Introduction

Google’s Tensor Processing Units (TPUs) are custom accelerators designed for large-scale machine learning workloads. Unlike CPUs and GPUs, TPUs use a unique systolic array architecture to maximize throughput for matrix operations, making them ideal for deep learning tasks. This post is a practical walkthrough of TPU architecture, setup, and usage, with tips for maximizing performance and minimizing cost.

Summary

  • TPUs excel at large, static-shape matrix multiplications, making them perfect for transformer models and similar architectures.
  • JAX is the preferred framework for TPUs; PyTorch and TensorFlow have less robust support.
  • Efficient TPU usage requires careful attention to data layout, shape, and storage.
  • Google Cloud provides several TPU versions, each with different capabilities and pricing.
  • Automating setup and teardown is crucial to avoid unnecessary costs.

TPU Architecture Overview

  1. CPUs have a small number of fast, general-purpose cores and are limited by memory bandwidth.
  2. GPUs contain many small cores and high memory bandwidth.
  3. TPUs (v3, for example) contain two 128x128 systolic arrays, which are grids of ALUs optimized for matrix operations.

The following animation shows how network weights are laid out inside a TPU:

And this animation shows the systolic movement of data inputs into TPUs:

These animations illustrate how a TPU multiplies each weight with the matching input and accumulates the products as data moves through the grid. Notably, TPUs avoid repeated trips to HBM (High Bandwidth Memory), which is slow relative to on-chip data movement, by passing data directly between ALUs, both vertically (inputs) and horizontally (accumulated results). This design enables high throughput for large matrix multiplications.

Practical Recommendations

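  • Avoid reshape operations and keep tensor shapes constant: shapes are compiled into the model, so a new shape forces a recompile.
  • Use large matrices with dimensions that are multiples of 8 for best performance; feature dimensions that are multiples of 128 also line up with a 128x128 MXU.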
  • Prefer matrix multiplications; other operations (add, sub, reshape) are slower.
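
The multiple-of-8 guidance above comes down to simple arithmetic. As a minimal sketch, the hypothetical helper round_up below (not part of any TPU tooling) computes the padded size for a dimension. The 128 in the last call is an assumption that mirrors the 128x128 MXU size, not a documented requirement.

```shell
#!/usr/bin/env bash
# round_up N [M]: smallest multiple of M (default 8) that is >= N.
# Hypothetical helper for planning tensor dimensions; not a TPU API.
round_up() {
  local n=$1 m=${2:-8}
  echo $(( (n + m - 1) / m * m ))
}

round_up 1000       # already a multiple of 8: prints 1000
round_up 1001       # prints 1008
round_up 50257 128  # prints 50304
```

Padding a dimension this way spends a little extra memory in exchange for a shape that tiles evenly.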

System Architecture and Key Terms

  • Batch Inference: Offline inference over a bulk of inputs; throughput matters more than latency.
  • TensorCore (TC): Contains matrix-multiply units (MXUs), vector, and scalar units. MXUs are 128x128 or 256x256.
  • TPU Cube: Topology of interconnected TPUs (v4+).
  • Multislice: A group of TPU slices connected over the data-center network rather than ICI.
  • Queued Resource: Manages requests for TPU environments.
  • Host/Sub-host: Linux VM(s) controlling TPUs; a host can manage multiple TPUs.
  • Slice: Collection of chips connected via fast interconnects (ICI).
  • SparseCore: Specialized hardware for large embedding tables (v5+).
  • TPU Pod: Cluster of TPUs for large-scale training.
  • TPU VM/Worker: Linux VM with direct TPU access.
  • TPU Versions: Each generation differs significantly in architecture and capabilities.

TPU Versions and Specs

TPU v6e

  • 1 TC per chip, 2 MXUs per TC
  • 32GB memory, 1640 GB/s BW, 918 TFLOPs
  • Max pod: 256 chips

TPU v5p

  • 95GB memory, 2765 GB/s BW, 459 TFLOPs
  • Max pod: 8960 chips

TPU v5e

  • 16GB memory, 819 GB/s BW, 197 TFLOPs
  • Max pod: 256 chips

TPU v4

  • 32GB memory, 1200 GB/s BW, 275 TFLOPs
  • Max pod: 4096 chips

TPU VM Images and Hardware

Default VM image: tpu-ubuntu2204-base. Example VM specs for v5litepod-1:

  • 1 v5e TPU, 24 CPUs, 48GB RAM
  • Larger pods scale up CPU/RAM and NUMA nodes

Regions and Zones

  • EU: europe-west4
  • v6e: us-east1-d, us-east5-b

Supported Models

Google’s MaxText repo provides optimized TPU training code for Llama2, Mistral, Gemma, etc., using JAX.

Getting Started: Requesting and Using TPUs

  1. Enable “Cloud TPU” in Google Cloud.
  2. Request quota for your desired TPU type and zone.
  3. Use gcloud to provision, monitor, and SSH into TPU VMs.
  4. Always delete resources when done to avoid charges.

Example setup and teardown commands:

# Set up environment variables
export PROJECT_ID=your-project-id  # placeholder: replace with your project ID
export SERVICE_ACCOUNT=xyz-compute@developer.gserviceaccount.com
export RESOURCE_NAME=v5litepod-1-resource

# Authenticate and enable TPUs
gcloud auth login
gcloud services enable tpu.googleapis.com
gcloud beta services identity create --service tpu.googleapis.com --project $PROJECT_ID

# Request a TPU node
gcloud alpha compute tpus queued-resources create $RESOURCE_NAME \
     --node-id v5litepod \
     --project $PROJECT_ID \
     --zone us-central1-a \
     --accelerator-type v5litepod-1 \
     --runtime-version v2-alpha-tpuv5-lite \
     --valid-until-duration 1d \
     --service-account $SERVICE_ACCOUNT

# Check status
gcloud alpha compute tpus queued-resources describe $RESOURCE_NAME \
     --project $PROJECT_ID \
     --zone us-central1-a

# SSH into TPU node
gcloud alpha compute tpus tpu-vm ssh v5litepod \
     --project $PROJECT_ID \
     --zone us-central1-a

# Delete TPU resource
gcloud alpha compute tpus queued-resources delete $RESOURCE_NAME \
     --project $PROJECT_ID \
     --zone us-central1-a --force --quiet
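
Queued resources are not available immediately, so the describe command above often has to be repeated. A rough sketch of a polling loop follows. The ZONE variable and the function name wait_for_active are hypothetical, and the --format='value(state.state)' projection and the state names ACTIVE, FAILED, and SUSPENDED are assumptions about the describe output that you should check against your own gcloud version.

```shell
#!/usr/bin/env bash
# Poll a queued resource until it becomes ACTIVE.
# Assumes RESOURCE_NAME, PROJECT_ID and ZONE (hypothetical variable) are set.
# Usage: wait_for_active [max_tries] [delay_seconds]
wait_for_active() {
  local tries=${1:-60} delay=${2:-30} i=0 state
  while [ "$i" -lt "$tries" ]; do
    # Assumption: the state is exposed as state.state in the describe output.
    state=$(gcloud alpha compute tpus queued-resources describe "$RESOURCE_NAME" \
      --project "$PROJECT_ID" --zone "$ZONE" --format='value(state.state)')
    case "$state" in
      ACTIVE) return 0 ;;
      FAILED|SUSPENDED)
        echo "queued resource entered state: $state" >&2
        return 1 ;;
    esac
    i=$((i + 1))
    sleep "$delay"
  done
  echo "timed out waiting for $RESOURCE_NAME" >&2
  return 1
}
```

Calling wait_for_active before the SSH step keeps a script from trying to connect to a node that has not been provisioned yet.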

Exploring the TPU VM

  • Check disk space: df -h
  • List hardware: lspci
  • CPU info: hwinfo | less (full details) or nproc (core count)
  • RAM: grep MemTotal /proc/meminfo
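
The checks above can be rolled into one helper for a quick look at a fresh VM. This is only a sketch: vm_summary is a hypothetical name, and it assumes a Linux VM with nproc, awk, and df available.

```shell
#!/usr/bin/env bash
# Print a short summary of the VM (Linux only; reads /proc/meminfo).
vm_summary() {
  echo "cpus: $(nproc)"
  # MemTotal is reported in kB; convert to GiB.
  echo "ram: $(awk '/MemTotal/ {printf "%.0f GiB", $2 / 1024 / 1024}' /proc/meminfo)"
  # -P keeps each filesystem on one line so the second row is always "/".
  echo "disk: $(df -hP / | awk 'NR == 2 {print $4 " free of " $2}')"
}

vm_summary
```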

Installing Packages

pip install 'torch_xla[tpu]' -f https://storage.googleapis.com/libtpu-releases/index.html
pip install 'torch_xla[pallas]'
pip install timm

Automating Setup with Startup Scripts

You can use a startup script to automate package installation and environment setup:

gcloud alpha compute tpus queued-resources create $RESOURCE_NAME \
     --node-id v5litepod \
     --project $PROJECT_ID \
     --zone us-central1-a \
     --accelerator-type v5litepod-1 \
     --runtime-version v2-alpha-tpuv5-lite \
     --valid-until-duration 1d \
     --service-account $SERVICE_ACCOUNT \
     --metadata startup-script='#! /bin/bash
      pip install "torch_xla[tpu]" -f https://storage.googleapis.com/libtpu-releases/index.html
      pip install "torch_xla[pallas]"
      pip install timm'

Storage Options

  • Boot Disk: 100GB by default; usable for small datasets.
  • Persistent Disk (PD): Add for larger, persistent storage.
  • Cloud Storage (gs://): Unlimited size, but slower than PD.
  • GCSFUSE: Mount a bucket as a local directory for easy access.
  • Filestore: High-performance, but minimum 1TB.
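
As a rough illustration of the GCSFUSE option, the sketch below mounts a bucket at a local path. It assumes gcsfuse is already installed on the VM and that the VM's service account can read the bucket. The function name, bucket name, and mount point are all hypothetical.

```shell
#!/usr/bin/env bash
# mount_bucket BUCKET DIR: mount a Cloud Storage bucket at DIR with gcsfuse.
# Assumes gcsfuse is installed; BUCKET is given without the gs:// prefix.
mount_bucket() {
  local bucket=$1 dir=$2
  mkdir -p "$dir" || return 1
  # --implicit-dirs makes objects under "folder/" prefixes show up as directories.
  gcsfuse --implicit-dirs "$bucket" "$dir"
}

# Example (hypothetical names):
#   mount_bucket my-training-data "$HOME/data"
#   ls "$HOME/data"
#   fusermount -u "$HOME/data"   # unmount when finished
```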

Training and Inference

  • Use JAX for best TPU support.
  • For multi-host pods, pass --worker=all to the gcloud tpu-vm ssh command to run a command on every VM.
  • Always match JAX/Flax versions to avoid compatibility issues.
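
For the multi-host case, a small wrapper keeps the --worker=all invocation in one place. This is a sketch under assumptions: NODE_ID and ZONE are hypothetical variables (the commands in this post use literal values), and run_on_all_workers is a made-up name.

```shell
#!/usr/bin/env bash
# Run one shell command on every VM of a multi-host TPU.
# Assumes NODE_ID, PROJECT_ID and ZONE are set (NODE_ID and ZONE are
# hypothetical variables; the post uses literal values instead).
run_on_all_workers() {
  gcloud alpha compute tpus tpu-vm ssh "$NODE_ID" \
    --project "$PROJECT_ID" --zone "$ZONE" \
    --worker=all --command "$1"
}

# Example: confirm every host reports the same JAX version.
#   run_on_all_workers 'python3 -c "import jax; print(jax.__version__)"'
```

Printing the version on every host is a cheap way to catch a mismatched install before a long training run starts.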

Cost and Quota Considerations

  • v5p and v6e chips are expensive and may require quota increases.
  • Always automate resource deletion to avoid unexpected charges.
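
One way to automate deletion is to wrap the workload so that teardown runs whether or not it succeeds. The sketch below is a minimal version of that idea, not a complete solution: it does not cover the case where the calling shell itself is killed. ZONE, delete_tpu, with_teardown, and train.sh are hypothetical names.

```shell
#!/usr/bin/env bash
# Delete the queued resource; same flags as the delete command in this post.
# Assumes RESOURCE_NAME, PROJECT_ID and ZONE (hypothetical variable) are set.
delete_tpu() {
  gcloud alpha compute tpus queued-resources delete "$RESOURCE_NAME" \
    --project "$PROJECT_ID" --zone "$ZONE" --force --quiet
}

# with_teardown CMD [ARGS...]: run CMD, then delete the TPU whether or not
# CMD succeeded, and return CMD's exit status.
with_teardown() {
  local status=0
  "$@" || status=$?
  delete_tpu || echo "warning: teardown failed, delete the TPU manually" >&2
  return "$status"
}

# Example (train.sh is a hypothetical script):
#   with_teardown ./train.sh
```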

Final Thoughts

TPUs offer high throughput for large-scale deep learning, but they require careful setup and management. By understanding the architecture, using the right frameworks, and automating your workflow, you can improve both performance and cost-efficiency.