Source code for vformer.encoder.cross

import torch
import torch.nn as nn

from ..attention import CrossAttentionWithClsToken
from ..utils import ENCODER_REGISTRY
from .vanilla import VanillaEncoder


@ENCODER_REGISTRY.register()
class CrossEncoder(nn.Module):
    """
    Encoder block used in CrossViT.

    Parameters
    ----------
    embedding_dim_s : int
        Dimension of the embedding of smaller patches, default is 1024
    embedding_dim_l : int
        Dimension of the embedding of larger patches, default is 1024
    attn_heads_s : int
        Number of self-attention heads for the smaller patches, default is 16
    attn_heads_l : int
        Number of self-attention heads for the larger patches, default is 16
    cross_head_s : int
        Number of cross-attention heads for the smaller patches, default is 8
    cross_head_l : int
        Number of cross-attention heads for the larger patches, default is 8
    head_dim_s : int
        Dimension of the head of the attention for the smaller patches, default is 64
    head_dim_l : int
        Dimension of the head of the attention for the larger patches, default is 64
    cross_dim_head_s : int
        Dimension of the head of the cross-attention for the smaller patches, default is 64
    cross_dim_head_l : int
        Dimension of the head of the cross-attention for the larger patches, default is 64
    depth_s : int
        Number of self-attention layers in the encoder for the smaller patches, default is 6
    depth_l : int
        Number of self-attention layers in the encoder for the larger patches, default is 6
    mlp_dim_s : int
        Dimension of the hidden layer in the feed-forward layer for the smaller patches, default is 2048
    mlp_dim_l : int
        Dimension of the hidden layer in the feed-forward layer for the larger patches, default is 2048
    p_dropout_s : float
        Dropout probability for the smaller patches, default is 0.0
    p_dropout_l : float
        Dropout probability for the larger patches, default is 0.0
    """

    def __init__(
        self,
        embedding_dim_s=1024,
        embedding_dim_l=1024,
        attn_heads_s=16,
        attn_heads_l=16,
        cross_head_s=8,
        cross_head_l=8,
        head_dim_s=64,
        head_dim_l=64,
        cross_dim_head_s=64,
        cross_dim_head_l=64,
        depth_s=6,
        depth_l=6,
        mlp_dim_s=2048,
        mlp_dim_l=2048,
        p_dropout_s=0.0,
        p_dropout_l=0.0,
    ):
        super().__init__()

        # One vanilla transformer encoder per patch branch
        self.s = VanillaEncoder(
            embedding_dim_s,
            depth_s,
            attn_heads_s,
            head_dim_s,
            mlp_dim_s,
            p_dropout_s,
        )
        self.l = VanillaEncoder(
            embedding_dim_l,
            depth_l,
            attn_heads_l,
            head_dim_l,
            mlp_dim_l,
            p_dropout_l,
        )

        # Cross-attention modules: each branch's CLS token attends to the
        # patch tokens of the other branch
        self.attend_s = CrossAttentionWithClsToken(
            embedding_dim_s, embedding_dim_l, cross_head_s, cross_dim_head_s
        )
        self.attend_l = CrossAttentionWithClsToken(
            embedding_dim_l, embedding_dim_s, cross_head_l, cross_dim_head_l
        )
    def forward(self, emb_s, emb_l):
        # Self-attention within each branch
        emb_s = self.s(emb_s)
        emb_l = self.l(emb_l)

        # Split each sequence into its CLS token and patch tokens
        s_cls, s_patches = emb_s[:, 0:1, :], emb_s[:, 1:, :]
        l_cls, l_patches = emb_l[:, 0:1, :], emb_l[:, 1:, :]

        # Each CLS token attends to the patch tokens of the other branch
        s_cls = self.attend_s(s_cls, l_patches)
        l_cls = self.attend_l(l_cls, s_patches)

        # Re-attach the updated CLS tokens to their own patch tokens
        emb_s = torch.cat([s_cls, s_patches], dim=1)
        emb_l = torch.cat([l_cls, l_patches], dim=1)

        return emb_s, emb_l
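

# Minimal usage sketch (not part of the library source). Each branch expects
# a sequence of [CLS token + patch tokens] in its own embedding dimension;
# the batch size and token counts below are illustrative assumptions, not
# values prescribed by the module.
if __name__ == "__main__":
    encoder = CrossEncoder(embedding_dim_s=1024, embedding_dim_l=1024)

    # Assumed shapes: 1 CLS + 16 small-patch tokens, 1 CLS + 4 large-patch tokens
    emb_s = torch.randn(2, 17, 1024)
    emb_l = torch.randn(2, 5, 1024)

    out_s, out_l = encoder(emb_s, emb_l)
    # Shapes are preserved: only the CLS tokens are updated via cross-attention
    print(out_s.shape, out_l.shape)  # torch.Size([2, 17, 1024]) torch.Size([2, 5, 1024])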