Window
- class vformer.attention.window.WindowAttention(dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_dropout=0.0, proj_dropout=0.0)[source]
Bases:
Module
- Parameters
dim (int) – Number of input channels.
window_size (int or tuple[int]) – The height and width of the window.
num_heads (int) – Number of attention heads.
qkv_bias (bool, default is True) – If True, add a learnable bias to query, key, value.
qk_scale (float, optional) – If set, overrides the default qk scale of head_dim ** -0.5.
attn_dropout (float, optional) – Dropout rate applied to the attention weights.
proj_dropout (float, optional) – Dropout rate applied to the output projection.
- forward(x, mask=None)[source]
- Parameters
x (torch.Tensor) – Input tensor.
mask (torch.Tensor, optional) – Attention mask used for shifted-window attention. If None, plain window attention is applied; otherwise the mask is added to the attention scores. For more details, see <https://github.com/microsoft/Swin-Transformer/issues/38>; a usage sketch follows this reference.
- Returns
Output tensor obtained by applying window attention (or shifted-window attention, when a mask is given) to the input tensor.
- Return type
torch.Tensor
- training: bool
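A minimal usage sketch (not taken from the library's own examples): the input layout is assumed to follow the Swin convention, i.e. x has shape (num_windows * batch_size, window_size * window_size, dim), and the optional mask has shape (num_windows, window_size * window_size, window_size * window_size) with large negative values at masked positions.

```python
import torch
from vformer.attention.window import WindowAttention

dim, window_size, num_heads = 96, 7, 3
attn = WindowAttention(dim=dim, window_size=window_size, num_heads=num_heads)

# 4 windows from a single image, each window flattened to 7 * 7 = 49 tokens
x = torch.randn(4, window_size * window_size, dim)

# Plain window attention: no mask
out = attn(x)                     # assumed shape: (4, 49, 96)

# Shifted-window attention: pass a per-window attention mask
# (all-zeros here, purely to illustrate the call signature)
mask = torch.zeros(4, window_size * window_size, window_size * window_size)
out_shifted = attn(x, mask=mask)  # assumed shape: (4, 49, 96)
```

The mask is typically built once per shift configuration (as in the linked Swin-Transformer issue) and reused across batches.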