How can I freeze layers and do fine-tuning?

Setting the trainable attribute

All layers & models have a layer.trainable boolean attribute:

>>> layer = Dense(3)
>>> layer.trainable
True

On all layers & models, the trainable attribute can be set (to True or False). When set to False, the layer.trainable_weights attribute is empty:

>>> layer = Dense(3)
>>> layer.build(input_shape=(None, 3)) # Create the weights of the layer
>>> layer.trainable
True
>>> layer.trainable_weights
[<KerasVariable shape=(3, 3), dtype=float32, path=dense/kernel>, <KerasVariable shape=(3,), dtype=float32, path=dense/bias>]
>>> layer.trainable = False
>>> layer.trainable_weights
[]

Setting the trainable attribute on a layer recursively sets it on all child layers (the contents of self.layers).
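
For example, freezing a nested model freezes every layer inside it (a minimal sketch; inner_model is just a hypothetical two-layer model):

>>> inner_model = Sequential([Dense(3), Dense(3)])
>>> inner_model.build(input_shape=(None, 3))  # Create the weights of both layers
>>> inner_model.trainable = False  # Recursively sets trainable=False on both Dense layers
>>> inner_model.layers[0].trainable
False
>>> inner_model.trainable_weights
[]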

1) When training with fit():

To do fine-tuning with fit(), you would:

  • Instantiate a base model and load pre-trained weights
  • Freeze that base model
  • Add trainable layers on top
  • Call compile() and fit()

Like this:

model = Sequential([
    ResNet50Base(input_shape=(32, 32, 3), weights='pretrained'),
    Dense(10),
])
model.layers[0].trainable = False  # Freeze ResNet50Base.

assert model.layers[0].trainable_weights == []  # ResNet50Base has no trainable weights.
assert len(model.trainable_weights) == 2  # Just the bias & kernel of the Dense layer.

model.compile(...)
model.fit(...)  # Train Dense while excluding ResNet50Base.

You can follow a similar workflow with the Functional API or the model subclassing API. Make sure to call compile() after changing the value of trainable in order for your changes to be taken into account. Calling compile() will freeze the state of the training step of the model.
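
For instance, a minimal sketch of the same workflow with the Functional API could look like this (Input and Model are the usual Keras symbols; ResNet50Base is the same placeholder base model as above):

base_model = ResNet50Base(input_shape=(32, 32, 3), weights='pretrained')
base_model.trainable = False  # Freeze the base model.

inputs = Input(shape=(32, 32, 3))
x = base_model(inputs)
outputs = Dense(10)(x)
model = Model(inputs, outputs)

model.compile(...)  # Compile *after* setting trainable.
model.fit(...)  # Only the Dense layer is updated.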

2) When using a custom training loop:

When writing a training loop, make sure to only update weights that are part of model.trainable_weights (and not all model.weights). Here’s a simple TensorFlow example:

model = Sequential([
    ResNet50Base(input_shape=(32, 32, 3), weights='pretrained'),
    Dense(10),
])
model.layers[0].trainable = False  # Freeze ResNet50Base.

# Iterate over the batches of a dataset.
for inputs, targets in dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
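
The loop above assumes that dataset, loss_fn, and optimizer have already been defined. A minimal setup might look like the following (x_train and y_train are hypothetical NumPy arrays of inputs and integer labels):

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)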

Interaction between trainable and compile()

Calling compile() on a model is meant to “freeze” the behavior of that model. This implies that the trainable attribute values at the time the model is compiled should be preserved throughout the lifetime of that model, until compile is called again. Hence, if you change trainable, make sure to call compile() again on your model for your changes to be taken into account.

For instance, if two models A & B share some layers, and:

  • Model A gets compiled
  • The trainable attribute value on the shared layers is changed
  • Model B is compiled

Then models A and B are using different trainable values for the shared layers. This mechanism is critical for most existing GAN implementations, which do:

discriminator.compile(...)  # the weights of `discriminator` should be updated when `discriminator` is trained
discriminator.trainable = False
gan.compile(...)  # `discriminator` is a submodel of `gan`, which should not be updated when `gan` is trained
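
The same rule applies to the fine-tuning workflow above: if you later unfreeze the base model for a second round of training, recompile before calling fit() again. A sketch:

model.layers[0].trainable = True  # Unfreeze ResNet50Base.
model.compile(...)  # Recompile so the change to `trainable` takes effect (typically with a lower learning rate).
model.fit(...)  # Now the base model's weights are updated as well.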
