Batch normalization#
In this guide, you will learn how to apply batch normalization
using flax.linen.BatchNorm.
Batch normalization is a regularization technique used to speed up training and improve convergence. During training, it maintains running averages of the per-feature batch statistics (mean and variance). This adds a new form of non-differentiable state that must be handled appropriately.
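As a rough illustration of where this extra state comes from (a minimal sketch, not Flax's internal implementation), batch normalization normalizes each feature with the statistics of the current batch during training, and keeps exponential moving averages of those statistics for use at inference time. The momentum value below mirrors Flax's default but should be treated as an assumption here:

import jax.numpy as jnp

def batchnorm_train_sketch(x, running_mean, running_var, scale, bias,
                           momentum=0.99, eps=1e-5):
  # Normalize with the statistics of the *current* batch.
  batch_mean = x.mean(axis=0)
  batch_var = x.var(axis=0)
  y = (x - batch_mean) / jnp.sqrt(batch_var + eps) * scale + bias
  # Update the running averages -- this is the non-differentiable state
  # that Flax stores in the batch_stats collection.
  new_mean = momentum * running_mean + (1 - momentum) * batch_mean
  new_var = momentum * running_var + (1 - momentum) * batch_var
  return y, new_mean, new_var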
Throughout the guide, you will be able to compare code examples with and without Flax BatchNorm.
Defining the model with BatchNorm#
In Flax, BatchNorm is a flax.linen.Module that exhibits different runtime
behavior between training and inference. The mode is specified explicitly via the use_running_average argument,
as demonstrated below.
A common pattern is to accept a train (training) argument in the parent Flax Module, and use
it to define BatchNorm’s use_running_average argument.
Note: In other machine learning frameworks, like PyTorch or
TensorFlow (Keras), this is specified via a mutable state or a call flag (for example, in
torch.nn.Module.eval
or tf.keras.Model by setting the
training flag).
class MLP(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=4)(x)
    x = nn.relu(x)
    x = nn.Dense(features=1)(x)
    return x
class MLP(nn.Module):
  @nn.compact
  def __call__(self, x, train: bool):
    x = nn.Dense(features=4)(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.relu(x)
    x = nn.Dense(features=1)(x)
    return x
Once you create your model, initialize it by calling flax.linen.init() to
get the variables structure. Here, the main difference between the code without BatchNorm
and with BatchNorm is that the train argument must be provided.
The batch_stats collection#
In addition to the params collection, BatchNorm also adds a batch_stats collection
that contains the running average of the batch statistics.
Note: You can learn more in the flax.linen variables
API documentation.
The batch_stats collection must be extracted from the variables for later use.
mlp = MLP()
x = jnp.ones((1, 3))
variables = mlp.init(jax.random.PRNGKey(0), x)
params = variables['params']
jax.tree_util.tree_map(jnp.shape, variables)
mlp = MLP()
x = jnp.ones((1, 3))
variables = mlp.init(jax.random.PRNGKey(0), x, train=False)
params = variables['params']
batch_stats = variables['batch_stats']
jax.tree_util.tree_map(jnp.shape, variables)
Flax BatchNorm adds a total of 4 variables: mean and var that live in the
batch_stats collection, and scale and bias that live in the params
collection.
FrozenDict({
    'params': {
        'Dense_0': {
            'bias': (4,),
            'kernel': (3, 4),
        },
        'Dense_1': {
            'bias': (1,),
            'kernel': (4, 1),
        },
    },
})
FrozenDict({
    'batch_stats': {
        'BatchNorm_0': {
            'mean': (4,),
            'var': (4,),
        },
    },
    'params': {
        'BatchNorm_0': {
            'bias': (4,),
            'scale': (4,),
        },
        'Dense_0': {
            'bias': (4,),
            'kernel': (3, 4),
        },
        'Dense_1': {
            'bias': (1,),
            'kernel': (4, 1),
        },
    },
})
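For example, the running statistics and the learned affine parameters shown above can be read directly from their respective collections (the variable names come from the output above):

batch_stats['BatchNorm_0']['mean'].shape   # (4,)
batch_stats['BatchNorm_0']['var'].shape    # (4,)
params['BatchNorm_0']['scale'].shape       # (4,)
params['BatchNorm_0']['bias'].shape        # (4,)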
Modifying flax.linen.apply#
When using flax.linen.apply to run your model with the train==True
argument (that is, you have use_running_average==False in the call to BatchNorm), you
need to consider the following:
- batch_stats must be passed as an input variable.
- The batch_stats collection needs to be marked as mutable by setting mutable=['batch_stats'].
- The mutated variables are returned as a second output. The updated batch_stats must be extracted from here.
y = mlp.apply(
  {'params': params},
  x,
)
...
y, updates = mlp.apply(
  {'params': params, 'batch_stats': batch_stats},
  x,
  train=True, mutable=['batch_stats']
)
batch_stats = updates['batch_stats']
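For inference, where train is False (that is, use_running_average is True), nothing is mutated, so mutable is not needed and apply returns just the output. A minimal sketch, reusing the mlp, params, batch_stats, and x defined above:

y = mlp.apply(
  {'params': params, 'batch_stats': batch_stats},
  x,
  train=False,
)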
Training and evaluation#
When integrating models that use BatchNorm into a training loop, the main challenge
is handling the additional batch_stats state. To do this, you need to:
- Add a batch_stats field to a custom flax.training.train_state.TrainState class.
- Pass the batch_stats values to the train_state.TrainState.create method.
from flax.training import train_state
state = train_state.TrainState.create(
  apply_fn=mlp.apply,
  params=params,
  tx=optax.adam(1e-3),
)
from typing import Any  # needed for the batch_stats field annotation

from flax.training import train_state

class TrainState(train_state.TrainState):
  batch_stats: Any

state = TrainState.create(
  apply_fn=mlp.apply,
  params=params,
  batch_stats=batch_stats,
  tx=optax.adam(1e-3),
)
In addition, update your train_step function to reflect these changes:
- Pass all new parameters to flax.linen.apply (as previously discussed).
- The updates to the batch_stats must be propagated out of the loss_fn.
- The batch_stats from the TrainState must be updated.
@jax.jit
def train_step(state: train_state.TrainState, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      x=batch['image'])
    # Reduce the per-example losses to the scalar required by jax.grad.
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']).mean()
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = {
    'loss': loss,
    'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics
@jax.jit
def train_step(state: TrainState, batch):
  """Train for a single step."""
  def loss_fn(params):
    # Pass batch_stats in and mark it mutable so the updated running
    # averages are returned alongside the logits.
    logits, updates = state.apply_fn(
      {'params': params, 'batch_stats': state.batch_stats},
      x=batch['image'], train=True, mutable=['batch_stats'])
    # Reduce the per-example losses to the scalar required by jax.grad.
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']).mean()
    return loss, (logits, updates)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (logits, updates)), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  # Store the new running statistics back into the TrainState.
  state = state.replace(batch_stats=updates['batch_stats'])
  metrics = {
    'loss': loss,
    'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics
The eval_step is much simpler. Because batch_stats is not mutated, no updates
need to be propagated. Make sure you pass batch_stats to flax.linen.apply
and set the train argument to False:
@jax.jit
def eval_step(state: train_state.TrainState, batch):
  """Evaluate for a single step."""
  logits = state.apply_fn(
    {'params': state.params},
    x=batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']).mean()
  metrics = {
    'loss': loss,
    'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics
@jax.jit
def eval_step(state: TrainState, batch):
  """Evaluate for a single step."""
  logits = state.apply_fn(
    {'params': state.params, 'batch_stats': state.batch_stats},
    x=batch['image'], train=False)
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']).mean()
  metrics = {
    'loss': loss,
    'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics
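Finally, here is a minimal sketch of how these step functions could be wired into a loop. num_epochs, train_ds, and test_ds are hypothetical placeholders for your own settings and batch iterators, with each batch being a dict containing 'image' and 'label' entries:

for epoch in range(num_epochs):    # num_epochs: hypothetical setting
  for batch in train_ds:           # train_ds: hypothetical iterable of training batches
    state, train_metrics = train_step(state, batch)
  for batch in test_ds:            # test_ds: hypothetical iterable of evaluation batches
    _, eval_metrics = eval_step(state, batch)
  # Metrics from the last batch of each epoch, for quick inspection.
  print(epoch, train_metrics, eval_metrics)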