vit_flax package
Submodules
vit_flax.layers module
class vit_flax.layers.FeedForward
Bases: flax.nn.base.Module
Applies two linear transformations with a GELU activation to the input.
Parameters: - x – Input tensor.
- latent_dim – FC latent dim.
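A minimal NumPy sketch of what FeedForward computes. It assumes the common ViT layout: the first dense layer widens each token to latent_dim, GELU is applied, and the second dense layer projects back to the input width. The helper names (gelu, feed_forward) and the explicit weight arguments are hypothetical. The real module is a flax.nn Module that creates its own parameters, and NumPy is used here only because its API mirrors jax.numpy.

```python
import math
import numpy as np

def gelu(x):
    # tanh approximation of GELU (the default form of jax.nn.gelu)
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def feed_forward(x, w1, b1, w2, b2):
    """Hypothetical sketch: Dense(latent_dim) -> GELU -> Dense(dim); x is (..., dim)."""
    hidden = gelu(x @ w1 + b1)  # (..., latent_dim)
    return hidden @ w2 + b2     # (..., dim), assumed to match the input width

rng = np.random.default_rng(0)
dim, latent_dim = 8, 32
x = rng.normal(size=(2, 5, dim))  # (batch, tokens, dim)
w1 = rng.normal(size=(dim, latent_dim)) * 0.02
b1 = np.zeros(latent_dim)
w2 = rng.normal(size=(latent_dim, dim)) * 0.02
b2 = np.zeros(dim)
y = feed_forward(x, w1, b1, w2, b2)
print(y.shape)  # (2, 5, 8)
```

Projecting back to the input width is what lets the block sit inside a residual connection, where its output is added to its input.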
apply(x, latent_dim)
Applies two linear transformations with a GELU activation to the input.
Parameters: - x – Input tensor.
- latent_dim – FC latent dim.
classmethod call(x, latent_dim)
Evaluate the module with the given parameters.
Parameters: - params – the parameters of the module. Typically, initial parameter values are constructed using Module.init or Module.init_by_shape.
- *args – arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- latent_dim – FC latent dim.
Returns: The output of the module’s apply function.
Apply docstring:
Applies two linear transformations with a GELU activation to the input.
classmethod create(x, latent_dim)
Create a module instance by evaluating the model.
DEPRECATION WARNING: create() is deprecated; use init() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.
To initialize without doing any computation, use create_by_shape instead. Because create() evaluates the model, initializer functions can depend on both the shape and the value of the inputs.
Parameters: - _rng – the random number generator used to initialize parameters.
- *args – arguments passed to the module’s apply function
- name – name of this module
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- latent_dim – FC latent dim.
Returns: A pair consisting of the model output and an instance of Model
Apply docstring:
Applies two linear transformations with a GELU activation to the input.
classmethod create_by_shape(input_specs, x, latent_dim)
Create a module instance using only shape and dtype information.
DEPRECATION WARNING: create_by_shape() is deprecated; use init_by_shape() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.
This method will initialize the model without computation. Initializer functions can depend on the shape but not the value of inputs.
Parameters: - _rng – the random number generator used to initialize parameters.
- input_specs – an iterable of (shape, dtype) pairs specifying the inputs
- *args – other arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- latent_dim – FC latent dim.
Returns: A pair consisting of the model output and an instance of Model
Apply docstring:
Applies two linear transformations with a GELU activation to the input.
classmethod init(x, latent_dim)
Initialize the module parameters.
Parameters: - _rng – the random number generator used to initialize parameters.
- *args – arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- latent_dim – FC latent dim.
Returns: A pair consisting of the model output and the initialized parameters
Apply docstring:
Applies two linear transformations with a GELU activation to the input.
classmethod init_by_shape(input_specs, x, latent_dim)
Initialize the module parameters.
This method will initialize the module parameters without computation. Initializer functions can depend on the shape but not the value of inputs.
Parameters: - _rng – the random number generator used to initialize parameters.
- input_specs – an iterable of (shape, dtype) pairs specifying the inputs
- *args – arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
Returns: A pair consisting of the model output and the initialized parameters
Example:
```
import jax
import jax.numpy as jnp
# batch_size, image_size and model are defined by the caller.
input_shape = (batch_size, image_size, image_size, 3)
model_output, initial_params = model.init_by_shape(jax.random.PRNGKey(0),
                                                   input_specs=[(input_shape, jnp.float32)])
```
Apply docstring:
Applies two linear transformations with a GELU activation to the input.
Parameters: - x – Input tensor.
- latent_dim – FC latent dim.
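To illustrate why init_by_shape can skip computation: a dense layer's parameters depend only on the last input dimension, so they can be built from a (shape, dtype) spec without any input values. The function init_dense_by_shape below is a hypothetical NumPy illustration, not part of Flax, and the variance 1 / fan_in scaling is an assumed choice of initializer.

```python
import numpy as np

def init_dense_by_shape(seed, input_spec, features):
    """Hypothetical: create Dense-layer parameters from a (shape, dtype) spec alone."""
    shape, dtype = input_spec
    fan_in = shape[-1]  # only the last input dimension matters for a Dense kernel
    rng = np.random.default_rng(seed)
    kernel = rng.normal(size=(fan_in, features)) / np.sqrt(fan_in)  # variance 1 / fan_in
    return {"kernel": kernel.astype(dtype), "bias": np.zeros(features, dtype=dtype)}

# No input array exists here, only its shape and dtype.
params = init_dense_by_shape(0, ((8, 10, 64), np.float32), features=128)
print(params["kernel"].shape)  # (64, 128)
```

By contrast, init and create run the module on real inputs, which is why their initializers may also depend on input values.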
classmethod partial(latent_dim)
Partially applies a module with the given arguments.
Unlike functools.partial, this returns a subclass of Module.
Parameters: - name – the name used for the module.
- **kwargs – the arguments to be applied.
- x – Input tensor.
- latent_dim – FC latent dim.
Returns: A subclass of Module which partially applies the given keyword arguments.
Apply docstring:
Applies two linear transformations with a GELU activation to the input.
class vit_flax.layers.PreNorm
Bases: flax.nn.base.Module
Applies a function to the normalized input.
Parameters: - x – Input tensor.
- norm – Normalization module
- fn – Callable that takes a tensor as input
Returns: Output of function after normalizing input
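A minimal NumPy sketch of pre-normalization. It assumes PreNorm computes fn(norm(x)) and that norm is a layer normalization over the feature axis. layer_norm and pre_norm are hypothetical helpers, and the learnable scale and bias of a real LayerNorm are omitted.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # normalize each feature vector to zero mean and unit variance
    # (the learnable scale and bias of a real LayerNorm are omitted)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def pre_norm(x, norm, fn):
    """Hypothetical sketch: normalize first, then apply fn, i.e. fn(norm(x))."""
    return fn(norm(x))

x = np.array([[1.0, 2.0, 3.0, 4.0],
              [10.0, 20.0, 30.0, 40.0]])
out = pre_norm(x, layer_norm, lambda h: 2.0 * h)
# both rows normalize to (nearly) the same vector, so fn sees inputs on a common scale
```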
apply(x, norm, fn)
Applies a function to normalized input
Parameters: - x – Input tensor.
- norm – Normalization module
- fn – Callable that takes a tensor as input
Returns: Output of function after normalizing input
classmethod call(x, norm, fn)
Evaluate the module with the given parameters.
Parameters: - params – the parameters of the module. Typically, initial parameter values are constructed using Module.init or Module.init_by_shape.
- *args – arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- norm – Normalization module
- fn – Callable that takes a tensor as input
Returns: The output of the module’s apply function.
Apply docstring:
Applies a function to normalized input
Returns: Output of function after normalizing input
classmethod create(x, norm, fn)
Create a module instance by evaluating the model.
DEPRECATION WARNING: create() is deprecated; use init() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.
To initialize without doing any computation, use create_by_shape instead. Because create() evaluates the model, initializer functions can depend on both the shape and the value of the inputs.
Parameters: - _rng – the random number generator used to initialize parameters.
- *args – arguments passed to the module’s apply function
- name – name of this module
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- norm – Normalization module
- fn – Callable that takes a tensor as input
Returns: A pair consisting of the model output and an instance of Model
Apply docstring:
Applies a function to normalized input
Returns: Output of function after normalizing input
classmethod create_by_shape(input_specs, x, norm, fn)
Create a module instance using only shape and dtype information.
DEPRECATION WARNING: create_by_shape() is deprecated; use init_by_shape() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.
This method will initialize the model without computation. Initializer functions can depend on the shape but not the value of inputs.
Parameters: - _rng – the random number generator used to initialize parameters.
- input_specs – an iterable of (shape, dtype) pairs specifying the inputs
- *args – other arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- norm – Normalization module
- fn – Callable that takes a tensor as input
Returns: A pair consisting of the model output and an instance of Model
Apply docstring:
Applies a function to normalized input
Returns: Output of function after normalizing input
classmethod init(x, norm, fn)
Initialize the module parameters.
Parameters: - _rng – the random number generator used to initialize parameters.
- *args – arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- norm – Normalization module
- fn – Callable that takes a tensor as input
Returns: A pair consisting of the model output and the initialized parameters
Apply docstring:
Applies a function to normalized input
Returns: Output of function after normalizing input
classmethod init_by_shape(input_specs, x, norm, fn)
Initialize the module parameters.
This method will initialize the module parameters without computation. Initializer functions can depend on the shape but not the value of inputs.
Parameters: - _rng – the random number generator used to initialize parameters.
- input_specs – an iterable of (shape, dtype) pairs specifying the inputs
- *args – arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
Returns: A pair consisting of the model output and the initialized parameters
Example:
```
import jax
import jax.numpy as jnp
# batch_size, image_size and model are defined by the caller.
input_shape = (batch_size, image_size, image_size, 3)
model_output, initial_params = model.init_by_shape(jax.random.PRNGKey(0),
                                                   input_specs=[(input_shape, jnp.float32)])
```
Apply docstring:
Applies a function to normalized input
Parameters: - x – Input tensor.
- norm – Normalization module
- fn – Callable that takes a tensor as input
Returns: Output of function after normalizing input
classmethod partial(norm, fn)
Partially applies a module with the given arguments.
Unlike functools.partial, this returns a subclass of Module.
Parameters: - name – the name used for the module.
- **kwargs – the arguments to be applied.
- x – Input tensor.
- norm – Normalization module
- fn – Callable that takes a tensor as input
Returns: A subclass of Module which partially applies the given keyword arguments.
Apply docstring:
Applies a function to normalized input
Returns: Output of function after normalizing input
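Because partial returns a Module subclass rather than a plain function, the result can still be used wherever a module class is expected. The plain-Python toy below sketches that idea. Module, Scale and PartialModule are hypothetical stand-ins, not Flax's actual implementation.

```python
class Module:
    """Toy stand-in for a module whose behaviour lives in classmethods."""

    @classmethod
    def apply(cls, x, **kwargs):
        raise NotImplementedError

    @classmethod
    def call(cls, x, **kwargs):
        return cls.apply(x, **kwargs)

    @classmethod
    def partial(cls, **partial_kwargs):
        parent = cls

        class PartialModule(cls):  # a subclass, unlike functools.partial
            @classmethod
            def apply(klass, x, **kwargs):
                # call-time keyword arguments override the pre-bound ones
                return parent.apply(x, **{**partial_kwargs, **kwargs})

        PartialModule.__name__ = cls.__name__
        return PartialModule


class Scale(Module):
    @classmethod
    def apply(cls, x, factor):
        return x * factor


Double = Scale.partial(factor=2)
print(issubclass(Double, Scale))  # True
print(Double.call(21))            # 42
```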
class vit_flax.layers.Residual
Bases: flax.nn.base.Module
Applies a residual (skip) connection to a residual_fn block.
Parameters: - x – Input tensor.
- residual_fn – Callable that takes a tensor as input.
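A minimal NumPy sketch of the skip connection. It assumes Residual returns x + residual_fn(x), the standard residual formulation; the function name residual is hypothetical. The addition only works if residual_fn preserves the input shape, which the sketch checks explicitly.

```python
import numpy as np

def residual(x, residual_fn):
    """Hypothetical sketch of a skip connection: x + residual_fn(x)."""
    y = residual_fn(x)
    # the sum is only well defined when the block preserves the input shape
    if y.shape != x.shape:
        raise ValueError(f"residual_fn changed the shape from {x.shape} to {y.shape}")
    return x + y

x = np.ones((2, 3))
out = residual(x, lambda h: 0.5 * h)  # every entry becomes 1.0 + 0.5 = 1.5
```

If residual_fn returns zeros, the block reduces to the identity, which is one reason residual stacks are easy to train.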
apply(x, residual_fn)
Applies a residual (skip) connection to a residual_fn block.
Parameters: - x – Input tensor.
- residual_fn – Callable that takes a tensor as input.
classmethod call(x, residual_fn)
Evaluate the module with the given parameters.
Parameters: - params – the parameters of the module. Typically, initial parameter values are constructed using Module.init or Module.init_by_shape.
- *args – arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- residual_fn – Callable that takes a tensor as input.
Returns: The output of the module’s apply function.
Apply docstring:
Applies a residual (skip) connection to a residual_fn block.
classmethod create(x, residual_fn)
Create a module instance by evaluating the model.
DEPRECATION WARNING: create() is deprecated; use init() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.
To initialize without doing any computation, use create_by_shape instead. Because create() evaluates the model, initializer functions can depend on both the shape and the value of the inputs.
Parameters: - _rng – the random number generator used to initialize parameters.
- *args – arguments passed to the module’s apply function
- name – name of this module
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- residual_fn – Callable that takes a tensor as input.
Returns: A pair consisting of the model output and an instance of Model
Apply docstring:
Applies a residual (skip) connection to a residual_fn block.
classmethod create_by_shape(input_specs, x, residual_fn)
Create a module instance using only shape and dtype information.
DEPRECATION WARNING: create_by_shape() is deprecated; use init_by_shape() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.
This method will initialize the model without computation. Initializer functions can depend on the shape but not the value of inputs.
Parameters: - _rng – the random number generator used to initialize parameters.
- input_specs – an iterable of (shape, dtype) pairs specifying the inputs
- *args – other arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- residual_fn – Callable that takes a tensor as input.
Returns: A pair consisting of the model output and an instance of Model
Apply docstring:
Applies a residual (skip) connection to a residual_fn block.
classmethod init(x, residual_fn)
Initialize the module parameters.
Parameters: - _rng – the random number generator used to initialize parameters.
- *args – arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- residual_fn – Callable that takes a tensor as input.
Returns: A pair consisting of the model output and the initialized parameters
Apply docstring:
Applies a residual (skip) connection to a residual_fn block.
classmethod init_by_shape(input_specs, x, residual_fn)
Initialize the module parameters.
This method will initialize the module parameters without computation. Initializer functions can depend on the shape but not the value of inputs.
Parameters: - _rng – the random number generator used to initialize parameters.
- input_specs – an iterable of (shape, dtype) pairs specifying the inputs
- *args – arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
Returns: A pair consisting of the model output and the initialized parameters
Example:
```
import jax
import jax.numpy as jnp
# batch_size, image_size and model are defined by the caller.
input_shape = (batch_size, image_size, image_size, 3)
model_output, initial_params = model.init_by_shape(jax.random.PRNGKey(0),
                                                   input_specs=[(input_shape, jnp.float32)])
```
Apply docstring:
Applies a residual (skip) connection to a residual_fn block.
Parameters: - x – Input tensor.
- residual_fn – Callable that takes a tensor as input.
classmethod partial(residual_fn)
Partially applies a module with the given arguments.
Unlike functools.partial, this returns a subclass of Module.
Parameters: - name – the name used for the module.
- **kwargs – the arguments to be applied.
- x – Input tensor.
- residual_fn – Callable that takes a tensor as input.
Returns: A subclass of Module which partially applies the given keyword arguments.
Apply docstring:
Applies a residual (skip) connection to a residual_fn block.
class vit_flax.layers.Transformer
Bases: flax.nn.base.Module
Applies residual normalized attention (Transformer) layers to the input.
Parameters: - x – Input tensor.
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- feed_forward_dim_1 – FC dimension
Returns: Transformer output embedding
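A minimal NumPy sketch of a Transformer stack built from the pieces documented above. It assumes each of the depth layers applies Residual(PreNorm(attention)) followed by Residual(PreNorm(FeedForward)), the usual ViT arrangement; the source does not spell out the exact composition. All helper names are hypothetical, biases and LayerNorm parameters are omitted, and feed_forward_dim here stands in for the feed_forward_dim_1 argument.

```python
import math
import numpy as np

def layer_norm(x, eps=1e-6):
    # normalize over the feature axis; learnable scale and bias omitted
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, p, num_heads):
    """Multi-head scaled dot-product self-attention; x has shape (batch, tokens, dim)."""
    b, n, d = x.shape
    if d % num_heads:
        raise ValueError("dim must be divisible by num_heads")
    head_dim = d // num_heads

    def split_heads(t):  # (b, n, d) -> (b, heads, n, head_dim)
        return t.reshape(b, n, num_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = (split_heads(x @ p[name]) for name in ("wq", "wk", "wv"))
    weights = softmax(q @ k.transpose(0, 1, 3, 2) / math.sqrt(head_dim))  # (b, heads, n, n)
    out = (weights @ v).transpose(0, 2, 1, 3).reshape(b, n, d)  # merge heads back
    return out @ p["wo"]

def transformer(x, layers, num_heads):
    """Each layer: x = x + attn(norm(x)); then x = x + ff(norm(x))."""
    for p in layers:  # len(layers) plays the role of `depth`
        x = x + self_attention(layer_norm(x), p, num_heads)
        x = x + gelu(layer_norm(x) @ p["w1"]) @ p["w2"]
    return x

def init_layers(rng, depth, dim, feed_forward_dim):
    def w(*shape):
        return rng.normal(size=shape) * 0.02
    return [
        {"wq": w(dim, dim), "wk": w(dim, dim), "wv": w(dim, dim), "wo": w(dim, dim),
         "w1": w(dim, feed_forward_dim), "w2": w(feed_forward_dim, dim)}
        for _ in range(depth)
    ]

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 16))  # (batch, tokens, dim)
layers = init_layers(rng, depth=3, dim=16, feed_forward_dim=64)
y = transformer(x, layers, num_heads=4)
print(y.shape)  # (2, 5, 16)
```

Every sublayer preserves the (batch, tokens, dim) shape, so the output embedding has the same shape as the input regardless of depth.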
apply(x, depth, num_heads, feed_forward_dim_1)
Applies residual normalized attention (Transformer) layers to the input.
Parameters: - x – Input tensor.
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- feed_forward_dim_1 – FC dimension
Returns: Transformer output embedding
classmethod call(x, depth, num_heads, feed_forward_dim_1)
Evaluate the module with the given parameters.
Parameters: - params – the parameters of the module. Typically, initial parameter values are constructed using Module.init or Module.init_by_shape.
- *args – arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- feed_forward_dim_1 – FC dimension
Returns: The output of the module’s apply function.
Apply docstring:
Applies residual normalized attention (Transformer) layers to the input.
Returns: Transformer output embedding
classmethod create(x, depth, num_heads, feed_forward_dim_1)
Create a module instance by evaluating the model.
DEPRECATION WARNING: create() is deprecated; use init() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.
To initialize without doing any computation, use create_by_shape instead. Because create() evaluates the model, initializer functions can depend on both the shape and the value of the inputs.
Parameters: - _rng – the random number generator used to initialize parameters.
- *args – arguments passed to the module’s apply function
- name – name of this module
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- feed_forward_dim_1 – FC dimension
Returns: A pair consisting of the model output and an instance of Model
Apply docstring:
Applies residual normalized attention (Transformer) layers to the input.
Returns: Transformer output embedding
classmethod create_by_shape(input_specs, x, depth, num_heads, feed_forward_dim_1)
Create a module instance using only shape and dtype information.
DEPRECATION WARNING: create_by_shape() is deprecated; use init_by_shape() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.
This method will initialize the model without computation. Initializer functions can depend on the shape but not the value of inputs.
Parameters: - _rng – the random number generator used to initialize parameters.
- input_specs – an iterable of (shape, dtype) pairs specifying the inputs
- *args – other arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- feed_forward_dim_1 – FC dimension
Returns: A pair consisting of the model output and an instance of Model
Apply docstring:
Applies residual normalized attention (Transformer) layers to the input.
Returns: Transformer output embedding
classmethod init(x, depth, num_heads, feed_forward_dim_1)
Initialize the module parameters.
Parameters: - _rng – the random number generator used to initialize parameters.
- *args – arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input tensor.
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- feed_forward_dim_1 – FC dimension
Returns: A pair consisting of the model output and the initialized parameters
Apply docstring:
Applies residual normalized attention (Transformer) layers to the input.
Returns: Transformer output embedding
classmethod init_by_shape(input_specs, x, depth, num_heads, feed_forward_dim_1)
Initialize the module parameters.
This method will initialize the module parameters without computation. Initializer functions can depend on the shape but not the value of inputs.
Parameters: - _rng – the random number generator used to initialize parameters.
- input_specs – an iterable of (shape, dtype) pairs specifying the inputs
- *args – arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
Returns: A pair consisting of the model output and the initialized parameters
Example:
```
import jax
import jax.numpy as jnp
# batch_size, image_size and model are defined by the caller.
input_shape = (batch_size, image_size, image_size, 3)
model_output, initial_params = model.init_by_shape(jax.random.PRNGKey(0),
                                                   input_specs=[(input_shape, jnp.float32)])
```
Apply docstring:
Applies residual normalized attention (Transformer) layers to the input.
Parameters: - x – Input tensor.
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- feed_forward_dim_1 – FC dimension
Returns: Transformer output embedding
classmethod partial(depth, num_heads, feed_forward_dim_1)
Partially applies a module with the given arguments.
Unlike functools.partial, this returns a subclass of Module.
Parameters: - name – the name used for the module.
- **kwargs – the arguments to be applied.
- x – Input tensor.
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- feed_forward_dim_1 – FC dimension
Returns: A subclass of Module which partially applies the given keyword arguments.
Apply docstring:
Applies residual normalized attention (Transformer) layers to the input.
Returns: Transformer output embedding
vit_flax.vit module
class vit_flax.vit.ViT
Bases: flax.nn.base.Module
Applies the Vision Transformer to the input tensor.
Parameters: - x – Input image tensor
- patch_size – Patch dimension from image
- dim – Latent dim
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- dense_dims – Tuple(int, int) - (Transformer FC dim, Classifier FC dim)
- img_size – Dimension of input image
- num_classes – Number of classification classes
- initializer – Flax initializer
Returns: Classification output
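A minimal NumPy sketch of the data flow commonly used in a ViT. It splits the image into non-overlapping patch_size x patch_size patches, projects each flattened patch to dim, prepends a class token, adds position embeddings, runs the Transformer, and classifies from the class-token output through an MLP head whose hidden width plays the role of dense_dims[1]. The class token, the GELU in the head, and all helper names are assumptions about a typical ViT, not read from vit_flax. The Transformer stack is passed in as a callable (identity by default) to keep the example short.

```python
import math
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def patchify(images, patch_size):
    """(batch, H, W, C) -> (batch, num_patches, patch_size * patch_size * C)."""
    b, h, w, c = images.shape
    if h % patch_size or w % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # bring the pixels of each patch together
    return x.reshape(b, gh * gw, patch_size * patch_size * c)

def init_params(rng, img_size, patch_size, channels, dim, classifier_dim, num_classes):
    num_patches = (img_size // patch_size) ** 2
    def w(*shape):
        return rng.normal(size=shape) * 0.02
    return {
        "patch_embed": w(patch_size * patch_size * channels, dim),
        "cls_token": w(1, 1, dim),
        "pos_embed": w(1, num_patches + 1, dim),  # +1 for the class token
        "head_w1": w(dim, classifier_dim),        # hidden width, like dense_dims[1]
        "head_w2": w(classifier_dim, num_classes),
    }

def vit_forward(images, params, patch_size, transformer=lambda t: t):
    tokens = patchify(images, patch_size) @ params["patch_embed"]  # (b, n, dim)
    b, _, dim = tokens.shape
    cls = np.broadcast_to(params["cls_token"], (b, 1, dim))
    tokens = np.concatenate([cls, tokens], axis=1) + params["pos_embed"]  # (b, n + 1, dim)
    tokens = transformer(tokens)  # stand-in for the Transformer stack
    hidden = gelu(tokens[:, 0] @ params["head_w1"])  # read out the class token
    return hidden @ params["head_w2"]                # (b, num_classes) logits

rng = np.random.default_rng(0)
params = init_params(rng, img_size=32, patch_size=8, channels=3,
                     dim=16, classifier_dim=32, num_classes=10)
images = rng.normal(size=(2, 32, 32, 3))
logits = vit_forward(images, params, patch_size=8)
print(logits.shape)  # (2, 10)
```

With img_size=32 and patch_size=8 the image becomes (32 / 8)^2 = 16 patch tokens, plus one class token, so the position embedding has 17 rows.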
apply(x, patch_size, dim, depth, num_heads, dense_dims, img_size, num_classes, initializer=<function normal.<locals>.init>)
Applies the Vision Transformer to the input tensor.
Parameters: - x – Input image tensor
- patch_size – Patch dimension from image
- dim – Latent dim
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- dense_dims – Tuple(int, int) - (Transformer FC dim, Classifier FC dim)
- img_size – Dimension of input image
- num_classes – Number of classification classes
- initializer – Flax initializer
Returns: Classification output
classmethod call(x, patch_size, dim, depth, num_heads, dense_dims, img_size, num_classes, initializer=<function normal.<locals>.init>)
Evaluate the module with the given parameters.
Parameters: - params – the parameters of the module. Typically, initial parameter values are constructed using Module.init or Module.init_by_shape.
- *args – arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input image tensor
- patch_size – Patch dimension from image
- dim – Latent dim
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- dense_dims – Tuple(int, int) - (Transformer FC dim, Classifier FC dim)
- img_size – Dimension of input image
- num_classes – Number of classification classes
- initializer – Flax initializer
Returns: The output of the module’s apply function.
Apply docstring:
Applies the Vision Transformer to the input tensor.
Returns: Classification output
classmethod create(x, patch_size, dim, depth, num_heads, dense_dims, img_size, num_classes, initializer=<function normal.<locals>.init>)
Create a module instance by evaluating the model.
DEPRECATION WARNING: create() is deprecated; use init() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.
To initialize without doing any computation, use create_by_shape instead. Because create() evaluates the model, initializer functions can depend on both the shape and the value of the inputs.
Parameters: - _rng – the random number generator used to initialize parameters.
- *args – arguments passed to the module’s apply function
- name – name of this module
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input image tensor
- patch_size – Patch dimension from image
- dim – Latent dim
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- dense_dims – Tuple(int, int) - (Transformer FC dim, Classifier FC dim)
- img_size – Dimension of input image
- num_classes – Number of classification classes
- initializer – Flax initializer
Returns: A pair consisting of the model output and an instance of Model
Apply docstring:
Applies the Vision Transformer to the input tensor.
Returns: Classification output
classmethod create_by_shape(input_specs, x, patch_size, dim, depth, num_heads, dense_dims, img_size, num_classes, initializer=<function normal.<locals>.init>)
Create a module instance using only shape and dtype information.
DEPRECATION WARNING: create_by_shape() is deprecated; use init_by_shape() to initialize parameters and then explicitly create an nn.Model from the module and the initialized parameters.
This method will initialize the model without computation. Initializer functions can depend on the shape but not the value of inputs.
Parameters: - _rng – the random number generator used to initialize parameters.
- input_specs – an iterable of (shape, dtype) pairs specifying the inputs
- *args – other arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input image tensor
- patch_size – Patch dimension from image
- dim – Latent dim
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- dense_dims – Tuple(int, int) - (Transformer FC dim, Classifier FC dim)
- img_size – Dimension of input image
- num_classes – Number of classification classes
- initializer – Flax initializer
Returns: A pair consisting of the model output and an instance of Model
Apply docstring:
Applies the Vision Transformer to the input tensor.
Returns: Classification output
classmethod init(x, patch_size, dim, depth, num_heads, dense_dims, img_size, num_classes, initializer=<function normal.<locals>.init>)
Initialize the module parameters.
Parameters: - _rng – the random number generator used to initialize parameters.
- *args – arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
- x – Input image tensor
- patch_size – Patch dimension from image
- dim – Latent dim
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- dense_dims – Tuple(int, int) - (Transformer FC dim, Classifier FC dim)
- img_size – Dimension of input image
- num_classes – Number of classification classes
- initializer – Flax initializer
Returns: A pair consisting of the model output and the initialized parameters
Apply docstring:
Applies the Vision Transformer to the input tensor.
Returns: Classification output
classmethod init_by_shape(input_specs, x, patch_size, dim, depth, num_heads, dense_dims, img_size, num_classes, initializer=<function normal.<locals>.init>)
Initialize the module parameters.
This method will initialize the module parameters without computation. Initializer functions can depend on the shape but not the value of inputs.
Parameters: - _rng – the random number generator used to initialize parameters.
- input_specs – an iterable of (shape, dtype) pairs specifying the inputs
- *args – arguments passed to the module’s apply function
- name – name of this module.
- **kwargs – keyword arguments passed to the module’s apply function
Returns: A pair consisting of the model output and the initialized parameters
Example:
```
import jax
import jax.numpy as jnp
# batch_size, image_size and model are defined by the caller.
input_shape = (batch_size, image_size, image_size, 3)
model_output, initial_params = model.init_by_shape(jax.random.PRNGKey(0),
                                                   input_specs=[(input_shape, jnp.float32)])
```
Apply docstring:
Applies the Vision Transformer to the input tensor.
Parameters: - x – Input image tensor
- patch_size – Patch dimension from image
- dim – Latent dim
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- dense_dims – Tuple(int, int) - (Transformer FC dim, Classifier FC dim)
- img_size – Dimension of input image
- num_classes – Number of classification classes
- initializer – Flax initializer
Returns: Classification output
classmethod partial(patch_size, dim, depth, num_heads, dense_dims, img_size, num_classes, initializer=<function normal.<locals>.init>)
Partially applies a module with the given arguments.
Unlike functools.partial, this returns a subclass of Module.
Parameters: - name – the name used for the module.
- **kwargs – the arguments to be applied.
- x – Input image tensor
- patch_size – Patch dimension from image
- dim – Latent dim
- depth – Number of residual normalized attention layers.
- num_heads – Number of attention heads
- dense_dims – Tuple(int, int) - (Transformer FC dim, Classifier FC dim)
- img_size – Dimension of input image
- num_classes – Number of classification classes
- initializer – Flax initializer
Returns: A subclass of Module which partially applies the given keyword arguments.
Apply docstring:
Applies the Vision Transformer to the input tensor.
Returns: Classification output