Window Attention

class vformer.attention.window.WindowAttention(dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_dropout=0.0, proj_dropout=0.0)

Implementation of window attention as introduced in "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows".

  • dim (int) – Number of input channels.

  • window_size (int or tuple[int]) – The height and width of the window.

  • num_heads (int) – Number of attention heads.

  • qkv_bias (bool) – If True, adds a learnable bias to the query, key, and value projections. Default is True.

  • qk_scale (float, optional) – If set, overrides the default qk scale of head_dim ** -0.5.

  • attn_dropout (float, optional) – Dropout rate applied to the attention weights. Default is 0.0.

  • proj_dropout (float, optional) – Dropout rate applied to the output of the final projection. Default is 0.0.
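
The relationship between dim, num_heads, window_size, and qk_scale can be sketched in plain Python. This is only an illustration: window_attention_sizes is a hypothetical helper, not part of vformer, and it assumes that head_dim is dim // num_heads, as is usual for multi-head attention.

```python
def window_attention_sizes(dim, window_size, num_heads, qk_scale=None):
    """Hypothetical helper: derive per-head sizes from the constructor arguments."""
    # window_size may be a single int (square window) or a (height, width) tuple.
    if isinstance(window_size, int):
        window_size = (window_size, window_size)
    # Assumption: the channels are split evenly across the heads.
    if dim % num_heads != 0:
        raise ValueError("dim must be divisible by num_heads")
    head_dim = dim // num_heads
    # qk_scale replaces the default scale of head_dim ** -0.5 when it is set.
    scale = head_dim ** -0.5 if qk_scale is None else qk_scale
    return {
        "head_dim": head_dim,
        "scale": scale,
        "tokens_per_window": window_size[0] * window_size[1],
    }


sizes = window_attention_sizes(dim=96, window_size=7, num_heads=3)
print(sizes["head_dim"], sizes["tokens_per_window"], round(sizes["scale"], 4))
# prints: 32 49 0.1768
```

With dim=96 and num_heads=3, each head works on 32 channels, so the default scale is 32 ** -0.5, and a 7 x 7 window contains 49 tokens that attend to one another.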

forward(x, mask=None)
  • x (torch.Tensor) – Input tensor.

  • mask (torch.Tensor) – Attention mask used for shifted-window attention. If None, plain window attention is applied; otherwise the mask is taken into account when computing attention. For a more detailed explanation, refer to this GitHub issue.


Returns the output tensor obtained by applying window attention or shifted-window attention to the input tensor.
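
To make the role of mask more concrete, here is a dependency-free sketch of how a shifted-window mask can be built and applied. It follows the region-labelling scheme commonly used for Swin-style shifted windows. It uses plain Python lists rather than torch tensors, and it is not taken from vformer: the function names, the fill value of -100.0, and the layout of one (N, N) mask per window are all assumptions made for illustration.

```python
import math


def shifted_window_mask(height, width, window_size, shift_size, fill=-100.0):
    """Build one (N, N) additive mask per window, with N = window_size ** 2."""
    if height % window_size or width % window_size:
        raise ValueError("height and width must be multiples of window_size")
    if not 0 < shift_size < window_size:
        raise ValueError("shift_size must satisfy 0 < shift_size < window_size")

    def bands(n):
        # Three bands per axis; the last two are the strips that end up
        # sharing a window after the cyclic shift.
        return [(0, n - window_size), (n - window_size, n - shift_size), (n - shift_size, n)]

    # Label every position with the id of the region it belongs to.
    region = [[0] * width for _ in range(height)]
    region_id = 0
    for top, bottom in bands(height):
        for left, right in bands(width):
            for i in range(top, bottom):
                for j in range(left, right):
                    region[i][j] = region_id
            region_id += 1

    masks = []
    for top in range(0, height, window_size):
        for left in range(0, width, window_size):
            ids = [region[top + a][left + b]
                   for a in range(window_size) for b in range(window_size)]
            # 0.0 keeps a query/key pair; `fill` suppresses pairs from different regions.
            masks.append([[0.0 if q == k else fill for k in ids] for q in ids])
    return masks


def masked_softmax(logits, mask_row):
    """Add the mask to the attention logits of one query, then apply softmax."""
    shifted = [logit + m for logit, m in zip(logits, mask_row)]
    peak = max(shifted)
    exps = [math.exp(v - peak) for v in shifted]
    total = sum(exps)
    return [e / total for e in exps]


masks = shifted_window_mask(height=4, width=4, window_size=2, shift_size=1)
weights = masked_softmax([0.0, 0.0, 0.0, 0.0], masks[1][0])
```

In this 4 x 4 example the first window lies entirely inside one region, so its mask is all zeros and it behaves like plain window attention. The second window mixes two regions, so the first query in it splits its attention between the two tokens of its own region, and the other two tokens receive a weight that is effectively zero.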

