Source code for vformer.models.classification.vivit

import torch
import torch.nn as nn
from einops import rearrange, repeat

from ...common.base_model import BaseClassificationModel
from ...decoder.mlp import MLPDecoder
from ...encoder.embedding import LinearVideoEmbedding, PosEmbedding, TubeletEmbedding
from ...encoder.vanilla import VanillaEncoder
from ...encoder.vivit import ViViTEncoder
from ...utils.registry import MODEL_REGISTRY
from ...utils.utils import pair


@MODEL_REGISTRY.register()
class ViViTModel2(BaseClassificationModel):
    """
    Model 2 implementation of: `ViViT: A Video Vision Transformer <https://arxiv.org/abs/2103.15691>`_

    Parameters
    ----------
    img_size: int
        Size of a single frame/image in the video
    in_channels: int
        Number of input channels
    patch_size: int
        Patch size
    embedding_dim: int
        Embedding dimension of a patch
    num_frames: int
        Number of frames in each video
    depth: int
        Number of encoder layers
    num_heads: int
        Number of attention heads
    head_dim: int
        Dimension of each attention head
    n_classes: int
        Number of classes
    mlp_dim: int
        Dimension of the hidden layer
    pool: str
        Pooling operation, must be one of {"cls", "mean"}, default is "cls"
    p_dropout: float
        Dropout probability
    attn_dropout: float
        Dropout probability for attention weights
    drop_path_rate: float
        Stochastic depth drop path rate
    """

    def __init__(
        self,
        img_size,
        in_channels,
        patch_size,
        embedding_dim,
        num_frames,
        depth,
        num_heads,
        head_dim,
        n_classes,
        mlp_dim=None,
        pool="cls",
        p_dropout=0.0,
        attn_dropout=0.0,
        drop_path_rate=0.02,
    ):
        super(ViViTModel2, self).__init__(
            img_size=img_size,
            in_channels=in_channels,
            patch_size=patch_size,
            pool=pool,
        )
        patch_dim = in_channels * patch_size**2
        self.patch_embedding = LinearVideoEmbedding(
            embedding_dim=embedding_dim,
            patch_height=patch_size,
            patch_width=patch_size,
            patch_dim=patch_dim,
        )

        self.pos_embedding = PosEmbedding(
            shape=[num_frames, self.num_patches + 1], dim=embedding_dim, drop=p_dropout
        )
        # Per-frame class token, analogous to the cls token in the vanilla vision transformer
        self.space_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
        self.spatial_transformer = VanillaEncoder(
            embedding_dim=embedding_dim,
            depth=depth,
            num_heads=num_heads,
            head_dim=head_dim,
            mlp_dim=mlp_dim,
            p_dropout=p_dropout,
            attn_dropout=attn_dropout,
            drop_path_rate=drop_path_rate,
        )
        self.time_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
        self.temporal_transformer = VanillaEncoder(
            embedding_dim=embedding_dim,
            depth=depth,
            num_heads=num_heads,
            head_dim=head_dim,
            mlp_dim=mlp_dim,
            p_dropout=p_dropout,
            attn_dropout=attn_dropout,
            drop_path_rate=drop_path_rate,
        )
        self.decoder = MLPDecoder(
            config=[
                embedding_dim,
            ],
            n_classes=n_classes,
        )

    def forward(self, x):
        x = self.patch_embedding(x)
        # x has shape (batch, time, num_patches, embedding_dim)
        b, t, n, d = x.shape

        cls_space_tokens = repeat(self.space_token, "() n d -> b t n d", b=b, t=t)
        # Note: do not wrap the concatenation in nn.Parameter here; doing so would
        # create a new leaf tensor every forward pass and cut the gradient flow.
        x = torch.cat((cls_space_tokens, x), dim=2)
        x = self.pos_embedding(x)

        x = rearrange(x, "b t n d -> (b t) n d")
        x = self.spatial_transformer(x)
        x = rearrange(x[:, 0], "(b t) ... -> b t ...", b=b)

        cls_temporal_tokens = repeat(self.time_token, "() n d -> b n d", b=b)
        x = torch.cat((cls_temporal_tokens, x), dim=1)
        x = self.temporal_transformer(x)

        x = x.mean(dim=1) if self.pool == "mean" else x[:, 0]
        x = self.decoder(x)

        return x
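
# Usage sketch (not part of the original source): a minimal example of running
# ViViTModel2 on a dummy clip. The input layout is assumed to be
# (batch, frames, channels, height, width), matching LinearVideoEmbedding above;
# all hyperparameter values below are illustrative only.
#
#     model = ViViTModel2(
#         img_size=224, in_channels=3, patch_size=16, embedding_dim=192,
#         num_frames=8, depth=4, num_heads=3, head_dim=64, n_classes=10,
#         mlp_dim=384,
#     )
#     video = torch.randn(2, 8, 3, 224, 224)  # assumed (batch, frames, C, H, W)
#     logits = model(video)                   # expected shape: (2, 10)
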
# model 3
@MODEL_REGISTRY.register()
class ViViTModel3(BaseClassificationModel):
    """
    Model 3 implementation of: `ViViT: A Video Vision Transformer <https://arxiv.org/abs/2103.15691>`_

    Parameters
    ----------
    img_size: int or tuple[int]
        Size of a frame
    patch_t: int
        Temporal length of a single tube/patch in the tubelet embedding
    patch_h: int
        Height of a single tube/patch in the tubelet embedding
    patch_w: int
        Width of a single tube/patch in the tubelet embedding
    in_channels: int
        Number of input channels, default is 3
    n_classes: int
        Number of classes
    num_frames: int
        Number of frames in each video
    embedding_dim: int
        Embedding dimension of a patch
    depth: int
        Number of encoder layers
    num_heads: int
        Number of attention heads
    head_dim: int
        Dimension of each attention head
    p_dropout: float
        Dropout rate/probability, default is 0.0
    mlp_dim: int
        Hidden dimension, optional
    """

    def __init__(
        self,
        img_size,
        patch_t,
        patch_h,
        patch_w,
        in_channels,
        n_classes,
        num_frames,
        embedding_dim,
        depth,
        num_heads,
        head_dim,
        p_dropout,
        mlp_dim=None,
    ):
        super(ViViTModel3, self).__init__(
            in_channels=in_channels,
            patch_size=(patch_h, patch_w),
            pool="mean",
            img_size=img_size,
        )
        h, w = pair(img_size)

        self.tubelet_embedding = TubeletEmbedding(
            embedding_dim=embedding_dim,
            tubelet_t=patch_t,
            tubelet_h=patch_h,
            tubelet_w=patch_w,
            in_channels=in_channels,
        )
        self.pos_embedding = PosEmbedding(
            shape=[num_frames // patch_t, (h * w) // (patch_w * patch_h)],
            dim=embedding_dim,
        )
        self.encoder = ViViTEncoder(
            dim=embedding_dim,
            num_heads=num_heads,
            head_dim=head_dim,
            p_dropout=p_dropout,
            depth=depth,
            hidden_dim=mlp_dim,
        )
        self.decoder = MLPDecoder(
            config=[
                embedding_dim,
            ],
            n_classes=n_classes,
        )

    def forward(self, x):
        x = self.tubelet_embedding(x)
        x = self.pos_embedding(x)
        x = self.encoder(x)
        x = x.mean(dim=1)
        x = self.decoder(x)

        return x
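
# Usage sketch (not part of the original source): analogous to the ViViTModel2
# example above, assuming TubeletEmbedding expects clips laid out as
# (batch, frames, channels, height, width), with num_frames divisible by patch_t
# and the frame size divisible by patch_h/patch_w; values are illustrative only.
#
#     model = ViViTModel3(
#         img_size=224, patch_t=2, patch_h=16, patch_w=16, in_channels=3,
#         n_classes=10, num_frames=8, embedding_dim=192, depth=4, num_heads=3,
#         head_dim=64, p_dropout=0.0, mlp_dim=384,
#     )
#     video = torch.randn(2, 8, 3, 224, 224)  # assumed (batch, frames, C, H, W)
#     logits = model(video)                   # expected shape: (2, 10)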