training is a boolean argument in call that determines whether the call should be run in inference mode or training mode. For example, in training mode, a Dropout layer applies random dropout and rescales the output. In inference mode, the same layer does nothing. Example:
y = Dropout(0.5)(x, training=True) # Applies dropout at training time *and* inference time
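A runnable sketch of the difference (assuming TensorFlow 2.x Keras):

```python
import numpy as np
import tensorflow as tf

x = tf.ones((4, 10))
layer = tf.keras.layers.Dropout(0.5)

# In inference mode, the layer is the identity function.
y_inference = layer(x, training=False)
assert np.allclose(y_inference.numpy(), x.numpy())

# In training mode, roughly half the units are zeroed and the
# survivors are rescaled by 1 / (1 - rate) = 2.
y_training = layer(x, training=True)
assert not np.allclose(y_training.numpy(), x.numpy())
```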
trainable is a boolean layer attribute that determines whether the trainable weights of the layer should be updated to minimize the loss during training. If layer.trainable is set to False, then layer.trainable_weights will always be an empty list. Example:
model = Sequential([
ResNet50Base(input_shape=(32, 32, 3), weights='pretrained'),
Dense(10),
])
model.layers[0].trainable = False # Freeze ResNet50Base.
assert model.layers[0].trainable_weights == [] # ResNet50Base has no trainable weights.
assert len(model.trainable_weights) == 2 # Just the bias & kernel of the Dense layer.
model.compile(...)
model.fit(...) # Train Dense while excluding ResNet50Base.
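Since ResNet50Base above is a stand-in for a pretrained backbone, the same assertions can be verified with any layers; a minimal runnable sketch using a small Dense stack (the layers here are illustrative substitutes, not from the original):

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8),   # stand-in for the pretrained base
    tf.keras.layers.Dense(10),
])

model.layers[0].trainable = False  # Freeze the first Dense layer.
assert model.layers[0].trainable_weights == []
# Only the kernel & bias of the second Dense layer remain trainable.
assert len(model.trainable_weights) == 2
```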
As you can see, “inference mode vs training mode” and “layer weight trainability” are two very different concepts.
You could imagine the following: a dropout layer where the scaling factor is learned during training, via backpropagation. Let’s name it AutoScaleDropout. This layer would simultaneously have trainable state and different behavior in inference and training. Because the trainable attribute and the training call argument are independent, you can do the following:
layer = AutoScaleDropout(0.5)
# Applies dropout at training time *and* inference time
# *and* learns the scaling factor during training
y = layer(x, training=True)
assert len(layer.trainable_weights) == 1
# Applies dropout at training time *and* inference time
# with a *frozen* scaling factor
layer = AutoScaleDropout(0.5)
layer.trainable = False
y = layer(x, training=True)
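AutoScaleDropout is hypothetical, so Keras doesn't ship one; a minimal sketch of how such a custom layer might be written (the implementation details are assumptions, not part of the original):

```python
import tensorflow as tf

class AutoScaleDropout(tf.keras.layers.Layer):
    """Dropout whose post-dropout scaling factor is a trainable weight."""

    def __init__(self, rate, **kwargs):
        super().__init__(**kwargs)
        self.rate = rate

    def build(self, input_shape):
        # Learned scaling factor, initialized to the usual 1 / (1 - rate).
        self.scale = self.add_weight(
            name="scale",
            shape=(),
            initializer=tf.keras.initializers.Constant(1.0 / (1.0 - self.rate)),
            trainable=True,
        )

    def call(self, inputs, training=False):
        if training:
            # Zero out units at random, then apply the learned scale.
            mask = tf.cast(
                tf.random.uniform(tf.shape(inputs)) >= self.rate, inputs.dtype
            )
            return inputs * mask * self.scale
        return inputs

layer = AutoScaleDropout(0.5)
y = layer(tf.ones((2, 4)), training=True)
assert len(layer.trainable_weights) == 1  # just the scaling factor
```

Setting layer.trainable = False would then freeze the scale weight while the training=True call still applies dropout, matching the two scenarios above.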
Special case of the BatchNormalization layer
For a BatchNormalization layer, setting bn.trainable = False will also make its training call argument default to False, meaning that the layer will not update its internal state (the moving mean and moving variance) during training.
This behavior only applies to BatchNormalization. For every other layer, weight trainability and “inference vs training mode” remain independent.
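A sketch of this special case (assuming TensorFlow 2.x behavior, where trainable = False overrides the training argument for BatchNormalization):

```python
import numpy as np
import tensorflow as tf

x = np.random.randn(32, 4).astype("float32") + 5.0  # batch mean far from 0

# Normal training call: the moving statistics move toward the batch mean.
bn = tf.keras.layers.BatchNormalization()
bn(x, training=True)
assert not np.allclose(bn.moving_mean.numpy(), np.zeros(4))

# Frozen layer: even training=True runs in inference mode, so the
# moving statistics keep their initial values.
frozen = tf.keras.layers.BatchNormalization()
frozen(x)  # Build the layer (inference call; no state update).
frozen.trainable = False
frozen(x, training=True)
assert np.allclose(frozen.moving_mean.numpy(), np.zeros(4))
```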