# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
from typing import Any, Callable, Dict, Generator, Generic, Mapping, Optional, Protocol, Tuple, TypeVar, Union, runtime_checkable
from flax.core import FrozenDict
import dataclasses
# Generic type variable standing for the type of the object wrapped by a
# ``Cursor`` (see ``Cursor.obj`` and the ``cursor`` factory).
A = TypeVar('A')
# Alias for key values; deliberately unconstrained (``Any``).
# NOTE(review): not referenced in the visible part of this file -- confirm
# whether other modules import it before removing.
Key = Any
@runtime_checkable
class Indexable(Protocol):
  """Structural type for objects that support ``obj[key]`` lookups.

  The protocol is ``runtime_checkable``, so ``isinstance(x, Indexable)`` can
  be used at runtime to test whether ``x`` defines ``__getitem__``.
  """

  def __getitem__(self, key) -> Any:
    ...
@dataclasses.dataclass
class ParentKey(Generic[A]):
  """Back-reference from a child ``Cursor`` to the node that owns it.

  Attributes:
    parent: the ``Cursor`` that contains the child.
    key: the item key or attribute name under which the child is reached
      from ``parent``.
  """

  parent: 'Cursor[A]'
  key: Any
class AccessType(enum.Enum):
  """How a child node is reached from its parent while walking a tree."""

  # Child is read/written with ``parent[key]``.
  GETITEM = enum.auto()
  # Child is read/written with ``getattr``/``setattr`` on the parent.
  GETATTR = enum.auto()
def is_named_tuple(obj):
  """Return ``True`` when ``obj`` looks like a ``namedtuple`` instance.

  The check is structural: ``obj`` must be a ``tuple`` instance that also
  exposes the ``_fields``, ``_asdict`` and ``_replace`` attributes.
  """
  if not isinstance(obj, tuple):
    return False
  required_attrs = ('_fields', '_asdict', '_replace')
  return all(hasattr(obj, attr) for attr in required_attrs)
def _get_changes(path, obj, update_fn):
  """Helper generator for ``Cursor.apply_update``.

  Walks ``obj`` depth-first and calls ``update_fn`` on every node except the
  root (the node visited with an empty ``path``). Whenever ``update_fn``
  returns an object that is not the very same object it was given (``is``
  comparison), that change is yielded and the walk does not descend any
  further into that branch.

  Containers that are descended into: ``FrozenDict``/``dict`` (by item),
  named tuples (by attribute), ``list``/``tuple`` (by index) and dataclass
  *instances* (by attribute, only fields with ``init=True``). Everything
  else is treated as a leaf.

  Args:
    path: tuple of ``(key, AccessType)`` pairs leading from the root to
      ``obj``; empty for the root.
    obj: the node currently being visited.
    update_fn: callable ``(str_path, value) -> new_value`` where ``str_path``
      is the keys of ``path`` joined with ``'/'``.

  Yields:
    ``(path, new_value)`` pairs. The yielded ``path`` is never empty, because
    ``update_fn`` is never applied to the root.
  """
  if path:
    str_path = '/'.join(str(key) for key, _ in path)
    new_obj = update_fn(str_path, obj)
    if new_obj is not obj:
      yield path, new_obj
      # The whole branch is being replaced; no need to look inside it.
      return

  if isinstance(obj, (FrozenDict, dict)):
    items = obj.items()
    access_type = AccessType.GETITEM
  elif is_named_tuple(obj):
    # Must be tested before the generic ``tuple`` case below so that named
    # tuple fields are addressed by name rather than by index.
    items = ((name, getattr(obj, name)) for name in obj._fields)  # type: ignore
    access_type = AccessType.GETATTR
  elif isinstance(obj, (list, tuple)):
    items = enumerate(obj)
    access_type = AccessType.GETITEM
  elif dataclasses.is_dataclass(obj) and not isinstance(obj, type):
    # ``is_dataclass`` is also true for dataclass *classes*. A class stored
    # as a value in the tree is a leaf: reading its fields with ``getattr``
    # would raise ``AttributeError`` for fields that have no default.
    items = (
        (f.name, getattr(obj, f.name))
        for f in dataclasses.fields(obj)
        if f.init
    )
    access_type = AccessType.GETATTR
  else:
    # Leaf node: nothing to descend into.
    return

  for key, value in items:
    yield from _get_changes(path + ((key, access_type),), value, update_fn)
