Source code for vformer.models.dense.dpt

import types

import torch
import torch.nn as nn

from ...utils.dpt_utils import _resize_pos_embed, forward_flex
from ...utils.registry import MODEL_REGISTRY

activations = {}


def get_activation(name):
    def hook(model, input, output):
        activations[name] = output

    return hook
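

# Illustrative sketch (not used by the module): how the `get_activation` hook can be
# attached to any nn.Module so that its output is cached in the module-level
# `activations` dict. The layer and tensor sizes below are arbitrary assumptions.
def _demo_get_activation():
    layer = nn.Linear(8, 4)
    handle = layer.register_forward_hook(get_activation("demo"))
    layer(torch.randn(2, 8))
    cached = activations["demo"]  # the layer's output, shape (2, 4)
    handle.remove()
    return cached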


attention = {}


def get_attention(name):
    def hook(module, input, output):
        x = input[0]
        B, N, C = x.shape
        qkv = (
            module.to_qkv(x)
            .reshape(B, N, 3, module.num_heads, C // module.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * module.scale

        attn = attn.softmax(dim=-1)  # [:,:,1,1:]
        attention[name] = attn

    return hook
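

# Illustrative sketch (not used by the module): `get_attention` recomputes the attention
# map from the hooked module's *input*, so the hooked module is assumed to expose
# `to_qkv`, `num_heads` and `scale` attributes. `TinyAttention` below is a minimal
# stand-in written only for this demo; all sizes are arbitrary assumptions.
def _demo_get_attention():
    class TinyAttention(nn.Module):
        def __init__(self, dim=8, num_heads=2):
            super().__init__()
            self.num_heads = num_heads
            self.scale = (dim // num_heads) ** -0.5
            self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        def forward(self, x):
            # The hook inspects the input tokens; the output itself is irrelevant here.
            return self.to_qkv(x)

    module = TinyAttention()
    module.register_forward_hook(get_attention("demo"))
    module(torch.randn(1, 5, 8))
    return attention["demo"]  # attention map, shape (1, num_heads, 5, 5)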


@MODEL_REGISTRY.register()
class DPTDepth(nn.Module):
    """
    Implementation of "Vision Transformers for Dense Prediction"
    https://arxiv.org/abs/2103.13413

    Parameters
    -----------
    backbone: str
        Name of the ViT model to be used as backbone, must be one of {`vitb16`, `vitl16`, `vit_tiny`}
    in_channels: int
        Number of channels in the input image, default is 3
    img_size: tuple[int]
        Input image size, default is (384, 384)
    readout: str
        Method to handle the `readout_token` or `cls_token`,
        must be one of {`add`, `ignore`, `project`}, default is `project`
    hooks: list[int]
        Indices of the encoder blocks on which forward hooks will be registered.
        These hooks extract features (e.g. attention) from different ViT blocks, default is (2, 5, 8, 11)
    channels_last: bool
        Alters the memory format used to store tensors, default is False.
        For more information see this `blog post <https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html>`_
    use_bn: bool
        If True, batch normalisation is used in `FeatureFusionBlock_custom`, default is False
    enable_attention_hooks: bool
        If True, the `get_attention` hook is registered, default is False
    non_negative: bool
        If True, a ReLU operation is applied in the `DPTDepth.model.head` block, default is True
    invert: bool
        If True, the forward-pass output of `DPTDepth.model.head` is inverted according to the
        `scale` and `shift` parameters, default is False
    scale: float
        Value that the forward-pass output of `DPTDepth.model.head` is multiplied by, default is 1.0
    shift: float
        Value added to the forward-pass output of `DPTDepth.model.head` after scaling, default is 0.0
    """

    def __init__(
        self,
        backbone,
        in_channels=3,
        img_size=(384, 384),
        readout="project",
        hooks=(2, 5, 8, 11),
        channels_last=False,
        use_bn=False,
        enable_attention_hooks=False,
        non_negative=True,
        scale=1.0,
        shift=0.0,
        invert=False,
    ):
        super(DPTDepth, self).__init__()

        self.channels_last = channels_last
        self.use_bn = use_bn
        self.enable_attention_hooks = enable_attention_hooks
        self.non_negative = non_negative
        self.scale = scale
        self.shift = shift
        self.invert = invert

        start_index = 1

        if backbone == "vitb16":
            scratch_in_features = (96, 192, 384, 768)
            self.model = MODEL_REGISTRY.get("VanillaViT")(
                img_size=img_size,
                patch_size=16,
                embedding_dim=768,
                head_dim=64,
                depth=12,
                num_heads=12,
                encoder_mlp_dim=768,
                n_classes=10,
                in_channels=in_channels,
            )
            hooks = (2, 5, 8, 11) if hooks is None else hooks
            self.vit_features = 768

        elif backbone == "vitl16":
            scratch_in_features = (256, 512, 1024, 1024)
            self.model = MODEL_REGISTRY.get("VanillaViT")(
                img_size=img_size,
                patch_size=16,
                embedding_dim=1024,
                head_dim=64,
                depth=24,
                num_heads=16,
                encoder_mlp_dim=1024,
                n_classes=10,
                in_channels=in_channels,
            )
            hooks = (5, 11, 17, 23) if hooks is None else hooks
            self.vit_features = 1024

        elif backbone == "vit_tiny":
            scratch_in_features = (48, 96, 144, 192)
            self.model = MODEL_REGISTRY.get("VanillaViT")(
                img_size=img_size,
                patch_size=16,
                embedding_dim=192,
                head_dim=64,
                depth=12,
                num_heads=3,
                encoder_mlp_dim=192,
                n_classes=3,  # does not matter; the decoder part is not used in DPTDepth's forward_vit
                in_channels=in_channels,
            )
            hooks = (2, 5, 8, 11) if hooks is None else hooks
            self.vit_features = 192

        else:
            raise NotImplementedError

        assert readout in (
            "add",
            "ignore",
            "project",
        ), f"Invalid `readout` parameter, must be one of ('add', 'ignore', 'project'), but got {readout}"

        features = scratch_in_features[0]

        self._register_hooks_and_add_postprocess(
            size=img_size,
            features=scratch_in_features,
            hooks=hooks,
            use_readout=readout,
            enable_attention_hooks=enable_attention_hooks,
            start_index=start_index,
        )
        self._make_scratch(
            in_shape=scratch_in_features,
            out_shape=features,
            groups=1,
            expand=False,
        )
        self._add_refinenet_to_scratch(features=features, use_bn=use_bn)

        self.model.head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

    def _register_hooks_and_add_postprocess(
        self,
        size=(384, 384),
        features=(96, 192, 384, 768),
        hooks=(2, 5, 8, 11),
        use_readout="ignore",
        enable_attention_hooks=False,
        start_index=1,
    ):
        """
        Registers forward hooks on the backbone and initialises the activation
        post-processing blocks (`act_postprocess1` ... `act_postprocess4`).

        Parameters
        -----------
        size: tuple[int]
            Input image size
        features: tuple[int]
            Number of features
        hooks: tuple[int]
            Indices of the encoder blocks on which forward hooks will be registered
        use_readout: str
            Readout operation, must be one of {`add`, `ignore`, `project`}
        enable_attention_hooks: bool
            If True, forward hooks are also registered on the attention blocks
        start_index: int
            Parameter that handles the readout operation, default value is 1
        """
        for i in range(4):
            self.model.encoder.encoder[hooks[i]][0].fn.register_forward_hook(
                get_activation(str(i + 1))
            )

        self.activations = activations

        if enable_attention_hooks:
            for i in range(4):
                self.model.encoder.encoder[hooks[i]][0].fn.register_forward_hook(
                    get_attention(f"attn_{str(i + 1)}")
                )
            self.attention = attention

        readout_oper = get_readout_oper(
            self.vit_features, features, use_readout, start_index
        )

        self.act_postprocess1 = nn.Sequential(
            readout_oper[0],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=self.vit_features,
                out_channels=features[0],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[0],
                out_channels=features[0],
                kernel_size=4,
                stride=4,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )

        self.act_postprocess2 = nn.Sequential(
            readout_oper[1],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=self.vit_features,
                out_channels=features[1],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[1],
                out_channels=features[1],
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )

        self.act_postprocess3 = nn.Sequential(
            readout_oper[2],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=self.vit_features,
                out_channels=features[2],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
        )

        self.act_postprocess4 = nn.Sequential(
            readout_oper[3],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=self.vit_features,
                out_channels=features[3],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.Conv2d(
                in_channels=features[3],
                out_channels=features[3],
                kernel_size=3,
                stride=2,
                padding=1,
            ),
        )

        self.model.start_index = start_index
        self.model.patch_size = [16, 16]
        self.model.forward_flex = types.MethodType(forward_flex, self.model)
        self.model._resize_pos_embed = types.MethodType(_resize_pos_embed, self.model)

    def _make_scratch(self, in_shape, out_shape, groups=1, expand=False):
        """
        Builds the `scratch` module, which holds one 3x3 projection convolution per backbone stage.

        Parameters
        -----------
        in_shape: list[int]
            Number of input channels for each stage
        out_shape: int
            Number of output channels
        groups: int
            Number of groups in the convolutions
        expand: bool
            If True, the number of output channels is doubled at each stage
        """
        self.scratch = nn.Module()
        for i in range(4):
            layer = nn.Conv2d(
                in_shape[i],
                out_shape * (2 ** i) if expand else out_shape,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
                groups=groups,
            )
            setattr(self.scratch, f"layer{i + 1}_rn", layer)

    def _add_refinenet_to_scratch(self, features, use_bn):
        """
        Adds the four `FeatureFusionBlock_custom` refinement blocks to the `scratch` module.

        Parameters
        -----------
        features: int
            Number of features
        use_bn: bool
            Whether to use batch normalisation
        """
        for i in range(4):
            refinenet = FeatureFusionBlock_custom(
                features,
                nn.ReLU(False),
                deconv=False,
                bn=use_bn,
                expand=False,
                align_corners=True,
            )
            setattr(self.scratch, f"refinenet{i + 1}", refinenet)

    def forward_vit(self, x):
        """
        Performs a forward pass of the backbone ViT model and fetches the outputs of
        different encoder blocks with the help of the registered hooks.

        Parameters
        -----------
        x: torch.Tensor
            Input image tensor
        """
        b, c, h, w = x.shape

        glob = self.model.forward_flex(x)

        layer_1 = self.activations["1"]
        layer_2 = self.activations["2"]
        layer_3 = self.activations["3"]
        layer_4 = self.activations["4"]

        layer_1 = self.act_postprocess1[0:2](layer_1)
        layer_2 = self.act_postprocess2[0:2](layer_2)
        layer_3 = self.act_postprocess3[0:2](layer_3)
        layer_4 = self.act_postprocess4[0:2](layer_4)

        unflatten = nn.Sequential(
            nn.Unflatten(
                2,
                torch.Size(
                    [
                        h // self.model.patch_size[1],
                        w // self.model.patch_size[0],
                    ]
                ),
            )
        )

        if layer_1.ndim == 3:
            layer_1 = unflatten(layer_1)
        if layer_2.ndim == 3:
            layer_2 = unflatten(layer_2)
        if layer_3.ndim == 3:
            layer_3 = unflatten(layer_3)
        if layer_4.ndim == 3:
            layer_4 = unflatten(layer_4)

        layer_1 = self.act_postprocess1[3 : len(self.act_postprocess1)](layer_1)
        layer_2 = self.act_postprocess2[3 : len(self.act_postprocess2)](layer_2)
        layer_3 = self.act_postprocess3[3 : len(self.act_postprocess3)](layer_3)
        layer_4 = self.act_postprocess4[3 : len(self.act_postprocess4)](layer_4)

        return layer_1, layer_2, layer_3, layer_4

    def forward(self, x):
        """
        Forward pass of DPTDepth.

        Parameters
        -----------
        x: torch.Tensor
            Input image tensor
        """
        if self.channels_last:
            x = x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = self.forward_vit(x)

        layer_1 = self.scratch.layer1_rn(layer_1)
        layer_2 = self.scratch.layer2_rn(layer_2)
        layer_3 = self.scratch.layer3_rn(layer_3)
        layer_4 = self.scratch.layer4_rn(layer_4)

        path1 = self.scratch.refinenet4(layer_4)
        path1 = self.scratch.refinenet3(path1, layer_3)
        path1 = self.scratch.refinenet2(path1, layer_2)
        path1 = self.scratch.refinenet1(path1, layer_1)

        inv_depth = self.model.head(path1).squeeze(dim=1)

        if self.invert:
            depth = self.scale * inv_depth + self.shift
            depth[depth < 1e-8] = 1e-8
            depth = 1.0 / depth
            return depth
        else:
            return inv_depth
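

# Illustrative usage sketch (not called anywhere in this module): building a DPTDepth
# model on the lightweight `vit_tiny` backbone and running a single forward pass.
# It assumes that `VanillaViT` is registered in MODEL_REGISTRY elsewhere in vformer,
# and uses an arbitrary random image batch purely for shape checking.
def _demo_dpt_depth():
    model = DPTDepth(backbone="vit_tiny", img_size=(384, 384), readout="project")
    image = torch.randn(1, 3, 384, 384)  # (batch, channels, height, width)
    with torch.no_grad():
        depth = model(image)
    return depth.shape  # expected: torch.Size([1, 384, 384])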


class Slice(nn.Module):
    """Handles the readout operation when the `readout` parameter is `ignore`.
    Removes the `cls_token` or `readout_token` by index slicing."""

    def __init__(self, start_index=1):
        super(Slice, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        return x[:, self.start_index :]


class AddReadout(nn.Module):
    """Handles the readout operation when the `readout` parameter is `add`.
    Removes the `cls_token` or `readout_token` from the tensor and adds it to the remaining tokens."""

    def __init__(self, start_index=1):
        super(AddReadout, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        readout = x[:, 0]
        return x[:, self.start_index :] + readout.unsqueeze(1)


class ProjectReadout(nn.Module):
    """Handles the readout operation when the `readout` parameter is `project`.
    Concatenates the `cls_token` or `readout_token` to every other token and projects back to `in_features`."""

    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index
        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
        features = torch.cat((x[:, self.start_index :], readout), -1)
        return self.project(features)
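

# Illustrative sketch (not used by the module): the three readout handlers applied to a
# toy token tensor of shape (batch, 1 + num_patches, dim), where index 0 is the
# readout/cls token. The sizes are arbitrary assumptions for demonstration.
def _demo_readout_ops():
    tokens = torch.randn(2, 1 + 9, 16)  # 1 readout token + 9 patch tokens, dim 16
    ignored = Slice(start_index=1)(tokens)                  # (2, 9, 16): readout token dropped
    added = AddReadout(start_index=1)(tokens)               # (2, 9, 16): readout token added to each patch token
    projected = ProjectReadout(16, start_index=1)(tokens)   # (2, 9, 16): concatenation + linear projection
    return ignored.shape, added.shape, projected.shape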


class Interpolate(nn.Module):
    """Interpolation module.

    Parameters
    -----------
    scale_factor: float
        Scaling factor used in interpolation
    mode: str
        Interpolation mode
    align_corners: bool
        Whether to align corners in the interpolation operation
    """

    def __init__(self, scale_factor, mode, align_corners=False):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass"""
        x = self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )
        return x
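

# Illustrative sketch (not used by the module): Interpolate simply wraps
# nn.functional.interpolate, here doubling the spatial size of an arbitrary feature map.
def _demo_interpolate():
    up = Interpolate(scale_factor=2, mode="bilinear", align_corners=True)
    return up(torch.randn(1, 32, 24, 24)).shape  # torch.Size([1, 32, 48, 48])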


class Transpose(nn.Module):
    """Transposes two dimensions of the input tensor."""

    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        x = x.transpose(self.dim0, self.dim1)
        return x


class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module.

    Parameters
    -----------
    features: int
        Number of features
    activation: nn.Module
        Activation module, default is nn.GELU
    bn: bool
        Whether to use batch normalisation
    """

    def __init__(self, features, activation=nn.GELU, bn=True):
        super().__init__()

        self.bn = bn
        self.groups = 1

        self.conv1 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not self.bn,
            groups=self.groups,
        )
        self.conv2 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not self.bn,
            groups=self.groups,
        )

        if self.bn:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass"""
        out = self.activation(x)
        out = self.conv1(out)
        if self.bn:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn:
            out = self.bn2(out)

        return self.skip_add.add(out, x)  # return out + x
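

# Illustrative sketch (not used by the module): a ResidualConvUnit_custom keeps the
# input shape unchanged, so its output can be summed with its own input. The activation
# instance and tensor sizes below are arbitrary assumptions for demonstration.
def _demo_residual_conv_unit():
    unit = ResidualConvUnit_custom(features=32, activation=nn.ReLU(False), bn=False)
    feature_map = torch.randn(1, 32, 24, 24)
    return unit(feature_map).shape  # torch.Size([1, 32, 24, 24])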


class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
    ):
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = 1
        self.expand = expand

        out_features = features

        self.out_conv = nn.Conv2d(
            features,
            out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            groups=1,
        )

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        """Forward pass"""
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)  # output += res

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )

        output = self.out_conv(output)

        return output
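

# Illustrative sketch (not used by the module): fusing a coarser decoder feature map with
# a skip feature map of the same shape; the block doubles the spatial resolution.
# All sizes and the ReLU activation instance are arbitrary assumptions for demonstration.
def _demo_feature_fusion_block():
    block = FeatureFusionBlock_custom(features=32, activation=nn.ReLU(False), bn=False)
    decoder_features = torch.randn(1, 32, 12, 12)
    skip_features = torch.randn(1, 32, 12, 12)
    return block(decoder_features, skip_features).shape  # torch.Size([1, 32, 24, 24])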


def get_readout_oper(vit_features, features, use_readout, start_index=1):
    """Returns one readout operation (`Slice`, `AddReadout` or `ProjectReadout`) per feature stage."""
    if use_readout == "ignore":
        readout_oper = [Slice(start_index)] * len(features)
    elif use_readout == "add":
        readout_oper = [AddReadout(start_index)] * len(features)
    else:
        readout_oper = [
            ProjectReadout(vit_features, start_index) for out_feat in features
        ]
    return readout_oper