What if I need to customize what fit() does?

You have two options:

1) Subclass the Model class and override the train_step (and test_step) methods

This is a better option if you want to use custom update rules but still want to leverage the functionality provided by fit(), such as callbacks, efficient step fusing, etc.

Note that this pattern does not prevent you from building models with the Functional API; in that case, you instantiate the class you created by passing it the inputs and outputs. The same goes for Sequential models: subclass keras.Sequential instead of keras.Model and override its train_step.
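A minimal sketch of option 1 follows. It assumes the TensorFlow backend, because the body of train_step is backend-specific. The sketch overrides train_step and test_step on a keras.Model subclass, tracks its own loss with keras.metrics.Mean, and instantiates the class through the Functional API. The class name CustomModel, the mean-squared-error loss, and the random toy data are illustrative choices only.

```python
import os

# Assumption: TensorFlow backend (train_step bodies differ per backend).
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import tensorflow as tf
import keras


class CustomModel(keras.Model):  # hypothetical name
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Loss and metric tracked by hand instead of being passed to compile().
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.loss_fn = keras.losses.MeanSquaredError()

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # forward pass
            loss = self.loss_fn(y, y_pred)
        # A custom update rule would go here; this sketch applies plain gradients.
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)  # no gradient update during evaluation
        self.loss_tracker.update_state(self.loss_fn(y, y_pred))
        return {"loss": self.loss_tracker.result()}

    @property
    def metrics(self):
        # Listing the tracker here lets fit()/evaluate() reset it automatically.
        return [self.loss_tracker]


# Functional API: build the graph, then instantiate the custom class with it.
inputs = keras.Input(shape=(4,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam")  # no loss argument: train_step computes it

x = np.random.random((64, 4)).astype("float32")
y = np.random.random((64, 1)).astype("float32")
history = model.fit(x, y, epochs=2, batch_size=16, verbose=0)
eval_logs = model.evaluate(x, y, verbose=0, return_dict=True)
```

Because train_step computes the loss itself, compile() only needs the optimizer. fit() still drives batching, epochs, and callbacks around the overridden step.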

See the following guides:

  • Writing a custom train step in JAX
  • Writing a custom train step in TensorFlow
  • Writing a custom train step in PyTorch

2) Write a low-level custom training loop

This is a good option if you want to be in control of every last little detail, though it can be somewhat verbose.
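A minimal sketch of option 2 follows, again assuming the TensorFlow backend. It is a hand-written loop over a tf.data.Dataset that computes gradients with tf.GradientTape and applies them with the optimizer. The model, loss, and toy data are hypothetical. Everything fit() would otherwise handle (callbacks, progress bars, validation) is left out, and adding it back by hand is where the extra verbosity comes from.

```python
import os

# Assumption: TensorFlow backend.
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import tensorflow as tf
import keras

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1),
])
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.MeanSquaredError()

# Hypothetical toy data, batched with tf.data.
x = np.random.random((64, 4)).astype("float32")
y = np.random.random((64, 1)).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)


def train_step(x_batch, y_batch):
    # Runs eagerly here; decorating it with @tf.function would compile it to a graph.
    with tf.GradientTape() as tape:
        y_pred = model(x_batch, training=True)
        loss = loss_fn(y_batch, y_pred)
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss


epoch_losses = []
for epoch in range(2):
    batch_losses = []
    for x_batch, y_batch in dataset:
        batch_losses.append(float(train_step(x_batch, y_batch)))
    epoch_losses.append(sum(batch_losses) / len(batch_losses))
    print(f"epoch {epoch}: mean loss {epoch_losses[-1]:.4f}")
```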

See the following guides:

  • Writing a custom training loop in JAX
  • Writing a custom training loop in TensorFlow
  • Writing a custom training loop in PyTorch
