As LLMs scale to hundreds of billions of parameters, device memory becomes a bottleneck during training. This post explains how to use JAX's host offloading feature to transfer activations from TPU device memory to CPU host memory (Intel Xeon processors) during the forward pass and retrieve them during the backward pass.
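The mechanism described above can be sketched with `jax.checkpoint` and a checkpoint policy. This is a minimal illustration, not the post's code. It assumes a recent JAX release that provides `jax.checkpoint_policies.save_and_offload_only_these_names`. The toy layer, the activation name `"act"`, and the shapes are hypothetical. On a CPU-only backend there is no separate device memory to offload from, so the sketch falls back to saving the activation in place. That fallback is an assumption made only so the example also runs without an accelerator.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def make_policy():
    """Pick a remat policy for the activation tagged "act" (hypothetical name).

    Assumption: offloading needs a backend with its own device memory (TPU/GPU).
    On CPU there is nothing to offload from, so we just save "act" in place.
    """
    if jax.default_backend() == "cpu":
        return jax.checkpoint_policies.save_only_these_names("act")
    return jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=[],           # keep no other residuals on device
        names_which_can_be_offloaded=["act"],  # move "act" to host memory
        offload_src="device",
        offload_dst="pinned_host",
    )


def layer(x, w1, w2):
    # Tag the intermediate activation so the policy can select it by name.
    h = checkpoint_name(jnp.tanh(x @ w1), "act")
    return h @ w2


def loss(params, x):
    w1, w2 = params
    # Under differentiation, jax.checkpoint applies the policy: the named
    # residual is stored (or offloaded) in the forward pass and read back
    # in the backward pass; untagged intermediates are recomputed.
    out = jax.checkpoint(layer, policy=make_policy())(x, w1, w2)
    return jnp.sum(out ** 2)


@jax.jit
def grads(params, x):
    return jax.grad(loss)(params, x)


# Usage with hypothetical toy shapes.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(k1, (16, 32)), jax.random.normal(k2, (32, 8)))
x = jax.random.normal(k3, (4, 16))
g1, g2 = grads(params, x)
```

Tagging tensors with `checkpoint_name` keeps the decision about what to offload out of the model code. Moving a tensor between "save on device", "offload to host", and "recompute" only changes the policy.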

From opensource.googleblog.com (5 min read)
Table of contents
- Host offloading
- Enabling memory offloading in JAX
- Measuring Host Offloading Benefits on TPU v5p
- When to offload activations
- Call to Action
- Acknowledgments
