import torch.nn as nn
from einops import rearrange, repeat
from ..attention.cross import CrossAttention
from ..encoder.nn import FeedForward
from ..functional import PreNorm
from ..utils import DECODER_REGISTRY
@DECODER_REGISTRY.register()
class PerceiverIODecoder(nn.Module):
"""
Implementation of the Perceiver IO Decoder
Parameters
----------
dim: int
Size of sequence to be encoded
latent_dim: int
Dimension of latent array
queries_dim: int
Dimension of queries array
num_latents: int
Number of latent arrays
num_cross_heads: int
Number of heads for cross attention
cross_head_dim: int
Dimension of cross attention head
logits_dim: int, optional
Dimension of output logits
decoder_ff: bool
Whether to include a feed forward layer for the decoder attention block
"""
    def __init__(
        self,
        dim=32,
        latent_dim=512,
        queries_dim=32,
        num_cross_heads=1,
        cross_head_dim=64,
        logits_dim=None,
        decoder_ff=False,
    ):
        super().__init__()

        # Decoder cross attention: the query array attends to the latent array
        # (passed as `context` in `forward`), wrapped in a PreNorm block.
        self.decoder_cross_attn = PreNorm(
            queries_dim,
            CrossAttention(
                queries_dim,
                latent_dim,
                num_heads=num_cross_heads,
                head_dim=cross_head_dim,
            ),
            context_dim=latent_dim,
        )

        # Optional feed forward block applied after the cross attention.
        self.decoder_ff = (
            PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None
        )

        # Final projection to logits; identity when `logits_dim` is not given.
        self.to_logits = (
            nn.Linear(queries_dim, logits_dim)
            if logits_dim is not None
            else nn.Identity()
        )
    def forward(self, x, mask=None, queries=None):
        b, *_, device = *x.shape, x.device
        # Without queries there is nothing to decode; return the latents as-is.
        if queries is None:
            return x
        # Broadcast a shared (num_queries, queries_dim) query array across the batch.
        if queries.ndim == 2:
            queries = repeat(queries, "n d -> b n d", b=b)
        # Cross attend from the decoder queries to the latents.
        latents = self.decoder_cross_attn(queries, context=x)
        # Optional feed forward block with a residual connection.
        if self.decoder_ff is not None:
            latents = latents + self.decoder_ff(latents)
        return self.to_logits(latents)
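

# Usage sketch (illustrative only, not part of the library source): assuming the
# class is importable from the installed package and that a Perceiver encoder has
# already produced a latent array of shape (batch, num_latents, latent_dim), the
# decoder cross-attends a query array onto those latents and projects the result
# to `logits_dim`:
#
#     import torch
#
#     decoder = PerceiverIODecoder(latent_dim=512, queries_dim=32, logits_dim=10)
#     latents = torch.randn(2, 256, 512)  # encoder output: (batch, num_latents, latent_dim)
#     queries = torch.randn(64, 32)       # shared output queries: (num_queries, queries_dim)
#     logits = decoder(latents, queries=queries)
#     print(logits.shape)                 # expected: torch.Size([2, 64, 10])
#
# A 2-D query array is broadcast across the batch inside `forward`; passing
# `queries=None` returns the latent array unchanged.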