import torch
import torch.nn as nn
import torch.nn.functional as F
from ...common import BaseClassificationModel
from ...decoder import MLPDecoder
from ...encoder import CVTEmbedding, PosEmbedding, VanillaEncoder
from ...utils import MODEL_REGISTRY, pair
@MODEL_REGISTRY.register()
class CCT(BaseClassificationModel):
"""
Implementation of `Escaping the Big Data Paradigm with Compact Transformers <https://arxiv.org/abs/2104.05704>`_
Parameters
    ----------
img_size: int
Size of the image
patch_size: int
Size of the single patch in the image
in_channels: int
Number of input channels in image
    seq_pool: bool
        Whether to use sequence pooling or not
    embedding_dim: int
        Patch embedding dimension
    num_layers: int
        Number of transformer encoder blocks
    head_dim: int
        Dimension of each attention head
    num_heads: int
        Number of heads in each transformer layer
    mlp_ratio: float
        Ratio of the MLP hidden dimension to the embedding dimension
    n_classes: int
        Number of classes for classification
    p_dropout: float
        Dropout probability
    attn_dropout: float
        Dropout probability for attention weights
drop_path: float
Stochastic depth rate, default is 0.1
    positional_embedding: str
        One of the string values {``'learnable'``, ``'sine'``, ``'none'``}, default is ``'learnable'``.
decoder_config: tuple(int) or int
Configuration of the decoder. If None, the default configuration is used.
pooling_kernel_size: int or tuple(int)
Size of the kernel in MaxPooling operation
pooling_stride: int or tuple(int)
Stride of MaxPooling operation
pooling_padding: int
Padding in MaxPooling operation
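
    Example
    -------
    A minimal usage sketch (the 32x32 input size and class count below are
    illustrative; all other arguments keep their defaults):

    >>> model = CCT(img_size=32, patch_size=4, n_classes=10)
    >>> logits = model(torch.randn(1, 3, 32, 32))
    >>> logits.shape
    torch.Size([1, 10])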
"""
def __init__(
self,
img_size=224,
patch_size=4,
in_channels=3,
seq_pool=True,
embedding_dim=768,
num_layers=1,
head_dim=96,
num_heads=1,
mlp_ratio=4.0,
n_classes=1000,
p_dropout=0.1,
attn_dropout=0.1,
drop_path=0.1,
positional_embedding="learnable",
decoder_config=(
768,
1024,
),
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
):
super().__init__(
img_size=img_size,
patch_size=patch_size,
)
assert (
img_size % patch_size == 0
), f"Image size ({img_size}) has to be divisible by patch size ({patch_size})"
img_size = pair(img_size)
self.in_channels = in_channels
        # Convolutional tokenizer: conv + ReLU + max-pool produces the patch tokens
        self.embedding = CVTEmbedding(
in_channels=in_channels,
out_channels=embedding_dim,
kernel_size=patch_size,
stride=patch_size,
padding=0,
max_pool=True,
pooling_kernel_size=pooling_kernel_size,
pooling_stride=pooling_stride,
pooling_padding=pooling_padding,
activation=nn.ReLU,
num_conv_layers=1,
conv_bias=True,
)
        # Fall back to sinusoidal embeddings if an unrecognised option is passed
        positional_embedding = (
            positional_embedding
            if positional_embedding in ["sine", "learnable", "none"]
            else "sine"
        )
hidden_dim = int(embedding_dim * mlp_ratio)
self.embedding_dim = embedding_dim
self.sequence_length = self.embedding.sequence_length(
n_channels=in_channels, height=img_size[0], width=img_size[1]
)
self.seq_pool = seq_pool
assert (
self.sequence_length is not None or positional_embedding == "none"
), f"Positional embedding is set to {positional_embedding} and the sequence length was not specified."
        if not seq_pool:
            # Without sequence pooling, prepend a learnable class token
            self.sequence_length += 1
            self.class_emb = nn.Parameter(
                torch.zeros(1, 1, self.embedding_dim), requires_grad=True
            )
        else:
            # With sequence pooling, learn attention weights over all tokens
            self.attention_pool = nn.Linear(self.embedding_dim, 1)
if positional_embedding != "none":
self.positional_emb = PosEmbedding(
self.sequence_length,
dim=embedding_dim,
drop=p_dropout,
                sinusoidal=positional_embedding == "sine",
)
else:
self.positional_emb = None
        # Linearly spaced stochastic depth (drop path) rates, one per layer
        dpr = [x.item() for x in torch.linspace(0, drop_path, num_layers)]
self.encoder_blocks = nn.ModuleList(
[
VanillaEncoder(
embedding_dim=embedding_dim,
num_heads=num_heads,
depth=1,
head_dim=head_dim,
mlp_dim=hidden_dim,
p_dropout=p_dropout,
attn_dropout=attn_dropout,
drop_path_rate=dpr[i],
)
for i in range(num_layers)
]
)
if decoder_config is not None:
            if not isinstance(decoder_config, (list, tuple)):
                decoder_config = [decoder_config]
            assert (
                decoder_config[0] == embedding_dim
            ), f"Configuration mismatch for MLPDecoder: first element of `decoder_config` expected to be {embedding_dim}, got {decoder_config[0]}"
self.decoder = MLPDecoder(config=decoder_config, n_classes=n_classes)
else:
self.decoder = MLPDecoder(config=embedding_dim, n_classes=n_classes)
    def forward(self, x):
"""
        Parameters
        ----------
        x: torch.Tensor
            Input image tensor of shape ``(batch_size, in_channels, height, width)``

        Returns
        -------
        torch.Tensor
            Output logits of shape ``(batch_size, n_classes)``
"""
x = self.embedding(x)
        # Zero-pad the token sequence to the expected length when positional
        # embeddings are disabled
        if self.positional_emb is None and x.size(1) < self.sequence_length:
            x = F.pad(
                x, (0, 0, 0, self.sequence_length - x.size(1)), mode="constant", value=0
            )
        if not self.seq_pool:
            # Prepend the class token to the sequence
            cls_token = self.class_emb.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_token, x), dim=1)
if self.positional_emb is not None:
x = self.positional_emb(x)
for blk in self.encoder_blocks:
x = blk(x)
        if self.seq_pool:
            # Sequence pooling: softmax-weighted sum over all output tokens
            x = torch.matmul(
                F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x
            ).squeeze(-2)
        else:
            # Classify from the class token
            x = x[:, 0]
x = self.decoder(x)
return x
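

# A self-contained sketch of the sequence-pooling step used in ``forward``
# above (shapes are illustrative; this block is not part of the model):
#
#   import torch
#   import torch.nn as nn
#   import torch.nn.functional as F
#
#   x = torch.randn(2, 16, 768)                  # (batch, tokens, embedding_dim)
#   attention_pool = nn.Linear(768, 1)
#   weights = F.softmax(attention_pool(x), dim=1)        # (2, 16, 1)
#   pooled = torch.matmul(weights.transpose(-1, -2), x)  # (2, 1, 768)
#   pooled = pooled.squeeze(-2)                          # (2, 768)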