Convolutional Vision Transformer

class vformer.models.classification.convvt.ConvVT(img_size=224, patch_size=[7, 3, 3], patch_stride=[4, 2, 2], patch_padding=[2, 1, 1], embedding_dim=[64, 192, 384], num_heads=[1, 3, 6], depth=[1, 2, 10], mlp_ratio=[4.0, 4.0, 4.0], p_dropout=[0, 0, 0], attn_dropout=[0, 0, 0], drop_path_rate=[0, 0, 0.1], kernel_size=[3, 3, 3], padding_q=[1, 1, 1], padding_kv=[1, 1, 1], stride_kv=[2, 2, 2], stride_q=[1, 1, 1], in_channels=3, num_stages=3, n_classes=1000)

Implementation of the paper "CvT: Introducing Convolutions to Vision Transformers".

Parameters
  • img_size (int) – Size of the image, default is 224

  • in_channels (int) – Number of input channels in image, default is 3

  • num_stages (int) – Number of stages in encoder block, default is 3

  • n_classes (int) – Number of classes for classification, default is 1000

The following parameters are all lists of int/float with length num_stages:

  • patch_size (list[int]) – Size of patch, default is [7, 3, 3]

  • patch_stride (list[int]) – Stride of patch, default is [4, 2, 2]

  • patch_padding (list[int]) – Padding for patch, default is [2, 1, 1]

  • embedding_dim (list[int]) – Embedding dimensions, default is [64, 192, 384]

  • depth (list[int]) – Number of CvT attention blocks in each stage, default is [1, 2, 10]

  • num_heads (list[int]) – Number of heads in attention, default is [1, 3, 6]

  • mlp_ratio (list[float]) – Feature dimension expansion ratio in MLP, default is [4.0, 4.0, 4.0]

  • p_dropout (list[float]) – Probability of dropout in MLP, default is [0, 0, 0]

  • attn_dropout (list[float]) – Probability of dropout in attention, default is [0, 0, 0]

  • drop_path_rate (list[float]) – Probability of drop path (stochastic depth), default is [0, 0, 0.1]

  • kernel_size (list[int]) – Size of kernel, default is [3, 3, 3]

  • padding_q (list[int]) – Size of padding in q, default is [1, 1, 1]

  • padding_kv (list[int]) – Size of padding in kv, default is [1, 1, 1]

  • stride_kv (list[int]) – Stride in kv, default is [2, 2, 2]

  • stride_q (list[int]) – Stride in q, default is [1, 1, 1]
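
The per-stage list parameters above determine how the spatial resolution shrinks from stage to stage. The sketch below works through that arithmetic for the default values. It assumes that each stage's patch embedding, and the key/value projection inside attention, behave like a standard 2D convolution with the listed kernel size, stride, and padding; this has not been checked against the ConvVT source. The helper names conv_out_size and stage_token_grids are hypothetical and are not part of vformer, and any class token is ignored.

```python
def conv_out_size(size, kernel, stride, padding):
    """Output length of a standard convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_token_grids(
    img_size=224,
    patch_size=(7, 3, 3),
    patch_stride=(4, 2, 2),
    patch_padding=(2, 1, 1),
    kernel_size=(3, 3, 3),
    stride_kv=(2, 2, 2),
    padding_kv=(1, 1, 1),
):
    """Return (grid side, query tokens, key/value tokens) for each stage.

    Hypothetical helper: assumes square inputs and plain convolution
    arithmetic for both the patch embedding and the kv projection.
    """
    rows = []
    size = img_size
    for i in range(len(patch_size)):
        # Patch embedding reduces the feature map entering this stage.
        size = conv_out_size(size, patch_size[i], patch_stride[i], patch_padding[i])
        # The strided kv projection subsamples keys and values further.
        kv = conv_out_size(size, kernel_size[i], stride_kv[i], padding_kv[i])
        rows.append((size, size * size, kv * kv))
    return rows


grids = stage_token_grids()
for stage, (side, q_tokens, kv_tokens) in enumerate(grids, start=1):
    print(f"stage {stage}: {side}x{side} grid, {q_tokens} query tokens, {kv_tokens} kv tokens")
```

Under these assumptions the defaults give 56x56, 28x28, and 14x14 grids for the three stages, with stride_kv=2 cutting the number of key/value tokens to roughly a quarter of the query tokens in each stage.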

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass is defined in this method, call the Module instance itself rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.
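
The difference between calling the instance and calling forward() can be seen with a forward hook. The following is a minimal sketch using a small stand-in module, TinyClassifier, which is hypothetical and used here only so the example runs without vformer; the same behavior should apply to any torch.nn.Module subclass, including ConvVT.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in module; not part of vformer."""

    def __init__(self):
        super().__init__()
        self.head = nn.Linear(4, 2)

    def forward(self, x):
        return self.head(x)


model = TinyClassifier()
calls = []

# The hook records the output shape each time it fires.
handle = model.register_forward_hook(
    lambda module, inputs, output: calls.append(tuple(output.shape))
)

x = torch.randn(3, 4)

model(x)  # goes through Module.__call__, so the hook runs
hooks_after_call = len(calls)

model.forward(x)  # bypasses __call__, so the hook is skipped
hooks_after_forward = len(calls)

handle.remove()
print(hooks_after_call, hooks_after_forward)
```

The hook count does not increase after the direct forward() call, which is why code that depends on hooks (feature extraction, profiling, some quantization tooling) can silently misbehave when forward() is invoked directly.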