import torch
import torch.nn as nn
from einops import repeat
from ..attention.cross import CrossAttention
from ..attention.vanilla import VanillaSelfAttention
from ..encoder.nn import FeedForward
from ..functional import PreNorm
from ..utils import ENCODER_REGISTRY
@ENCODER_REGISTRY.register()
class PerceiverIOEncoder(nn.Module):
"""
Implementation of the Perceiver IO Encoder containing Iterative Cross Attention and Processor
Parameters
----------
dim: int
Size of sequence to be encoded
depth: int
Depth of latent attention blocks
latent_dim: int
Dimension of latent array
num_latents: int
Number of latent arrays
num_cross_heads: int
Number of heads for cross attention
num_latent_heads: int
Number of heads for latent attention
cross_head_dim: int
Dimension of cross attention head
latent_head_dim: int
Dimension of latent attention head
"""
def __init__(
self,
dim=32,
depth=6,
latent_dim=512,
num_latents=512,
num_cross_heads=1,
num_latent_heads=8,
cross_head_dim=64,
latent_head_dim=64,
):
super().__init__()
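        # Learned latent array; expanded across the batch in forward() and used as the
        # query side of the cross-attention below.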
self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
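        # Encoder cross-attention: latents (queries) attend to the input sequence
        # (keys/values), followed by a feed-forward block; both are pre-normalized.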
self.cross_attn = PreNorm(
latent_dim,
CrossAttention(
latent_dim, dim, num_heads=num_cross_heads, head_dim=cross_head_dim
),
context_dim=dim,
)
self.cross_ff = PreNorm(latent_dim, FeedForward(latent_dim))
        # Build one fresh pre-norm self-attention / feed-forward pair per layer so that
        # weights are not shared across the `depth` processor blocks.
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            latent_attn = PreNorm(
                latent_dim,
                VanillaSelfAttention(
                    latent_dim, num_heads=num_latent_heads, head_dim=latent_head_dim
                ),
            )
            latent_ff = PreNorm(latent_dim, FeedForward(latent_dim))
            self.layers.append(nn.ModuleList([latent_attn, latent_ff]))
    def forward(self, x, mask=None):
        b = x.shape[0]
        # Expand the learned latents across the batch dimension.
        inner_x = repeat(self.latents, "n d -> b n d", b=b)
        # Encode: latents cross-attend to the input sequence, then feed-forward,
        # each with a residual connection.
        inner_x = self.cross_attn(inner_x, context=x, mask=mask) + inner_x
        inner_x = self.cross_ff(inner_x) + inner_x
        # Process: stack of latent self-attention / feed-forward blocks with residuals.
        for self_attn, self_ff in self.layers:
            inner_x = self_attn(inner_x) + inner_x
            inner_x = self_ff(inner_x) + inner_x
        return inner_x
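

# Minimal usage sketch (assumes this module sits inside a vformer-style package so the
# relative imports above resolve; the shapes below are illustrative, not prescriptive):
#
#     encoder = PerceiverIOEncoder(dim=32, depth=6, latent_dim=512, num_latents=256)
#     tokens = torch.randn(2, 100, 32)   # (batch, seq_len, dim)
#     latents = encoder(tokens)          # (batch, num_latents, latent_dim) -> (2, 256, 512)
#
# The returned latent array is what a Perceiver IO decoder would then cross-attend to
# with an output query array to produce task-specific outputs.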