Cross Attention
- class vformer.attention.cross.CrossAttention(query_dim, context_dim, num_heads=8, head_dim=64)[source]
Bases: Module
This variant of Cross Attention is used iteratively in Perceiver IO.
Queries are computed from x, and keys and values are computed from context, so each query position attends over the whole context sequence.
- Parameters
query_dim (int) – Dimension of query array
context_dim (int) – Dimension of context array
num_heads (int) – Number of cross-attention heads
head_dim (int) – Dimension of each head
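The following is a minimal sketch of Perceiver-style cross-attention in plain PyTorch, assuming the standard scaled dot-product formulation with num_heads * head_dim inner channels. The class name CrossAttentionSketch, the projection layout (to_q, to_kv, to_out), and the boolean mask convention are assumptions for illustration, not vformer's actual implementation.

```python
import torch
import torch.nn as nn


class CrossAttentionSketch(nn.Module):
    """Hypothetical re-implementation for illustration; not vformer's source."""

    def __init__(self, query_dim, context_dim, num_heads=8, head_dim=64):
        super().__init__()
        inner_dim = num_heads * head_dim
        self.num_heads = num_heads
        self.scale = head_dim ** -0.5
        # Queries are projected from x; keys and values from context.
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context, mask=None):
        b, n, _ = x.shape
        h = self.num_heads
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        # Split heads: (b, len, h * d) -> (b, h, len, d)
        q, k, v = (t.reshape(b, -1, h, t.shape[-1] // h).transpose(1, 2) for t in (q, k, v))
        scores = (q @ k.transpose(-2, -1)) * self.scale  # (b, h, n, m)
        if mask is not None:
            # Assumption: mask is a boolean (b, m) tensor where True means "attend".
            scores = scores.masked_fill(~mask[:, None, None, :], float("-inf"))
        attn = scores.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)  # merge heads
        return self.to_out(out)  # (b, n, query_dim)


attn = CrossAttentionSketch(query_dim=32, context_dim=48, num_heads=4, head_dim=16)
latents = torch.randn(2, 8, 32)    # e.g. a small latent array
context = torch.randn(2, 100, 48)  # e.g. a long input array
out = attn(latents, context)
print(out.shape)  # torch.Size([2, 8, 32])
```

The output keeps the query length and query_dim, while the context length only affects the cost of attention; this is what lets a Perceiver-style model compress a long input into a fixed-size latent array.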
- forward(x, context, mask=None)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
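A small sketch of the difference described in the note, using a plain torch.nn.Linear as a stand-in for any Module (the list name calls is illustrative): a registered forward hook fires when the module instance is called, but not when forward is invoked directly.

```python
import torch
import torch.nn as nn

calls = []
layer = nn.Linear(4, 2)
# A forward hook runs after forward when the module instance is called.
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(3, 4)
layer(x)          # __call__ runs forward, then the registered hook
layer.forward(x)  # direct call: the hook is silently skipped
print(calls)  # ['hook']
```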
- training: bool
- class vformer.attention.cross.CrossAttentionWithClsToken(cls_dim, patch_dim, num_heads=8, head_dim=64)[source]
Bases: Module
Cross-attention with a CLS token, introduced in the paper CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification.
In this cross-attention, the CLS token from one branch is fused with the patch tokens from another branch.
- Parameters
cls_dim (int) – Dimension of the CLS token embedding
patch_dim (int) – Dimension of the patch token embeddings that the CLS token is fused with
num_heads (int) – Number of cross-attention heads
head_dim (int) – Dimension of each head
- forward(cls, patches)[source]
- Parameters
cls (torch.Tensor) – CLS token from one branch
patches (torch.Tensor) – Patch tokens from another branch
- Returns
Output tensor obtained by applying cross-attention between the CLS token and the patch tokens
- Return type
torch.Tensor
- training: bool
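As an illustration, here is a minimal sketch of CLS-token cross-attention roughly following the CrossViT paper's description: the CLS token is projected to the other branch's width, concatenated with that branch's patch tokens, used as the only query over the combined sequence, added back residually, and projected to its original width. The class name ClsCrossAttentionSketch, the projection names f and g, the LayerNorm placement, and the (batch, 1, dim) CLS shape are assumptions; vformer's actual CrossAttentionWithClsToken may be structured differently.

```python
import torch
import torch.nn as nn


class ClsCrossAttentionSketch(nn.Module):
    """Hypothetical CrossViT-style CLS-token cross-attention; not vformer's source."""

    def __init__(self, cls_dim, patch_dim, num_heads=8, head_dim=64):
        super().__init__()
        inner_dim = num_heads * head_dim
        self.num_heads = num_heads
        self.scale = head_dim ** -0.5
        self.f = nn.Linear(cls_dim, patch_dim)  # project CLS into the other branch's width
        self.g = nn.Linear(patch_dim, cls_dim)  # project the fused CLS back
        self.norm = nn.LayerNorm(patch_dim)
        self.to_q = nn.Linear(patch_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(patch_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, patch_dim)

    def forward(self, cls, patches):
        # cls: (b, 1, cls_dim); patches: (b, n, patch_dim)
        b, h = cls.shape[0], self.num_heads
        cls_p = self.f(cls)  # (b, 1, patch_dim)
        tokens = self.norm(torch.cat([cls_p, patches], dim=1))  # (b, 1 + n, patch_dim)
        q = self.to_q(tokens[:, :1])  # only the CLS token issues a query
        k, v = self.to_kv(tokens).chunk(2, dim=-1)
        # Split heads: (b, len, h * d) -> (b, h, len, d)
        q, k, v = (t.reshape(b, -1, h, t.shape[-1] // h).transpose(1, 2) for t in (q, k, v))
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)  # (b, h, 1, 1 + n)
        out = (attn @ v).transpose(1, 2).reshape(b, 1, -1)  # merge heads
        fused = cls_p + self.to_out(out)  # residual in patch_dim space
        return self.g(fused)  # (b, 1, cls_dim)


cls_token = torch.randn(2, 1, 192)  # CLS token from the small branch
patches = torch.randn(2, 49, 384)   # patch tokens from the large branch
fuse = ClsCrossAttentionSketch(cls_dim=192, patch_dim=384, num_heads=4, head_dim=32)
print(fuse(cls_token, patches).shape)  # torch.Size([2, 1, 192])
```

Because only one token acts as the query, the attention map has a single row per head, so the cost grows linearly with the number of patches rather than quadratically.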