Source code for vformer.functional.norm

import torch.nn as nn


class PreNorm(nn.Module):
    """
    Layer-normalise the input (and optionally a context tensor) before
    calling a wrapped module.

    Parameters
    ----------
    dim: int
        Dimension of the embedding; ``x`` is normalised by
        ``nn.LayerNorm(dim)``, i.e. over its trailing dimension of size ``dim``
    fn: nn.Module
        Attention class (or any module) called as ``fn(norm(x), **kwargs)``
    context_dim: int, optional
        Dimension of the context array used in cross attention. When given, a
        separate ``nn.LayerNorm(context_dim)`` is applied to the ``context``
        keyword argument before it is forwarded to ``fn``. When ``None``
        (default), passing a non-``None`` ``context`` raises ``TypeError``.
    """

    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # A context norm is only built when cross attention is configured.
        self.context_norm = (
            nn.LayerNorm(context_dim) if context_dim is not None else None
        )
        self.fn = fn

    def forward(self, x, **kwargs):
        """
        Normalise ``x`` (and ``context`` if supplied) and apply ``fn``.

        Parameters
        ----------
        x: torch.Tensor
            Input whose trailing dimension has size ``dim``
        **kwargs
            Forwarded unchanged to ``fn``, except for a non-``None``
            ``context``, which is replaced by its layer-normalised version.

        Returns
        -------
        Whatever ``fn`` returns.

        Raises
        ------
        TypeError
            If a non-``None`` ``context`` is given but the module was built
            without ``context_dim``.
        """
        context = kwargs.get("context")
        if context is not None:
            if self.context_norm is None:
                # Previously this surfaced as "'NoneType' object is not
                # callable"; make the misconfiguration explicit instead.
                raise TypeError(
                    "PreNorm received a `context` tensor but was constructed "
                    "without `context_dim`; pass `context_dim` to enable "
                    "context normalisation."
                )
            # kwargs is a fresh dict per call, so rebinding here does not
            # affect the caller.
            kwargs["context"] = self.context_norm(context)
        return self.fn(self.norm(x), **kwargs)