How can I ensure my training run can recover from program interruptions?

To make a training run fault-tolerant, so that it can recover from an interruption at any point, use the keras.callbacks.BackupAndRestore callback. It regularly saves your training progress, including the epoch number and the model weights, to disk, and restores that state the next time you call Model.fit().

import numpy as np
import keras

class InterruptingCallback(keras.callbacks.Callback):
  """A callback to intentionally introduce interruption to training."""
  def on_epoch_end(self, epoch, logs=None):
    if epoch == 15:
      raise RuntimeError('Interruption')

model = keras.Sequential([keras.layers.Dense(10)])
optimizer = keras.optimizers.SGD()
model.compile(optimizer, loss="mse")

x = np.random.random((24, 10))
y = np.random.random((24,))

# Training state (epoch number and weights) is backed up under backup_dir.
backup_callback = keras.callbacks.BackupAndRestore(
    backup_dir='/tmp/backup')
try:
  model.fit(x, y, epochs=20, steps_per_epoch=5,
            callbacks=[backup_callback, InterruptingCallback()])
except RuntimeError:
  print('***Handling interruption***')
  # This continues at the epoch where it left off.
  model.fit(x, y, epochs=20, steps_per_epoch=5,
            callbacks=[backup_callback])
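
To see why resuming works, it can help to look at the pattern without any framework. The following is a minimal sketch of the general backup-and-restore idea, not Keras's actual implementation: the function `train`, the `fail_at` parameter, and the JSON backup layout are all hypothetical names made up for this illustration, and a single float stands in for the model weights.

```python
import json
import os
import tempfile


def train(backup_path, epochs, fail_at=None):
    """Toy training loop that resumes from a backup file if one exists."""
    # Restore the saved state if a backup is present, otherwise start fresh.
    if os.path.exists(backup_path):
        with open(backup_path) as f:
            state = json.load(f)
    else:
        state = {"epoch": 0, "weight": 0.0}

    epochs_run = []
    for epoch in range(state["epoch"], epochs):
        # Simulated interruption at the start of epoch `fail_at`.
        if fail_at is not None and epoch == fail_at:
            raise RuntimeError("Interruption")

        state["weight"] += 1.0  # stand-in for a real weight update
        state["epoch"] = epoch + 1  # next epoch to run after a restart
        epochs_run.append(epoch)

        # Write to a temporary file, then rename it over the backup, so a
        # crash in the middle of a write cannot leave a corrupt backup.
        tmp_path = backup_path + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump(state, f)
        os.replace(tmp_path, backup_path)

    # Training completed, so the backup is no longer needed.
    if os.path.exists(backup_path):
        os.remove(backup_path)
    return state, epochs_run


backup_path = os.path.join(tempfile.mkdtemp(), "backup.json")
try:
    train(backup_path, epochs=20, fail_at=15)
except RuntimeError:
    # The second call finds the backup and continues at epoch 15.
    final_state, resumed_epochs = train(backup_path, epochs=20)
```

In this sketch the second call to `train` only runs the epochs that were still outstanding, because the epoch counter is stored next to the weights. Saving both together is the key design choice: weights alone would not tell the loop where to continue.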
