flax.linen.cond
- flax.linen.cond(pred, true_fun, false_fun, mdl, *operands, variables=True, rngs=True)
Lifted version of jax.lax.cond. The values returned by true_fun and false_fun must have the same Pytree structure, shapes, and dtypes. The variables created or updated inside the branches must also have the same structure. This constraint is violated when variables or submodules are created in only one branch, because initializing variables in just one branch gives the two branches different parameter structures.
Example:
```python
import flax.linen as nn

class CondExample(nn.Module):
  @nn.compact
  def __call__(self, x, pred):
    # Create both counters outside the branches so the state structure
    # is the same whichever branch runs.
    self.variable('state', 'true_count', lambda: 0)
    self.variable('state', 'false_count', lambda: 0)

    def true_fn(mdl, x):
      mdl.variable('state', 'true_count').value += 1
      # Same submodule name in both branches, so both use one set of params.
      return nn.Dense(2, name='dense')(x)

    def false_fn(mdl, x):
      mdl.variable('state', 'false_count').value += 1
      return -nn.Dense(2, name='dense')(x)

    return nn.cond(pred, true_fn, false_fn, self, x)
```
- Parameters
pred – Determines whether true_fun or false_fun is evaluated.
true_fun – The function evaluated when pred is True. The signature is (module, *operands) -> T.
false_fun – The function evaluated when pred is False. The signature is (module, *operands) -> T.
mdl – The Module target passed to true_fun and false_fun as their first argument.
*operands – The arguments passed to true_fun and false_fun.
variables – The variable collections passed to the conditional branches (default: all).
rngs – The PRNG sequences passed to the conditional branches (default: all).
- Returns
The result of the evaluated branch (true_fun or false_fun).