import torch
import torch.nn as nn
from einops import rearrange
from torchvision.ops import StochasticDepth
from ..attention.convvt import ConvVTAttention
from ..encoder.nn import FeedForward
from ..utils import ENCODER_REGISTRY
from .embedding.convvt import ConvEmbedding
@ENCODER_REGISTRY.register()
class ConvVTStage(nn.Module):
"""
    Implementation of a single stage in CvT (Convolutional Vision Transformer)

    Parameters
    ----------
    patch_size: int
        Size of patch, default is 7
patch_stride: int
Stride of patch, default is 4
patch_padding: int
Padding for patch, default is 0
    in_channels: int
        Number of input channels in image, default is 3
img_size: int
Size of the image, default is 224
embedding_dim: int
Embedding dimensions, default is 64
depth: int
Number of CVT Attention blocks in each stage, default is 1
num_heads: int
Number of heads in attention, default is 6
mlp_ratio: float
Feature dimension expansion ratio in MLP, default is 4.0
p_dropout: float
Probability of dropout in MLP, default is 0.0
attn_dropout: float
Probability of dropout in attention, default is 0.0
    drop_path_rate: float
        Maximum rate for drop path (stochastic depth); per-block rates increase linearly up to this value, default is 0.0
with_cls_token: bool
Whether to include classification token, default is False
kernel_size: int
Size of kernel, default is 3
padding_q: int
Size of padding in q, default is 1
padding_kv: int
Size of padding in kv, default is 2
stride_kv: int
Stride in kv, default is 2
stride_q: int
Stride in q, default is 1
init: str ('trunc_norm' or 'xavier')
Initialization method, default is 'trunc_norm'
"""
def __init__(
self,
patch_size=7,
patch_stride=4,
patch_padding=0,
in_channels=3,
embedding_dim=64,
depth=1,
p_dropout=0.0,
drop_path_rate=0.0,
with_cls_token=False,
init="trunc_norm",
**kwargs
):
super().__init__()
self.patch_embed = ConvEmbedding(
patch_size=patch_size,
in_channels=in_channels,
embedding_dim=embedding_dim,
stride=patch_stride,
padding=patch_padding,
)
if with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, embedding_dim))
else:
self.cls_token = None
self.pos_drop = nn.Dropout(p=p_dropout)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
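        # e.g. depth=4, drop_path_rate=0.3 yields dpr = [0.0, 0.1, 0.2, 0.3],
        # so deeper blocks are dropped with higher probability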
blocks = []
for j in range(depth):
blocks.append(
ConvVTBlock(
dim_in=embedding_dim,
dim_out=embedding_dim,
p_dropout=p_dropout,
with_cls_token=with_cls_token,
drop_path=dpr[j],
**kwargs
)
)
self.blocks = nn.ModuleList(blocks)
if self.cls_token is not None:
nn.init.trunc_normal_(self.cls_token, std=0.02)
if init == "xavier":
self.apply(self._init_weights_xavier)
elif init == "trunc_norm":
self.apply(self._init_weights_trunc_normal)
else:
raise ValueError("Init method {} not found".format(init))
def _init_weights_trunc_normal(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _init_weights_xavier(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
    def forward(self, x):
        x = self.patch_embed(x)  # convolutional patch embedding -> (B, C, H, W)
        B, C, H, W = x.size()
        x = rearrange(x, "b c h w -> b (h w) c")  # flatten spatial grid into a token sequence
cls_tokens = None
if self.cls_token is not None:
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
        if self.cls_token is not None:
            cls_tokens, x = torch.split(x, [1, H * W], 1)  # detach the cls token from the patch tokens
        x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)  # restore the spatial layout
        return x, cls_tokens
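
# Hedged usage sketch (not part of the library): assuming ConvEmbedding wraps a
# Conv2d with the given kernel/stride/padding, patch_size=7, patch_stride=4 and
# patch_padding=2 map a 224x224 image to a 56x56 token grid. Extra keyword
# arguments (num_heads, img_size, ...) are forwarded to ConvVTAttention, whose
# exact signature should be checked before running this.
#
#     stage = ConvVTStage(patch_size=7, patch_stride=4, patch_padding=2,
#                         embedding_dim=64)
#     x = torch.randn(1, 3, 224, 224)
#     out, cls_tokens = stage(x)  # out: (1, 64, 56, 56), cls_tokens: None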
class ConvVTBlock(nn.Module):
"""
    Implementation of an attention-MLP block in CvT

    Parameters
    ----------
dim_in: int
Input dimensions
dim_out: int
Output dimensions
num_heads: int
Number of heads in attention
img_size: int
Size of image
    mlp_ratio: float
        Feature dimension expansion ratio in MLP, default is 4.0
p_dropout: float
Probability of dropout in MLP, default is 0.0
attn_dropout: float
Probability of dropout in attention, default is 0.0
drop_path: float
Probability of droppath, default is 0.0
    with_cls_token: bool
        Whether to include classification token, default is False
    drop_path_mode: str
        Mode for StochasticDepth, either 'batch' or 'row', default is 'batch'
"""
def __init__(
self,
dim_in,
dim_out,
mlp_ratio=4.0,
p_dropout=0.0,
drop_path=0.0,
drop_path_mode="batch",
**kwargs
):
super().__init__()
self.norm1 = nn.LayerNorm(dim_in)
self.attn = ConvVTAttention(dim_in, dim_out, **kwargs)
self.drop_path = (
StochasticDepth(p=drop_path, mode=drop_path_mode)
if drop_path > 0.0
else nn.Identity()
)
self.norm2 = nn.LayerNorm(dim_out)
dim_mlp_hidden = int(dim_out * mlp_ratio)
self.mlp = FeedForward(
dim=dim_out, hidden_dim=dim_mlp_hidden, p_dropout=p_dropout
)
    def forward(self, x):
        res = x
        x = self.norm1(x)
        attn = self.attn(x)
        x = res + self.drop_path(attn)  # pre-norm attention with residual connection
        x = x + self.drop_path(self.mlp(self.norm2(x)))  # pre-norm MLP with residual connection
        return x
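
# Hedged usage sketch (not part of the library): exercises a single ConvVTBlock
# on a random token sequence. num_heads and img_size are assumed to be keyword
# arguments consumed by ConvVTAttention; adjust to its actual signature.
#
#     block = ConvVTBlock(dim_in=64, dim_out=64, num_heads=4, img_size=56)
#     tokens = torch.randn(2, 56 * 56, 64)  # (batch, sequence length, channels)
#     out = block(tokens)                   # same shape: (2, 3136, 64)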