Cross Encoder

class vformer.encoder.cross.CrossEncoder(embedding_dim_s=1024, embedding_dim_l=1024, attn_heads_s=16, attn_heads_l=16, cross_head_s=8, cross_head_l=8, head_dim_s=64, head_dim_l=64, cross_dim_head_s=64, cross_dim_head_l=64, depth_s=6, depth_l=6, mlp_dim_s=2048, mlp_dim_l=2048, p_dropout_s=0.0, p_dropout_l=0.0)

Encoder block used in CrossViT.

Parameters
  • embedding_dim_s (int) – Dimension of the embedding of smaller patches, default is 1024

  • embedding_dim_l (int) – Dimension of the embedding of larger patches, default is 1024

  • attn_heads_s (int) – Number of self-attention heads for the smaller patches, default is 16

  • attn_heads_l (int) – Number of self-attention heads for the larger patches, default is 16

  • cross_head_s (int) – Number of cross-attention heads for the smaller patches, default is 8

  • cross_head_l (int) – Number of cross-attention heads for the larger patches, default is 8

  • head_dim_s (int) – Dimension of each attention head for the smaller patches, default is 64

  • head_dim_l (int) – Dimension of each attention head for the larger patches, default is 64

  • cross_dim_head_s (int) – Dimension of each cross-attention head for the smaller patches, default is 64

  • cross_dim_head_l (int) – Dimension of each cross-attention head for the larger patches, default is 64

  • depth_s (int) – Number of self-attention layers in encoder for the smaller patches, default is 6

  • depth_l (int) – Number of self-attention layers in encoder for the larger patches, default is 6

  • mlp_dim_s (int) – Dimension of the hidden layer in the feed-forward layer for the smaller patches, default is 2048

  • mlp_dim_l (int) – Dimension of the hidden layer in the feed-forward layer for the larger patches, default is 2048

  • p_dropout_s (float) – Dropout probability for the smaller patches, default is 0.0

  • p_dropout_l (float) – Dropout probability for the larger patches, default is 0.0
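
The parameters come in matching `_s` / `_l` pairs, one set per branch. The sketch below groups the documented defaults per branch and expands them into the keyword names listed above. `BranchConfig`, `as_kwargs`, `self_attn_width` and `cross_attn_width` are hypothetical helpers, not part of vformer. The two width methods assume the common multi-head convention that the projected attention width equals the number of heads times the per-head dimension, which this page does not state.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class BranchConfig:
    """Hypothetical helper holding one branch's documented defaults."""

    embedding_dim: int = 1024
    attn_heads: int = 16
    cross_head: int = 8
    head_dim: int = 64
    cross_dim_head: int = 64
    depth: int = 6
    mlp_dim: int = 2048
    p_dropout: float = 0.0

    def as_kwargs(self, suffix):
        # Append "_s" or "_l" so the keys match the documented parameter names.
        return {f"{name}_{suffix}": value for name, value in asdict(self).items()}

    def self_attn_width(self):
        # Assumption: projected width = heads * per-head dimension.
        return self.attn_heads * self.head_dim

    def cross_attn_width(self):
        # Same assumption, applied to the cross-attention heads.
        return self.cross_head * self.cross_dim_head


small = BranchConfig()
large = BranchConfig()

# 16 keyword arguments: eight for the small branch, eight for the large one.
kwargs = {**small.as_kwargs("s"), **large.as_kwargs("l")}
```

With vformer installed, the resulting dictionary could be passed as `CrossEncoder(**kwargs)`, which is equivalent to calling the constructor with its defaults.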

forward(emb_s, emb_l)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass is defined in this method, call the Module instance itself (for example, encoder(emb_s, emb_l)) rather than calling forward directly. Calling the instance runs any registered hooks, whereas calling forward directly silently skips them.
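
The difference between calling the instance and calling `forward` directly can be illustrated with a toy stand-in. `ToyModule` below is hypothetical and only mimics the general idea of `__call__` wrapping `forward` with hooks. It is not PyTorch's actual implementation, and it ignores hook return values.

```python
class ToyModule:
    """Hypothetical stand-in for a module with forward hooks."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, emb_s, emb_l):
        # The "recipe": this toy version just returns its inputs unchanged.
        return emb_s, emb_l

    def __call__(self, *args):
        # Calling the instance runs forward, then every registered hook.
        output = self.forward(*args)
        for hook in self._forward_hooks:
            hook(self, args, output)
        return output


calls = []
toy = ToyModule()
toy.register_forward_hook(lambda module, inputs, output: calls.append(output))

toy("small", "large")          # goes through __call__, so the hook fires
toy.forward("small", "large")  # bypasses __call__, so the hook is skipped

hook_fire_count = len(calls)
```

Only the first call is recorded in `calls`, which is why the instance call is the recommended entry point.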