vit_flax package

Submodules

vit_flax.layers module

class vit_flax.layers.FeedForward

Bases: flax.nn.base.Module

Applies two linear transformations with a GELU activation to the input.

Parameters:
  • x – Input tensor.
  • latent_dim – FC latent dim.
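A minimal NumPy sketch of this computation follows. It assumes the common layout: a Dense layer to `latent_dim`, then GELU, then a Dense layer back to the input width. Weights are passed explicitly, and `feed_forward` and its arguments are hypothetical names, not vit_flax's code. The GELU uses the tanh approximation, which is what `jax.nn.gelu` computes by default. Every array operation used here has a direct `jax.numpy` counterpart.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU (the default form of jax.nn.gelu)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def feed_forward(x, w1, b1, w2, b2):
    """Hypothetical FeedForward: (..., d) -> (..., latent_dim) -> (..., d)."""
    hidden = gelu(x @ w1 + b1)  # first linear transformation + GELU
    return hidden @ w2 + b2     # second linear transformation back to width d


# Example: 4 tokens of width 8 and latent_dim = 16 (illustrative values).
rng = np.random.default_rng(0)
d, latent_dim = 8, 16
x = rng.normal(size=(4, d))
w1, b1 = rng.normal(size=(d, latent_dim)), np.zeros(latent_dim)
w2, b2 = rng.normal(size=(latent_dim, d)), np.zeros(d)
y = feed_forward(x, w1, b1, w2, b2)
print(y.shape)  # (4, 8)
```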
apply(x, latent_dim)

Applies two linear transformations with a GELU activation to the input.

Parameters:
  • x – Input tensor.
  • latent_dim – FC latent dim.
classmethod call(x, latent_dim)

Evaluate the module with the given parameters.

Parameters:
  • params – the parameters of the module. Typically, initial parameter values are constructed using Module.init or Module.init_by_shape.
  • *args – arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • latent_dim – FC latent dim.
Returns:

The output of the module’s apply function.

Apply docstring:

Applies two linear transformations with a GELU activation to the input.

classmethod create(x, latent_dim)

Create a module instance by evaluating the model.

DEPRECATION WARNING: create() is deprecated; use init() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.

Use create_by_shape instead to initialize without doing computation. Initializer functions can depend both on the shape and the value of inputs.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • *args – arguments passed to the module’s apply function
  • name – name of this module
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • latent_dim – FC latent dim.
Returns:

A pair consisting of the model output and an instance of Model

Apply docstring:

Applies two linear transformations with a GELU activation to the input.

classmethod create_by_shape(input_specs, x, latent_dim)

Create a module instance using only shape and dtype information.

DEPRECATION WARNING: create_by_shape() is deprecated; use init_by_shape() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.

This method will initialize the model without computation. Initializer functions can depend on the shape but not the value of inputs.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • input_specs – an iterable of (shape, dtype) pairs specifying the inputs
  • *args – other arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • latent_dim – FC latent dim.
Returns:

A pair consisting of the model output and an instance of Model

Apply docstring:

Applies two linear transformations with a GELU activation to the input.

classmethod init(x, latent_dim)

Initialize the module parameters.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • *args – arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • latent_dim – FC latent dim.
Returns:

A pair consisting of the model output and the initialized parameters

Apply docstring:

Applies two linear transformations with a GELU activation to the input.

classmethod init_by_shape(input_specs, x, latent_dim)

Initialize the module parameters.

This method will initialize the module parameters without computation. Initializer functions can depend on the shape but not the value of inputs.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • input_specs – an iterable of (shape, dtype) pairs specifying the inputs
  • *args – arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
Returns:

A pair consisting of the model output and the initialized parameters

Example

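```
import jax
import jax.numpy as jnp

# `model` is a flax.nn.Module subclass (e.g. one returned by partial());
# batch_size and image_size are placeholders for your input dimensions.
input_shape = (batch_size, image_size, image_size, 3)
model_output, initial_params = model.init_by_shape(
    jax.random.PRNGKey(0), input_specs=[(input_shape, jnp.float32)])
```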
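The following NumPy sketch makes the shape-versus-value distinction concrete. In the spirit of `init_by_shape`, it builds Dense parameters from an `(shape, dtype)` input spec alone, without ever seeing input data. The LeCun-style scaling (standard deviation 1/sqrt(fan_in)) and the helper names `lecun_normal` and `init_dense_by_shape` are assumptions for illustration. flax.nn's own initializers and parameter bookkeeping differ in detail.

```python
import numpy as np


def lecun_normal(rng, shape, dtype):
    # Depends only on the shape: fan_in is the input dimension of the kernel.
    fan_in = shape[0]
    return rng.normal(scale=1.0 / np.sqrt(fan_in), size=shape).astype(dtype)


def init_dense_by_shape(rng, input_spec, features):
    """Build Dense parameters from a (shape, dtype) spec without seeing any data."""
    shape, dtype = input_spec
    in_features = shape[-1]
    return {
        "kernel": lecun_normal(rng, (in_features, features), dtype),
        "bias": np.zeros((features,), dtype=dtype),
    }


# Example spec: a batch of 8 vectors of width 64, projected to 128 features.
params = init_dense_by_shape(np.random.default_rng(0), ((8, 64), np.float32), 128)
print(params["kernel"].shape, params["kernel"].dtype)  # (64, 128) float32
```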

Apply docstring:

Applies two linear transformations with a GELU activation to the input.

Parameters:
  • x – Input tensor.
  • latent_dim – FC latent dim.
classmethod partial(latent_dim)

Partially applies a module with the given arguments.

Unlike functools.partial this will return a subclass of Module.

Parameters:
  • name – the name used for the module.
  • **kwargs – the keyword arguments to be applied.
  • x – Input tensor.
  • latent_dim – FC latent dim.
Returns:

A subclass of Module which partially applies the given keyword arguments.
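The difference from `functools.partial` can be illustrated with a toy, pure-Python stand-in. Here `partial` binds keyword arguments and returns a new class derived from the module, so the result can be passed anywhere a module class is expected. `ToyModule`, `Scale`, and this `partial` implementation are hypothetical. They only mimic the behaviour described above and are not flax.nn's code.

```python
import functools


class ToyModule:
    """Hypothetical stand-in for a flax.nn-style Module (apply-based classes)."""

    @classmethod
    def call(cls, *args, **kwargs):
        return cls.apply(*args, **kwargs)

    @classmethod
    def partial(cls, **bound):
        # Return a *subclass* whose apply() fills in the bound keyword arguments.
        class PartialModule(cls):
            @classmethod
            def apply(sub, *args, **kwargs):
                return super().apply(*args, **{**bound, **kwargs})

        return PartialModule


class Scale(ToyModule):
    @classmethod
    def apply(cls, x, factor):
        return x * factor


Double = Scale.partial(factor=2)
print(issubclass(Double, Scale))  # True: still a module class
print(Double.call(21))            # 42

fn = functools.partial(Scale.apply, factor=2)
print(isinstance(fn, type))       # False: a partial object, not a class
```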

Apply docstring:

Applies two linear transformations with a GELU activation to the input.

class vit_flax.layers.PreNorm

Bases: flax.nn.base.Module

Applies a function to the normalized input.

Parameters:
  • x – Input tensor.
  • norm – Normalization module
  • fn – Callable that takes a tensor as input.
Returns:

Output of fn applied to the normalized input.
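Below is a NumPy sketch of the pre-normalization pattern. It assumes `norm` is a layer normalization over the last axis, with the learned scale and bias left out (equivalent to scale 1 and bias 0). The epsilon of 1e-6 and the function names are assumptions for illustration.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # Normalize each feature vector (last axis) to zero mean and unit variance.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def pre_norm(x, norm, fn):
    """Apply fn to the normalized input, as PreNorm describes."""
    return fn(norm(x))


x = np.array([[1.0, 2.0, 3.0, 4.0],
              [10.0, 10.0, 20.0, 20.0]])
out = pre_norm(x, layer_norm, lambda t: 2.0 * t)
print(out.round(3))
```

Normalizing before `fn` rather than after keeps the input to each block on a comparable scale regardless of what earlier layers produced.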

apply(x, norm, fn)

Applies a function to the normalized input.

Parameters:
  • x – Input tensor.
  • norm – Normalization module
  • fn – Callable that takes a tensor as input.
Returns:

Output of fn applied to the normalized input.

classmethod call(x, norm, fn)

Evaluate the module with the given parameters.

Parameters:
  • params – the parameters of the module. Typically, initial parameter values are constructed using Module.init or Module.init_by_shape.
  • *args – arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • norm – Normalization module
  • fn – Callable that takes a tensor as input.
Returns:

The output of the module’s apply function.

Apply docstring:

Applies a function to the normalized input.

Returns:

Output of fn applied to the normalized input.

classmethod create(x, norm, fn)

Create a module instance by evaluating the model.

DEPRECATION WARNING: create() is deprecated; use init() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.

Use create_by_shape instead to initialize without doing computation. Initializer functions can depend both on the shape and the value of inputs.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • *args – arguments passed to the module’s apply function
  • name – name of this module
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • norm – Normalization module
  • fn – Callable that takes a tensor as input.
Returns:

A pair consisting of the model output and an instance of Model

Apply docstring:

Applies a function to the normalized input.

Returns:

Output of fn applied to the normalized input.

classmethod create_by_shape(input_specs, x, norm, fn)

Create a module instance using only shape and dtype information.

DEPRECATION WARNING: create_by_shape() is deprecated; use init_by_shape() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.

This method will initialize the model without computation. Initializer functions can depend on the shape but not the value of inputs.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • input_specs – an iterable of (shape, dtype) pairs specifying the inputs
  • *args – other arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • norm – Normalization module
  • fn – Callable that takes a tensor as input.
Returns:

A pair consisting of the model output and an instance of Model

Apply docstring:

Applies a function to the normalized input.

Returns:

Output of fn applied to the normalized input.

classmethod init(x, norm, fn)

Initialize the module parameters.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • *args – arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • norm – Normalization module
  • fn – Callable that takes a tensor as input.
Returns:

A pair consisting of the model output and the initialized parameters

Apply docstring:

Applies a function to the normalized input.

Returns:

Output of fn applied to the normalized input.

classmethod init_by_shape(input_specs, x, norm, fn)

Initialize the module parameters.

This method will initialize the module parameters without computation. Initializer functions can depend on the shape but not the value of inputs.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • input_specs – an iterable of (shape, dtype) pairs specifying the inputs
  • *args – arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
Returns:

A pair consisting of the model output and the initialized parameters

Example

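```
import jax
import jax.numpy as jnp

# `model` is a flax.nn.Module subclass (e.g. one returned by partial());
# batch_size and image_size are placeholders for your input dimensions.
input_shape = (batch_size, image_size, image_size, 3)
model_output, initial_params = model.init_by_shape(
    jax.random.PRNGKey(0), input_specs=[(input_shape, jnp.float32)])
```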

Apply docstring:

Applies a function to the normalized input.

Parameters:
  • x – Input tensor.
  • norm – Normalization module
  • fn – Callable that takes a tensor as input.
Returns:

Output of fn applied to the normalized input.

classmethod partial(norm, fn)

Partially applies a module with the given arguments.

Unlike functools.partial this will return a subclass of Module.

Parameters:
  • name – the name used for the module.
  • **kwargs – the keyword arguments to be applied.
  • x – Input tensor.
  • norm – Normalization module
  • fn – Callable that takes a tensor as input.
Returns:

A subclass of Module which partially applies the given keyword arguments.

Apply docstring:

Applies a function to the normalized input.

Returns:

Output of fn applied to the normalized input.

class vit_flax.layers.Residual

Bases: flax.nn.base.Module

Applies a residual (skip) connection around a residual_fn block.

Parameters:
  • x – Input tensor.
  • residual_fn – Callable that takes a tensor as input.
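In code, the skip connection adds the block's output back onto its input. This sketch therefore requires `residual_fn` to return a tensor with the same shape as `x`. It is a minimal NumPy version, and `residual` is a hypothetical name.

```python
import numpy as np


def residual(x, residual_fn):
    """Skip connection: add the block's output back onto its input."""
    out = residual_fn(x)
    assert out.shape == x.shape, "residual_fn must preserve the input shape"
    return x + out


x = np.array([2.0, -4.0])
y = residual(x, lambda t: 0.5 * t)
print(y)  # [ 3. -6.]
```

Because the identity path is always present, a block whose output is zero leaves the input unchanged. This is what makes deep stacks of such blocks easy to train.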
apply(x, residual_fn)

Applies a residual (skip) connection around a residual_fn block.

Parameters:
  • x – Input tensor.
  • residual_fn – Callable that takes a tensor as input.
classmethod call(x, residual_fn)

Evaluate the module with the given parameters.

Parameters:
  • params – the parameters of the module. Typically, initial parameter values are constructed using Module.init or Module.init_by_shape.
  • *args – arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • residual_fn – Callable that takes a tensor as input.
Returns:

The output of the module’s apply function.

Apply docstring:

Applies a residual (skip) connection around a residual_fn block.

classmethod create(x, residual_fn)

Create a module instance by evaluating the model.

DEPRECATION WARNING: create() is deprecated; use init() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.

Use create_by_shape instead to initialize without doing computation. Initializer functions can depend both on the shape and the value of inputs.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • *args – arguments passed to the module’s apply function
  • name – name of this module
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • residual_fn – Callable that takes a tensor as input.
Returns:

A pair consisting of the model output and an instance of Model

Apply docstring:

Applies a residual (skip) connection around a residual_fn block.

classmethod create_by_shape(input_specs, x, residual_fn)

Create a module instance using only shape and dtype information.

DEPRECATION WARNING: create_by_shape() is deprecated; use init_by_shape() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.

This method will initialize the model without computation. Initializer functions can depend on the shape but not the value of inputs.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • input_specs – an iterable of (shape, dtype) pairs specifying the inputs
  • *args – other arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • residual_fn – Callable that takes a tensor as input.
Returns:

A pair consisting of the model output and an instance of Model

Apply docstring:

Applies a residual (skip) connection around a residual_fn block.

classmethod init(x, residual_fn)

Initialize the module parameters.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • *args – arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • residual_fn – Callable that takes a tensor as input.
Returns:

A pair consisting of the model output and the initialized parameters

Apply docstring:

Applies a residual (skip) connection around a residual_fn block.

classmethod init_by_shape(input_specs, x, residual_fn)

Initialize the module parameters.

This method will initialize the module parameters without computation. Initializer functions can depend on the shape but not the value of inputs.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • input_specs – an iterable of (shape, dtype) pairs specifying the inputs
  • *args – arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
Returns:

A pair consisting of the model output and the initialized parameters

Example

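```
import jax
import jax.numpy as jnp

# `model` is a flax.nn.Module subclass (e.g. one returned by partial());
# batch_size and image_size are placeholders for your input dimensions.
input_shape = (batch_size, image_size, image_size, 3)
model_output, initial_params = model.init_by_shape(
    jax.random.PRNGKey(0), input_specs=[(input_shape, jnp.float32)])
```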

Apply docstring:

Applies a residual (skip) connection around a residual_fn block.

Parameters:
  • x – Input tensor.
  • residual_fn – Callable that takes a tensor as input.
classmethod partial(residual_fn)

Partially applies a module with the given arguments.

Unlike functools.partial this will return a subclass of Module.

Parameters:
  • name – the name used for the module.
  • **kwargs – the keyword arguments to be applied.
  • x – Input tensor.
  • residual_fn – Callable that takes a tensor as input.
Returns:

A subclass of Module which partially applies the given keyword arguments.

Apply docstring:

Applies a residual (skip) connection around a residual_fn block.

class vit_flax.layers.Transformer

Bases: flax.nn.base.Module

Applies a stack of residual, normalized attention layers (a Transformer) to the input.

Parameters:
  • x – Input tensor.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • feed_forward_dim_1 – FC (feed-forward) dimension.
Returns:

Transformer output embedding
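The NumPy sketch below assumes the standard ViT encoder layout: each of the `depth` layers is Residual(PreNorm(LayerNorm, multi-head self-attention)) followed by Residual(PreNorm(LayerNorm, FeedForward)). Biases and learned LayerNorm parameters are omitted, and weights are passed as a list of dicts. All function and key names are hypothetical, so this is not vit_flax's implementation. `feed_forward_dim` in the code plays the role of the FC dimension.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # LayerNorm over the feature axis, without learned scale/bias.
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(x, wq, wk, wv, wo, num_heads):
    """Multi-head scaled dot-product self-attention on x of shape (tokens, dim)."""
    n, dim = x.shape
    assert dim % num_heads == 0, "dim must be divisible by num_heads"
    hd = dim // num_heads

    def split_heads(t):  # (n, dim) -> (num_heads, n, hd)
        return t.reshape(n, num_heads, hd).transpose(1, 0, 2)

    q, k, v = split_heads(x @ wq), split_heads(x @ wk), split_heads(x @ wv)
    weights = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(hd))  # (num_heads, n, n)
    heads = weights @ v                                         # (num_heads, n, hd)
    return heads.transpose(1, 0, 2).reshape(n, dim) @ wo        # merge heads, project


def transformer(x, layers, num_heads):
    """Apply len(layers) == depth pre-norm residual blocks to x."""
    for p in layers:
        # Residual(PreNorm(LayerNorm, attention))
        x = x + self_attention(layer_norm(x), p["wq"], p["wk"], p["wv"], p["wo"], num_heads)
        # Residual(PreNorm(LayerNorm, FeedForward))
        x = x + gelu(layer_norm(x) @ p["w1"]) @ p["w2"]
    return x


def random_layers(rng, depth, dim, feed_forward_dim):
    shapes = {"wq": (dim, dim), "wk": (dim, dim), "wv": (dim, dim), "wo": (dim, dim),
              "w1": (dim, feed_forward_dim), "w2": (feed_forward_dim, dim)}
    return [{name: rng.normal(scale=0.02, size=s) for name, s in shapes.items()}
            for _ in range(depth)]


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))  # 5 tokens, dim = 8
layers = random_layers(rng, depth=2, dim=8, feed_forward_dim=16)
y = transformer(x, layers, num_heads=2)
print(y.shape)  # (5, 8)
```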

apply(x, depth, num_heads, feed_forward_dim_1)

Applies a stack of residual, normalized attention layers (a Transformer) to the input.

Parameters:
  • x – Input tensor.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • feed_forward_dim_1 – FC (feed-forward) dimension.
Returns:

Transformer output embedding

classmethod call(x, depth, num_heads, feed_forward_dim_1)

Evaluate the module with the given parameters.

Parameters:
  • params – the parameters of the module. Typically, initial parameter values are constructed using Module.init or Module.init_by_shape.
  • *args – arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • feed_forward_dim_1 – FC (feed-forward) dimension.
Returns:

The output of the module’s apply function.

Apply docstring:

Applies a stack of residual, normalized attention layers (a Transformer) to the input.

Returns:

Transformer output embedding

classmethod create(x, depth, num_heads, feed_forward_dim_1)

Create a module instance by evaluating the model.

DEPRECATION WARNING: create() is deprecated; use init() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.

Use create_by_shape instead to initialize without doing computation. Initializer functions can depend both on the shape and the value of inputs.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • *args – arguments passed to the module’s apply function
  • name – name of this module
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • feed_forward_dim_1 – FC (feed-forward) dimension.
Returns:

A pair consisting of the model output and an instance of Model

Apply docstring:

Applies a stack of residual, normalized attention layers (a Transformer) to the input.

Returns:

Transformer output embedding

classmethod create_by_shape(input_specs, x, depth, num_heads, feed_forward_dim_1)

Create a module instance using only shape and dtype information.

DEPRECATION WARNING: create_by_shape() is deprecated; use init_by_shape() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.

This method will initialize the model without computation. Initializer functions can depend on the shape but not the value of inputs.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • input_specs – an iterable of (shape, dtype) pairs specifying the inputs
  • *args – other arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • feed_forward_dim_1 – FC (feed-forward) dimension.
Returns:

A pair consisting of the model output and an instance of Model

Apply docstring:

Applies a stack of residual, normalized attention layers (a Transformer) to the input.

Returns:

Transformer output embedding

classmethod init(x, depth, num_heads, feed_forward_dim_1)

Initialize the module parameters.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • *args – arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input tensor.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • feed_forward_dim_1 – FC (feed-forward) dimension.
Returns:

A pair consisting of the model output and the initialized parameters

Apply docstring:

Applies a stack of residual, normalized attention layers (a Transformer) to the input.

Returns:

Transformer output embedding

classmethod init_by_shape(input_specs, x, depth, num_heads, feed_forward_dim_1)

Initialize the module parameters.

This method will initialize the module parameters without computation. Initializer functions can depend on the shape but not the value of inputs.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • input_specs – an iterable of (shape, dtype) pairs specifying the inputs
  • *args – arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
Returns:

A pair consisting of the model output and the initialized parameters

Example

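```
import jax
import jax.numpy as jnp

# `model` is a flax.nn.Module subclass (e.g. one returned by partial());
# batch_size and image_size are placeholders for your input dimensions.
input_shape = (batch_size, image_size, image_size, 3)
model_output, initial_params = model.init_by_shape(
    jax.random.PRNGKey(0), input_specs=[(input_shape, jnp.float32)])
```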

Apply docstring:

Applies a stack of residual, normalized attention layers (a Transformer) to the input.

Parameters:
  • x – Input tensor.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • feed_forward_dim_1 – FC (feed-forward) dimension.
Returns:

Transformer output embedding

classmethod partial(depth, num_heads, feed_forward_dim_1)

Partially applies a module with the given arguments.

Unlike functools.partial this will return a subclass of Module.

Parameters:
  • name – the name used for the module.
  • **kwargs – the keyword arguments to be applied.
  • x – Input tensor.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • feed_forward_dim_1 – FC (feed-forward) dimension.
Returns:

A subclass of Module which partially applies the given keyword arguments.

Apply docstring:

Applies a stack of residual, normalized attention layers (a Transformer) to the input.

Returns:

Transformer output embedding

vit_flax.vit module

class vit_flax.vit.ViT

Bases: flax.nn.base.Module

Applies the Vision Transformer (ViT) to the input tensor.

Parameters:
  • x – Input image tensor.
  • patch_size – Patch size (side length of each image patch).
  • dim – Latent (embedding) dimension.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • dense_dims – Tuple[int, int]: (Transformer FC dimension, classifier FC dimension).
  • img_size – Input image size (side length).
  • num_classes – Number of classification classes.
  • initializer – Flax initializer.
Returns:

Classification output
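Before the Transformer runs, a ViT cuts each image into non-overlapping patch_size x patch_size patches and flattens each one, giving (img_size / patch_size)^2 patches per image. For example, a 32x32 RGB image with patch_size = 8 yields 4 x 4 = 16 patches of 8 * 8 * 3 = 192 values each. The NumPy sketch below assumes channels-last input of shape (batch, img_size, img_size, channels), as in the init_by_shape example. How vit_flax extracts patches internally is not shown in this reference, and `patchify` is a hypothetical name.

```python
import numpy as np


def patchify(images, patch_size):
    """Split (B, H, W, C) images into (B, num_patches, patch_size * patch_size * C)."""
    b, h, w, c = images.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image size must be a multiple of patch_size"
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, gh, gw, patch_size, patch_size, C)
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


# img_size = 32, patch_size = 8, 3 channels -> 16 patches of 192 values each.
images = np.arange(2 * 32 * 32 * 3, dtype=np.float32).reshape(2, 32, 32, 3)
patches = patchify(images, patch_size=8)
print(patches.shape)  # (2, 16, 192)
```

Patches are ordered row by row across the image, so patch index i * (img_size // patch_size) + j covers grid row i and grid column j.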

apply(x, patch_size, dim, depth, num_heads, dense_dims, img_size, num_classes, initializer=<function normal.<locals>.init>)

Applies the Vision Transformer (ViT) to the input tensor.

Parameters:
  • x – Input image tensor.
  • patch_size – Patch size (side length of each image patch).
  • dim – Latent (embedding) dimension.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • dense_dims – Tuple[int, int]: (Transformer FC dimension, classifier FC dimension).
  • img_size – Input image size (side length).
  • num_classes – Number of classification classes.
  • initializer – Flax initializer.
Returns:

Classification output
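The following is a rough NumPy sketch of what surrounds the Transformer in a ViT forward pass. It assumes four steps: flattened patches are linearly projected to `dim`, a learned class token is prepended, learned position embeddings are added, and the class token's output passes through a classifier MLP (hidden width `dense_dims[1]`, GELU) to produce `num_classes` logits. An identity placeholder stands in for the Transformer stack to keep the example short, and any final LayerNorm some ViT variants apply before the head is omitted. All names and the exact head layout are assumptions, not vit_flax's code.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def vit_forward(patches, w_embed, cls_token, pos_emb, w_hidden, w_out,
                transformer=lambda t: t):
    """patches: (B, N, patch_dim) -> logits: (B, num_classes)."""
    b = patches.shape[0]
    tokens = patches @ w_embed                               # (B, N, dim)
    cls = np.broadcast_to(cls_token, (b, 1, cls_token.shape[-1]))
    x = np.concatenate([cls, tokens], axis=1) + pos_emb      # (B, N + 1, dim)
    x = transformer(x)                                       # placeholder for the Transformer stack
    cls_out = x[:, 0]                                        # (B, dim)
    return gelu(cls_out @ w_hidden) @ w_out                  # (B, num_classes)


# Illustrative sizes: 16 patches of 192 values, dim = 64,
# dense_dims[1] = 128, num_classes = 10.
rng = np.random.default_rng(0)
B, N, patch_dim, dim, hidden, num_classes = 2, 16, 192, 64, 128, 10
logits = vit_forward(
    rng.normal(size=(B, N, patch_dim)),
    rng.normal(scale=0.02, size=(patch_dim, dim)),
    rng.normal(scale=0.02, size=(dim,)),
    rng.normal(scale=0.02, size=(N + 1, dim)),
    rng.normal(scale=0.02, size=(dim, hidden)),
    rng.normal(scale=0.02, size=(hidden, num_classes)),
)
print(logits.shape)  # (2, 10)
```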

classmethod call(x, patch_size, dim, depth, num_heads, dense_dims, img_size, num_classes, initializer=<function normal.<locals>.init>)

Evaluate the module with the given parameters.

Parameters:
  • params – the parameters of the module. Typically, initial parameter values are constructed using Module.init or Module.init_by_shape.
  • *args – arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input image tensor.
  • patch_size – Patch size (side length of each image patch).
  • dim – Latent (embedding) dimension.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • dense_dims – Tuple[int, int]: (Transformer FC dimension, classifier FC dimension).
  • img_size – Input image size (side length).
  • num_classes – Number of classification classes.
  • initializer – Flax initializer.
Returns:

The output of the module’s apply function.

Apply docstring:

Applies the Vision Transformer (ViT) to the input tensor.

Returns:

Classification output

classmethod create(x, patch_size, dim, depth, num_heads, dense_dims, img_size, num_classes, initializer=<function normal.<locals>.init>)

Create a module instance by evaluating the model.

DEPRECATION WARNING: create() is deprecated; use init() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.

Use create_by_shape instead to initialize without doing computation. Initializer functions can depend both on the shape and the value of inputs.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • *args – arguments passed to the module’s apply function
  • name – name of this module
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input image tensor.
  • patch_size – Patch size (side length of each image patch).
  • dim – Latent (embedding) dimension.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • dense_dims – Tuple[int, int]: (Transformer FC dimension, classifier FC dimension).
  • img_size – Input image size (side length).
  • num_classes – Number of classification classes.
  • initializer – Flax initializer.
Returns:

A pair consisting of the model output and an instance of Model

Apply docstring:

Applies the Vision Transformer (ViT) to the input tensor.

Returns:

Classification output

classmethod create_by_shape(input_specs, x, patch_size, dim, depth, num_heads, dense_dims, img_size, num_classes, initializer=<function normal.<locals>.init>)

Create a module instance using only shape and dtype information.

DEPRECATION WARNING: create_by_shape() is deprecated; use init_by_shape() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.

This method will initialize the model without computation. Initializer functions can depend on the shape but not the value of inputs.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • input_specs – an iterable of (shape, dtype) pairs specifying the inputs
  • *args – other arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input image tensor.
  • patch_size – Patch size (side length of each image patch).
  • dim – Latent (embedding) dimension.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • dense_dims – Tuple[int, int]: (Transformer FC dimension, classifier FC dimension).
  • img_size – Input image size (side length).
  • num_classes – Number of classification classes.
  • initializer – Flax initializer.
Returns:

A pair consisting of the model output and an instance of Model

Apply docstring:

Applies the Vision Transformer (ViT) to the input tensor.

Returns:

Classification output

classmethod init(x, patch_size, dim, depth, num_heads, dense_dims, img_size, num_classes, initializer=<function normal.<locals>.init>)

Initialize the module parameters.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • *args – arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
  • x – Input image tensor.
  • patch_size – Patch size (side length of each image patch).
  • dim – Latent (embedding) dimension.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • dense_dims – Tuple[int, int]: (Transformer FC dimension, classifier FC dimension).
  • img_size – Input image size (side length).
  • num_classes – Number of classification classes.
  • initializer – Flax initializer.
Returns:

A pair consisting of the model output and the initialized parameters

Apply docstring:

Applies the Vision Transformer (ViT) to the input tensor.

Returns:

Classification output

classmethod init_by_shape(input_specs, x, patch_size, dim, depth, num_heads, dense_dims, img_size, num_classes, initializer=<function normal.<locals>.init>)

Initialize the module parameters.

This method will initialize the module parameters without computation. Initializer functions can depend on the shape but not the value of inputs.

Parameters:
  • _rng – the random number generator used to initialize parameters.
  • input_specs – an iterable of (shape, dtype) pairs specifying the inputs
  • *args – arguments passed to the module’s apply function
  • name – name of this module.
  • **kwargs – keyword arguments passed to the module’s apply function
Returns:

A pair consisting of the model output and the initialized parameters

Example

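```
import jax
import jax.numpy as jnp

# `model` is a flax.nn.Module subclass (e.g. one returned by partial());
# batch_size and image_size are placeholders for your input dimensions.
input_shape = (batch_size, image_size, image_size, 3)
model_output, initial_params = model.init_by_shape(
    jax.random.PRNGKey(0), input_specs=[(input_shape, jnp.float32)])
```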

Apply docstring:

Applies the Vision Transformer (ViT) to the input tensor.

Parameters:
  • x – Input image tensor.
  • patch_size – Patch size (side length of each image patch).
  • dim – Latent (embedding) dimension.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • dense_dims – Tuple[int, int]: (Transformer FC dimension, classifier FC dimension).
  • img_size – Input image size (side length).
  • num_classes – Number of classification classes.
  • initializer – Flax initializer.
Returns:

Classification output

classmethod partial(patch_size, dim, depth, num_heads, dense_dims, img_size, num_classes, initializer=<function normal.<locals>.init>)

Partially applies a module with the given arguments.

Unlike functools.partial this will return a subclass of Module.

Parameters:
  • name – the name used for the module.
  • **kwargs – the keyword arguments to be applied.
  • x – Input image tensor.
  • patch_size – Patch size (side length of each image patch).
  • dim – Latent (embedding) dimension.
  • depth – Number of residual, normalized attention layers.
  • num_heads – Number of attention heads.
  • dense_dims – Tuple[int, int]: (Transformer FC dimension, classifier FC dimension).
  • img_size – Input image size (side length).
  • num_classes – Number of classification classes.
  • initializer – Flax initializer.
Returns:

A subclass of Module which partially applies the given keyword arguments.

Apply docstring:

Applies the Vision Transformer (ViT) to the input tensor.

Returns:

Classification output

Module contents