flax.linen.activation.logsumexp
- flax.linen.activation.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False)
Log-sum-exp reduction.
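Computes
\[\mathrm{logsumexp}(a) = \mathrm{log} \sum_j b_{ij} \cdot \mathrm{exp}(a_{ij})\]
where the \(j\) indices range over one or more dimensions to be reduced, and \(b_{ij}\) is the corresponding element of b broadcast to the shape of a (taken as 1 when b is omitted).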
- Parameters
a – the input array.
axis – the axis or axes over which to reduce. May be None (the default, which reduces over all dimensions), an int, or a tuple of ints.
b – scaling factors for \(\mathrm{exp}(a)\). Must be broadcastable to the shape of a.
keepdims – if True, the axes that are reduced are left in the output as dimensions of size 1.
return_sign – if True, the output will be a (result, sign) pair, where sign is the sign of the sums and result contains the logarithms of their absolute values. If False, only result is returned, and it will contain NaN values where the sums are negative (which can only happen when b has negative entries).
- Returns
Either an array result or a pair of arrays (result, sign), depending on the value of the return_sign argument.