import torch
from einops import repeat
from ...encoder import ConViTEncoder
from ...utils import MODEL_REGISTRY
from .vanilla import VanillaViT
@MODEL_REGISTRY.register()
class ConViT(VanillaViT):
    """
    Implementation of `ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases <https://arxiv.org/abs/2103.10697>`_

    The model runs a stack of gated positional self-attention (GPSA) layers
    (``self.encoder_gpsa``) on the patch tokens only, then re-attaches the
    class token and runs the self-attention (SA) encoder ``self.encoder``,
    which is not built here (presumably created by ``VanillaViT`` from
    ``depth_sa`` / ``attn_heads_sa`` -- confirm against the parent class).

    Parameters
    ----------
    img_size: int
        Size of the image
    patch_size: int
        Size of a patch
    n_classes: int
        Number of classes for classification
    embedding_dim: int
        Dimension of hidden layer
    head_dim: int
        Dimension of the attention head
    depth_sa: int
        Number of attention layers in the encoder for self attention layers
    depth_gpsa: int
        Number of attention layers in the encoder for global positional self attention layers
    attn_heads_sa: int
        Number of the attention heads for self attention layers
    attn_heads_gpsa: int
        Number of the attention heads for global positional self attention layers
    encoder_mlp_dim: int
        Dimension of hidden layer in the encoder
    in_channels: int
        Number of input channels
    decoder_config: int or tuple or list, optional
        Configuration of the decoder. If None, the default configuration is used.
    pool: str
        Feature pooling type, one of {``cls``,``mean``}
    p_dropout_encoder: float
        Dropout probability in the encoder
    p_dropout_embedding: float
        Dropout probability in the embedding layer
    """

    def __init__(
        self,
        img_size,
        patch_size,
        n_classes,
        embedding_dim=1024,
        head_dim=64,
        depth_sa=6,
        depth_gpsa=6,
        attn_heads_sa=16,
        attn_heads_gpsa=16,
        encoder_mlp_dim=2048,
        in_channels=3,
        decoder_config=None,
        pool="cls",
        p_dropout_encoder=0,
        p_dropout_embedding=0,
    ):
        # The parent receives the SA depth/heads; positional order must match
        # VanillaViT's signature (not visible here -- keep in sync with it).
        super().__init__(
            img_size,
            patch_size,
            n_classes,
            embedding_dim,
            head_dim,
            depth_sa,
            attn_heads_sa,
            encoder_mlp_dim,
            in_channels,
            decoder_config,
            pool,
            p_dropout_encoder,
            p_dropout_embedding,
        )
        # Extra GPSA stage that runs before the parent's SA encoder.
        self.encoder_gpsa = ConViTEncoder(
            embedding_dim,
            depth_gpsa,
            attn_heads_gpsa,
            head_dim,
            encoder_mlp_dim,
            p_dropout_encoder,
        )

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        ----------
        torch.Tensor
            Output of the decoder head; presumably class logits of shape
            (batch, n_classes) -- confirm against the decoder implementation.
        """
        x = self.patch_embedding(x)
        # Unpacking enforces a rank-3 token tensor; only the batch size is needed.
        b, _, _ = x.shape

        # Broadcast the learned class token across the batch and prepend it.
        cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_embedding(x)

        # GPSA layers see only the patch tokens; the class token is held aside
        # and re-attached afterwards (as in the ConViT paper, where the class
        # token joins after the last GPSA layer).
        x_cls = x[:, 0:1, :]
        x = x[:, 1:, :]
        x = self.encoder_gpsa(x)
        x = torch.cat((x_cls, x), dim=1)

        x = self.encoder(x)
        x = self.pool(x)
        x = self.decoder(x)
        return x