Window Attention

class vformer.attention.window.WindowAttention(dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_dropout=0.0, proj_dropout=0.0)[source]

Bases: Module

Implementation of Window Attention introduced in: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
Parameters

  • dim (int) – Number of input channels.

  • window_size (int or tuple[int]) – The height and width of the window.

  • num_heads (int) – Number of attention heads.

  • qkv_bias (bool) – If True, adds a learnable bias to the query, key, and value projections. Default is True.

  • qk_scale (float, optional) – If set, overrides the default qk scale of head_dim ** -0.5.

  • attn_dropout (float, optional) – Dropout rate applied to the attention weights. Default is 0.0.

  • proj_dropout (float, optional) – Dropout rate applied to the output projection. Default is 0.0.
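To make the mechanics concrete, here is a minimal sketch of the core window-attention computation: multi-head scaled dot-product attention applied independently within each window. This is an illustrative toy, not vformer's implementation; it omits the learned qkv/output projections, the relative position bias, and dropout, and it assumes the Swin-style input layout of shape (num_windows * batch, tokens_per_window, channels).

```python
import torch

def window_attention(x, num_heads, scale=None):
    # x: (num_windows * batch, N, C), where N = window height * window width.
    # Toy version: q, k, v are the input itself (no learned projections).
    B_, N, C = x.shape
    head_dim = C // num_heads
    scale = scale or head_dim ** -0.5          # mirrors the qk_scale default
    # Split channels into heads: (B_, num_heads, N, head_dim)
    q = k = v = x.reshape(B_, N, num_heads, head_dim).transpose(1, 2)
    attn = (q * scale) @ k.transpose(-2, -1)   # (B_, num_heads, N, N)
    attn = attn.softmax(dim=-1)
    # Merge heads back: (B_, N, C)
    return (attn @ v).transpose(1, 2).reshape(B_, N, C)

# Example: 4 windows of 7x7 = 49 tokens each, 96 channels, 3 heads.
x = torch.randn(4, 49, 96)
y = window_attention(x, num_heads=3)
```

Each window attends only to its own tokens, which is what keeps the cost linear in image size rather than quadratic.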

forward(x, mask=None)[source]

Parameters

  • x (torch.Tensor) – Input tensor.

  • mask (torch.Tensor, optional) – Attention mask used for shifted-window attention. If None, plain window attention is applied; otherwise the mask is taken into account when computing the attention weights. For a detailed explanation, you may refer to this GitHub issue.


Returns the output tensor obtained by applying Window Attention or Shifted-Window Attention to the input tensor.

Return type

torch.Tensor
training: bool
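The mask argument exists because, after the cyclic shift used by Shifted-Window Attention, a single window can contain tokens that came from disjoint regions of the feature map; those tokens must not attend to each other. The sketch below shows one way such an additive mask can be constructed (following the region-labeling scheme from the Swin Transformer paper); it is an assumption-laden illustration, not vformer's code. Positions whose region ids differ receive a large negative bias so their attention weights vanish after softmax.

```python
import torch

def shift_attn_mask(H, W, window_size, shift_size):
    # Label each pixel of an H x W map with the id of its pre-shift region.
    img = torch.zeros(1, H, W, 1)
    cnt = 0
    for h in (slice(0, -window_size),
              slice(-window_size, -shift_size),
              slice(-shift_size, None)):
        for w in (slice(0, -window_size),
                  slice(-window_size, -shift_size),
                  slice(-shift_size, None)):
            img[:, h, w, :] = cnt
            cnt += 1
    # Partition the id map into non-overlapping windows: (num_windows, N)
    win = img.view(1, H // window_size, window_size,
                   W // window_size, window_size, 1)
    win = win.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)
    # Pairwise id differences within each window: (num_windows, N, N)
    diff = win.unsqueeze(1) - win.unsqueeze(2)
    # Different regions -> -100 (suppressed after softmax), same region -> 0.
    return diff.masked_fill(diff != 0, -100.0).masked_fill(diff == 0, 0.0)

# Example: 8x8 map, 4x4 windows, shift of 2 -> 4 windows of 16 tokens.
mask = shift_attn_mask(H=8, W=8, window_size=4, shift_size=2)
```

A mask built this way would be added to the attention logits before the softmax inside forward, one (N, N) slice per window.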