Source code for vit_flax.layers
from flax import nn
class FeedForward(nn.Module):
    """Two-layer MLP: Dense(latent_dim) -> GELU -> Dense(back to input width)."""

    def apply(self, x, latent_dim):
        """Expand ``x`` to ``latent_dim`` features, apply GELU, then project back.

        :param x: Input tensor; its last axis is treated as the feature axis.
        :param latent_dim: Width of the hidden (latent) Dense layer.
        :return: Tensor whose last axis has the same size as ``x``'s last axis.
        """
        # Remember the incoming feature width so the second Dense restores it.
        out_features = x.shape[-1]
        hidden = nn.gelu(nn.Dense(x, latent_dim))
        return nn.Dense(hidden, out_features)
class Residual(nn.Module):
    """Skip-connection wrapper that returns ``residual_fn(x) + x``."""

    def apply(self, x, residual_fn):
        """Run ``residual_fn`` on ``x`` and add the input back onto its output.

        :param x: Input tensor.
        :param residual_fn: Callable (e.g. a partially-applied module) taking a
            tensor and returning a tensor that can be added to ``x``.
        :return: ``residual_fn(x) + x``.
        """
        branch = residual_fn(x)
        return branch + x
class PreNorm(nn.Module):
    """Pre-normalization wrapper that returns ``fn(norm(x))``."""

    def apply(self, x, norm, fn):
        """Normalize ``x`` with ``norm`` and then feed the result to ``fn``.

        :param x: Input tensor.
        :param norm: Normalization callable applied first (in this file,
            ``nn.LayerNorm``).
        :param fn: Callable applied to the normalized tensor.
        :return: ``fn(norm(x))``.
        """
        normalized = norm(x)
        return fn(normalized)
class Transformer(nn.Module):
    """A stack of pre-norm Transformer blocks.

    Each of the ``depth`` blocks computes, in order::

        x = x + SelfAttention(LayerNorm(x))
        x = x + FeedForward(LayerNorm(x))
    """

    def apply(self, x, depth, num_heads, feed_forward_dim_1):
        """Apply ``depth`` residual pre-norm attention + feed-forward blocks.

        :param x: Input tensor; presumably shaped (batch, length, features) as
            ``nn.SelfAttention`` expects -- TODO confirm against callers.
        :param depth: Number of stacked attention + feed-forward blocks; the
            input is returned unchanged when ``depth`` is 0.
        :param num_heads: Number of heads in each self-attention sub-layer.
        :param feed_forward_dim_1: Hidden (latent) width of each feed-forward
            sub-layer, forwarded to ``FeedForward`` as ``latent_dim``.
        :return: Transformer output embedding. Each sub-layer adds its output
            back onto its input, so the result has the same shape as ``x``.
        """
        norm = nn.LayerNorm

        # Sub-layer 1: x + SelfAttention(LayerNorm(x)).
        attention = nn.SelfAttention.partial(num_heads=num_heads)
        norm_attention = PreNorm.partial(norm=norm, fn=attention)
        residual_attention = Residual.partial(residual_fn=norm_attention)

        # Sub-layer 2: x + FeedForward(LayerNorm(x)).
        forward = FeedForward.partial(latent_dim=feed_forward_dim_1)
        norm_forward = PreNorm.partial(norm=norm, fn=forward)
        residual_forward = Residual.partial(residual_fn=norm_forward)

        # NOTE(review): in flax.nn an un-named module call inside ``apply``
        # presumably creates a fresh submodule, so each block should get its
        # own parameters -- confirm.
        for _ in range(depth):
            x = residual_forward(residual_attention(x))
        return x