Cross
- class vformer.attention.cross.CrossAttention(query_dim, context_dim, num_heads=8, head_dim=64)[source]
Bases: Module
Cross-Attention
- Parameters
query_dim (int) – Dimension of the query tensor
context_dim (int) – Dimension of the context tensor
num_heads (int) – Number of cross-attention heads
head_dim (int) – Dimension of each head
- forward(x, context, mask=None)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
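To make the shape contract concrete, here is a minimal NumPy sketch of the cross-attention computation this class describes: queries are projected from `x` and keys/values from `context`, so the two inputs may have different dimensions and different token counts. This is an illustrative sketch of the mechanism only, not vformer's implementation; all weight names and shapes here are assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(x, context, w_q, w_k, w_v, w_out, num_heads, head_dim):
    """Single-example cross-attention: queries come from x,
    keys and values come from context."""
    n_q, n_c = x.shape[0], context.shape[0]
    inner = num_heads * head_dim
    # Project and split into heads: (tokens, inner) -> (heads, tokens, head_dim)
    q = (x @ w_q).reshape(n_q, num_heads, head_dim).transpose(1, 0, 2)
    k = (context @ w_k).reshape(n_c, num_heads, head_dim).transpose(1, 0, 2)
    v = (context @ w_v).reshape(n_c, num_heads, head_dim).transpose(1, 0, 2)
    # Scaled dot-product attention: (heads, n_q, n_c)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
    attended = softmax(scores) @ v                  # (heads, n_q, head_dim)
    # Merge heads and project back to the query dimension
    merged = attended.transpose(1, 0, 2).reshape(n_q, inner)
    return merged @ w_out                           # (n_q, query_dim)

# Illustrative dimensions (not from the source)
rng = np.random.default_rng(0)
query_dim, context_dim, num_heads, head_dim = 64, 128, 8, 32
inner = num_heads * head_dim
x = rng.standard_normal((16, query_dim))       # 16 query tokens
ctx = rng.standard_normal((50, context_dim))   # 50 context tokens
w_q = 0.02 * rng.standard_normal((query_dim, inner))
w_k = 0.02 * rng.standard_normal((context_dim, inner))
w_v = 0.02 * rng.standard_normal((context_dim, inner))
w_out = 0.02 * rng.standard_normal((inner, query_dim))
out = cross_attention(x, ctx, w_q, w_k, w_v, w_out, num_heads, head_dim)
print(out.shape)  # (16, 64): one output vector per query token, in query_dim
```

Note the asymmetry: the output always has one row per query token and lives in the query's dimension, regardless of how many context tokens there are.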
- class vformer.attention.cross.CrossAttentionWithClsToken(cls_dim, patch_dim, num_heads=8, head_dim=64)[source]
Bases: Module
Cross-Attention with Cls Token
- Parameters
cls_dim (int) – Dimension of cls token embedding
patch_dim (int) – Dimension of the patch token embeddings with which the cls token is to be fused
num_heads (int) – Number of cross-attention heads
head_dim (int) – Dimension of each head
- forward(cls, patches)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
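The cls-token variant can be illustrated with one common fusion scheme (as used in CrossViT-style models), where the cls token acts as the sole query attending over the patch tokens. This NumPy sketch is an assumption about the general mechanism, not vformer's exact internals; the projection names and dimensions are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def cls_cross_attention(cls_tok, patches, w_q, w_k, w_v, w_out,
                        num_heads, head_dim):
    """The cls token (a single query) attends over all patch tokens,
    then is projected back to its own embedding dimension."""
    n_p = patches.shape[0]
    inner = num_heads * head_dim
    q = (cls_tok @ w_q).reshape(1, num_heads, head_dim).transpose(1, 0, 2)
    k = (patches @ w_k).reshape(n_p, num_heads, head_dim).transpose(1, 0, 2)
    v = (patches @ w_v).reshape(n_p, num_heads, head_dim).transpose(1, 0, 2)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)  # (heads, 1, n_p)
    attended = softmax(scores) @ v                         # (heads, 1, head_dim)
    merged = attended.transpose(1, 0, 2).reshape(1, inner)
    return merged @ w_out                                  # (1, cls_dim)

# Illustrative dimensions (not from the source)
rng = np.random.default_rng(0)
cls_dim, patch_dim, num_heads, head_dim = 64, 128, 8, 32
inner = num_heads * head_dim
cls_tok = rng.standard_normal((1, cls_dim))       # single cls token
patches = rng.standard_normal((196, patch_dim))   # e.g. a 14x14 patch grid
w_q = 0.02 * rng.standard_normal((cls_dim, inner))
w_k = 0.02 * rng.standard_normal((patch_dim, inner))
w_v = 0.02 * rng.standard_normal((patch_dim, inner))
w_out = 0.02 * rng.standard_normal((inner, cls_dim))
fused = cls_cross_attention(cls_tok, patches, w_q, w_k, w_v, w_out,
                            num_heads, head_dim)
print(fused.shape)  # (1, 64): the cls token, now fused with patch information
```

The key design point is that only the cls token is updated; the patch tokens serve purely as keys and values, which keeps the fusion step cheap (one query) while still letting the cls embedding aggregate information from every patch.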