flax.struct package
Utilities for defining custom classes that can be used with JAX transformations.
- flax.struct.dataclass(clz)
Create a class which can be passed to functional transformations.
NOTE: Inherit from PyTreeNode instead to avoid type checking issues when using PyType.

JAX transformations such as jax.jit and jax.grad require objects that are immutable and can be mapped over using the jax.tree_util methods. The dataclass decorator makes it easy to define custom classes that can be passed safely to JAX. For example:
from typing import Any, Callable

import jax
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class Model:
  params: Any
  # use pytree_node=False to indicate an attribute should not be touched
  # by Jax transformations.
  apply_fn: Callable = struct.field(pytree_node=False)

  def __call__(self, *args):
    return self.apply_fn(*args)

# Placeholder parameters, apply function, and loss, for illustration only.
params = {'w': jnp.ones(3)}
params_b = {'w': jnp.zeros(3)}

def apply_fn(p, x):
  return jnp.dot(p['w'], x)

def some_loss_fn(m):
  return m(m.params, jnp.arange(3.0))

model = Model(params, apply_fn)

# model.params = params_b  # Model is immutable. This would raise an error.
model_b = model.replace(params=params_b)  # Use the replace method instead.

# This class can now be used safely in Jax to compute gradients w.r.t. the
# parameters.
model_grad = jax.grad(some_loss_fn)(model)
Note that dataclasses have an auto-generated __init__ where the arguments of the constructor and the attributes of the created instance match 1:1. This correspondence is what makes these objects valid containers that work with JAX transformations and, more generally, with the jax.tree_util library.

Sometimes a “smart constructor” is desired, for example because some of the attributes can be (optionally) derived from others. The way to do this with Flax dataclasses is to define a static method or class method that provides the smart constructor. This way the simple constructor used by jax.tree_util is preserved. Consider the following example:
import jax
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class DirectionAndScaleKernel:
  direction: jax.Array
  scale: jax.Array

  @classmethod
  def create(cls, kernel):
    scale = jnp.linalg.norm(kernel, axis=0, keepdims=True)
    direction = kernel / scale
    return cls(direction, scale)
- Parameters
clz – the class that will be transformed by the decorator.
- Returns
The new class.
- class flax.struct.PyTreeNode(*args, **kwargs)
Base class for dataclasses that should act like a JAX pytree node.
See flax.struct.dataclass for the jax.tree_util behavior. This base class additionally avoids type checking errors when using PyType.

Example:
from typing import Any, Callable

import jax
import jax.numpy as jnp
from flax import struct

class Model(struct.PyTreeNode):
  params: Any
  # use pytree_node=False to indicate an attribute should not be touched
  # by Jax transformations.
  apply_fn: Callable = struct.field(pytree_node=False)

  def __call__(self, *args):
    return self.apply_fn(*args)

# Placeholder parameters, apply function, and loss, for illustration only.
params = {'w': jnp.ones(3)}
params_b = {'w': jnp.zeros(3)}

def apply_fn(p, x):
  return jnp.dot(p['w'], x)

def some_loss_fn(m):
  return m(m.params, jnp.arange(3.0))

model = Model(params, apply_fn)

# model.params = params_b  # Model is immutable. This would raise an error.
model_b = model.replace(params=params_b)  # Use the replace method instead.

# This class can now be used safely in Jax to compute gradients w.r.t. the
# parameters.
model_grad = jax.grad(some_loss_fn)(model)