Memory Efficient Attention

class vformer.attention.memory_efficient.MemoryEfficientAttention(dim, num_heads=8, head_dim=64, p_dropout=0.0, query_chunk_size=1024, key_chunk_size=4096)[source]

Bases: Module

Implementation of Memory-Efficient O(1) Attention from "Self-attention Does Not Need O(n²) Memory": https://arxiv.org/abs/2112.05682

Implementation based on https://github.com/AminRezaei0x443/memory-efficient-attention

Parameters
  • dim (int) – Dimension of the embedding

  • num_heads (int) – Number of attention heads

  • head_dim (int) – Dimension of each attention head

  • p_dropout (float) – Dropout probability

  • query_chunk_size (int) – Number of query positions processed per chunk

  • key_chunk_size (int) – Number of key/value positions attended to per chunk
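
A minimal instantiation sketch, assuming the class is imported from the module path shown above; the parameter values mirror the signature defaults and are purely illustrative. Smaller chunk sizes lower peak memory at the cost of more passes over the sequence.

   import torch
   from vformer.attention.memory_efficient import MemoryEfficientAttention

   # Chunk sizes trade peak memory for extra iterations over the sequence;
   # the values below match the defaults in the signature above.
   attn = MemoryEfficientAttention(
       dim=192,
       num_heads=8,
       head_dim=64,
       p_dropout=0.0,
       query_chunk_size=1024,
       key_chunk_size=4096,
   )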

static dynamic_slice(x, starts, sizes)[source]
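
For intuition only, a hypothetical stand-in for a dynamic_slice-style helper (the name and body here are illustrative, not the library's code): it extracts a sub-tensor that starts at starts and spans sizes along the leading dimensions.

   def dynamic_slice_sketch(x, starts, sizes):
       # Equivalent to x[starts[0]:starts[0]+sizes[0], starts[1]:starts[1]+sizes[1], ...]
       # over the first len(starts) dimensions; remaining dimensions are kept whole.
       idx = tuple(slice(s, s + n) for s, n in zip(starts, sizes))
       return x[idx]
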
forward(x)[source]
Parameters
  • x (torch.Tensor) – Input tensor

Returns
  Output tensor obtained by applying self-attention to the input tensor

Return type
  torch.Tensor
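
Continuing the instantiation sketch above, a hedged example of a forward pass; the (batch, sequence length, dim) input layout is an assumption based on the other attention modules in the library.

   x = torch.randn(2, 4096, 192)   # (batch, sequence length, dim) -- assumed layout
   out = attn(x)                   # attention is computed chunk by chunk internally
   print(out.shape)                # expected: torch.Size([2, 4096, 192])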

static map_pt(f, xs)[source]
query_chunk_attention(query, key, value)[source]
static scan(f, init, xs, length=None)[source]
static summarize_chunk(query, key, value)[source]
training: bool
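
query_chunk_attention and summarize_chunk implement the chunked scheme from the paper. Below is a self-contained sketch of that idea, not the library's exact code (the _sketch names, the scaling factor, and the chunking loop are assumptions): each query-chunk/key-chunk pair is summarised by an un-normalised weighted sum of values plus log-sum-exp statistics, and the per-chunk summaries are rescaled and combined so that only one score block is held in memory at a time.

   import math
   import torch

   def summarize_chunk_sketch(q, k, v):
       # Attention of one query chunk against one key/value chunk.
       # q: (q_len, d), k and v: (k_len, d)
       scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
       m = scores.max(dim=-1, keepdim=True).values    # per-query max, for numerical stability
       p = torch.exp(scores - m)                      # un-normalised softmax weights
       return p @ v, p.sum(dim=-1), m.squeeze(-1)     # weighted values, weight sums, maxes

   def query_chunk_attention_sketch(q, k, v, key_chunk_size=4096):
       # Combine per-key-chunk summaries with log-sum-exp rescaling, so only one
       # (q_len x key_chunk_size) score block is materialised at a time.
       values, weights, maxes = [], [], []
       for start in range(0, k.shape[0], key_chunk_size):
           vc, wc, mc = summarize_chunk_sketch(
               q, k[start:start + key_chunk_size], v[start:start + key_chunk_size]
           )
           values.append(vc), weights.append(wc), maxes.append(mc)
       maxes = torch.stack(maxes)                     # (n_chunks, q_len)
       global_max = maxes.max(dim=0).values           # (q_len,)
       alpha = torch.exp(maxes - global_max)          # rescale each chunk's summary
       num = (torch.stack(values) * alpha.unsqueeze(-1)).sum(dim=0)
       den = (torch.stack(weights) * alpha).sum(dim=0)
       return num / den.unsqueeze(-1)                 # exact softmax attention output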