As LLMs scale to hundreds of billions of parameters, device memory becomes a bottleneck during training. This post explains how to use JAX's host offloading feature to transfer activations from TPU device memory to CPU host memory (Intel Xeon processors) during the forward pass and retrieve them during the backward pass. Experiments on TPU v5p show up to 10% reduction in training time for PaliGemma2 28B fine-tuning and ~5% for Llama2-13B training with MaxText, compared to full rematerialization. The post includes code snippets using checkpoint_name() and save_and_offload_only_these_names(), and notes that offloading is only beneficial for larger models where transfer time is less than recomputation time.
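A minimal sketch of how those two APIs fit together, assuming a toy two-layer model; the activation names ("act1", "act2"), shapes, and policy arguments below are illustrative rather than taken from the post:

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(x, w, name):
    # Tag the activation so a rematerialization policy can refer to it by name.
    h = checkpoint_name(jnp.dot(x, w), name)
    return jnp.tanh(h)

def model(params, x):
    x = layer(x, params["w1"], "act1")
    x = layer(x, params["w2"], "act2")
    return jnp.sum(x)

# Offload the tagged activations to pinned host memory during the forward
# pass instead of keeping them on-device or recomputing them in the backward pass.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["act1", "act2"],
    offload_src="device",
    offload_dst="pinned_host",
)

loss_fn = jax.checkpoint(model, policy=policy)
grad_fn = jax.jit(jax.grad(loss_fn))

# Hypothetical parameters and inputs, only to make the sketch runnable.
params = {
    "w1": jnp.ones((1024, 1024)),
    "w2": jnp.ones((1024, 1024)),
}
x = jnp.ones((8, 1024))
grads = grad_fn(params, x)
```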

From opensource.googleblog.com (5 min read)
Table of contents
- Host offloading
- Enabling memory offloading in JAX
- Measuring Host Offloading Benefits on TPU v5p
- When to offload activations
- Call to Action
- Acknowledgments
