Source code for vformer.models.classification.pyramid

import numpy as np
import torch
import torch.nn as nn

from ...decoder import MLPDecoder
from ...encoder import OverlapPatchEmbed, PVTEncoder, PVTPosEmbedding
from ...utils import MODEL_REGISTRY


@MODEL_REGISTRY.register()
class PVTClassification(nn.Module):
    """
    Implementation of `Pyramid Vision Transformer: A Versatile Backbone for
    Dense Prediction without Convolution <https://arxiv.org/abs/2102.12122>`_

    Parameters
    ----------
    img_size: int
        Image size
    patch_size: list(int)
        List of patch sizes
    in_channels: int
        Number of input channels in the image, default is 3
    n_classes: int
        Number of classes for classification
    embed_dims: int
        Patch embedding dimension
    num_heads: tuple[int]
        Number of heads in each transformer layer
    depths: tuple[int]
        Depth of each transformer layer
    mlp_ratio: float
        Ratio of the MLP hidden dimension to the embedding dimension
    qkv_bias: bool, default=False
        Adds bias to the qkv if true
    qk_scale: float, optional
        Override default qk scale of head_dim ** -0.5 in Spatial Attention if set
    p_dropout: float
        Dropout rate, default is 0.0
    attn_dropout: float
        Attention dropout rate, default is 0.0
    drop_path_rate: float
        Stochastic depth rate, default is 0.0
    norm_layer: nn.Module
        Normalization layer, default is nn.LayerNorm
    sr_ratios: float
        Spatial reduction ratio
    decoder_config: int or tuple[int], optional
        Configuration of the decoder. If None, the default configuration is used.
    linear: bool
        Whether to use linear Spatial attention, default is False
    use_dwconv: bool
        Whether to use depth-wise convolutions, default is False
    ape: bool
        Whether to use absolute position embedding, default is True
    """

    def __init__(
        self,
        img_size=224,
        patch_size=[7, 3, 3, 3],
        in_channels=3,
        n_classes=1000,
        embed_dims=[64, 128, 256, 512],
        num_heads=[1, 2, 4, 8],
        mlp_ratio=[4, 4, 4, 4],
        qkv_bias=False,
        qk_scale=None,
        p_dropout=0.0,
        attn_dropout=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
        decoder_config=None,
        linear=False,
        use_dwconv=False,
        ape=True,
    ):
        super(PVTClassification, self).__init__()

        self.ape = ape
        self.depths = depths
        assert (
            len(depths) == len(num_heads) == len(embed_dims)
        ), "Configurations do not match"

        # Stochastic depth decay rule: drop-path rate increases linearly per block
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        self.patch_embeds = nn.ModuleList([])
        self.blocks = nn.ModuleList([])
        self.norms = nn.ModuleList()
        self.pos_embeds = nn.ModuleList()

        for i in range(len(depths)):
            self.patch_embeds.append(
                nn.ModuleList(
                    [
                        OverlapPatchEmbed(
                            img_size=img_size
                            if i == 0
                            else img_size // (2 ** (i + 1)),
                            patch_size=patch_size[i],
                            stride=4 if i == 0 else 2,
                            in_channels=in_channels if i == 0 else embed_dims[i - 1],
                            embedding_dim=embed_dims[i],
                        )
                    ]
                )
            )
            if ape:
                if i != len(depths) - 1:
                    self.pos_embeds.append(
                        nn.ModuleList(
                            [
                                PVTPosEmbedding(
                                    pos_shape=img_size
                                    // np.prod(patch_size[: i + 1]),
                                    pos_dim=embed_dims[i],
                                )
                            ]
                        )
                    )
                else:
                    # Last stage: learned absolute position embedding shared with
                    # the cls token added in `forward`
                    self.last_pos = nn.Parameter(
                        torch.randn(
                            1,
                            (img_size // np.prod(patch_size[: i + 1])) ** 2,
                            embed_dims[-1],
                        )
                    )

            self.blocks.append(
                nn.ModuleList(
                    [
                        PVTEncoder(
                            dim=embed_dims[i],
                            num_heads=num_heads[i],
                            mlp_ratio=mlp_ratio[i],
                            qkv_bias=qkv_bias,
                            qk_scale=qk_scale,
                            p_dropout=p_dropout,
                            depth=depths[i],
                            attn_dropout=attn_dropout,
                            drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
                            sr_ratio=sr_ratios[i],
                            linear=linear,
                            activation=nn.GELU,
                            use_dwconv=use_dwconv,
                        )
                    ]
                )
            )
            self.norms.append(norm_layer(embed_dims[i]))

        # cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1]))
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        if decoder_config is not None:
            if not isinstance(decoder_config, (list, tuple)):
                decoder_config = [decoder_config]
            assert decoder_config[0] == embed_dims[-1], (
                f"Configurations do not match for MLPDecoder, first element of "
                f"`decoder_config` expected to be {embed_dims[-1]}, "
                f"got {decoder_config[0]}"
            )
            self.decoder = MLPDecoder(config=decoder_config, n_classes=n_classes)
        else:
            self.decoder = MLPDecoder(config=embed_dims[-1], n_classes=n_classes)
    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Returns tensor of size `n_classes`
        """
        B = x.shape[0]

        for i in range(len(self.depths)):
            patch_embed = self.patch_embeds[i]
            block = self.blocks[i]
            norm = self.norms[i]

            x, H, W = patch_embed[0](x)
            N = x.shape[1]

            if self.ape:
                if i == len(self.depths) - 1:
                    # Last stage: prepend the cls token, then add the learned
                    # absolute position embedding
                    x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
                    x += self.last_pos[:, : (N + 1)]
                else:
                    pos_embed = self.pos_embeds[i]
                    x = pos_embed[0](x, H=H, W=W)

            for blk in block:
                x = blk(x, H=H, W=W)
            x = norm(x)

            if i == len(self.depths) - 1:
                # Global average pool over tokens before the classifier head
                x = x.mean(dim=1)
            else:
                # Restore the 2D spatial layout for the next patch-embedding stage
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        x = self.decoder(x)

        return x
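
# A minimal, hedged usage sketch (not part of the original module): it shows how
# PVTClassification with its default configuration maps a batch of 224x224 RGB
# images to class logits. The shapes in the comments are illustrative only.
#
#   model = PVTClassification(img_size=224, n_classes=1000)
#   images = torch.randn(2, 3, 224, 224)   # (batch, channels, height, width)
#   logits = model(images)                 # expected shape: (2, 1000)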

@MODEL_REGISTRY.register()
class PVTClassificationV2(PVTClassification):
    """
    Implementation of `PVT v2: Improved Baselines with Pyramid Vision Transformer
    <https://arxiv.org/abs/2106.13797>`_

    Parameters
    ----------
    img_size: int
        Image size
    patch_size: list(int)
        List of patch sizes
    in_channels: int
        Number of input channels in the image, default is 3
    n_classes: int
        Number of classes for classification
    embedding_dims: int
        Patch embedding dimension
    num_heads: tuple[int]
        Number of heads in each transformer layer
    depths: tuple[int]
        Depth of each transformer layer
    mlp_ratio: float
        Ratio of the MLP hidden dimension to the embedding dimension
    qkv_bias: bool, default=False
        Adds bias to the qkv if true
    qk_scale: float, optional
        Override default qk scale of head_dim ** -0.5 in Spatial Attention if set
    p_dropout: float
        Dropout rate, default is 0.0
    attn_dropout: float
        Attention dropout rate, default is 0.0
    drop_path_rate: float
        Stochastic depth rate, default is 0.0
    norm_layer: nn.Module
        Normalization layer, default is nn.LayerNorm
    sr_ratios: float
        Spatial reduction ratio
    decoder_config: int or tuple[int], optional
        Configuration of the decoder. If None, the default configuration is used.
    linear: bool
        Whether to use linear Spatial attention, default is False
    use_dwconv: bool
        Whether to use depth-wise convolutions, default is True
    ape: bool
        Whether to use absolute position embedding, default is False
    """

    def __init__(
        self,
        img_size=224,
        patch_size=[7, 3, 3, 3],
        in_channels=3,
        n_classes=1000,
        embedding_dims=[64, 128, 256, 512],
        num_heads=[1, 2, 4, 8],
        mlp_ratio=[4, 4, 4, 4],
        qkv_bias=False,
        qk_scale=0.0,
        p_dropout=0.0,
        attn_dropout=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
        decoder_config=None,
        use_dwconv=True,
        linear=False,
        ape=False,
    ):
        super(PVTClassificationV2, self).__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            n_classes=n_classes,
            embed_dims=embedding_dims,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            p_dropout=p_dropout,
            attn_dropout=attn_dropout,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            depths=depths,
            sr_ratios=sr_ratios,
            decoder_config=decoder_config,
            ape=ape,
            use_dwconv=use_dwconv,
            linear=linear,
        )
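
# A minimal, hedged usage sketch (not part of the original module): it assumes the
# package layout implied by the imports above and simply runs a forward pass with
# the PVT v2 defaults (depth-wise convolutions in the MLP, no absolute position
# embedding). The `n_classes=10` choice and the printed shape are illustrative.
if __name__ == "__main__":
    model_v2 = PVTClassificationV2(img_size=224, n_classes=10)
    logits = model_v2(torch.randn(2, 3, 224, 224))
    print(logits.shape)  # expected: torch.Size([2, 10])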