API Reference
- flax.config package
- flax.core.frozen_dict package
- flax.cursor package
- flax.error package
  - AlreadyExistsError
  - ApplyModuleInvalidMethodError
  - ApplyScopeInvalidVariablesStructureError
  - ApplyScopeInvalidVariablesTypeError
  - AssignSubModuleError
  - CallCompactUnboundModuleError
  - CallSetupUnboundModuleError
  - CallUnbindOnUnboundModuleError
  - DescriptorAttributeError
  - IncorrectPostInitOverrideError
  - InvalidCheckpointError
  - InvalidFilterError
  - InvalidInstanceModuleError
  - InvalidRngError
  - InvalidScopeError
  - JaxTransformError
  - LazyInitError
  - MPACheckpointingRequiredError
  - MPARestoreDataCorruptedError
  - MPARestoreTargetRequiredError
  - ModifyScopeVariableError
  - MultipleMethodsCompactError
  - NameInUseError
  - PartitioningUnspecifiedError
  - ReservedModuleAttributeError
  - ScopeCollectionNotFound
  - ScopeParamNotFoundError
  - ScopeParamShapeError
  - ScopeVariableNotFoundError
  - SetAttributeFrozenModuleError
  - SetAttributeInModuleSetupError
  - TransformTargetError
  - TransformedMethodReturnValueError
- flax.jax_utils package
- flax.linen
- Module
- Init/Apply
- Layers
- Linear Modules
- Pooling
- Normalization
- Combinators
- Stochastic
- Attention
- Recurrent
- flax.linen.Dense
- flax.linen.DenseGeneral
- flax.linen.Conv
- flax.linen.ConvTranspose
- flax.linen.ConvLocal
- flax.linen.Embed
- flax.linen.BatchNorm
- flax.linen.LayerNorm
- flax.linen.GroupNorm
- flax.linen.RMSNorm
- flax.linen.Sequential
- flax.linen.Dropout
- flax.linen.SelfAttention
- flax.linen.MultiHeadDotProductAttention
- flax.linen.RNNCellBase
- flax.linen.LSTMCell
- flax.linen.OptimizedLSTMCell
- flax.linen.GRUCell
- flax.linen.RNN
- flax.linen.Bidirectional
- flax.linen.max_pool
- flax.linen.avg_pool
- flax.linen.pool
- flax.linen.dot_product_attention_weights
- flax.linen.dot_product_attention
- flax.linen.make_attention_mask
- flax.linen.make_causal_mask
- Activation functions
- flax.linen.activation.PReLU
- flax.linen.activation.celu
- flax.linen.activation.elu
- flax.linen.activation.gelu
- flax.linen.activation.glu
- flax.linen.activation.hard_sigmoid
- flax.linen.activation.hard_silu
- flax.linen.activation.hard_swish
- flax.linen.activation.hard_tanh
- flax.linen.activation.leaky_relu
- flax.linen.activation.log_sigmoid
- flax.linen.activation.log_softmax
- flax.linen.activation.logsumexp
- flax.linen.activation.one_hot
- flax.linen.activation.relu
- flax.linen.activation.relu6
- flax.linen.activation.selu
- flax.linen.activation.sigmoid
- flax.linen.activation.silu
- flax.linen.activation.soft_sign
- flax.linen.activation.softmax
- flax.linen.activation.softplus
- flax.linen.activation.standardize
- flax.linen.activation.swish
- flax.linen.activation.tanh
- Initializers
- flax.linen.initializers.constant
- flax.linen.initializers.delta_orthogonal
- flax.linen.initializers.glorot_normal
- flax.linen.initializers.glorot_uniform
- flax.linen.initializers.he_normal
- flax.linen.initializers.he_uniform
- flax.linen.initializers.kaiming_normal
- flax.linen.initializers.kaiming_uniform
- flax.linen.initializers.lecun_normal
- flax.linen.initializers.lecun_uniform
- flax.linen.initializers.normal
- flax.linen.initializers.ones
- flax.linen.initializers.ones_init
- flax.linen.initializers.orthogonal
- flax.linen.initializers.uniform
- flax.linen.initializers.variance_scaling
- flax.linen.initializers.xavier_normal
- flax.linen.initializers.xavier_uniform
- flax.linen.initializers.zeros
- flax.linen.initializers.zeros_init
- Transformations
- flax.linen.vmap
- flax.linen.scan
- flax.linen.jit
- flax.linen.remat
- flax.linen.remat_scan
- flax.linen.map_variables
- flax.linen.jvp
- flax.linen.vjp
- flax.linen.custom_vjp
- flax.linen.while_loop
- flax.linen.cond
- flax.linen.switch
- Inspection
- Variable dictionary
- SPMD
- flax.linen.Partitioned
- flax.linen.with_partitioning
- flax.linen.get_partition_spec
- flax.linen.get_sharding
- flax.linen.LogicallyPartitioned
- flax.linen.logical_axis_rules
- flax.linen.set_logical_axis_rules
- flax.linen.get_logical_axis_rules
- flax.linen.logical_to_mesh_axes
- flax.linen.logical_to_mesh
- flax.linen.logical_to_mesh_sharding
- flax.linen.with_logical_constraint
- flax.linen.with_logical_partitioning
- Decorators
- Profiling
- flax.serialization package
- flax.struct package
- flax.traceback_util package
- flax.training package
- flax.traverse_util package