Source code for vformer.models.classification.swin
import torch
import torch.nn as nn
from ...common import BaseClassificationModel
from ...decoder import MLPDecoder
from ...encoder import PatchEmbedding, PosEmbedding, SwinEncoder
from ...functional import PatchMerging
from ...utils import MODEL_REGISTRY
@MODEL_REGISTRY.register()
class SwinTransformer(BaseClassificationModel):
    """
    Implementation of `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030v1>`_

    The model embeds the image into patches, optionally adds an absolute
    position embedding, runs one `SwinEncoder` stage per entry of `depths`
    (every stage except the last is built with `PatchMerging` as its
    downsample layer), normalizes, average-pools and feeds the result to an
    `MLPDecoder` head.

    Parameters
    -----------
    img_size: int
        Size of an Image
    patch_size: int
        Patch Size
    in_channels: int
        Input channels in image
    n_classes: int
        Number of classes for classification
    embedding_dim: int
        Patch Embedding dimension, default is 96
    depths: tuple[int]
        Depth of each encoder stage; its length sets the number of stages
    num_heads: tuple[int]
        Number of heads in each encoder stage; indexed in step with `depths`
    window_size: int
        Window Size
    mlp_ratio : float
        MLP ratio forwarded unchanged to every `SwinEncoder` stage
    qkv_bias: bool, default= True
        Adds bias to the qkv if true
    qk_scale: float, optional
        Override default qk scale of head_dim ** -0.5 in Window Attention if set
    p_dropout: float
        Dropout rate, default is 0.0
    attn_dropout: float
        Attention dropout rate, default is 0.0
    drop_path_rate: float
        Stochastic depth rate, default is 0.1. The per-block rates are spaced
        linearly from 0 to this value across all blocks of all stages.
    norm_layer: nn.Module
        Normalization layer, default is nn.LayerNorm
    ape: bool, optional
        Whether to add absolute position embedding to patch embedding, default is True
    decoder_config: int or tuple[int], optional
        Configuration of the decoder. If None, the default configuration is used.
        When given, its first entry must equal
        ``embedding_dim * 2 ** (len(depths) - 1)``.
    patch_norm: bool, optional
        Whether to add Normalization layer in PatchEmbedding, default is True
    """

    def __init__(
        self,
        img_size,
        patch_size,
        in_channels,
        n_classes,
        embedding_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=8,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        p_dropout=0.0,
        attn_dropout=0.0,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        ape=True,
        decoder_config=None,
        patch_norm=True,
    ):
        # NOTE: `depths` and `num_heads` keep their list defaults for interface
        # compatibility; they are only read here, never mutated.
        super(SwinTransformer, self).__init__(
            img_size, patch_size, in_channels, pool="mean"
        )

        self.patch_embed = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embedding_dim=embedding_dim,
            norm_layer=norm_layer if patch_norm else nn.Identity,
        )
        self.patch_resolution = self.patch_embed.patch_resolution
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]
        self.ape = ape

        num_stages = len(depths)
        # The channel width is scaled by 2 per stage index, so the last stage
        # is built with embedding_dim * 2 ** (num_stages - 1) channels.
        num_features = int(embedding_dim * 2 ** (num_stages - 1))

        self.absolute_pos_embed = (
            PosEmbedding(shape=num_patches, dim=embedding_dim, drop=p_dropout, std=0.02)
            if ape
            else nn.Identity()
        )

        # One stochastic-depth rate per block, increasing linearly over the
        # whole network; each stage receives its own contiguous slice.
        dpr = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        self.encoder = nn.ModuleList()
        block_start = 0
        for i_layer, depth in enumerate(depths):
            block_end = block_start + depth
            scale = 2**i_layer
            layer = SwinEncoder(
                dim=int(embedding_dim * scale),
                input_resolution=(
                    self.patch_resolution[0] // scale,
                    self.patch_resolution[1] // scale,
                ),
                depth=depth,
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qkv_scale=qk_scale,
                p_dropout=p_dropout,
                attn_dropout=attn_dropout,
                drop_path=dpr[block_start:block_end],
                norm_layer=norm_layer,
                # The final stage is built without a downsample layer.
                downsample=PatchMerging if i_layer < num_stages - 1 else None,
            )
            self.encoder.append(layer)
            block_start = block_end

        if decoder_config is not None:
            # A bare int is a documented input; `list(int)` would raise
            # TypeError, so wrap it explicitly before normalizing to a list.
            if isinstance(decoder_config, int):
                decoder_config = [decoder_config]
            elif not isinstance(decoder_config, list):
                decoder_config = list(decoder_config)

            assert (
                decoder_config[0] == num_features
            ), f"first item of `decoder_config` should be equal to the `num_features`; num_features=embed_dim * 2** (len(depths)-1) which is = {num_features} "

            self.decoder = MLPDecoder(decoder_config, n_classes)
        else:
            self.decoder = MLPDecoder(num_features, n_classes)

        self.pool = nn.AdaptiveAvgPool1d(1)
        # Instantiate nn.Identity: assigning the class itself would make
        # `self.norm(x)` construct a module instead of returning the tensor.
        self.norm = norm_layer(num_features) if norm_layer is not None else nn.Identity()
        # Not applied in `forward`; kept so the attribute stays available.
        self.pos_drop = nn.Dropout(p=p_dropout)

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        ----------
        torch.Tensor
            Output of the decoder head; presumably of shape
            (batch, `n_classes`) -- confirm against `MLPDecoder`.
        """
        x = self.patch_embed(x)
        x = self.absolute_pos_embed(x)

        for layer in self.encoder:
            x = layer(x)

        x = self.norm(x)
        # Swap dims 1 and 2 so the 1-d pool averages over what was dim 1
        # (assumed to be the token axis), then drop the pooled axis.
        x = self.pool(x.transpose(1, 2)).flatten(1)
        x = self.decoder(x)
        return x