What’s the difference between the training argument in call() and the trainable attribute?

training is a boolean argument in call that determines whether the call should be run in inference mode or training mode. For example, in training mode, a Dropout layer applies random dropout and rescales the output. In inference mode, the same layer does nothing. Example:

y = Dropout(0.5)(x, training=True)  # Applies dropout at training time *and* inference time
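The sketch below checks this description. It assumes TensorFlow 2 with its bundled tf.keras, and calls one Dropout layer in both modes on a tensor of ones. The input tensor and the rate of 0.5 are arbitrary choices for illustration.

```python
import numpy as np
import tensorflow as tf

x = tf.ones((4, 8))
dropout = tf.keras.layers.Dropout(0.5)

# Inference mode: Dropout passes its input through unchanged.
y_infer = dropout(x, training=False)

# Training mode: each unit is zeroed with probability 0.5 and the surviving
# units are scaled by 1 / (1 - 0.5) = 2, so every entry is 0.0 or 2.0.
y_train = dropout(x, training=True)

print(np.array_equal(y_infer.numpy(), x.numpy()))  # True
print(set(np.unique(y_train.numpy()).tolist()) <= {0.0, 2.0})  # True

# Dropout owns no weights, so its trainable flag has nothing to freeze.
print(len(dropout.trainable_weights))  # 0
```

The training argument changes what the layer computes. It says nothing about which weights get updated.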

trainable is a boolean layer attribute that determines whether the layer's trainable weights should be updated to minimize the loss during training. If layer.trainable is set to False, then layer.trainable_weights will always be an empty list. Example:

model = Sequential([
    ResNet50Base(input_shape=(32, 32, 3), weights='pretrained'),
    Dense(10),
])
model.layers[0].trainable = False  # Freeze ResNet50Base.

assert model.layers[0].trainable_weights == []  # ResNet50Base has no trainable weights.
assert len(model.trainable_weights) == 2  # Just the bias & kernel of the Dense layer.

model.compile(...)
model.fit(...)  # Train Dense while excluding ResNet50Base.
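ResNet50Base above is a placeholder rather than a real Keras class, so that example does not run as written. The sketch below checks the same idea end to end under these assumptions: a single Dense layer named "base" stands in for the pretrained network (hypothetical, randomly initialized), the data is random, and the base is frozen before compile(), which is the safe order.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins: a small Dense "base" in place of a pretrained
# network, followed by a Dense head.
base = tf.keras.layers.Dense(8, activation="relu", name="base")
head = tf.keras.layers.Dense(1, name="head")
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), base, head])

base.trainable = False  # Freeze the base before compiling.
assert base.trainable_weights == []
assert len(model.trainable_weights) == 2  # Only the head's kernel and bias.

# Snapshot the weights so we can compare them after training.
base_before = [w.copy() for w in base.get_weights()]
head_before = [w.copy() for w in head.get_weights()]

model.compile(optimizer="sgd", loss="mse")
rng = np.random.default_rng(0)
x = rng.random((32, 4), dtype=np.float32)
y = rng.random((32, 1), dtype=np.float32)
model.fit(x, y, epochs=2, verbose=0)

# The frozen base is untouched; the head has been updated by SGD.
base_unchanged = all(np.array_equal(a, b) for a, b in zip(base_before, base.get_weights()))
head_changed = any(not np.array_equal(a, b) for a, b in zip(head_before, head.get_weights()))
print(base_unchanged, head_changed)  # True True
```

Comparing get_weights() snapshots before and after fit() is a simple way to confirm that a freeze actually took effect.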

As you can see, “inference mode vs training mode” and “layer weight trainability” are two very different concepts.

You could imagine the following: a dropout layer where the scaling factor is learned during training, via backpropagation. Let’s name it AutoScaleDropout. This layer would simultaneously have trainable state and different behavior in inference and training. Because the trainable attribute and the training call argument are independent, you can do the following:

layer = AutoScaleDropout(0.5)

# Applies dropout at training time *and* inference time
# *and* learns the scaling factor during training
y = layer(x, training=True)

assert len(layer.trainable_weights) == 1

# Applies dropout at training time *and* inference time
# with a *frozen* scaling factor
layer = AutoScaleDropout(0.5)
layer.trainable = False
y = layer(x, training=True)

assert len(layer.trainable_weights) == 0
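The AutoScaleDropout layer used above is hypothetical and not part of Keras. As a minimal sketch, assume a single trainable scalar replaces the fixed 1 / (1 - rate) rescaling of ordinary dropout and starts at that value. The layer could then be written with TensorFlow's tf.keras like this:

```python
import tensorflow as tf


class AutoScaleDropout(tf.keras.layers.Layer):
    """Hypothetical dropout variant whose output scale is a trainable weight."""

    def __init__(self, rate, **kwargs):
        super().__init__(**kwargs)
        self.rate = rate

    def build(self, input_shape):
        # One trainable scalar, initialized to the usual 1 / (1 - rate) factor.
        self.scale = self.add_weight(
            name="scale",
            shape=(),
            initializer=tf.keras.initializers.Constant(1.0 / (1.0 - self.rate)),
            trainable=True,
        )

    def call(self, inputs, training=None):
        if training:
            # Keep each unit with probability 1 - rate, then apply the learned scale.
            keep = tf.cast(tf.random.uniform(tf.shape(inputs)) >= self.rate, inputs.dtype)
            return self.scale * (inputs * keep)
        # Inference mode: pass the inputs through unchanged.
        return inputs


# Usage: the layer owns one trainable weight; freezing it empties trainable_weights.
x = tf.ones((4, 8))
layer = AutoScaleDropout(0.5)
y = layer(x, training=True)
print(len(layer.trainable_weights))  # 1

frozen = AutoScaleDropout(0.5)
frozen.trainable = False
y_frozen = frozen(x, training=True)  # Still drops units: training mode is unaffected.
print(len(frozen.trainable_weights))  # 0
```

The training branch in call() and the trainable flag never consult each other, which is exactly the independence described above.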

Special case of the BatchNormalization layer

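For a BatchNormalization layer, setting bn.trainable = False will also make its training call argument default to False. The layer then runs in inference mode, normalizing with its moving mean and variance instead of batch statistics, and does not update those statistics during training.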

This behavior applies only to BatchNormalization. For every other layer, weight trainability and “inference vs training mode” remain independent.
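The sketch below illustrates the special case by comparing a trainable and a frozen BatchNormalization layer. It assumes TensorFlow 2 with tf.keras, where a frozen BatchNormalization layer stays in inference mode even when training=True is passed explicitly. Both layers see one batch whose features have a mean near 5. Only the trainable layer should move its moving mean away from the initial zeros.

```python
import numpy as np
import tensorflow as tf

# A batch whose features have mean ~5, far from the initial moving mean of 0.
x = tf.random.normal((64, 3), mean=5.0, seed=0)

bn_trainable = tf.keras.layers.BatchNormalization()
bn_frozen = tf.keras.layers.BatchNormalization()
bn_frozen.trainable = False  # Also puts the layer in inference mode.

# Pass training=True explicitly to both layers.
bn_trainable(x, training=True)
y_frozen = bn_frozen(x, training=True)

# The trainable layer moved its moving mean toward the batch mean...
print(np.asarray(bn_trainable.moving_mean))
# ...while the frozen layer kept its initial moving mean of zeros.
print(np.asarray(bn_frozen.moving_mean))
```

Because the frozen layer normalizes with its initial statistics (mean 0, variance 1), its output here is essentially x / sqrt(1 + epsilon) rather than a batch-normalized tensor.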

