Making an RNN stateful means that the states computed for the samples of each batch are reused as the initial states for the samples in the next batch.
When using stateful RNNs, it is therefore assumed that:
- all batches have the same number of samples
- if `x1` and `x2` are successive batches of samples, then `x2[i]` is the follow-up sequence to `x1[i]`, for every `i`.
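One way to satisfy the second condition is sketched below. It assumes your data is a set of equally long sequences held in a NumPy array of shape `(n_sequences, total_steps, features)`. The idea is to cut every sequence into consecutive windows and let batch `k` hold window `k` of every sequence. The result is a list of batches you could feed one by one, for example to `train_on_batch`. The helper `make_stateful_batches` is hypothetical and not part of Keras.

```python
import numpy as np


def make_stateful_batches(series, window):
    """Hypothetical helper: split (n_sequences, total_steps, features) into windows.

    Batch k holds timesteps [k*window, (k+1)*window) of every sequence, so row i
    of batch k+1 continues row i of batch k. Trailing timesteps that do not fill
    a whole window are dropped, so every batch has the same shape.
    """
    n_windows = series.shape[1] // window
    return [series[:, k * window:(k + 1) * window, :] for k in range(n_windows)]


series = np.arange(4 * 25 * 2, dtype="float32").reshape(4, 25, 2)
batches = make_stateful_batches(series, window=10)

print(len(batches))       # 2 (the last 5 timesteps are dropped)
print(batches[0].shape)   # (4, 10, 2)
# Joining the batches along the time axis gives back the original sequences,
# i.e. row i of the second batch picks up where row i of the first ended.
print(np.array_equal(np.concatenate(batches, axis=1), series[:, :20, :]))  # True
```

Because every batch contains all sequences, the number of samples per batch stays fixed. That also satisfies the first condition.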
To use statefulness in RNNs, you need to:
- explicitly specify the batch size you are using, by passing a `batch_size` argument to the first layer in your model, e.g. `batch_size=32` for batches of 32 sequences, each with 10 timesteps of 16 features.
- set `stateful=True` in your RNN layer(s).
- specify `shuffle=False` when calling `fit()`.
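For the `fit()` step, the row order of the training array has to line up with the state slots. The sketch below rests on two assumptions. First, your data is one long sequence per state slot, stored as an array of shape `(batch_size, total_steps, features)`. Second, `fit()` with `shuffle=False` walks a NumPy array in consecutive slices of `batch_size` rows. Under those assumptions, you can stack the time windows along the sample axis so that slice `k` is window `k` of every sequence. `stack_for_fit` is a hypothetical helper, not a Keras function.

```python
import numpy as np


def stack_for_fit(series, window):
    """Hypothetical helper: arrange (batch_size, total_steps, features) data
    so that rows [k*batch_size, (k+1)*batch_size) hold window k of every sequence.

    Leftover timesteps are dropped, so the row count is an exact multiple of
    batch_size and no partial batch is produced.
    """
    total_steps = series.shape[1]
    n_windows = total_steps // window
    windows = [series[:, k * window:(k + 1) * window, :] for k in range(n_windows)]
    return np.concatenate(windows, axis=0)


batch_size = 32
series = np.random.random((batch_size, 21, 16))
x = stack_for_fit(series, window=10)
print(x.shape)  # (64, 10, 16)

# Taking consecutive slices of batch_size rows, in order, visits the windows in
# time order: row i of the second slice continues row i of the first.
second = x[batch_size:2 * batch_size]
print(np.array_equal(second[:, 0, :], series[:, 10, :]))  # True
```

You would then call something like `model.fit(x, y, batch_size=32, shuffle=False)`, with targets `y` arranged in the same row order.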
To reset the accumulated states:
- use `model.reset_states()` to reset the states of all layers in the model
- use `layer.reset_states()` to reset the states of a specific stateful RNN layer
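The following is a deliberately tiny, NumPy-only stand-in for a stateful recurrent layer, meant to make the carry-over and the reset concrete. `ToyStatefulRNN` is hypothetical and does not reflect how Keras implements its layers. It only mimics the behaviour described above: the final state of one call becomes the initial state of the next, and `reset_states()` sets it back to zeros.

```python
import numpy as np


class ToyStatefulRNN:
    """Hypothetical NumPy stand-in for a stateful simple RNN layer.

    Computes h_t = tanh(x_t @ W + h_{t-1} @ U) and keeps the last h between
    calls, one row per sample slot of a fixed-size batch.
    """

    def __init__(self, batch_size, input_dim, units, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.normal(scale=0.5, size=(input_dim, units))
        self.U = rng.normal(scale=0.5, size=(units, units))
        self.batch_size = batch_size
        self.units = units
        self.reset_states()

    def reset_states(self):
        # Zero the carried state for every sample slot.
        self.h = np.zeros((self.batch_size, self.units))

    def __call__(self, x):
        # x: (batch_size, timesteps, input_dim). The stored state has one row
        # per sample slot, which is why the batch size must stay fixed.
        if x.shape[0] != self.batch_size:
            raise ValueError(f"expected batch size {self.batch_size}, got {x.shape[0]}")
        h = self.h
        for t in range(x.shape[1]):
            h = np.tanh(x[:, t, :] @ self.W + h @ self.U)
        self.h = h  # the final state becomes the next call's initial state
        return h


x = np.random.default_rng(1).random((32, 21, 16))
rnn = ToyStatefulRNN(batch_size=32, input_dim=16, units=8)

out_1 = rnn(x[:, :10, :])     # starts from zeros
out_2 = rnn(x[:, 10:20, :])   # starts from the state left behind by out_1
rnn.reset_states()            # back to zeros
out_1_again = rnn(x[:, :10, :])

print(out_2.shape)                         # (32, 8)
print(np.array_equal(out_1, out_1_again))  # True: the reset restored the initial state
```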
Example:
import keras
from keras import layers
import numpy as np
x = np.random.random((32, 21, 16)) # this is our input data, of shape (32, 21, 16)
# we will feed it to our model in sequences of length 10
model = keras.Sequential()
model.add(layers.LSTM(32, input_shape=(10, 16), batch_size=32, stateful=True))
model.add(layers.Dense(16, activation='softmax'))
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
# we train the network to predict the 11th timestep given the first 10:
model.train_on_batch(x[:, :10, :], np.reshape(x[:, 10, :], (32, 16)))
# the state of the network has changed. We can feed the follow-up sequences:
model.train_on_batch(x[:, 10:20, :], np.reshape(x[:, 20, :], (32, 16)))
# let's reset the states of the LSTM layer:
model.reset_states()
# another way to do it in this case:
model.layers[0].reset_states()
Note that the methods `predict`, `fit`, `train_on_batch`, etc. all update the states of the stateful layers in a model. This lets you use stateful layers for prediction as well as for training.
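Stateful prediction works because running a sequence in consecutive chunks, carrying the state from one chunk to the next, gives the same final state as running the whole sequence in one pass. The NumPy sketch below checks this for a hypothetical simple RNN. `run_rnn` is an illustration, not a Keras function.

```python
import numpy as np


def run_rnn(x, W, U, h0):
    """Hypothetical simple RNN: h_t = tanh(x_t @ W + h_{t-1} @ U); returns the last h."""
    h = h0
    for t in range(x.shape[1]):
        h = np.tanh(x[:, t, :] @ W + h @ U)
    return h


rng = np.random.default_rng(0)
W = rng.normal(scale=0.5, size=(16, 8))
U = rng.normal(scale=0.5, size=(8, 8))
x = rng.random((32, 20, 16))
h0 = np.zeros((32, 8))

# One pass over all 20 timesteps.
h_full = run_rnn(x, W, U, h0)

# Stateful style: two chunks of 10 timesteps, with the first chunk's final
# state passed in as the second chunk's initial state.
h_chunk1 = run_rnn(x[:, :10, :], W, U, h0)
h_chunked = run_rnn(x[:, 10:, :], W, U, h_chunk1)

print(np.allclose(h_full, h_chunked))  # True
```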