flax.cursor package
The Cursor API allows for mutability of pytrees. It provides a more
ergonomic way to make partial updates to deeply nested immutable
data structures than writing many nested dataclasses.replace calls.
To illustrate, consider the example below:
from flax.cursor import cursor
import dataclasses
from typing import Any

@dataclasses.dataclass
class A:
    x: Any

a = A(A(A(A(A(A(A(0)))))))
To replace the int 0 using dataclasses.replace, we would have to write many nested calls:
a2 = dataclasses.replace(
    a,
    x=dataclasses.replace(
        a.x,
        x=dataclasses.replace(
            a.x.x,
            x=dataclasses.replace(
                a.x.x.x,
                x=dataclasses.replace(
                    a.x.x.x.x,
                    x=dataclasses.replace(
                        a.x.x.x.x.x,
                        x=dataclasses.replace(a.x.x.x.x.x.x, x=1),
                    ),
                ),
            ),
        ),
    ),
)
The equivalent can be achieved much more simply using the Cursor API:
a3 = cursor(a).x.x.x.x.x.x.x.set(1)
assert a2 == a3
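For comparison, the nested dataclasses.replace calls can also be folded into a short recursive helper that uses only the standard library. This is a sketch of the manual approach, not of how flax.cursor works internally; `replace_nested` is a hypothetical name introduced here for illustration.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class A:
    x: Any


def replace_nested(obj, fields, value):
    """Return a copy of `obj` with the attribute chain `fields` set to `value`.

    `replace_nested` is a hypothetical helper, not part of Flax.
    """
    if not fields:
        # End of the attribute chain: this position takes the new value.
        return value
    head, rest = fields[0], fields[1:]
    # Rebuild the child first, then copy the parent around it.
    new_child = replace_nested(getattr(obj, head), rest, value)
    return dataclasses.replace(obj, **{head: new_child})


a = A(A(A(A(A(A(A(0)))))))
a_new = replace_nested(a, ['x'] * 7, 1)

assert a_new == A(A(A(A(A(A(A(1)))))))
assert a == A(A(A(A(A(A(A(0)))))))  # the original is left untouched
```

Such a helper only handles attribute access on dataclasses. The Cursor API also covers indexing into dicts, lists and tuples, as the examples further down show.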
The Cursor object keeps track of changes made to it and, when .build is called,
generates a new object with the accumulated changes. Basic usage involves
wrapping the object in a Cursor, making changes to the Cursor object, and
generating a new copy of the original object with the accumulated changes.
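The record-then-build cycle described above can be pictured with a toy recorder for nested dicts, lists and tuples. The sketch below is an assumption-laden illustration: `MiniCursor` is a hypothetical class written for this page, it is not Flax's implementation, and it ignores attribute access and other container types.

```python
class MiniCursor:
    """Toy change recorder for nested dicts, lists and tuples.

    Hypothetical illustration only; flax.cursor.Cursor is more general.
    """

    def __init__(self, obj):
        self._obj = obj
        self._changes = {}  # key -> new value or child MiniCursor

    def __getitem__(self, key):
        # Indexing returns (and remembers) a child recorder.
        child = self._changes.get(key)
        if not isinstance(child, MiniCursor):
            base = child if key in self._changes else self._obj[key]
            child = MiniCursor(base)
            self._changes[key] = child
        return child

    def __setitem__(self, key, value):
        # Assignment is only recorded; the wrapped object is not touched.
        self._changes[key] = value

    def build(self):
        # Resolve children first, then copy this container around them,
        # so the result is assembled bottom-up.
        resolved = {
            key: child.build() if isinstance(child, MiniCursor) else child
            for key, child in self._changes.items()
        }
        if isinstance(self._obj, dict):
            return {**self._obj, **resolved}
        if isinstance(self._obj, (list, tuple)):
            items = list(self._obj)
            for index, value in resolved.items():
                items[index] = value
            return type(self._obj)(items)
        return self._obj  # leaf without recorded changes


data = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
c = MiniCursor(data)
c['b'][0] = 10          # recorded on the child wrapping the tuple
c['a'] = (100, 200)     # recorded on the root
new_data = c.build()

assert new_data == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}
assert data == {'a': 1, 'b': (2, 3), 'c': [4, 5]}  # input is unchanged
```

Recording changes instead of applying them immediately is what lets an element of an immutable tuple be "assigned": nothing is written until build constructs a fresh tuple and a fresh dict around it.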
- flax.cursor.cursor(obj)
Wrap Cursor over obj and return it. Changes can then be applied to the Cursor object in the following ways:
- single-line change via the .set method
- multiple changes, and then calling the .build method
- multiple changes conditioned on the tree path and node value, via the .apply_update method
.set example:
from flax.cursor import cursor

dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}

.build example:
from flax.cursor import cursor

dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
c = cursor(dict_obj)
c['b'][0] = 10
c['a'] = (100, 200)
modified_dict_obj = c.build()
assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}

.apply_update example:
from flax.cursor import cursor
from flax.training import train_state
import optax

def update_fn(path, value):
    '''Replace params with empty dictionary.'''
    if 'params' in path:
        return {}
    return value

state = train_state.TrainState.create(
    apply_fn=lambda x: x,
    params={'a': 1, 'b': 2},
    tx=optax.adam(1e-3),
)
c = cursor(state)
state2 = c.apply_update(update_fn).build()
assert state2.params == {}
assert state.params == {'a': 1, 'b': 2}  # make sure original params are unchanged
View the docstrings for each method to see more examples of their usage.
- Parameters
obj – the object to wrap in a Cursor
- Returns
A Cursor object wrapped around obj.
- class flax.cursor.Cursor(obj, parent_key)
- apply_update(update_fn)
Traverse the Cursor object and apply conditional changes recursively via an update_fn. The update_fn has a function signature of (str, Any) -> Any:
- The input arguments are the current key path (in the form of a string delimited by '/') and value at that current key path
- The output is the new value (either modified by the update_fn or same as the input value if the condition wasn't fulfilled)
To generate a copy of the original object with the accumulated changes, call the .build method.
NOTES:
- If the update_fn returns a modified value, this function will not recurse any further down that branch to apply changes. For example, if we intend to replace an attribute that points to a dictionary with an int, we don't need to look for further changes inside the dictionary, since the dictionary will be replaced anyways.
- The is operator is used to determine whether the return value is modified (by comparing it to the input value). Therefore if the update_fn modifies a mutable container (e.g. lists, dicts, etc.) and returns the same container, .apply_update will treat the returned value as unmodified as it contains the same id. To avoid this, return a copy of the modified value.
- The .apply_update WILL NOT apply the update_fn to the value at the top-most level of the pytree (i.e. the root node). The update_fn will be applied recursively, starting at the root node's children.
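The three notes above can be traced with a small pure-Python traversal over nested dicts. This is a sketch under stated assumptions: `apply_update_sketch` is a hypothetical stand-in that handles only dicts and returns the new tree directly, so it is not Flax's implementation and skips the separate .build step.

```python
def apply_update_sketch(tree, update_fn, prefix=''):
    """Apply `update_fn` to every child of a nested dict and return a copy.

    Hypothetical stand-in restricted to nested dicts; not Flax's code.
    """
    result = {}
    for key, value in tree.items():
        # Key paths are strings delimited by '/', e.g. 'layer/kernel'.
        path = f'{prefix}/{key}' if prefix else str(key)
        new_value = update_fn(path, value)
        if new_value is not value:
            # A different object came back: take it and stop descending.
            result[key] = new_value
        elif isinstance(value, dict):
            # The same object came back: keep looking further down.
            result[key] = apply_update_sketch(value, update_fn, path)
        else:
            result[key] = value
    return result


visited = []


def update_fn(path, value):
    visited.append(path)
    if path == 'opt':
        return 0  # replace a whole sub-dict with an int
    if path.endswith('kernel'):
        return value * 2
    return value


tree = {'layer': {'kernel': 3, 'bias': 1}, 'opt': {'mu': 5}}
new_tree = apply_update_sketch(tree, update_fn)

assert new_tree == {'layer': {'kernel': 6, 'bias': 1}, 'opt': 0}
# The root itself was never passed to update_fn, and the replaced
# 'opt' branch was not searched any further.
assert visited == ['layer', 'layer/kernel', 'layer/bias', 'opt']


def append_in_place(path, value):
    if isinstance(value, list):
        value.append(99)  # the same object is returned below
    return value


def append_to_copy(path, value):
    if isinstance(value, list):
        return value + [99]  # a new list, so the change is detected
    return value


original = {'xs': [1, 2]}
aliased = apply_update_sketch(original, append_in_place)
assert aliased['xs'] is original['xs']  # treated as unmodified
assert original['xs'] == [1, 2, 99]     # and the input was mutated

original = {'xs': [1, 2]}
copied = apply_update_sketch(original, append_to_copy)
assert copied['xs'] == [1, 2, 99]
assert original['xs'] == [1, 2]         # input left intact
```

The last two cases show why returning a copy matters: an in-place edit passes the identity check unnoticed and also leaks into the original object, which defeats the purpose of an immutable-style update.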
Example:
import flax.linen as nn
from flax.cursor import cursor
import jax
import jax.numpy as jnp

class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(3)(x)
        x = nn.relu(x)
        x = nn.Dense(3)(x)
        x = nn.relu(x)
        x = nn.Dense(3)(x)
        x = nn.relu(x)
        return x

params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']

def update_fn(path, value):
    '''Multiply all dense kernel params by 2 and add 1.
    Subtract the Dense_1 bias param by 1.'''
    if 'kernel' in path:
        return value * 2 + 1
    elif 'Dense_1' in path and 'bias' in path:
        return value - 1
    return value

c = cursor(params)
new_params = c.apply_update(update_fn).build()
for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
    assert (new_params[layer]['kernel'] == 2 * params[layer]['kernel'] + 1).all()
    if layer == 'Dense_1':
        assert (new_params[layer]['bias'] == jnp.array([-1, -1, -1])).all()
    else:
        assert (new_params[layer]['bias'] == params[layer]['bias']).all()

assert jax.tree_util.tree_all(
    jax.tree_util.tree_map(
        lambda x, y: (x == y).all(),
        params,
        Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params'],
    )
)  # make sure original params are unchanged
- Parameters
update_fn – the function that will conditionally apply changes to the Cursor object
- Returns
The current Cursor object with the updates applied by the update_fn.
- build()
Create and return a copy of the original object with accumulated changes. This method is to be called after making changes to the Cursor object.
NOTE: The new object is built bottom-up: changes are applied first to the leaf nodes, then to their parents, and so on up to the root.
Example:
from flax.cursor import cursor
from flax.training import train_state
import optax

dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
c = cursor(dict_obj)
c['b'][0] = 10
c['a'] = (100, 200)
modified_dict_obj = c.build()
assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}

state = train_state.TrainState.create(
    apply_fn=lambda x: x,
    params=dict_obj,
    tx=optax.adam(1e-3),
)
new_fn = lambda x: x + 1
c = cursor(state)
c.params['b'][1] = 10
c.apply_fn = new_fn
modified_state = c.build()
assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}
assert modified_state.apply_fn == new_fn
- Returns
A copy of the original object with the accumulated changes.
- set(value)
Set a new value for an attribute, property, element or entry in the Cursor object and return a copy of the original object containing the newly set value.
Example:
from flax.cursor import cursor
from flax.training import train_state
import optax

dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}

state = train_state.TrainState.create(
    apply_fn=lambda x: x,
    params=dict_obj,
    tx=optax.adam(1e-3),
)
modified_state = cursor(state).params['b'][1].set(10)
assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}
- Parameters
value – the value used to set an attribute, property, element or entry in the Cursor object
- Returns
A copy of the original object with the newly set value.