import torch.nn as nn
from ..functional import PreNorm
from ..utils import ATTENTION_REGISTRY
@ATTENTION_REGISTRY.register()
class SpatialAttention(nn.Module):
"""
Spatial Reduction Attention introduced in : `Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions <https://arxiv.org/abs/2102.12122>`_
This class also supports the linear complexity spatial attention in the improved `paper <https://arxiv.org/abs/2106.13797>`_
Parameters
-----------
dim: int
Dimension of the input tensor
num_heads: int
Number of attention heads
sr_ratio :int
Spatial Reduction ratio
qkv_bias : bool
If True, add a learnable bias to query, key, value, default is ``True``
qk_scale : float, optional
Override default qk scale of head_dim ** -0.5 if set
attn_drop : float, optional
Dropout rate
proj_drop :float, optional
Dropout rate
linear : bool
Whether to use linear Spatial attention,default is ``False``.
activation : nn.Module
Activation function, default is ``nn.GELU``.
"""
    def __init__(
        self,
        dim,
        num_heads,
        sr_ratio=1,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        linear=False,
        activation=nn.GELU,
    ):
        super().__init__()

        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divisible by num_heads {num_heads}."

        self.num_heads = num_heads
        self.sr_ratio = sr_ratio
        head_dim = dim // num_heads
        # Scaled dot-product attention uses head_dim ** -0.5, not head_dim ** 0.5.
        self.scale = qk_scale or head_dim**-0.5
        inner_dim = head_dim * num_heads

        self.q = nn.Linear(dim, inner_dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, inner_dim * 2, bias=qkv_bias)
        self.attn = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(p=attn_drop))
        self.to_out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(p=proj_drop))

        self.linear = linear
        self.norm = PreNorm(dim=dim, fn=activation() if linear else nn.Identity())
        if not linear:
            if sr_ratio > 1:
                self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
        else:
            self.pool = nn.AdaptiveAvgPool2d(7)
            self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
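
    # In the default (non-linear) path, keys and values are computed on a grid that is
    # spatially downsampled by a strided conv when sr_ratio > 1, shrinking attention
    # cost from O(N^2) to O(N^2 / sr_ratio^2). In the linear path, the grid is instead
    # pooled to a fixed 7x7 map, making attention cost linear in N.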
    def forward(self, x, H, W):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, N, C), where N == H * W
        H : int
            Height of the image-patch grid
        W : int
            Width of the image-patch grid

        Returns
        -------
        torch.Tensor
            Output tensor obtained by applying spatial attention to the input tensor
        """
        B, N, C = x.shape
        # Queries are always computed at full resolution: (B, num_heads, N, head_dim).
        q = (
            self.q(x)
            .reshape(B, N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )

        if not self.linear:
            if self.sr_ratio > 1:
                # Downsample the token grid with a strided conv before computing k, v.
                x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
                x_ = self.norm(self.sr(x_).reshape(B, C, -1).permute(0, 2, 1))
                kv = (
                    self.kv(x_)
                    .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
                    .permute(2, 0, 3, 1, 4)
                )
            else:
                kv = (
                    self.kv(x)
                    .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
                    .permute(2, 0, 3, 1, 4)
                )
        else:
            # Linear variant: pool the grid to a fixed 7x7 map before computing k, v.
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.norm(self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1))
            kv = (
                self.kv(x_)
                .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
                .permute(2, 0, 3, 1, 4)
            )
        k, v = kv[0], kv[1]

        # Scaled dot-product attention over the (possibly reduced) key/value tokens.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.to_out(x)
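

# A minimal usage sketch (not part of the library): shapes are illustrative and assume
# PreNorm applies LayerNorm before its wrapped fn, as in vformer's functional module.
if __name__ == "__main__":
    import torch

    B, H, W, C = 2, 16, 16, 64  # hypothetical batch size and patch-grid dimensions
    x = torch.randn(B, H * W, C)

    # Default PVT-style attention with 4x spatial reduction of keys/values.
    attn = SpatialAttention(dim=C, num_heads=8, sr_ratio=4)
    out = attn(x, H, W)
    print(out.shape)  # torch.Size([2, 256, 64])

    # Linear-complexity variant from PVTv2: keys/values pooled to a fixed 7x7 grid.
    linear_attn = SpatialAttention(dim=C, num_heads=8, linear=True)
    print(linear_attn(x, H, W).shape)  # torch.Size([2, 256, 64])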