Index _ | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | Q | R | S | T | U | V | W | X | Z _ __call__() (flax.linen.activation.PReLU method) (flax.linen.BatchNorm method), [1] (flax.linen.Bidirectional method), [1] (flax.linen.Conv method), [1] (flax.linen.ConvLocal method), [1] (flax.linen.ConvTranspose method), [1] (flax.linen.Dense method), [1] (flax.linen.DenseGeneral method), [1] (flax.linen.Dropout method), [1] (flax.linen.Embed method), [1] (flax.linen.GroupNorm method), [1] (flax.linen.GRUCell method), [1] (flax.linen.LayerNorm method), [1] (flax.linen.LSTMCell method), [1] (flax.linen.MultiHeadDotProductAttention method), [1] (flax.linen.OptimizedLSTMCell method), [1] (flax.linen.RMSNorm method) (flax.linen.RNN method), [1] (flax.linen.RNNCellBase method), [1] (flax.linen.SelfAttention method), [1] (flax.linen.Sequential method), [1] __init__() (flax.linen.activation.PReLU method) (flax.linen.LogicallyPartitioned method) (flax.linen.Partitioned method) (flax.linen.Variable method) (flax.traverse_util.ModelParamTraversal method) __setattr__() (flax.linen.Module method) A activation_fn (flax.linen.GRUCell attribute), [1] (flax.linen.LSTMCell attribute), [1] (flax.linen.OptimizedLSTMCell attribute), [1] AlreadyExistsError apply() (flax.linen.Module method) (in module flax.linen), [1] apply_gradients() (flax.training.train_state.TrainState method) apply_update() (flax.cursor.Cursor method) ApplyModuleInvalidMethodError ApplyScopeInvalidVariablesStructureError ApplyScopeInvalidVariablesTypeError AssignSubModuleError attention_fn (flax.linen.MultiHeadDotProductAttention attribute), [1] avg_pool() (in module flax.linen), [1] axis (flax.linen.BatchNorm attribute), [1] (flax.linen.DenseGeneral attribute), [1] axis_index_groups (flax.linen.BatchNorm attribute), [1] (flax.linen.GroupNorm attribute), [1] (flax.linen.LayerNorm attribute), [1] (flax.linen.RMSNorm attribute) axis_name (flax.linen.BatchNorm attribute), [1] (flax.linen.GroupNorm attribute), [1] (flax.linen.LayerNorm attribute), [1] (flax.linen.RMSNorm attribute) B batch_dims (flax.linen.DenseGeneral attribute), [1] BatchNorm (class in flax.linen), [1] best_metric (flax.training.early_stopping.EarlyStopping attribute) bias_init (flax.linen.BatchNorm attribute), [1] (flax.linen.Conv attribute), [1] (flax.linen.ConvLocal attribute), [1] (flax.linen.ConvTranspose attribute), [1] (flax.linen.Dense attribute), [1] (flax.linen.DenseGeneral attribute), [1] (flax.linen.GroupNorm attribute), [1] (flax.linen.GRUCell attribute), [1] (flax.linen.LayerNorm attribute), [1] (flax.linen.LSTMCell attribute), [1] (flax.linen.MultiHeadDotProductAttention attribute), [1] (flax.linen.OptimizedLSTMCell attribute), [1] Bidirectional (class in flax.linen), [1] bind() (flax.linen.Module method) Bound Module broadcast_dims (flax.linen.Dropout attribute), [1] broadcast_dropout (flax.linen.MultiHeadDotProductAttention attribute), [1] build() (flax.cursor.Cursor method) C CallCompactUnboundModuleError CallSetupUnboundModuleError CallUnbindOnUnboundModuleError cell (flax.linen.RNN attribute), [1] celu() (in module flax.linen.activation), [1] Compact / Non-compact Module compact() (in module flax.linen), [1] compose() (flax.traverse_util.Traversal method) cond() (in module flax.linen), [1] constant() (in module flax.linen.initializers), [1] Conv (class in flax.linen), [1] convert_pre_linen() (in module flax.training.checkpoints) ConvLocal (class in flax.linen), [1] ConvTranspose (class in flax.linen), [1] copy() (flax.core.frozen_dict.FrozenDict method) (in module flax.core.frozen_dict) create() (flax.training.train_state.TrainState class method) create_constant_learning_rate_schedule() (in module flax.training.lr_schedule) create_cosine_learning_rate_schedule() (in module flax.training.lr_schedule) create_stepped_learning_rate_schedule() (in module flax.training.lr_schedule) Cursor (class in flax.cursor) cursor() (in module flax.cursor) custom_vjp() (in module flax.linen), [1] D dataclass() (in module flax.struct) decode (flax.linen.MultiHeadDotProductAttention attribute), [1] define_bool_state() (in module flax.configurations) delta_orthogonal() (in module flax.linen.initializers), [1] Dense (class in flax.linen), [1] DenseGeneral (class in flax.linen), [1] DescriptorAttributeError deterministic (flax.linen.Dropout attribute), [1] (flax.linen.MultiHeadDotProductAttention attribute), [1] disable_named_call() (in module flax.linen), [1] dot_product_attention() (in module flax.linen), [1] dot_product_attention_weights() (in module flax.linen), [1] Dropout (class in flax.linen), [1] dropout_rate (flax.linen.MultiHeadDotProductAttention attribute), [1] dtype (flax.linen.BatchNorm attribute), [1] (flax.linen.Conv attribute), [1] (flax.linen.ConvLocal attribute), [1] (flax.linen.ConvTranspose attribute), [1] (flax.linen.Dense attribute), [1] (flax.linen.DenseGeneral attribute), [1] (flax.linen.Embed attribute), [1] (flax.linen.GroupNorm attribute), [1] (flax.linen.GRUCell attribute), [1] (flax.linen.LayerNorm attribute), [1] (flax.linen.LSTMCell attribute), [1] (flax.linen.MultiHeadDotProductAttention attribute), [1] (flax.linen.OptimizedLSTMCell attribute), [1] (flax.linen.RMSNorm attribute) E each() (flax.traverse_util.Traversal method) EarlyStopping (class in flax.training.early_stopping) elu() (in module flax.linen.activation), [1] Embed (class in flax.linen), [1] embedding_init (flax.linen.Embed attribute), [1] enable_named_call() (in module flax.linen), [1] epsilon (flax.linen.BatchNorm attribute), [1] (flax.linen.GroupNorm attribute), [1] (flax.linen.LayerNorm attribute), [1] (flax.linen.RMSNorm attribute) F feature_axes (flax.linen.LayerNorm attribute), [1] (flax.linen.RMSNorm attribute) feature_group_count (flax.linen.Conv attribute), [1] (flax.linen.ConvLocal attribute), [1] features (flax.linen.Conv attribute), [1] (flax.linen.ConvLocal attribute), [1] (flax.linen.ConvTranspose attribute), [1] (flax.linen.Dense attribute), [1] (flax.linen.DenseGeneral attribute), [1] (flax.linen.Embed attribute), [1] (flax.linen.LSTMCell attribute), [1] filter() (flax.traverse_util.Traversal method) flatten_dict() (in module flax.traverse_util) flax.configurations module flax.core.variables module flax.errors module flax.jax_utils module flax.linen module flax.linen.activation module flax.linen.initializers module flax.linen.spmd module flax.linen.transforms module flax.serialization module flax.struct module flax.traceback_util module flax.training.checkpoints module flax.training.lr_schedule module flax.traverse_util module Folding in freeze() (in module flax.core.frozen_dict) from_bytes() (in module flax.serialization) from_state_dict() (in module flax.serialization) FrozenDict (class in flax.core.frozen_dict) Functional core G gate_fn (flax.linen.GRUCell attribute), [1] (flax.linen.LSTMCell attribute), [1] (flax.linen.OptimizedLSTMCell attribute), [1] gelu() (in module flax.linen.activation), [1] get_logical_axis_rules() (in module flax.linen), [1] get_metrics() (in module flax.training.common_utils) get_partition_spec() (in module flax.linen), [1] get_sharding() (in module flax.linen), [1] glorot_normal() (in module flax.linen.initializers), [1] glorot_uniform() (in module flax.linen.initializers), [1] glu() (in module flax.linen.activation), [1] group_size (flax.linen.GroupNorm attribute), [1] GroupNorm (class in flax.linen), [1] GRUCell (class in flax.linen), [1] H hard_sigmoid() (in module flax.linen.activation), [1] hard_silu() (in module flax.linen.activation), [1] hard_swish() (in module flax.linen.activation), [1] hard_tanh() (in module flax.linen.activation), [1] he_normal() (in module flax.linen.initializers), [1] he_uniform() (in module flax.linen.initializers), [1] hide_flax_in_tracebacks() (in module flax.traceback_util) I IncorrectPostInitOverrideError init() (flax.linen.Module method) (in module flax.linen), [1] init_with_output() (flax.linen.Module method) (in module flax.linen), [1] input_dilation (flax.linen.Conv attribute), [1] (flax.linen.ConvLocal attribute), [1] InvalidCheckpointError InvalidFilterError InvalidInstanceModuleError InvalidRngError InvalidScopeError is_initializing() (flax.linen.Module method) iterate() (flax.traverse_util.Traversal method) (flax.traverse_util.TraverseAttr method) (flax.traverse_util.TraverseCompose method) (flax.traverse_util.TraverseEach method) (flax.traverse_util.TraverseFilter method) (flax.traverse_util.TraverseId method) (flax.traverse_util.TraverseItem method) (flax.traverse_util.TraverseMerge method) (flax.traverse_util.TraverseTree method) J JaxTransformError jit() (in module flax.linen), [1] jvp() (in module flax.linen), [1] K kaiming_normal() (in module flax.linen.initializers), [1] kaiming_uniform() (in module flax.linen.initializers), [1] keep_order (flax.linen.RNN attribute), [1] kernel_dilation (flax.linen.Conv attribute), [1] (flax.linen.ConvLocal attribute), [1] (flax.linen.ConvTranspose attribute), [1] kernel_init (flax.linen.Conv attribute), [1] (flax.linen.ConvLocal attribute), [1] (flax.linen.ConvTranspose attribute), [1] (flax.linen.Dense attribute), [1] (flax.linen.DenseGeneral attribute), [1] (flax.linen.GRUCell attribute), [1] (flax.linen.LSTMCell attribute), [1] (flax.linen.MultiHeadDotProductAttention attribute), [1] (flax.linen.OptimizedLSTMCell attribute), [1] kernel_size (flax.linen.Conv attribute), [1] (flax.linen.ConvLocal attribute), [1] (flax.linen.ConvTranspose attribute), [1] L latest_checkpoint() (in module flax.training.checkpoints) LayerNorm (class in flax.linen), [1] Lazy initialization LazyInitError leaky_relu() (in module flax.linen.activation), [1] lecun_normal() (in module flax.linen.initializers), [1] lecun_uniform() (in module flax.linen.initializers), [1] Lifted transformation log_sigmoid() (in module flax.linen.activation), [1] log_softmax() (in module flax.linen.activation), [1] logical_axis_rules() (in module flax.linen), [1] logical_to_mesh() (in module flax.linen), [1] logical_to_mesh_axes() (in module flax.linen), [1] logical_to_mesh_sharding() (in module flax.linen), [1] LogicallyPartitioned (class in flax.linen) LogicallyPartitioned() (in module flax.linen) logsumexp() (in module flax.linen.activation), [1] LSTMCell (class in flax.linen), [1] M make_attention_mask() (in module flax.linen), [1] make_causal_mask() (in module flax.linen), [1] make_rng() (flax.linen.Module method) map_variables() (in module flax.linen), [1] mask (flax.linen.Conv attribute), [1] (flax.linen.ConvLocal attribute), [1] (flax.linen.ConvTranspose attribute), [1] max_pool() (in module flax.linen), [1] merge() (flax.traverse_util.Traversal method) min_delta (flax.training.early_stopping.EarlyStopping attribute) ModelParamTraversal (class in flax.traverse_util) ModifyScopeVariableError Module module (class in flax.linen) flax.configurations flax.core.variables flax.errors flax.jax_utils flax.linen flax.linen.activation flax.linen.initializers flax.linen.spmd flax.linen.transforms flax.serialization flax.struct flax.traceback_util flax.training.checkpoints flax.training.lr_schedule flax.traverse_util momentum (flax.linen.BatchNorm attribute), [1] MPACheckpointingRequiredError MPARestoreDataCorruptedError MPARestoreTargetRequiredError msgpack_restore() (in module flax.serialization) msgpack_serialize() (in module flax.serialization) MultiHeadDotProductAttention (class in flax.linen), [1] MultipleMethodsCompactError N NameInUseError negative_slope_init (flax.linen.activation.PReLU attribute), [1] normal() (in module flax.linen.initializers), [1] normalize_qk (flax.linen.MultiHeadDotProductAttention attribute), [1] nowrap() (in module flax.linen), [1] num_embeddings (flax.linen.Embed attribute), [1] num_groups (flax.linen.GroupNorm attribute), [1] num_heads (flax.linen.MultiHeadDotProductAttention attribute), [1] O one_hot() (in module flax.linen.activation), [1] onehot() (in module flax.training.common_utils) ones() (in module flax.linen.initializers), [1] ones_init() (in module flax.linen.initializers), [1] OptimizedLSTMCell (class in flax.linen), [1] orthogonal() (in module flax.linen.initializers), [1] out_features (flax.linen.MultiHeadDotProductAttention attribute), [1] override_named_call() (in module flax.linen), [1] P pad_shard_unpad() (in module flax.jax_utils) padding (flax.linen.Conv attribute), [1] (flax.linen.ConvLocal attribute), [1] (flax.linen.ConvTranspose attribute), [1] param() (flax.linen.Module method) param_dtype (flax.linen.activation.PReLU attribute), [1] (flax.linen.BatchNorm attribute), [1] (flax.linen.Conv attribute), [1] (flax.linen.ConvLocal attribute), [1] (flax.linen.ConvTranspose attribute), [1] (flax.linen.Dense attribute), [1] (flax.linen.DenseGeneral attribute), [1] (flax.linen.Embed attribute), [1] (flax.linen.GroupNorm attribute), [1] (flax.linen.GRUCell attribute), [1] (flax.linen.LayerNorm attribute), [1] (flax.linen.LSTMCell attribute), [1] (flax.linen.MultiHeadDotProductAttention attribute), [1] (flax.linen.OptimizedLSTMCell attribute), [1] (flax.linen.RMSNorm attribute) Params / parameters partial_eval_by_shape() (in module flax.jax_utils) Partitioned (class in flax.linen) Partitioned() (in module flax.linen) PartitioningUnspecifiedError path_aware_map() (in module flax.traverse_util) patience (flax.training.early_stopping.EarlyStopping attribute) patience_count (flax.training.early_stopping.EarlyStopping attribute) perturb() (flax.linen.Module method) pmean() (in module flax.jax_utils) pool() (in module flax.linen), [1] pop() (flax.core.frozen_dict.FrozenDict method) (in module flax.core.frozen_dict) precision (flax.linen.Conv attribute), [1] (flax.linen.ConvLocal attribute), [1] (flax.linen.ConvTranspose attribute), [1] (flax.linen.Dense attribute), [1] (flax.linen.DenseGeneral attribute), [1] (flax.linen.MultiHeadDotProductAttention attribute), [1] prefetch_to_device() (in module flax.jax_utils) PReLU (class in flax.linen.activation), [1] pretty_repr() (flax.core.frozen_dict.FrozenDict method) (in module flax.core.frozen_dict) PyTreeNode (class in flax.struct) Q qkv_features (flax.linen.MultiHeadDotProductAttention attribute), [1] R rate (flax.linen.Dropout attribute), [1] recurrent_kernel_init (flax.linen.GRUCell attribute), [1] (flax.linen.LSTMCell attribute), [1] (flax.linen.OptimizedLSTMCell attribute), [1] reduction_axes (flax.linen.LayerNorm attribute), [1] (flax.linen.RMSNorm attribute) register_serialization_state() (in module flax.serialization) relu (in module flax.linen.activation) relu() (in module flax.linen.activation) relu6 (in module flax.linen.activation) remat() (in module flax.linen), [1] remat_scan() (in module flax.linen), [1] replicate() (in module flax.jax_utils) ReservedModuleAttributeError restore_checkpoint() (in module flax.training.checkpoints) return_carry (flax.linen.RNN attribute), [1] reverse (flax.linen.RNN attribute), [1] RMSNorm (class in flax.linen) RNG sequences rng_collection (flax.linen.Dropout attribute), [1] RNN (class in flax.linen), [1] RNNCellBase (class in flax.linen), [1] S save_checkpoint() (in module flax.training.checkpoints) save_checkpoint_multiprocess() (in module flax.training.checkpoints) scale_init (flax.linen.BatchNorm attribute), [1] (flax.linen.GroupNorm attribute), [1] (flax.linen.LayerNorm attribute), [1] (flax.linen.RMSNorm attribute) scan() (in module flax.linen), [1] Scope ScopeCollectionNotFound ScopeParamNotFoundError ScopeParamShapeError ScopeVariableNotFoundError SelfAttention (class in flax.linen), [1] selu() (in module flax.linen.activation), [1] Sequential (class in flax.linen), [1] set() (flax.cursor.Cursor method) (flax.traverse_util.Traversal method) set_logical_axis_rules() (in module flax.linen), [1] SetAttributeFrozenModuleError SetAttributeInModuleSetupError setup() (flax.linen.Module method) Shape inference shard() (in module flax.training.common_utils) shard_prng_key() (in module flax.training.common_utils) should_stop (flax.training.early_stopping.EarlyStopping attribute) show_flax_in_tracebacks() (in module flax.traceback_util) sigmoid() (in module flax.linen.activation), [1] silu() (in module flax.linen.activation), [1] soft_sign() (in module flax.linen.activation), [1] softmax() (in module flax.linen.activation), [1] softplus() (in module flax.linen.activation), [1] sow() (flax.linen.Module method) split_rngs (flax.linen.RNN attribute), [1] stack_forest() (in module flax.training.common_utils) standardize() (in module flax.linen.activation), [1] static_bool_env() (in module flax.configurations) strides (flax.linen.Conv attribute), [1] (flax.linen.ConvLocal attribute), [1] (flax.linen.ConvTranspose attribute), [1] swish() (in module flax.linen.activation), [1] switch() (in module flax.linen), [1] T tabulate() (flax.linen.Module method) (in module flax.linen), [1] tanh() (in module flax.linen.activation), [1] time_major (flax.linen.RNN attribute), [1] to_bytes() (in module flax.serialization) to_state_dict() (in module flax.serialization) TrainState (class in flax.training.train_state) TransformedMethodReturnValueError TransformTargetError transpose_kernel (flax.linen.ConvTranspose attribute), [1] Traversal (class in flax.traverse_util) TraverseAttr (class in flax.traverse_util) TraverseCompose (class in flax.traverse_util) TraverseEach (class in flax.traverse_util) TraverseFilter (class in flax.traverse_util) TraverseId (class in flax.traverse_util) TraverseItem (class in flax.traverse_util) TraverseMerge (class in flax.traverse_util) TraverseTree (class in flax.traverse_util) tree() (flax.traverse_util.Traversal method) U unbind() (flax.linen.Module method) unflatten_dict() (in module flax.traverse_util) unfreeze() (flax.core.frozen_dict.FrozenDict method) (in module flax.core.frozen_dict) uniform() (in module flax.linen.initializers), [1] unreplicate() (in module flax.jax_utils) unroll (flax.linen.RNN attribute), [1] update() (flax.training.early_stopping.EarlyStopping method) (flax.traverse_util.Traversal method) (flax.traverse_util.TraverseAttr method) (flax.traverse_util.TraverseCompose method) (flax.traverse_util.TraverseEach method) (flax.traverse_util.TraverseFilter method) (flax.traverse_util.TraverseId method) (flax.traverse_util.TraverseItem method) (flax.traverse_util.TraverseMerge method) (flax.traverse_util.TraverseTree method) use_bias (flax.linen.BatchNorm attribute), [1] (flax.linen.Conv attribute), [1] (flax.linen.ConvLocal attribute), [1] (flax.linen.ConvTranspose attribute), [1] (flax.linen.Dense attribute), [1] (flax.linen.DenseGeneral attribute), [1] (flax.linen.GroupNorm attribute), [1] (flax.linen.LayerNorm attribute), [1] (flax.linen.MultiHeadDotProductAttention attribute), [1] use_fast_variance (flax.linen.BatchNorm attribute), [1] (flax.linen.GroupNorm attribute), [1] (flax.linen.LayerNorm attribute), [1] use_running_average (flax.linen.BatchNorm attribute), [1] use_scale (flax.linen.BatchNorm attribute), [1] (flax.linen.GroupNorm attribute), [1] (flax.linen.LayerNorm attribute), [1] (flax.linen.RMSNorm attribute) V Variable (class in flax.linen), [1] Variable collections Variable dictionary variable() (flax.linen.Module method) variable_axes (flax.linen.RNN attribute), [1] variable_broadcast (flax.linen.RNN attribute), [1] variable_carry (flax.linen.RNN attribute), [1] variables (flax.linen.Module property) variance_scaling() (in module flax.linen.initializers), [1] vjp() (in module flax.linen), [1] vmap() (in module flax.linen), [1] W while_loop() (in module flax.linen), [1] with_logical_constraint() (in module flax.linen), [1] with_logical_partitioning() (in module flax.linen), [1] with_partitioning() (in module flax.linen), [1] X xavier_normal() (in module flax.linen.initializers), [1] xavier_uniform() (in module flax.linen.initializers), [1] Z zeros() (in module flax.linen.initializers), [1] zeros_init() (in module flax.linen.initializers), [1]