Cross-Attention Transformer

class vformer.models.classification.cross.CrossViT(img_size, patch_size_s, patch_size_l, n_classes, cross_dim_head_s=64, cross_dim_head_l=64, latent_dim_s=1024, latent_dim_l=1024, head_dim_s=64, head_dim_l=64, depth_s=6, depth_l=6, attn_heads_s=16, attn_heads_l=16, cross_head_s=8, cross_head_l=8, encoder_mlp_dim_s=2048, encoder_mlp_dim_l=2048, in_channels=3, decoder_config_s=None, decoder_config_l=None, pool_s='cls', pool_l='cls', p_dropout_encoder_s=0.0, p_dropout_encoder_l=0.0, p_dropout_embedding_s=0.0, p_dropout_embedding_l=0.0)[source]

Implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

Parameters
  • img_size (int) – Size of the image

  • patch_size_s (int) – Size of the smaller patches

  • patch_size_l (int) – Size of the larger patches

  • n_classes (int) – Number of classes for classification

  • cross_dim_head_s (int) – Dimension of each cross-attention head for the smaller patches

  • cross_dim_head_l (int) – Dimension of each cross-attention head for the larger patches

  • latent_dim_s (int) – Dimension of the latent embedding for the smaller patches

  • latent_dim_l (int) – Dimension of the latent embedding for the larger patches

  • head_dim_s (int) – Dimension of each attention head for the smaller patches

  • head_dim_l (int) – Dimension of each attention head for the larger patches

  • depth_s (int) – Number of attention layers in the encoder for the smaller patches

  • depth_l (int) – Number of attention layers in the encoder for the larger patches

  • attn_heads_s (int) – Number of attention heads for the smaller patches

  • attn_heads_l (int) – Number of attention heads for the larger patches

  • cross_head_s (int) – Number of cross-attention heads for the smaller patches

  • cross_head_l (int) – Number of cross-attention heads for the larger patches

  • encoder_mlp_dim_s (int) – Dimension of the hidden layer in the encoder MLP for the smaller patches

  • encoder_mlp_dim_l (int) – Dimension of the hidden layer in the encoder MLP for the larger patches

  • in_channels (int) – Number of input channels

  • decoder_config_s (int or tuple or list, optional) – Configuration of the decoder for the smaller patches

  • decoder_config_l (int or tuple or list, optional) – Configuration of the decoder for the larger patches

  • pool_s (str) – Feature pooling type for the smaller patches, one of {'cls', 'mean'}

  • pool_l (str) – Feature pooling type for the larger patches, one of {'cls', 'mean'}

  • p_dropout_encoder_s (float) – Dropout probability in the encoder for the smaller patches

  • p_dropout_encoder_l (float) – Dropout probability in the encoder for the larger patches

  • p_dropout_embedding_s (float) – Dropout probability in the embedding layer for the smaller patches

  • p_dropout_embedding_l (float) – Dropout probability in the embedding layer for the larger patches
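
Below is a minimal construction sketch based on the signature above. The hyperparameter values (img_size=256, patch sizes 16 and 64, 10 classes) are illustrative assumptions, not recommendations; every other argument falls back to the documented defaults.

  from vformer.models.classification.cross import CrossViT

  # Two-branch CrossViT: 16x16 patches on the small-patch branch and
  # 64x64 patches on the large-patch branch. Both patch sizes are
  # chosen here so that they divide img_size evenly.
  model = CrossViT(
      img_size=256,
      patch_size_s=16,
      patch_size_l=64,
      n_classes=10,
  )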

forward(img)[source]

Parameters
  • img (torch.Tensor) – Input image tensor of shape (batch_size, in_channels, img_size, img_size)

Returns
  Tensor of class logits with shape (batch_size, n_classes)

Return type
  torch.Tensor
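
Continuing the sketch above, a forward pass maps a batch of images to class logits; the batch size of 2 is arbitrary.

  import torch

  # Random batch of two RGB images matching img_size=256 from the sketch above.
  img = torch.randn(2, 3, 256, 256)  # (batch_size, in_channels, height, width)
  logits = model(img)                # expected shape: (2, 10), i.e. (batch_size, n_classes)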