Window Attention

class vformer.attention.window.WindowAttention(dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_dropout=0.0, proj_dropout=0.0)[source]

Bases: Module

Implementation of Window Attention introduced in: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

Parameters
  • dim (int) – Number of input channels.

  • window_size (int or tuple[int]) – The height and width of the window.

  • num_heads (int) – Number of attention heads.

  • qkv_bias (bool) – If True, adds a learnable bias to the query, key, and value projections; default is True.

  • qk_scale (float, optional) – If set, overrides the default query-key scale of head_dim ** -0.5.

  • attn_dropout (float, optional) – Dropout rate applied to the attention weights; default is 0.0.

  • proj_dropout (float, optional) – Dropout rate applied to the output projection; default is 0.0.
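
A minimal instantiation sketch; the concrete values below (96 channels, 7×7 windows, 3 heads) are illustrative assumptions, not library defaults:

    from vformer.attention.window import WindowAttention

    # Illustrative configuration: 96-channel tokens, 7x7 windows, 3 attention heads.
    attn = WindowAttention(dim=96, window_size=7, num_heads=3)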

forward(x, mask=None)[source]
Parameters
  • x (torch.Tensor) – Input tensor.

  • mask (torch.Tensor, optional) – Attention mask used for shifted-window attention. If None, plain window attention is applied; otherwise the mask is taken into account when computing attention. For more background, you may refer to this GitHub issue.

Returns

Output tensor obtained by applying window attention (or shifted-window attention when a mask is given) to the input tensor.

Return type

torch.Tensor

training: bool
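
A hedged usage sketch for forward. The tensor layout below is an assumption following the Swin Transformer convention: the input is a batch of flattened windows of shape (num_windows * B, window_size * window_size, dim), and the shifted-window mask is an additive tensor of shape (num_windows, N, N) with 0 for allowed pairs and large negative values for blocked pairs:

    import torch

    from vformer.attention.window import WindowAttention

    attn = WindowAttention(dim=96, window_size=7, num_heads=3)

    # Plain window attention: flattened 7x7 windows (assumed Swin layout).
    x = torch.randn(8, 7 * 7, 96)              # (num_windows * B, N, C)
    out = attn(x)                              # output keeps the input shape, e.g. [8, 49, 96]

    # Shifted-window attention: an additive mask (assumed Swin convention)
    # suppresses attention between tokens from different original windows.
    num_windows, batch = 4, 2
    mask = torch.zeros(num_windows, 49, 49)    # all-zero mask == no blocking, for illustration
    x_shifted = torch.randn(num_windows * batch, 49, 96)
    out_shifted = attn(x_shifted, mask=mask)   # output keeps the input shape, e.g. [8, 49, 96]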