ConvVT

class vformer.models.classification.convvt.ConvVT(img_size=224, patch_size=[7, 3, 3], patch_stride=[4, 2, 2], patch_padding=[2, 1, 1], embedding_dim=[64, 192, 384], num_heads=[1, 3, 6], depth=[1, 2, 10], mlp_ratio=[4.0, 4.0, 4.0], p_dropout=[0, 0, 0], attn_dropout=[0, 0, 0], drop_path_rate=[0, 0, 0.1], kernel_size=[3, 3, 3], padding_q=[1, 1, 1], padding_kv=[1, 1, 1], stride_kv=[2, 2, 2], stride_q=[1, 1, 1], in_channels=3, num_stages=3, n_classes=1000)

Implementation of CvT: Introducing Convolutions to Vision Transformers: https://arxiv.org/pdf/2103.15808.pdf

img_size: int

Size of the input image (height and width), default is 224

in_channels: int

Number of channels in the input image, default is 3

num_stages: int

Number of stages in the encoder, default is 3

n_classes: int

Number of classes for classification, default is 1000

  • The following parameters are all lists of int/float with one entry per stage (length num_stages)
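Because every per-stage argument must have exactly num_stages entries, a mismatched list is an easy mistake when changing the number of stages. The sketch below uses a hypothetical helper, check_stage_lists (not part of vformer), to report which arguments have the wrong length; whether ConvVT itself validates this is not stated here. DEFAULTS simply copies the signature defaults.

```python
# Signature defaults of ConvVT, copied from the documented signature.
DEFAULTS = {
    "patch_size": [7, 3, 3],
    "patch_stride": [4, 2, 2],
    "patch_padding": [2, 1, 1],
    "embedding_dim": [64, 192, 384],
    "num_heads": [1, 3, 6],
    "depth": [1, 2, 10],
    "mlp_ratio": [4.0, 4.0, 4.0],
    "p_dropout": [0, 0, 0],
    "attn_dropout": [0, 0, 0],
    "drop_path_rate": [0, 0, 0.1],
    "kernel_size": [3, 3, 3],
    "padding_q": [1, 1, 1],
    "padding_kv": [1, 1, 1],
    "stride_kv": [2, 2, 2],
    "stride_q": [1, 1, 1],
}


def check_stage_lists(num_stages, **per_stage):
    """Hypothetical helper: names of per-stage arguments whose length != num_stages."""
    return [name for name, values in per_stage.items() if len(values) != num_stages]


print(check_stage_lists(3, **DEFAULTS))                         # []
print(check_stage_lists(3, **dict(DEFAULTS, depth=[1, 2])))     # ['depth']
```

Raising num_stages to 4 without extending the lists would flag all 15 per-stage arguments.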

patch_size: list[int]

Kernel size of the convolutional patch (token) embedding in each stage, default is [7, 3, 3]

patch_stride: list[int]

Stride of the convolutional patch embedding in each stage, default is [4, 2, 2]

patch_padding: list[int]

Padding of the convolutional patch embedding in each stage, default is [2, 1, 1]
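As a rough illustration of how patch_size, patch_stride and patch_padding interact, the sketch below computes the side length of the token grid after each stage, assuming (as in the CvT paper) that each stage's patch embedding is a strided 2D convolution applied to the previous stage's output. The helper names conv_out and stage_resolutions are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a 2D convolution along one spatial axis (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_resolutions(img_size=224, patch_size=(7, 3, 3),
                      patch_stride=(4, 2, 2), patch_padding=(2, 1, 1)):
    """Side length of the token grid produced by each stage's patch embedding.

    Assumes each stage embeds the previous stage's output with a convolution
    of kernel patch_size[i], stride patch_stride[i], padding patch_padding[i].
    """
    sides = []
    side = img_size
    for k, s, p in zip(patch_size, patch_stride, patch_padding):
        side = conv_out(side, k, s, p)
        sides.append(side)
    return sides


sides = stage_resolutions()
print(sides)                    # [56, 28, 14]
print([s * s for s in sides])   # [3136, 784, 196] tokens per stage
```

With the defaults the overall downsampling is 4 x 2 x 2 = 16, so a 224 x 224 image ends up as a 14 x 14 grid in the last stage.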

embedding_dim: list[int]

Embedding dimension of each stage, default is [64, 192, 384]

depth: list[int]

Number of CvT attention blocks in each stage, default is [1, 2, 10]

num_heads: list[int]

Number of attention heads in each stage, default is [1, 3, 6]

mlp_ratio: list[float]

Feature dimension expansion ratio in MLP, default is [4.0, 4.0, 4.0]
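The per-stage sizes above can be sanity-checked with a little arithmetic. This sketch assumes the usual multi-head attention convention that embedding_dim is split evenly across num_heads, and that the MLP hidden width is embedding_dim * mlp_ratio; it only works with plain lists and does not call vformer.

```python
embedding_dim = [64, 192, 384]
num_heads = [1, 3, 6]
mlp_ratio = [4.0, 4.0, 4.0]
depth = [1, 2, 10]

head_dims = []
mlp_hidden = []
for dim, heads, ratio in zip(embedding_dim, num_heads, mlp_ratio):
    # Assumed even split of the embedding across attention heads.
    assert dim % heads == 0, f"embedding_dim {dim} not divisible by num_heads {heads}"
    head_dims.append(dim // heads)
    # Assumed MLP expansion: dim -> dim * mlp_ratio -> dim.
    mlp_hidden.append(int(dim * ratio))

print(head_dims)    # [64, 64, 64]
print(mlp_hidden)   # [256, 768, 1536]
print(sum(depth))   # 13 attention blocks across all stages
```

The defaults keep the per-head width constant at 64 while the number of heads grows with the embedding dimension.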

p_dropout: list[float]

Dropout probability in the MLP, default is [0, 0, 0]

attn_dropout: list[float]

Dropout probability in the attention, default is [0, 0, 0]

drop_path_rate: list[float]

Drop path (stochastic depth) rate, default is [0, 0, 0.1]

kernel_size: list[int]

Kernel size of the convolutional projection for q, k and v, default is [3, 3, 3]

padding_q: list[int]

Padding of the convolutional projection for q, default is [1, 1, 1]

padding_kv: list[int]

Padding of the convolutional projection for k and v, default is [1, 1, 1]

stride_kv: list[int]

Stride of the convolutional projection for k and v, default is [2, 2, 2]

stride_q: list[int]

Stride of the convolutional projection for q, default is [1, 1, 1]
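The practical effect of stride_kv > stride_q is that keys and values are computed on a coarser grid than queries, which shrinks the attention matrix. The sketch below assumes, as in the CvT paper, that q, k and v come from convolutions over the 2D token map with kernel_size, padding and stride as documented, and takes the token-grid sides [56, 28, 14] that the default patch arguments give for a 224 x 224 input. The function names are hypothetical; the padding_kv value used is the signature default.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a 2D convolution along one spatial axis (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def attention_shapes(grid=(56, 28, 14), kernel_size=(3, 3, 3),
                     padding_q=(1, 1, 1), padding_kv=(1, 1, 1),
                     stride_q=(1, 1, 1), stride_kv=(2, 2, 2)):
    """(num query tokens, num key/value tokens) per stage, one head's view."""
    shapes = []
    for side, k, pq, pkv, sq, skv in zip(grid, kernel_size, padding_q,
                                         padding_kv, stride_q, stride_kv):
        q_side = conv_out(side, k, sq, pq)     # stride 1 keeps the full grid
        kv_side = conv_out(side, k, skv, pkv)  # stride 2 halves each side
        shapes.append((q_side * q_side, kv_side * kv_side))
    return shapes


for n_q, n_kv in attention_shapes():
    print(n_q, n_kv, n_q * n_kv)  # attention scores per head
# 3136 784 2458624
# 784 196 153664
# 196 49 9604
```

With stride_kv = 2, each attention map has a quarter of the entries it would have if keys and values kept the full query resolution.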

forward(x)

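Defines the computation performed at every call: takes a batch of images x of shape (batch, in_channels, img_size, img_size) and returns the classification logits, one per class (n_classes in total).

ConvVT overrides the forward method inherited from torch.nn.Module.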

Note

Although the forward pass is defined within this function, you should call the module instance (model(x)) rather than model.forward(x), because calling the instance runs any registered hooks while calling forward directly silently ignores them.