flax.linen.vjp
- flax.linen.vjp(fn, mdl, *primals, has_aux=False, reduce_axes=(), vjp_variables='params', variables=True, rngs=True)
A lifted version of jax.vjp. See jax.vjp for the unlifted vector-Jacobian product (backward gradient).

Note that a gradient is returned for all variables in the collections specified by vjp_variables. However, the backward function only expects a cotangent for the return value of fn. If variables require a cotangent as well, they can be returned from fn using Module.variables.
Example:
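```python
import flax.linen as nn
import jax.numpy as jnp


class LearnScale(nn.Module):
  @nn.compact
  def __call__(self, x, y):
    p = self.param('scale', nn.initializers.zeros_init(), ())
    return p * x * y


class Foo(nn.Module):
  @nn.compact
  def __call__(self, x, y):
    # fn receives the module followed by the primals.
    z, bwd = nn.vjp(lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
    # bwd maps a cotangent for z to (params gradient, x cotangent, y cotangent).
    params_grad, x_grad, y_grad = bwd(jnp.ones(z.shape))
    return z, params_grad, x_grad, y_grad
```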
- Parameters
fn – Function to be differentiated. Apart from the leading module argument, its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars. It will receive the module and the primals as arguments.
mdl – The module of which the variables will be differentiated.
*primals – A sequence of primal values at which the Jacobian of fn should be evaluated. The number of primals should equal the number of positional parameters of fn after the leading module argument. Each primal value should be an array, a scalar, or a standard Python container thereof.
has_aux – Optional, bool. Indicates whether fn returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fn implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the VJP will be per-example over named axes. For example, if 'batch' is a named batch axis, vjp(f, *args, reduce_axes=('batch',)) will create a VJP function that sums over the batch, while vjp(f, *args) will create a per-example VJP.
vjp_variables – Filter selecting the variable collections for which vjpfun returns a cotangent. Default 'params'.
variables – Other variable collections that are available inside fn but do not receive a cotangent.
rngs – The PRNG sequences that are available inside fn.
- Returns
If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fn(mdl, *primals). vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple whose first element holds the cotangents for the variable collections selected by vjp_variables, followed by one cotangent per primal with the same shape as that primal; together these represent the vector-Jacobian product of fn evaluated at primals. If has_aux is True, returns a (primals_out, vjpfun, aux) tuple, where aux is the auxiliary data returned by fn.