class Cursor(Generic[A]):
  """Records pending changes to a wrapped object without mutating it.

  Indexing into a cursor (``c[key]``) or reading an attribute (``c.name``)
  returns a child ``Cursor`` wrapping the corresponding part of the wrapped
  object. Assigning (``c[key] = v`` / ``c.name = v``) records the new value
  in ``changes``. ``build`` then produces a new object with all recorded
  changes applied.

  Attributes:
    obj: the wrapped object.
    parent_key: link to the parent cursor and the key this cursor is stored
      under, or ``None`` for the root cursor.
    changes: children and pending values, keyed by item key or attribute
      name.
  """

  obj: A
  parent_key: Optional[ParentKey[A]]
  changes: Dict[Any, Union[Any, 'Cursor[A]']]

  def __init__(self, obj: A, parent_key: Optional[ParentKey[A]]):
    # NOTE: ``__setattr__`` is overridden to record changes, so the instance
    # state is written straight into ``vars(self)`` (i.e. ``self.__dict__``).
    vars(self)['obj'] = obj
    vars(self)['parent_key'] = parent_key
    vars(self)['changes'] = {}

  @property
  def root(self) -> 'Cursor[A]':
    """The top-most cursor, i.e. the one whose ``parent_key`` is ``None``."""
    node = self
    while node.parent_key is not None:
      node = node.parent_key.parent  # type: ignore
    return node

  def __getitem__(self, key) -> 'Cursor[A]':
    """Return the child cursor for ``self.obj[key]``, creating it if needed.

    Raises:
      TypeError: if the wrapped object does not define ``__getitem__``.
      KeyError: if the wrapped object is a ``Mapping`` without ``key``.
    """
    if key in self.changes:
      return self.changes[key]

    if not isinstance(self.obj, Indexable):
      raise TypeError(f'Cannot index into {self.obj}')

    if isinstance(self.obj, Mapping) and key not in self.obj:
      raise KeyError(f'Key {key} not found in {self.obj}')

    if is_named_tuple(self.obj):
      # Named tuple children are always tracked by field name so that
      # ``build`` can pass them to ``_replace(**changes)``.
      return getattr(self, self.obj._fields[key])  # type: ignore

    child = Cursor(self.obj[key], ParentKey(self, key))
    self.changes[key] = child
    return child

  def __getattr__(self, name) -> 'Cursor[A]':
    """Return the child cursor for ``self.obj.<name>``, creating it if needed.

    Raises:
      AttributeError: if the wrapped object has no attribute ``name``.
    """
    # ``__getattr__`` only runs after normal attribute lookup has failed. If
    # the instance state itself is missing (e.g. the object was created via
    # ``__new__`` by ``copy``/``pickle`` and ``__init__`` never ran), reading
    # ``self.changes`` below would re-enter this method without end.
    if 'changes' not in vars(self):
      raise AttributeError(name)

    if name in self.changes:
      return self.changes[name]

    if not hasattr(self.obj, name):
      raise AttributeError(f'Attribute {name} not found in {self.obj}')

    child = Cursor(getattr(self.obj, name), ParentKey(self, name))
    self.changes[name] = child
    return child

  def __setitem__(self, key, value):
    """Record ``value`` as the pending replacement for ``self.obj[key]``."""
    if is_named_tuple(self.obj):
      # Translate the positional index into the field name (see
      # ``__getitem__``).
      return setattr(self, self.obj._fields[key], value)  # type: ignore
    self.changes[key] = Cursor(value, ParentKey(self, key))

  def __setattr__(self, name, value):
    """Record ``value`` as the pending replacement for ``self.obj.<name>``."""
    self.changes[name] = Cursor(value, ParentKey(self, name))

  def set(self, value) -> A:
    """Set a new value for an attribute, property, element or entry
    in the Cursor object and return a copy of the original object,
    containing the new set value.

    The change is also recorded on the parent cursor, exactly as an
    assignment (``parent[key] = value`` / ``parent.name = value``) would be.
    When called on the root cursor, ``value`` is returned as is.

    Example::

      from flax.cursor import cursor
      from flax.training import train_state
      import optax

      dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
      modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
      assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}

      state = train_state.TrainState.create(
          apply_fn=lambda x: x,
          params=dict_obj,
          tx=optax.adam(1e-3),
      )
      modified_state = cursor(state).params['b'][1].set(10)
      assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}

    Args:
      value: the value used to set an attribute, property, element or entry
        in the Cursor object

    Returns:
      A copy of the original object with the new set value.
    """
    if self.parent_key is None:
      return value
    parent, key = self.parent_key.parent, self.parent_key.key  # type: ignore
    # Wrap the value in a Cursor (as ``__setitem__``/``__setattr__`` do) so
    # that a later ``parent[key]`` / ``parent.<key>`` still returns a Cursor
    # rather than the bare value.
    parent.changes[key] = Cursor(value, self.parent_key)
    return parent.root.build()

  def build(self) -> A:
    """Create and return a copy of the original object with accumulated changes.
    This method is to be called after making changes to the Cursor object.

    NOTE: The new object is built bottom-up, the changes will be first applied
    to the leaf nodes, and then its parent, all the way up to the root.

    Supported containers are ``FrozenDict``, ``dict``, ``list``, named tuples,
    ``tuple`` and dataclass instances. For any other wrapped object the
    original object is returned unchanged and recorded changes are ignored.

    Example::

      from flax.cursor import cursor
      from flax.training import train_state
      import optax

      dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
      c = cursor(dict_obj)
      c['b'][0] = 10
      c['a'] = (100, 200)
      modified_dict_obj = c.build()
      assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}

      state = train_state.TrainState.create(
          apply_fn=lambda x: x,
          params=dict_obj,
          tx=optax.adam(1e-3),
      )
      new_fn = lambda x: x + 1
      c = cursor(state)
      c.params['b'][1] = 10
      c.apply_fn = new_fn
      modified_state = c.build()
      assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}
      assert modified_state.apply_fn == new_fn

    Returns:
      A copy of the original object with the accumulated changes.
    """
    # Resolve children first (bottom-up); plain values are used as they are.
    changes = {
        key: child.build() if isinstance(child, Cursor) else child
        for key, child in self.changes.items()
    }

    if isinstance(self.obj, FrozenDict):
      obj = self.obj.copy(changes)  # type: ignore
    elif isinstance(self.obj, (dict, list)):
      obj = self.obj.copy()  # type: ignore
      for key, value in changes.items():
        obj[key] = value
    elif is_named_tuple(self.obj):
      obj = self.obj._replace(**changes)  # type: ignore
    elif isinstance(self.obj, tuple):
      # Tuples are immutable: edit a list copy, then convert back.
      obj = list(self.obj)  # type: ignore
      for key, value in changes.items():
        obj[key] = value
      obj = tuple(obj)  # type: ignore
    elif dataclasses.is_dataclass(self.obj) and not isinstance(self.obj, type):
      # ``is_dataclass`` is also true for dataclass *classes*, which
      # ``dataclasses.replace`` rejects with a ``TypeError``; a class held as
      # a value falls through to the leaf case below.
      obj = dataclasses.replace(self.obj, **changes)  # type: ignore
    else:
      obj = self.obj  # type: ignore
      # NOTE: There is a way to try to do a general replace for pytrees, but
      # it requires the key of `changes` to store the type of access
      # (getattr, getitem, etc.) in order to access those values from the
      # original object and replace them with the new value (e.g. by
      # flattening the object, swapping the changed leaves and unflattening
      # it again). For simplicity, this is not implemented for now.
    return obj  # type: ignore

  def apply_update(
      self,
      update_fn: Callable[[str, Any], Any],
  ) -> 'Cursor[A]':
    """Traverse the Cursor object and apply conditional changes recursively via an ``update_fn``.
    The ``update_fn`` has a function signature of ``(str, Any) -> Any``:

    - The input arguments are the current key path (in the form of a string delimited
      by '/') and value at that current key path
    - The output is the new value (either modified by the ``update_fn`` or same as the
      input value if the condition wasn't fulfilled)

    To generate a copy of the original object with the accumulated changes, call the ``.build`` method.

    NOTES:

    - If the ``update_fn`` returns a modified value, this function will not recurse any further
      down that branch to apply changes. For example, if we intend to replace an attribute that points
      to a dictionary with an int, we don't need to look for further changes inside the dictionary,
      since the dictionary will be replaced anyways.
    - The ``is`` operator is used to determine whether the return value is modified (by comparing it
      to the input value). Therefore if the ``update_fn`` modifies a mutable container (e.g. lists,
      dicts, etc.) and returns the same container, ``.apply_update`` will treat the returned value as
      unmodified as it contains the same ``id``. To avoid this, return a copy of the modified value.
    - The ``.apply_update`` WILL NOT apply the ``update_fn`` to the value at the top-most level of
      the pytree (i.e. the root node). The ``update_fn`` will be applied recursively, starting at the
      root node's children.

    Example::

      import flax.linen as nn
      from flax.cursor import cursor
      import jax
      import jax.numpy as jnp

      class Model(nn.Module):
        @nn.compact
        def __call__(self, x):
          x = nn.Dense(3)(x)
          x = nn.relu(x)
          x = nn.Dense(3)(x)
          x = nn.relu(x)
          x = nn.Dense(3)(x)
          x = nn.relu(x)
          return x

      params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']

      def update_fn(path, value):
        '''Multiply all dense kernel params by 2 and add 1.
        Subtract the Dense_1 bias param by 1.'''
        if 'kernel' in path:
          return value * 2 + 1
        elif 'Dense_1' in path and 'bias' in path:
          return value - 1
        return value

      c = cursor(params)
      new_params = c.apply_update(update_fn).build()
      for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
        assert (new_params[layer]['kernel'] == 2 * params[layer]['kernel'] + 1).all()
        if layer == 'Dense_1':
          assert (new_params[layer]['bias'] == jnp.array([-1, -1, -1])).all()
        else:
          assert (new_params[layer]['bias'] == params[layer]['bias']).all()

      assert jax.tree_util.tree_all(
          jax.tree_util.tree_map(
              lambda x, y: (x == y).all(),
              params,
              Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[
                  'params'
              ],
          )
      )  # make sure original params are unchanged

    Args:
      update_fn: the function that will conditionally apply changes to the Cursor object

    Returns:
      The current Cursor object with the updates applied by the ``update_fn``.
    """
    for path, value in _get_changes((), self.obj, update_fn):
      # Walk down to the cursor that owns the changed entry ...
      child = self
      for key, access_type in path[:-1]:
        if access_type is AccessType.GETITEM:
          child = child[key]
        else:  # access_type is AccessType.GETATTR
          child = getattr(child, key)
      # ... and record the new value on it using the matching access style.
      key, access_type = path[-1]
      if access_type is AccessType.GETITEM:
        child[key] = value
      else:  # access_type is AccessType.GETATTR
        setattr(child, key, value)
    return self
