flax.linen.initializers.variance_scaling
- flax.linen.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)
Initializer that adapts its scale to the shape of the weights tensor.
With distribution="truncated_normal" or distribution="normal", samples are drawn from a (truncated) normal distribution with a mean of zero and a standard deviation (after truncation, if applicable) of \(\sqrt{\frac{scale}{n}}\), where n is:
- the number of input units in the weights tensor, if mode="fan_in",
- the number of output units, if mode="fan_out", or
- the average of the numbers of input and output units, if mode="fan_avg".
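The choice of n can be worked through by hand. The sketch below uses only the standard library; target_std is a hypothetical helper written for illustration, not part of Flax or JAX, and it takes the fan sizes directly instead of deriving them from a kernel shape.

```python
import math


def target_std(fan_in, fan_out, scale, mode):
    """Hypothetical helper: the documented standard deviation sqrt(scale / n)."""
    if mode == "fan_in":
        n = fan_in
    elif mode == "fan_out":
        n = fan_out
    elif mode == "fan_avg":
        n = (fan_in + fan_out) / 2
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return math.sqrt(scale / n)


# Dense kernel of shape (256, 128): 256 input units, 128 output units.
for mode in ("fan_in", "fan_out", "fan_avg"):
    print(mode, target_std(256, 128, scale=2.0, mode=mode))
```

For this shape, mode="fan_out" gives sqrt(2 / 128) = 0.125, and mode="fan_in" gives the smallest of the three values because the input dimension is the largest.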
This initializer can be configured with in_axis, out_axis, and batch_axis to work with general convolutional or dense layers; axes that are not in any of those arguments are assumed to be the “receptive field” (convolution kernel spatial axes).
With distribution="truncated_normal", the absolute values of the samples are truncated at 2 standard deviations before scaling.
With distribution="uniform", samples are drawn from:
- a uniform interval, if dtype is real, or
- a uniform disk, if dtype is complex,
with a mean of zero and a standard deviation of \(\sqrt{\frac{scale}{n}}\), where n is defined above.
- Parameters
scale – scaling factor (positive float).
mode – one of "fan_in", "fan_out", and "fan_avg".
distribution – random distribution to use. One of "truncated_normal", "normal" and "uniform".
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.