Memory Efficient Self-Attention O(n)

class vformer.attention.memory_efficient.MemoryEfficientAttention(dim, num_heads=8, head_dim=64, p_dropout=0.0, query_chunk_size=1024, key_chunk_size=4096)[source]

Bases: Module

Memory efficient attention introduced in the paper Self-attention Does Not Need O(n²) Memory

Implementation based on this repository
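
The O(n²) memory cost of standard attention comes from materializing the full n × n score matrix. The paper instead processes keys and values in chunks and keeps only running per-query statistics, rescaled by the running maximum score for numerical stability. A sketch of the merge rule, in my own notation rather than the docstring's: for a single query q, each key/value chunk C_j contributes

    m_j = \max_{i \in C_j} q \cdot k_i, \qquad
    s_j = \sum_{i \in C_j} e^{\,q \cdot k_i - m_j}, \qquad
    v_j = \sum_{i \in C_j} e^{\,q \cdot k_i - m_j}\, v_i

and the running statistics (m, s, v) are updated as

    m' = \max(m, m_j), \qquad
    s' = s\, e^{\,m - m'} + s_j\, e^{\,m_j - m'}, \qquad
    v' = v\, e^{\,m - m'} + v_j\, e^{\,m_j - m'}

After the last chunk, v' / s' equals the standard softmax attention output for q, while only one chunk of scores is ever held in memory.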

Parameters
  • dim (int) – Dimension of the embedding

  • num_heads (int) – Number of the attention heads

  • head_dim (int) – Dimension of each head

  • p_dropout (float) – Dropout probability

  • query_chunk_size (int) – Number of query tokens processed per chunk

  • key_chunk_size (int) – Number of key/value tokens processed per chunk
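
A minimal usage sketch; the dim value, sequence length, and tensor shapes below are illustrative assumptions, not values taken from the vformer documentation:

    import torch
    from vformer.attention.memory_efficient import MemoryEfficientAttention

    # Input is assumed to be (batch, sequence_length, dim); dim must match
    # the value passed to the constructor.
    attn = MemoryEfficientAttention(
        dim=192,
        num_heads=8,
        head_dim=64,
        p_dropout=0.0,
        query_chunk_size=1024,  # queries processed 1024 at a time
        key_chunk_size=4096,    # keys/values processed 4096 at a time
    )

    x = torch.randn(2, 8192, 192)  # long sequence; the 8192 x 8192 score matrix is never materialized
    out = attn(x)                  # expected to match the input shape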

static dynamic_slice(x, starts, sizes)[source]
forward(x)[source]
Parameters

x (torch.Tensor) – Input tensor

Returns

Output tensor obtained by applying self-attention to the input tensor

Return type

torch.Tensor

static map_pt(f, xs)[source]
query_chunk_attention(query, key, value)[source]
static scan(f, init, xs, length=None)[source]
static summarize_chunk(query, key, value)[source]
training: bool
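
The static helpers above appear to split the computation in the same way as the reference implementation: summarize_chunk computes per-chunk statistics, query_chunk_attention merges them across key chunks, and scan / map_pt iterate over query chunks. The single-head, unbatched sketch below illustrates the same chunk-and-merge scheme; it is an independent rewrite for illustration, assuming standard scaled dot-product attention, and is not the vformer source:

    import torch

    def chunked_attention(q, k, v, key_chunk_size=4096):
        """Memory-efficient attention over (n, d) tensors.

        Only keys/values are chunked here; the real module also chunks the
        queries (query_chunk_size) and works on multi-head projections.
        """
        scale = q.shape[-1] ** -0.5
        acc_values = torch.zeros_like(q)                      # running sum of exp-weighted values
        acc_weights = q.new_zeros(q.shape[0], 1)              # running softmax normalizer
        acc_max = q.new_full((q.shape[0], 1), float("-inf"))  # running maximum score
        for start in range(0, k.shape[0], key_chunk_size):
            k_chunk = k[start:start + key_chunk_size]
            v_chunk = v[start:start + key_chunk_size]
            scores = (q @ k_chunk.T) * scale                  # only an (n_q, chunk) block is in memory
            chunk_max = scores.amax(dim=-1, keepdim=True)
            exp_scores = torch.exp(scores - chunk_max)
            chunk_values = exp_scores @ v_chunk
            chunk_weights = exp_scores.sum(dim=-1, keepdim=True)
            # rescale old and new statistics to a common maximum, then merge
            new_max = torch.maximum(acc_max, chunk_max)
            old_scale = torch.exp(acc_max - new_max)
            new_scale = torch.exp(chunk_max - new_max)
            acc_values = acc_values * old_scale + chunk_values * new_scale
            acc_weights = acc_weights * old_scale + chunk_weights * new_scale
            acc_max = new_max
        return acc_values / acc_weights                       # equals softmax(q @ k.T * scale) @ v

    # Sanity check against standard attention
    q, k, v = torch.randn(16, 32), torch.randn(80, 32), torch.randn(80, 32)
    reference = torch.softmax(q @ k.T * 32 ** -0.5, dim=-1) @ v
    assert torch.allclose(chunked_attention(q, k, v, key_chunk_size=17), reference, atol=1e-5)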