Source code for vformer.common.base_model

import torch.nn as nn

from ..utils import pair


class BaseClassificationModel(nn.Module):
    """
    Base module that validates image/patch sizes and records patch bookkeeping.

    The constructor only checks its arguments and stores derived values; it
    defines no layers and no ``forward`` method.

    Parameters
    ----------
    img_size: int
        Size of the image (normalised to a height/width pair via ``pair``)
    patch_size: int or tuple(int)
        Size of the patch (normalised to a height/width pair via ``pair``)
    in_channels: int
        Number of channels in input image
    pool: {"mean","cls"}
        Feature pooling type

    Attributes
    ----------
    patch_height, patch_width: int
        Patch dimensions as returned by ``pair(patch_size)``
    num_patches: int
        ``(img_height // patch_height) * (img_width // patch_width)``
    patch_dim: int
        ``in_channels * patch_height * patch_width``
    pool: str
        The validated pooling type

    Raises
    ------
    AssertionError
        If the image dimensions are not divisible by the patch dimensions,
        or if ``pool`` is neither ``"cls"`` nor ``"mean"``.
    """

    def __init__(self, img_size, patch_size, in_channels=3, pool="cls"):
        super(BaseClassificationModel, self).__init__()

        img_height, img_width = pair(img_size)
        patch_height, patch_width = pair(patch_size)

        # Raise explicitly rather than using `assert`: assert statements are
        # removed under `python -O`, which would let a bad configuration
        # through. AssertionError is kept so existing callers see the same
        # exception type and message.
        if img_height % patch_height != 0 or img_width % patch_width != 0:
            raise AssertionError(
                "Image dimensions must be divisible by the patch size."
            )

        num_patches = (img_height // patch_height) * (img_width // patch_width)
        patch_dim = in_channels * patch_height * patch_width

        self.patch_height = patch_height
        self.patch_width = patch_width
        self.num_patches = num_patches
        self.patch_dim = patch_dim

        if pool not in {"cls", "mean"}:
            raise AssertionError(
                "Feature pooling type must be either cls (cls token) or mean (mean pooling)"
            )
        self.pool = pool