Activation functions#
Activation functions.
- class flax.linen.activation.PReLU(param_dtype=<class 'jax.numpy.float32'>, negative_slope_init=0.01, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Parametric Rectified Linear Unit (PReLU) activation function.
Note that PReLU is a Flax layer and not a simple activation function, so it needs to be initialized before being called.
- Example usage::
x = nn.PReLU()(x)
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type
Any
- negative_slope_init#
the value to initialize the negative slope (default 0.01).
- Type
float
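A slightly fuller sketch of the layer in context (the Model module and the shapes below are illustrative, not part of the API)::
import jax
import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(4)(x)
        # PReLU is a layer: its negative slope is a learnable parameter
        # created when the module is initialized.
        return nn.PReLU()(x)

x = jnp.ones((2, 3))
variables = Model().init(jax.random.PRNGKey(0), x)
y = Model().apply(variables, x)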
- flax.linen.activation.celu(x, alpha=1.0)[source]#
Continuously-differentiable exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]
For more information, see Continuously Differentiable Exponential Linear Units.
- Parameters
x – input array
alpha – array or scalar (default: 1.0)
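A minimal element-wise sketch (the input values are illustrative)::
import jax.numpy as jnp
import flax.linen as nn

x = jnp.array([-2.0, 0.0, 3.0])
y = nn.celu(x, alpha=1.0)  # negative inputs are smoothly squashed toward -alpha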
- flax.linen.activation.elu(x, alpha=1.0)[source]#
Exponential linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]
- Parameters
x – input array
alpha – scalar or array of alpha values (default: 1.0)
- flax.linen.activation.gelu(x, approximate=True)[source]#
Gaussian error linear unit activation function.
If approximate=False, computes the element-wise function:
\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)\]
If approximate=True, uses the approximate formulation of GELU:
\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]
For more information, see Gaussian Error Linear Units (GELUs), section 2.
- Parameters
x – input array
approximate – whether to use the approximate or exact formulation.
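A short sketch comparing the two formulations (input values are illustrative)::
import jax.numpy as jnp
import flax.linen as nn

x = jnp.linspace(-3.0, 3.0, 7)
exact = nn.gelu(x, approximate=False)   # erf-based form
approx = nn.gelu(x, approximate=True)   # tanh-based approximation (the default)
# The two agree closely for moderate inputs.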
- flax.linen.activation.glu(x, axis=-1)[source]#
Gated linear unit activation function.
- Parameters
x – input array
axis – the axis along which the split should be computed (default: -1)
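A minimal sketch showing that the chosen axis is halved (the shapes are illustrative)::
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 8))
y = nn.glu(x, axis=-1)  # splits the last axis into halves a, b and returns a * sigmoid(b)
assert y.shape == (2, 4)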
- flax.linen.activation.hard_sigmoid(x)[source]#
Hard Sigmoid activation function.
Computes the element-wise function
\[\mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}\]
- Parameters
x – input array
- flax.linen.activation.hard_silu(x)[source]#
Hard SiLU activation function.
Computes the element-wise function
\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]
- Parameters
x – input array
- flax.linen.activation.hard_swish(x)#
Hard SiLU activation function.
Computes the element-wise function
\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]
- Parameters
x – input array
- flax.linen.activation.hard_tanh(x)[source]#
Hard \(\mathrm{tanh}\) activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{hard\_tanh}(x) = \begin{cases} -1, & x < -1\\ x, & -1 \le x \le 1\\ 1, & 1 < x \end{cases}\end{split}\]
- Parameters
x – input array
- flax.linen.activation.leaky_relu(x, negative_slope=0.01)[source]#
Leaky rectified linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]
where \(\alpha\) = negative_slope.
- Parameters
x – input array
negative_slope – array or scalar specifying the negative slope (default: 0.01)
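A minimal element-wise sketch (input values are illustrative)::
import jax.numpy as jnp
import flax.linen as nn

x = jnp.array([-4.0, 0.0, 4.0])
y = nn.leaky_relu(x, negative_slope=0.1)  # gives [-0.4, 0.0, 4.0]: negatives are scaled, not clipped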
- flax.linen.activation.log_sigmoid(x)[source]#
Log-sigmoid activation function.
Computes the element-wise function:
\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]
- Parameters
x – input array
- flax.linen.activation.log_softmax(x, axis=-1, where=None, initial=None)[source]#
Log-Softmax function.
Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0)\).
\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
- Parameters
x – input array
axis – the axis or axes along which the log_softmax should be computed. Either an integer or a tuple of integers.
where – Elements to include in the log_softmax.
initial – The minimum value used to shift the input array. Must be present when where is not None.
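A minimal sketch over the last axis (the logits are illustrative)::
import jax.numpy as jnp
import flax.linen as nn

logits = jnp.array([[1.0, 2.0, 3.0]])
log_probs = nn.log_softmax(logits, axis=-1)
probs = jnp.exp(log_probs)  # exponentiating recovers probabilities that sum to 1 along the axis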
- flax.linen.activation.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False)[source]#
Log-sum-exp reduction.
Computes
\[\mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij})\]
where the \(j\) indices range over one or more dimensions to be reduced.
- Parameters
a – the input array
axis – the axis or axes over which to reduce. May be either None, an int, or a tuple of ints.
b – scaling factors for \(\mathrm{exp}(a)\). Must be broadcastable to the shape of a.
keepdims – If True, the axes that are reduced are left in the output as dimensions of size 1.
return_sign – If True, the output will be a (result, sign) pair, where sign is the sign of the sums and result contains the logarithms of their absolute values. If False only result is returned and it will contain NaN values if the sums are negative.
- Returns
Either an array result or a pair of arrays (result, sign), depending on the value of the return_sign argument.
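A minimal sketch of why the reduction is useful: it stays finite where the naive computation overflows (input values are illustrative)::
import jax.numpy as jnp
from flax.linen.activation import logsumexp

a = jnp.array([1000.0, 1000.5, 999.0])
y = logsumexp(a)  # stable: the maximum is subtracted before exponentiating
# jnp.log(jnp.sum(jnp.exp(a))) would overflow to inf for inputs this large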
- flax.linen.activation.one_hot(x, num_classes, *, dtype=<class 'jax.numpy.float64'>, axis=-1)[source]#
One-hot encodes the given indices.
Each index in the input x is encoded as a vector of zeros of length num_classes with the element at index set to one:
>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)
Indices outside the range [0, num_classes) will be encoded as zeros:
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
- Parameters
x – A tensor of indices.
num_classes – Number of classes in the one-hot dimension.
dtype – optional, a float dtype for the returned values (default jnp.float_).
axis – the axis or axes along which the function should be computed.
- flax.linen.activation.relu(x)[source]#
Rectified linear unit activation function.
Computes the element-wise function:
\[\mathrm{relu}(x) = \max(x, 0)\]
except under differentiation, we take:
\[\nabla \mathrm{relu}(0) = 0\]
For more information see Numerical influence of ReLU’(0) on backpropagation.
- Parameters
x – input array
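A minimal sketch, including the gradient convention at zero (input values are illustrative)::
import jax
import jax.numpy as jnp
import flax.linen as nn

y = nn.relu(jnp.array([-1.0, 0.0, 2.0]))  # [0., 0., 2.]
g = jax.grad(nn.relu)(0.0)                # 0.0, by the convention above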
- flax.linen.activation.selu(x)[source]#
Scaled exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]
where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).
For more information, see Self-Normalizing Neural Networks.
- Parameters
x – input array
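A rough sketch of the self-normalizing property (the random input is illustrative)::
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jax.random.normal(jax.random.PRNGKey(0), (10000,))  # roughly zero mean, unit variance
y = nn.selu(x)
# y.mean() and y.std() stay close to 0 and 1, which is what the constants above are chosen for.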
- flax.linen.activation.sigmoid(x)[source]#
Sigmoid activation function.
Computes the element-wise function:
\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
- Parameters
x – input array
- flax.linen.activation.silu(x)[source]#
SiLU activation function.
Computes the element-wise function:
\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]
- Parameters
x – input array
- flax.linen.activation.soft_sign(x)[source]#
Soft-sign activation function.
Computes the element-wise function
\[\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}\]
- Parameters
x – input array
- flax.linen.activation.softmax(x, axis=-1, where=None, initial=None)[source]#
Softmax function.
Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).
\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
- Parameters
x – input array
axis – the axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to \(1\). Either an integer or a tuple of integers.
where – Elements to include in the softmax.
initial – The minimum value used to shift the input array. Must be present when where is not None.
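A minimal sketch, including the masked variant described by the where and initial parameters (the logits and mask are illustrative)::
import jax.numpy as jnp
import flax.linen as nn

logits = jnp.array([[1.0, 2.0, 3.0],
                    [1.0, 2.0, 3.0]])
probs = nn.softmax(logits, axis=-1)  # probs.sum(axis=-1) is [1., 1.]

mask = jnp.array([[True, True, False],
                  [True, True, True]])
masked = nn.softmax(logits, axis=-1, where=mask, initial=-jnp.inf)  # excluded positions get probability 0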
- flax.linen.activation.softplus(x)[source]#
Softplus activation function.
Computes the element-wise function
\[\mathrm{softplus}(x) = \log(1 + e^x)\]
- Parameters
x – input array
- flax.linen.activation.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)[source]#
Normalizes an array by subtracting mean and dividing by \(\sqrt{\mathrm{variance}}\).
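A minimal sketch (the random input and its shift/scale are illustrative)::
import jax
import jax.numpy as jnp
from flax.linen.activation import standardize

x = jax.random.normal(jax.random.PRNGKey(0), (4, 8)) * 3.0 + 5.0
y = standardize(x, axis=-1)  # along the last axis, y has mean ~0 and variance ~1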
- flax.linen.activation.swish(x)#
SiLU activation function.
Computes the element-wise function:
\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]
- Parameters
x – input array
- flax.linen.activation.tanh(x, /)#
Compute hyperbolic tangent element-wise.
LAX-backend implementation of numpy.tanh(). Original docstring below.
Equivalent to np.sinh(x)/np.cosh(x) or -1j * np.tan(1j*x).
- Parameters
x (array_like) – Input array.
- Returns
y – The corresponding hyperbolic tangent values. This is a scalar if x is a scalar.
- Return type
ndarray
Summary
PReLU – Parametric Rectified Linear Unit (PReLU) activation function.
celu – Continuously-differentiable exponential linear unit activation.
elu – Exponential linear unit activation function.
gelu – Gaussian error linear unit activation function.
glu – Gated linear unit activation function.
hard_sigmoid – Hard Sigmoid activation function.
hard_silu – Hard SiLU activation function.
hard_swish – Hard SiLU activation function.
hard_tanh – Hard \(\mathrm{tanh}\) activation function.
leaky_relu – Leaky rectified linear unit activation function.
log_sigmoid – Log-sigmoid activation function.
log_softmax – Log-Softmax function.
logsumexp – Log-sum-exp reduction.
one_hot – One-hot encodes the given indices.
relu – Rectified linear unit activation function.
relu6 – Rectified Linear Unit 6 activation function.
selu – Scaled exponential linear unit activation.
sigmoid – Sigmoid activation function.
silu – SiLU activation function.
soft_sign – Soft-sign activation function.
softmax – Softmax function.
softplus – Softplus activation function.
standardize – Normalizes an array by subtracting mean and dividing by \(\sqrt{\mathrm{variance}}\).
swish – SiLU activation function.
tanh – Compute hyperbolic tangent element-wise.