from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from ..utils import ATTENTION_REGISTRY

@ATTENTION_REGISTRY.register()
class ConvVTAttention(nn.Module):
    """
    Attention with Convolutional Projection introduced in Paper:
    `Introducing Convolutions to Vision Transformers <https://arxiv.org/abs/2103.15808>`_

    Position-wise linear projection for Multi-Head Self-Attention (MHSA) replaced
    by Depth-wise separable convolutions

    Parameters
    -----------
    dim_in: int
        Dimension of input tensor
    dim_out: int
        Dimension of output tensor
    num_heads: int
        Number of heads in attention
    img_size: int
        Size of image; the spatial tokens are reshaped to an
        ``img_size x img_size`` grid before the convolutional projection
    attn_dropout: float
        Probability of dropout in attention
    proj_dropout: float
        Probability of dropout applied after the output linear projection
    method: str
        Method of projection, ``'dw_bn'`` for depth-wise convolution and batch norm,
        ``'avg'`` for average pooling. default is ``'dw_bn'``
    kernel_size: int
        Size of kernel
    stride_kv: int
        Size of stride for key value
    stride_q: int
        Size of stride for query
    padding_kv: int
        Padding for key value
    padding_q: int
        Padding for query
    with_cls_token: bool
        Whether to include classification token, default is ``False``.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        num_heads,
        img_size,
        attn_dropout=0.0,
        proj_dropout=0.0,
        method="dw_bn",
        kernel_size=3,
        stride_kv=1,
        stride_q=1,
        padding_kv=1,
        padding_q=1,
        with_cls_token=False,
    ):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.with_cls_token = with_cls_token
        self.dim = dim_out
        self.num_heads = num_heads
        # NOTE: the scale uses the full output dimension, not the per-head one.
        self.scale = dim_out**-0.5
        # The token grid is square: height and width both equal ``img_size``.
        self.h, self.w = img_size, img_size

        # Spatial (convolution or pooling) projections applied on the 2D grid.
        self.conv_proj_q = self._build_projection(
            dim_in, kernel_size, padding_q, stride_q, method
        )
        self.conv_proj_k = self._build_projection(
            dim_in, kernel_size, padding_kv, stride_kv, method
        )
        self.conv_proj_v = self._build_projection(
            dim_in, kernel_size, padding_kv, stride_kv, method
        )

        # Position-wise linear maps from ``dim_in`` to ``dim_out`` channels.
        self.proj_q = nn.Linear(dim_in, dim_out, bias=False)
        self.proj_k = nn.Linear(dim_in, dim_out, bias=False)
        self.proj_v = nn.Linear(dim_in, dim_out, bias=False)

        self.attn_drop = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(dim_out, dim_out)
        self.proj_drop = nn.Dropout(proj_dropout)

    def _build_projection(self, dim_in, kernel_size, padding, stride, method):
        """
        Build the spatial projection module used for query, key or value.

        Parameters
        ----------
        dim_in: int
            Number of input (and output) channels of the projection
        kernel_size: int
            Size of kernel
        padding: int
            Padding of the convolution / pooling
        stride: int
            Stride of the convolution / pooling
        method: str
            ``'dw_bn'`` for depth-wise convolution followed by batch norm,
            ``'avg'`` for average pooling

        Returns
        ----------
        torch.nn.Module
            The projection module

        Raises
        ----------
        ValueError
            If ``method`` is neither ``'dw_bn'`` nor ``'avg'``
        """
        if method == "dw_bn":
            # ``groups=dim_in`` makes the convolution depth-wise.
            return nn.Sequential(
                OrderedDict(
                    [
                        (
                            "conv",
                            nn.Conv2d(
                                dim_in,
                                dim_in,
                                kernel_size=kernel_size,
                                padding=padding,
                                stride=stride,
                                bias=False,
                                groups=dim_in,
                            ),
                        ),
                        ("bn", nn.BatchNorm2d(dim_in)),
                    ]
                )
            )
        if method == "avg":
            return nn.AvgPool2d(
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
                ceil_mode=True,
            )
        raise ValueError(f"Unknown method ({method})")

    def forward_conv(self, x):
        """
        Apply the convolutional projections to obtain query, key and value tokens.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor laid out as ``(batch, tokens, channels)``; ``tokens`` is
            ``h * w``, plus one leading classification token when
            ``with_cls_token`` is ``True``

        Returns
        ----------
        tuple of torch.Tensor
            ``(q, k, v)``, each laid out as ``(batch, tokens, channels)``
        """
        if self.with_cls_token:
            # The classification token is kept out of the spatial projection.
            cls_token, x = torch.split(x, [1, self.h * self.w], 1)

        # Restore the 2D grid so that convolution / pooling can be applied.
        x = rearrange(x, "b (h w) c -> b c h w", h=self.h, w=self.w)

        q = rearrange(self.conv_proj_q(x), "b c h w -> b (h w) c")
        k = rearrange(self.conv_proj_k(x), "b c h w -> b (h w) c")
        v = rearrange(self.conv_proj_v(x), "b c h w -> b (h w) c")

        if self.with_cls_token:
            # Put the untouched classification token back in front.
            q = torch.cat((cls_token, q), dim=1)
            k = torch.cat((cls_token, k), dim=1)
            v = torch.cat((cls_token, v), dim=1)

        return q, k, v

    def forward(self, x):
        """

        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        ----------
        torch.Tensor
            Returns output tensor by applying self-attention on input tensor

        """
        q, k, v = self.forward_conv(x)

        # Linear projection, then split the channels into ``num_heads`` heads.
        q = rearrange(self.proj_q(q), "b t (h d) -> b h t d", h=self.num_heads)
        k = rearrange(self.proj_k(k), "b t (h d) -> b h t d", h=self.num_heads)
        v = rearrange(self.proj_v(v), "b t (h d) -> b h t d", h=self.num_heads)

        # Scaled dot-product attention; softmax runs over the key tokens.
        attn_score = torch.einsum("bhlk,bhtk->bhlt", q, k) * self.scale
        attn = F.softmax(attn_score, dim=-1)
        attn = self.attn_drop(attn)

        x = torch.einsum("bhlt,bhtv->bhlv", attn, v)
        # Merge the heads back into a single channel dimension.
        x = rearrange(x, "b h t d -> b t (h d)")

        x = self.proj(x)
        x = self.proj_drop(x)
        return x