flax.linen.LogicallyPartitioned
- class flax.linen.LogicallyPartitioned(value: Any, names: Tuple[Optional[str], ...], mesh: Optional[jax._src.mesh.Mesh] = None, rules: Optional[Sequence[Tuple[str, Union[str, Tuple[str], NoneType]]]] = None)
- __init__(value, names, mesh=None, rules=None)
Methods
__init__(value, names[, mesh, rules])
add_axis(index, params)              Adds a new axis to the axis metadata.
get_partition_spec()                 Returns the PartitionSpec for this partitioned value.
get_sharding(mesh)                   Returns the NamedSharding for this partitioned value.
remove_axis(index, params)           Removes an axis from the axis metadata.
replace(**updates)                   Returns a new object replacing the specified fields with new values.
replace_boxed(val)                   Replaces the boxed value with the provided value.
unbox([apply_constraint])            Returns the wrapped value with the partitioning constraint applied.
Attributes
mesh
rules