Source code for vformer.functional.merge

import torch
import torch.nn as nn

from ..utils import pair


class PatchMerging(nn.Module):
    """
    Merge every 2x2 neighbourhood of tokens into a single token.

    The four tokens of each 2x2 window are concatenated along the channel
    axis (``C -> 4 * C``), normalised, and then linearly projected to
    ``2 * dim`` channels, so the token count shrinks by a factor of four.

    Parameters
    ----------
    input_resolution: int or tuple[int]
        Resolution of input features. The value is passed through ``pair``
        and then unpacked as ``(H, W)``.
    dim : int
        Number of channels of each input token. The normalisation layer and
        the projection are built for ``4 * dim`` input features, so the last
        dimension of the tensor given to ``forward`` is expected to be
        ``dim``.
    norm_layer : callable, optional
        Called as ``norm_layer(4 * dim)`` to build the normalisation module
        applied before the projection. Defaults to ``nn.LayerNorm``.
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = pair(input_resolution)
        self.dim = dim
        # Registration order (reduction, then norm) is kept as in the
        # original so parameter ordering and initialisation are unchanged.
        self.reduction = nn.Linear(4 * self.dim, 2 * self.dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape ``(B, L, C)`` with ``L == H * W``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(B, (H // 2) * (W // 2), 2 * dim)``.

        Raises
        ------
        AssertionError
            If ``L != H * W`` or if ``H`` or ``W`` is odd.
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        # Kept as assertions so the raised exception type stays the same
        # for existing callers.
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        # Four strided views, each B x H/2 x W/2 x C. Concatenation order is
        # (row, col) offsets (0, 0), (1, 0), (0, 1), (1, 1).
        quadrants = [
            x[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)
        ]
        x = torch.cat(quadrants, dim=-1)  # B H/2 W/2 4*C

        # Use the explicit merged length instead of -1: with an empty batch
        # (B == 0) the -1 dimension cannot be inferred and view() raises.
        x = x.view(B, (H // 2) * (W // 2), 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)
        return x