flax.linen.apply
- flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)
Creates an apply function that calls fn with a bound module.

Unlike Module.apply, this function returns a new function with the signature (variables, *args, rngs=None, **kwargs) -> T, where T is the return type of fn. If mutable is not False, the return type is a tuple whose second item is a FrozenDict with the mutated variables.

The returned apply function can be composed directly with JAX transformations such as jax.jit:

```python
import jax
import flax.linen as nn

# Foo is assumed to be a Module that defines encode and decode methods,
# and variables to come from an earlier foo.init(...) call.
def f(foo, x):
  z = foo.encode(x)
  y = foo.decode(z)
  # ...
  return y

foo = Foo()
f_jitted = jax.jit(nn.apply(f, foo))
f_jitted(variables, x)
```
- Parameters
  - fn – The function to apply. The first argument passed to it is an instance of module with variables and RNGs bound to it.
  - module – The Module to which variables and RNGs are bound. The Module passed as the first argument to fn is a clone of module.
  - mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable:
    - bool: all/no collections are mutable.
    - str: the name of a single mutable collection.
    - list: a list of names of mutable collections.
  - capture_intermediates – If True, captures the intermediate return values of all Modules in the "intermediates" collection. By default, only the return values of __call__ methods are stored. A function can also be passed to change the filter behavior: it takes the Module instance and the method name, and returns a bool indicating whether the output of that method invocation should be stored.
- Returns
  The apply function wrapping fn.