import torch
import torch.nn as nn
from ..utils import pair
class PatchMerging(nn.Module):
    """Patch merging (downsampling) layer as used in Swin Transformer.

    Merges each 2x2 neighborhood of patch tokens into one token by
    concatenating the four feature vectors (giving 4*C channels),
    normalizing, and linearly projecting down to 2*C channels. The
    spatial resolution is halved along each axis and the channel count
    is doubled.

    Parameters
    ----------
    input_resolution : int or tuple[int]
        Resolution (H, W) of the input features; a single int is
        expanded to a square resolution via ``pair``.
    dim : int
        Number of input channels C.
    norm_layer : callable, optional
        Normalization layer constructor applied to the concatenated
        4*C-channel features, by default ``nn.LayerNorm``.
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super(PatchMerging, self).__init__()
        self.input_resolution = pair(input_resolution)
        self.dim = dim
        # Bias-free reduction 4*C -> 2*C, as in the reference Swin implementation.
        self.reduction = nn.Linear(4 * self.dim, 2 * self.dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Merge 2x2 patch neighborhoods.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (B, H*W, C) where (H, W) == input_resolution.

        Returns
        -------
        torch.Tensor
            Output of shape (B, H/2 * W/2, 2*C).
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)
        # Gather the four interleaved sub-grids of each 2x2 neighborhood.
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C  (even rows, even cols)
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C  (odd rows,  even cols)
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C  (even rows, odd cols)
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C  (odd rows,  odd cols)
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B (H/2*W/2) 4*C

        x = self.norm(x)
        x = self.reduction(x)
        return x