TPUs are fast, efficient hardware accelerators for deep learning, publicly available on Google Cloud. You can use TPUs via Colab, Kaggle notebooks, and GCP Deep Learning VMs (provided the TPU_NAME environment variable is set on the VM).
All Keras backends (JAX, TensorFlow, PyTorch) are supported on TPU, but we recommend JAX or TensorFlow in this case.
Using JAX:
When connected to a TPU runtime, just insert this code snippet before model construction:
import jax
import keras

# Shard each training batch across all available TPU devices.
distribution = keras.distribution.DataParallel(devices=jax.devices())
keras.distribution.set_distribution(distribution)
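For context, here is a minimal end-to-end sketch of what this looks like in a full script (the tiny model and random data are hypothetical, purely for illustration): once the distribution is set, model construction, compile(), and fit() need no further TPU-specific changes, and each batch is automatically sharded across the TPU devices.

import numpy as np

# Hypothetical toy model and random data, for illustration only.
model = keras.Sequential([
    keras.layers.Input(shape=(32,)),
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(10),
])
model.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

x = np.random.rand(1024, 32).astype("float32")
y = np.random.randint(0, 10, size=(1024,))
model.fit(x, y, batch_size=128, epochs=1)  # Batches are sharded across devices.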
Using TensorFlow:
When connected to a TPU runtime, use TPUClusterResolver to detect the TPU. Then create a TPUStrategy and construct your model in the strategy scope:
import tensorflow as tf

try:
    # Detect the TPU cluster and connect to it.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Device:", tpu.master())
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    # No TPU detected: fall back to the default (CPU/GPU) strategy.
    strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

with strategy.scope():
    # Create your model here.
    ...
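To make the scope requirement concrete, here is a hedged sketch (the architecture and random data are made up for illustration): anything that creates variables, i.e. model construction and compile(), must happen inside strategy.scope(), while fit() can be called outside it.

import numpy as np

with strategy.scope():
    # Variable creation (model construction and compile()) must happen in the scope.
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(32,)),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

# fit() itself can run outside the scope.
x = np.random.rand(1024, 32).astype("float32")
y = np.random.randint(0, 10, size=(1024,))
model.fit(x, y, batch_size=128, epochs=1)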
Importantly, you should:
- Make sure you are able to read your data fast enough to keep the TPU utilized (see the sketch after this list).
- Consider running multiple steps of gradient descent per graph execution in order to keep the TPU utilized. You can do this via the experimental_steps_per_execution argument of compile(). It will yield a significant speedup for small models; see the sketch after this list.
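As a hedged illustration of both tips (reusing the strategy object from the snippet above; the dataset and model are hypothetical), a typical pattern is to stream batches through a prefetched tf.data pipeline and to request several training steps per graph execution in compile(). Note that recent TF/Keras releases name the argument steps_per_execution, while older releases used experimental_steps_per_execution.

import numpy as np
import tensorflow as tf

x = np.random.rand(4096, 32).astype("float32")
y = np.random.randint(0, 10, size=(4096,))

# Prefetch so the host prepares the next batch while the TPU trains on the current one.
dataset = (
    tf.data.Dataset.from_tensor_slices((x, y))
    .shuffle(4096)
    .batch(1024, drop_remainder=True)  # Large, fixed-size batches suit TPUs.
    .prefetch(tf.data.AUTOTUNE)
)

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(32,)),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        # Run 8 training steps per graph execution to reduce host/TPU round trips.
        steps_per_execution=8,
    )

model.fit(dataset, epochs=2)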