Accessing your TPUs in Docker Containers with TPU VM

tl;dr: the full setup is in the GitHub repo.

On June 1, 2021, Google finally released TPU VMs to the public, after quietly announcing a private alpha late last year.

Inside a TPU VM: 96 vCPUs and 335 GB of RAM. Prettyyy.

The key difference between TPU VMs and the previous way of using TPUs (TPU Nodes) is that with TPU Nodes you never had direct access to the TPU host itself; all communication had to go through gRPC.

This meant you had to set up a separate VM that communicated with the TPU host, which often introduced problems you wouldn't otherwise have when training on GPUs.

Examples include:

  • Not being able to access local files on disk
  • Bottlenecks introduced by TPU <-> VM networking latency
  • Potentially unexpected Cloud Storage costs from having to store everything on GCS, especially if the bucket is not in the same region as the TPU/VM
  • Unexpected TPU errors, often accompanied by cryptic debugging messages

While TPU VMs solve these problems, they are so new that documentation on specific use cases is scarce, if it exists at all.

For example, at Growth Engine we use TPUs to experiment with large-scale NLP models, and we've built a system that containerizes model training on a single VM host connected to multiple TPU Nodes, each running an individual training job.

So, of course, I wanted to see how to access the TPU directly from a Docker container. Unlike with TPU Nodes, connecting to the TPU in a TPU VM is much trickier, since you can't rely on simply setting

If you want the short version, just check out the Docker Compose file in the GitHub repo.

Things I tried that didn’t work:

  • Setting the JAX backend target to the TPU VM's internal IP
  • Attaching the unusual devices I found in /dev that looked like they could be TPUs (getting warmer…)
  • Initializing TensorFlow's TPUClusterResolver before initializing JAX. This half worked with the above devices attached, but produced other errors.

After digging through the TPU documentation, TensorFlow's source code, and more, I finally found the answer in JAX's source code, specifically in this tiny piece.

So what ultimately worked is a combination of:

  • Attaching "/dev:/dev" as a device
  • Setting privileged: true
  • Mounting:
    ◦ /var/run/docker.sock:/var/run/docker.sock
    ◦ /usr/share/tpu/:/usr/share/tpu/
    ◦ /lib/
  • Environment variables:
    ◦ TPU_NAME=tpu_name
    ◦ XRT_TPU_CONFIG="localservice;0;localhost:51011" (likely necessary for PyTorch)
    ◦ TF_XLA_FLAGS=--tf_xla_enable_xla_devices (necessary for TF/JAX)
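Before wiring these settings into Compose, it can help to confirm from inside a running container that they actually took effect. The following is a minimal Python sketch of such a preflight check, assuming the environment values and bind mounts listed above. The function name `missing_settings` is hypothetical, and the /lib/ mount is deliberately left out because its full mapping isn't spelled out in the list.

```python
import os
from typing import Callable, List, Mapping

# Environment values from the list above. None means "any non-empty value";
# TPU_NAME's value ("tpu_name") is only a placeholder for your own TPU's name.
REQUIRED_ENV = {
    "TPU_NAME": None,
    "XRT_TPU_CONFIG": "localservice;0;localhost:51011",
    "TF_XLA_FLAGS": "--tf_xla_enable_xla_devices",
}

# Bind mounts from the list above (the /lib/ mount is intentionally omitted).
REQUIRED_PATHS = ["/usr/share/tpu/", "/var/run/docker.sock"]


def missing_settings(
    env: Mapping[str, str] = os.environ,
    path_exists: Callable[[str], bool] = os.path.exists,
) -> List[str]:
    """Return one human-readable line per setting that looks wrong."""
    problems = []
    for name, expected in REQUIRED_ENV.items():
        value = env.get(name)
        if not value:
            problems.append(f"env {name} is not set")
        elif expected is not None and expected not in value:
            # Substring check so extra flags (e.g. in TF_XLA_FLAGS) are allowed.
            problems.append(f"env {name}={value!r} does not contain {expected!r}")
    for path in REQUIRED_PATHS:
        if not path_exists(path):
            problems.append(f"path {path} is not mounted")
    return problems


if __name__ == "__main__":
    issues = missing_settings()
    print("all settings present" if not issues else "\n".join(issues))
```

Run inside the container, it prints either `all settings present` or one line per missing item. Passing `env` and `path_exists` explicitly makes the check easy to exercise on a machine without a TPU.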

All together, it looks something like:
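As a minimal sketch, assuming the settings listed above, a single-service Compose definition can be assembled in Python and printed as JSON. JSON is valid YAML, so the output should be usable as a Compose file, or it can be converted to YAML by hand. The service name `trainer`, the image `my-training-image:latest`, and both helper functions are hypothetical placeholders; `tpu_name` is the placeholder from the list above.

```python
import json

# Hypothetical placeholders, not values from the original setup.
SERVICE_NAME = "trainer"
IMAGE = "my-training-image:latest"
TPU_NAME = "tpu_name"


def tpu_service(image: str = IMAGE, tpu_name: str = TPU_NAME) -> dict:
    """Build one Compose service definition using the settings listed above."""
    return {
        "image": image,
        "privileged": True,
        "devices": ["/dev:/dev"],
        "volumes": [
            "/var/run/docker.sock:/var/run/docker.sock",
            "/usr/share/tpu/:/usr/share/tpu/",
            # The list above also mounts /lib/; its exact mapping isn't given,
            # so add it here to match your own setup.
        ],
        "environment": {
            "TPU_NAME": tpu_name,
            "XRT_TPU_CONFIG": "localservice;0;localhost:51011",
            "TF_XLA_FLAGS": "--tf_xla_enable_xla_devices",
        },
    }


def compose_file(service_name: str = SERVICE_NAME, **kwargs) -> str:
    """Serialize a single-service Compose file as indented JSON."""
    return json.dumps({"services": {service_name: tpu_service(**kwargs)}}, indent=2)


if __name__ == "__main__":
    print(compose_file())
```

For example, redirecting the output to `docker-compose.json` and running `docker-compose -f docker-compose.json up` (file names hypothetical) should start the container with these settings.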

And finally, you can access the TPU VM's TPU from within your Docker container:
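A quick way to confirm this from inside the container is to ask JAX which backend and devices it sees. This is a minimal sketch, assuming JAX with TPU support is installed in the image; `describe_devices` is a hypothetical helper, while `jax.devices()` and `jax.default_backend()` are standard JAX calls.

```python
import importlib.util


def describe_devices() -> str:
    """Summarize the devices JAX can see, or report that JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed in this container"
    import jax

    devices = jax.devices()  # devices on the default backend
    kinds = sorted({d.device_kind for d in devices})
    return f"backend={jax.default_backend()} devices={len(devices)} kinds={kinds}"


if __name__ == "__main__":
    summary = describe_devices()
    print(summary)
    if summary.startswith("backend=") and "backend=tpu" not in summary:
        print("warning: JAX is not using the TPU backend")
```

In a correctly configured container on a TPU VM, the backend should be reported as `tpu`; a `cpu` backend usually means JAX fell back because it could not reach the TPU.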

I hope my hours of headaches and debugging help a few people who are working with TPU VMs!

Final Note: You Probably Want a VPC & Proxy

Unlike regular VMs in GCP, TPU VMs don't let you modify their firewall settings, so their public ports aren't accessible from outside the GCP network.

So, unfortunately, you can't just add the TPU VM's external IP to your DNS and call it a day.

What I eventually did was have another VM run as a proxy using an nginx container that points to the TPU VM’s internal IP address.
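The nginx side of this can be as small as a single server block. The sketch below renders one from Python, assuming the TPU VM exposes an HTTP service; the internal IP `10.128.0.5`, port `8000`, and the helper name are hypothetical, and the proxy headers are common defaults rather than anything specific to TPU VMs.

```python
import ipaddress


def nginx_proxy_config(internal_ip: str, port: int, listen_port: int = 80) -> str:
    """Render an nginx server block that forwards all requests to the TPU VM."""
    # Fail early on typos: raises ValueError if this isn't a valid IP address.
    ipaddress.ip_address(internal_ip)
    return f"""server {{
    listen {listen_port};

    location / {{
        # Forward everything to the service on the TPU VM's internal IP.
        proxy_pass http://{internal_ip}:{port};
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    }}
}}
"""


if __name__ == "__main__":
    # Hypothetical internal IP and port; substitute your TPU VM's values.
    print(nginx_proxy_config("10.128.0.5", 8000))
```

With the official nginx image, a file like this is typically mounted under `/etc/nginx/conf.d/` so the default configuration picks it up.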

Both the VM and the TPU VM were on the same VPC network, so I’m not sure how this might work if they aren’t.
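Since the proxy only works if it can reach the TPU VM over the internal network, a plain TCP check run from the proxy VM is a quick way to test that before debugging nginx itself. This is a minimal standard-library sketch; the helper name `can_reach` is hypothetical.

```python
import socket


def can_reach(host: str, port: int, timeout: float = 3.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within the timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and unreachable networks.
        return False
```

For example, if `can_reach("10.128.0.5", 8000)` (hypothetical IP and port) returns `False` on the proxy VM, the problem is more likely networking or VPC configuration than nginx.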

Special thanks to TRC (TPU Research Cloud) for granting access to TPUs and TPU VMs.

Chief Architect @ Growth Engine