# Source code for vformer.encoder.embedding.linear

from einops.layers.torch import Rearrange
from torch import nn

class LinearEmbedding(nn.Module):
    """
    Projects image patches into embedding space using a Linear layer.

    Parameters
    ----------
    embedding_dim: int
        Dimension of the resultant embedding
    patch_height: int
        Height of the patch
    patch_width: int
        Width of the patch
    patch_dim: int
        Dimension of the patch (flattened patch length fed to the Linear layer)
    """

    def __init__(
        self,
        embedding_dim,
        patch_height,
        patch_width,
        patch_dim,
    ):
        super().__init__()

        # Split the image into non-overlapping (patch_height x patch_width)
        # patches and flatten each one; the Linear layer then projects every
        # flattened patch to `embedding_dim`.
        to_flat_patches = Rearrange(
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=patch_height,
            p2=patch_width,
        )
        self.patch_embedding = nn.Sequential(
            to_flat_patches,
            nn.Linear(patch_dim, embedding_dim),
        )

    def forward(self, x):
        """
        Flatten image patches and project them into embedding space.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor

        Returns
        ----------
        torch.Tensor
            Returns patch embeddings of size `embedding_dim`
        """
        return self.patch_embedding(x)