Source code for vformer.encoder.swin

import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torchvision.ops import StochasticDepth

from ..attention.window import WindowAttention
from ..utils import (
from .nn import FeedForward

@ENCODER_REGISTRY.register()
class SwinEncoderBlock(nn.Module):
    """
    One Swin block: windowed self-attention followed by a feed-forward
    network, each added back to its input through a residual connection
    with optional stochastic depth.

    Parameters
    -----------
    dim: int
        Number of the input channels
    input_resolution: int or tuple[int]
        Input resolution of patches
    num_heads: int
        Number of attention heads
    window_size: int
        Window size
    shift_size: int
        Shift size for Shifted Window Masked Self Attention (SW_MSA)
    mlp_ratio: float
        Ratio of MLP hidden dimension to embedding dimension
    qkv_bias: bool, default= True
        Whether to add a bias vector to the q,k, and v matrices
    qk_scale: float, Optional
        Forwarded unchanged to `WindowAttention`
    p_dropout: float
        Dropout rate
    attn_dropout: float
        Dropout rate
    drop_path_rate: float
        Stochastic depth rate
    norm_layer: nn.Module
        Normalization layer, default is `nn.LayerNorm`
    drop_path_mode: str
        `mode` argument handed to `StochasticDepth`, default is "batch"
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        p_dropout=0.0,
        attn_dropout=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        drop_path_mode="batch",
    ):
        super().__init__()

        resolution = pair(input_resolution)
        smallest_side = min(resolution)

        # When the feature map is no larger than the requested window, use a
        # single window spanning the smallest side and disable shifting.
        if smallest_side <= window_size:
            effective_window, effective_shift = smallest_side, 0
        else:
            effective_window, effective_shift = window_size, shift_size

        self.dim = dim
        self.input_resolution = resolution
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.window_size = effective_window
        self.shift_size = effective_shift

        # The bound is checked against the *requested* window size.
        assert (
            0 <= self.shift_size < window_size
        ), "shift size must range from 0 to window size"

        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim=dim,
            window_size=pair(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_dropout=attn_dropout,
            proj_dropout=p_dropout,
        )

        # Stochastic depth is only instantiated when it would do something.
        if drop_path_rate > 0.0:
            self.drop_path = StochasticDepth(p=drop_path_rate, mode=drop_path_mode)
        else:
            self.drop_path = nn.Identity()

        # Feed-forward branch.
        self.norm2 = norm_layer(dim)
        self.mlp = FeedForward(
            dim=dim, hidden_dim=int(dim * mlp_ratio), p_dropout=p_dropout
        )

        # A mask is only built for shifted blocks; otherwise the buffer is
        # registered as None.
        attn_mask = None
        if self.shift_size > 0:
            attn_mask = create_mask(
                self.window_size,
                self.shift_size,
                H=self.input_resolution[0],
                W=self.input_resolution[1],
            )
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input of shape (B, L, C); L must equal H * W of
            `input_resolution`

        Returns
        ----------
        torch.Tensor
            Returns output tensor of shape (B, H * W, C)
        """
        height, width = self.input_resolution
        batch, tokens, channels = x.shape
        assert tokens == height * width, "Input tensor shape not compatible"

        win = self.window_size
        shifted = self.shift_size > 0

        residual = x

        # Normalise, then lay the token sequence out as a 2-D feature map.
        feature_map = self.norm1(x).view(batch, height, width, channels)
        if shifted:
            feature_map = cyclicshift(feature_map, shift_size=-self.shift_size)

        # Attend inside each window independently.
        windows = window_partition(feature_map, win).view(-1, win * win, channels)
        attended = self.attn(windows, mask=self.attn_mask).view(
            -1, win, win, channels
        )

        # Stitch the windows back together and undo the shift, if any.
        feature_map = window_reverse(attended, win, height, width)
        if shifted:
            feature_map = cyclicshift(feature_map, shift_size=self.shift_size)
        x = feature_map.view(batch, height * width, channels)

        x = residual + self.drop_path(x)
        return x + self.drop_path(self.mlp(self.norm2(x)))
@ENCODER_REGISTRY.register()
class SwinEncoder(nn.Module):
    """
    A stage of `depth` stacked `SwinEncoderBlock` modules, optionally
    followed by a downsampling layer. Blocks at even indices use no shift;
    blocks at odd indices use a shift of `window_size // 2`.

    Parameters
    -----------
    dim: int
        Number of input channels.
    input_resolution: tuple[int]
        Input resolution.
    depth: int
        Number of blocks.
    num_heads: int
        Number of attention heads.
    window_size: int
        Local window size.
    mlp_ratio: float
        Ratio of MLP hidden dim to embedding dim.
    qkv_bias: bool, default is True
        Whether to add a bias vector to the q,k, and v matrices
    qkv_scale: float, optional
        Passed to every block as its `qk_scale` argument
    p_dropout: float,
        Dropout rate.
    attn_dropout: float, optional
        Attention dropout rate
    drop_path: float or list[float] or tuple[float]
        Stochastic depth rate. A single float is shared by all blocks; a
        list or tuple is indexed per block and must have at least `depth`
        entries.
    norm_layer: nn.Module
        Normalization layer. default is nn.LayerNorm
    downsample: nn.Module, optional
        Downsample layer(like PatchMerging) at the end of the layer, default
        is None. It is constructed here as
        `downsample(input_resolution, dim=dim, norm_layer=norm_layer)`.
    use_checkpoint: bool, default is False
        If True, every block is run through `torch.utils.checkpoint.checkpoint`
        in `forward`.
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qkv_scale=None,
        p_dropout=0.0,
        attn_dropout=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Accept tuples as well as lists for per-block rates. Previously a
        # tuple was forwarded whole to the block, whose
        # `drop_path_rate > 0.0` comparison then raised TypeError.
        per_block_rates = isinstance(drop_path, (list, tuple))

        self.blocks = nn.ModuleList(
            [
                SwinEncoderBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    # Alternate between unshifted and shifted windows.
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qkv_scale,
                    p_dropout=p_dropout,
                    attn_dropout=attn_dropout,
                    drop_path_rate=drop_path[i] if per_block_rates else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, norm_layer=norm_layer
            )
        else:
            self.downsample = None

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.Tensor

        Returns
        ----------
        torch.Tensor
            Returns output tensor
        """
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.downsample is not None:
            x = self.downsample(x)
        return x