How can I use stateful RNNs?

Making an RNN stateful means that the final states computed for the samples of one batch are reused as the initial states for the samples of the next batch.

When using stateful RNNs, it is therefore assumed that:

  • all batches have the same number of samples
  • if x1 and x2 are successive batches of samples, then x2[i] is the follow-up sequence to x1[i], for every i.
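
The second assumption is mostly a question of how the data is cut up. The sketch below is one way to do that with NumPy, assuming the long sequences are stored as an array of shape (samples, timesteps, features); the helper name split_into_stateful_batches and the toy array are hypothetical and not part of Keras.

```python
import numpy as np


def split_into_stateful_batches(data, window):
    """Cut (samples, timesteps, features) data into successive batches.

    Hypothetical helper, not a Keras function. Every returned batch keeps
    all samples in the same row order, so row i of one batch is continued
    by row i of the next batch.
    """
    n_chunks = data.shape[1] // window  # leftover timesteps are dropped
    return [data[:, k * window:(k + 1) * window, :] for k in range(n_chunks)]


# Toy data: 4 long sequences, 20 timesteps, 2 features per timestep.
long_sequences = np.arange(4 * 20 * 2, dtype="float32").reshape(4, 20, 2)

x1, x2 = split_into_stateful_batches(long_sequences, window=10)

# Both batches have the same number of samples, and for every i,
# x2[i] picks up where x1[i] stopped.
rebuilt = np.concatenate([x1, x2], axis=1)
```

Shuffling samples between batches, or changing the number of samples from one batch to the next, would break this row-by-row correspondence.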

To use statefulness in RNNs, you need to:

  • explicitly specify the batch size you are using, by passing a batch_size argument to the first layer in your model. E.g. batch_size=32 for batches of 32 sequences, each with 10 timesteps and 16 features per timestep.
  • set stateful=True in your RNN layer(s).
  • specify shuffle=False when calling fit().

To reset the accumulated states:

  • use model.reset_states() to reset the states of all layers in the model
  • use layer.reset_states() to reset the states of a specific stateful RNN layer
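
To make the effect of carrying and resetting state concrete, here is a minimal pure-Python sketch. ToyStatefulCell is a hypothetical scalar recurrent cell written only for illustration; it is not a Keras class, and its weights w and u are arbitrary assumptions. It only mimics the bookkeeping: each call starts from the stored state, and reset_states() puts that state back to zero.

```python
import math


class ToyStatefulCell:
    """Hypothetical scalar recurrent cell, for illustration only."""

    def __init__(self, w=0.5, u=0.9):
        self.w = w        # input weight (arbitrary)
        self.u = u        # recurrent weight (arbitrary)
        self.state = 0.0  # state kept between calls

    def run(self, sequence):
        # Start from the stored state rather than from zero.
        h = self.state
        for x in sequence:
            h = math.tanh(self.w * x + self.u * h)
        self.state = h  # remember the final state for the next call
        return h

    def reset_states(self):
        self.state = 0.0


sequence = [0.1, 0.4, -0.3, 0.8, 0.2, -0.5]
cell = ToyStatefulCell()

# One pass over the whole sequence.
full = cell.run(sequence)

# Two passes over the two halves, with the state carried across.
cell.reset_states()
cell.run(sequence[:3])
chunked = cell.run(sequence[3:])

# Second half only, starting from a freshly reset state.
cell.reset_states()
fresh = cell.run(sequence[3:])
```

In this sketch, chunked matches full because the state is carried from the first half into the second, while fresh differs because the state was reset before the second half was processed.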

Example:

import keras
from keras import layers
import numpy as np

x = np.random.random((32, 21, 16))  # this is our input data, of shape (32, 21, 16)
# we will feed it to our model in sequences of length 10

model = keras.Sequential()
model.add(layers.LSTM(32, input_shape=(10, 16), batch_size=32, stateful=True))
model.add(layers.Dense(16, activation='softmax'))

model.compile(optimizer='rmsprop', loss='categorical_crossentropy')

# we train the network to predict the 11th timestep given the first 10:
model.train_on_batch(x[:, :10, :], np.reshape(x[:, 10, :], (32, 16)))

# the state of the network has changed. We can feed the follow-up sequences:
model.train_on_batch(x[:, 10:20, :], np.reshape(x[:, 20, :], (32, 16)))

# let's reset the states of the LSTM layer:
model.reset_states()

# another way to do it in this case:
model.layers[0].reset_states()

Note that the methods predict, fit, train_on_batch, etc. will all update the states of the stateful layers in a model. This allows you to do not only stateful training, but also stateful prediction.

