Window

class vformer.attention.window.WindowAttention(dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_dropout=0.0, proj_dropout=0.0)

Bases: Module

Parameters
  • dim (int) – Number of input channels.

  • window_size (int or tuple[int]) – The height and width of the window.

  • num_heads (int) – Number of attention heads.

  • qkv_bias (bool, default is True) – If True, add a learnable bias to query, key, value.

  • qk_scale (float, optional) – If set, overrides the default qk scale of head_dim ** -0.5.

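  • attn_dropout (float, optional) – Dropout rate applied to the attention weights. Default is 0.0.

  • proj_dropout (float, optional) – Dropout rate applied to the output projection. Default is 0.0.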
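
The quantities implied by these parameters can be sketched in plain Python. This is only an illustration, not the library's code: the helper name `derived_attention_sizes` is hypothetical, and it assumes that `dim` is split evenly across heads (head_dim = dim // num_heads) and that an integer `window_size` denotes a square window.

```python
def derived_attention_sizes(dim, window_size, num_heads, qk_scale=None):
    """Hypothetical helper: sizes implied by the constructor arguments."""
    # Accept either a single int (square window) or a (height, width) pair.
    if isinstance(window_size, int):
        window_size = (window_size, window_size)
    win_h, win_w = window_size

    # Assumption: channels are split evenly across the attention heads.
    if dim % num_heads != 0:
        raise ValueError("dim must be divisible by num_heads")
    head_dim = dim // num_heads

    # Use the override when given, else the documented default head_dim ** -0.5.
    scale = qk_scale if qk_scale is not None else head_dim ** -0.5

    return {
        "tokens_per_window": win_h * win_w,
        "head_dim": head_dim,
        "scale": scale,
    }


sizes = derived_attention_sizes(dim=96, window_size=7, num_heads=3)
print(sizes)
```

With dim=96 and num_heads=3, each head sees 32 channels, and a 7 x 7 window holds 49 tokens.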

forward(x, mask=None)
Parameters
  • x (torch.Tensor) – Input tensor.

  • mask (torch.Tensor) – Attention mask used for shifted window attention. If None, plain window attention is applied; otherwise the mask is taken into account when computing attention. For a fuller explanation, see <https://github.com/microsoft/Swin-Transformer/issues/38>.
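
The effect of the mask can be illustrated with a minimal pure-Python sketch. It assumes the additive convention used in the Swin Transformer reference implementation, where the mask holds 0 for token pairs that may attend to each other and a large negative value (here -100.0) for pairs that may not, and is added to the attention scores before the softmax. This page does not confirm that convention, and the function names below are hypothetical.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def masked_attention_weights(scores, mask=None):
    """Hypothetical helper: scores and mask are (tokens x tokens) lists of lists."""
    if mask is not None:
        # Assumed convention: the mask is added to the raw scores.
        scores = [
            [s + m for s, m in zip(score_row, mask_row)]
            for score_row, mask_row in zip(scores, mask)
        ]
    return [softmax(row) for row in scores]


BLOCKED = -100.0  # assumption: large negative value for disallowed pairs

# Four tokens in one window; after shifting, tokens 0-1 and 2-3 come
# from different image regions and should not attend to each other.
region = [0, 0, 1, 1]
scores = [[0.0] * 4 for _ in range(4)]
mask = [
    [0.0 if region[i] == region[j] else BLOCKED for j in range(4)]
    for i in range(4)
]

plain = masked_attention_weights(scores)          # mask=None
shifted = masked_attention_weights(scores, mask)  # mask supplied
print(plain[0])
print(shifted[0])
```

Without a mask, every token in the window attends to every other token equally in this toy case. With the mask, the weight on tokens from the other region becomes negligible and the remaining weight is shared within the token's own region.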

Returns

Output tensor obtained by applying window attention or shifted-window attention to the input tensor.

Return type

torch.Tensor

training: bool