Source code for vit_flax.vit

from vit_flax.layers import Transformer

from flax import nn
from jax import lax


[docs]class ViT(nn.Module): """Vision transformer"""
[docs] def apply(self, x, patch_size, dim, depth, num_heads, dense_dims, img_size, num_classes, initializer=nn.initializers.normal(stddev=1.0)): """Applies the Vision transformer to input tensor. :param x: Input tensor image :param patch_size: Patch dimension from image :param dim: Latent dim :param depth: Number of layers of Residual-normalized attention layers. :param num_heads: Number of attention heads :param dense_dims: Tuple(int, int) - (Transformer FC dim, Classifier FC dim) :param img_size: Dimension of input image :param num_classes: Number of classification classes :param initializer: Flax initializer :return:Classification output """ b, h, w, c = x.shape patch = x.reshape(b, c, h, w) num_patches = (img_size // patch_size) ** 2 patch = patch.reshape(b, (h*w)//(patch_size*patch_size), c*patch_size**2) fc_embedding = nn.Dense(patch, dim) class_tokens = self.param('class_tokens', (b, 1, dim), initializer) pos_embedding = self.param( 'pos_embedding', (1, num_patches+1, dim), initializer ) x = lax.concatenate([class_tokens, fc_embedding], dimension=1) x += pos_embedding x = Transformer(x, depth, num_heads, dense_dims[0]) x = x[:, 0] x = nn.Dense(x, dense_dims[1]) x = nn.gelu(x) x = nn.Dense(x, num_classes) return x