def cursor(obj: A) -> Cursor[A]:
  """Create a root ``Cursor`` around ``obj``.

  Once wrapped, changes can be expressed in three ways:

  - a one-off change through the ``.set`` method
  - several assignments followed by a call to the ``.build`` method
  - changes that depend on the tree path and node value, through the
    ``.apply_update`` method

  ``.set`` example::

    from flax.cursor import cursor

    dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
    modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
    assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}

  ``.build`` example::

    from flax.cursor import cursor

    dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
    c = cursor(dict_obj)
    c['b'][0] = 10
    c['a'] = (100, 200)
    modified_dict_obj = c.build()
    assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}

  ``.apply_update`` example::

    from flax.cursor import cursor
    from flax.training import train_state
    import optax

    def update_fn(path, value):
      '''Replace params with empty dictionary.'''
      if 'params' in path:
        return {}
      return value

    state = train_state.TrainState.create(
        apply_fn=lambda x: x,
        params={'a': 1, 'b': 2},
        tx=optax.adam(1e-3),
    )
    c = cursor(state)
    state2 = c.apply_update(update_fn).build()
    assert state2.params == {}
    assert state.params == {'a': 1, 'b': 2}  # make sure original params are unchanged

  The docstring of each method contains further usage examples.

  Args:
    obj: the object to wrap in a Cursor

  Returns:
    A Cursor object wrapped around ``obj``.
  """
  # A parent key of ``None`` marks the new cursor as the root of its tree.
  root_cursor = Cursor(obj, None)
  return root_cursor