Source code for vformer.attention.cross

import torch
import torch.nn as nn
from einops import rearrange, repeat

from ..utils import ATTENTION_REGISTRY


@ATTENTION_REGISTRY.register()
class CrossAttentionWithClsToken(nn.Module):
    """
    Cross-Attention with Cls Token introduced in Paper: CrossViT:
    `Cross-Attention Multi-Scale Vision Transformer for Image Classification
    <https://arxiv.org/abs/2103.14899>`_

    The CLS token of one branch is used as the query and attends over itself
    concatenated with the patch tokens of another branch.

    Parameters
    ----------
    cls_dim: int
        Dimension of cls token embedding
    patch_dim: int
        Dimension of the patch token embeddings the cls token is fused with
    num_heads: int
        Number of cross-attention heads
    head_dim: int
        Dimension of each head
    """

    def __init__(self, cls_dim, patch_dim, num_heads=8, head_dim=64):
        super().__init__()
        inner_dim = num_heads * head_dim
        dims_match = cls_dim == patch_dim

        self.num_heads = num_heads
        self.scale = head_dim**-0.5

        # NOTE: submodules are created in a fixed order (fl, gl, to_k, to_v,
        # to_q, cls_project) so that parameter registration and seeded
        # initialisation stay stable.

        # fl maps the cls token into the patch embedding space; gl maps the
        # fused token back. Both are no-ops when the two widths already agree.
        if dims_match:
            self.fl = nn.Identity()
        else:
            self.fl = nn.Linear(cls_dim, patch_dim)
        if dims_match:
            self.gl = nn.Identity()
        else:
            self.gl = nn.Linear(patch_dim, cls_dim)

        self.to_k = nn.Linear(patch_dim, inner_dim)
        self.to_v = nn.Linear(patch_dim, inner_dim)
        self.to_q = nn.Linear(patch_dim, inner_dim)

        # Projects concatenated heads back to patch_dim so the residual add
        # in forward() is shape-compatible.
        if inner_dim == patch_dim:
            self.cls_project = nn.Identity()
        else:
            self.cls_project = nn.Linear(inner_dim, patch_dim)

        self.attend = nn.Softmax(dim=-1)

    def forward(self, cls, patches):
        """
        Parameters
        ----------
        cls: torch.Tensor
            CLS token(s) from one branch, shaped ``(batch, n_cls, cls_dim)``
        patches: torch.Tensor
            Patch tokens from another branch, shaped
            ``(batch, n_patches, patch_dim)``

        Returns
        ----------
        torch.Tensor
            Fused CLS token(s) of shape ``(batch, n_cls, cls_dim)`` obtained
            by cross-attending over the CLS and patch tokens
        """
        heads = self.num_heads
        split_pattern = "b n (h d) -> b h n d"

        # Bring the cls token to patch width, then prepend it to the patches
        # along the token axis so it can attend to itself as well.
        cls = self.fl(cls)
        tokens = torch.cat([cls, patches], dim=-2)

        query = self.to_q(cls)
        key = self.to_k(tokens)
        value = self.to_v(tokens)

        key = rearrange(key, split_pattern, h=heads)
        query = rearrange(query, split_pattern, h=heads)
        value = rearrange(value, split_pattern, h=heads)

        # Scaled dot-product attention of the cls query over all tokens.
        scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)

        fused = rearrange(weights @ value, "b h n d -> b n (h d)")
        fused = self.cls_project(fused)

        # Residual connection, then map back to the cls branch width.
        return self.gl(cls + fused)
@ATTENTION_REGISTRY.register()
class CrossAttention(nn.Module):
    """
    Multi-head cross-attention in which a query array attends over a separate
    context array. This variant of Cross Attention is iteratively used in
    Perceiver IO.

    Parameters
    ----------
    query_dim: int
        Dimension of query array
    context_dim: int
        Dimension of context array
    num_heads: int
        Number of cross-attention heads
    head_dim: int
        Dimension of each head
    """

    def __init__(self, query_dim, context_dim, num_heads=8, head_dim=64):
        super().__init__()
        inner_dim = num_heads * head_dim
        self.num_heads = num_heads
        self.scale = head_dim**-0.5

        self.to_q = nn.Linear(query_dim, inner_dim)
        self.to_k = nn.Linear(context_dim, inner_dim)
        self.to_v = nn.Linear(context_dim, inner_dim)
        # Output projection is skipped when the concatenated heads already
        # have the query width.
        self.to_out = (
            nn.Linear(inner_dim, query_dim)
            if inner_dim != query_dim
            else nn.Identity()
        )
        self.attend = nn.Softmax(dim=-1)

    def forward(self, x, context, mask=None):
        """
        Parameters
        ----------
        x: torch.Tensor
            Query array of shape ``(batch, n_queries, query_dim)``
        context: torch.Tensor
            Context array of shape ``(batch, n_context, context_dim)``
        mask: torch.Tensor, optional
            Mask over the context positions with leading batch dimension; all
            trailing dimensions are flattened and must total ``n_context``.
            Truthy (non-zero) entries are attended to, falsy entries are
            masked out. Non-boolean masks are converted to boolean.

        Returns
        ----------
        torch.Tensor
            Tensor of shape ``(batch, n_queries, query_dim)`` produced by
            attending from ``x`` over ``context``
        """
        h = self.num_heads

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)

        k = rearrange(k, "b n (h d) -> b h n d", h=h)
        q = rearrange(q, "b n (h d) -> b h n d", h=h)
        v = rearrange(v, "b n (h d) -> b h n d", h=h)

        attention = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        if mask is not None:
            # Flatten any spatial layout of the mask to one context axis and
            # force bool so that `~` is a logical (not bitwise) negation.
            mask = rearrange(mask, "b ... -> b (...)").to(torch.bool)
            max_neg_value = -torch.finfo(attention.dtype).max
            # Broadcast over heads and query positions instead of
            # materialising one copy of the mask per head.
            attention = attention.masked_fill(
                ~mask[:, None, None, :], max_neg_value
            )

        attention = self.attend(attention)
        attention_value = attention @ v
        attention_value = rearrange(attention_value, "b h n d -> b n (h d)")
        return self.to_out(attention_value)