Fine-Tune HuggingFace Models on TPUs with TorchAX (No JAX Rewrite Required)

If you've ever tried to fine-tune a large language model on Google Cloud TPUs, you've probably hit the same wall: TPUs love JAX, but your model is written in PyTorch. The traditional solution? Rewrite everything in JAX/Flax. The new solution? TorchAX—a bridge that lets you run PyTorch HuggingFace models directly on TPUs without touching a single line of JAX code.

This tutorial walks through fine-tuning models like Gemma on TPUs using TorchAX and LoRA, complete with evaluation, model persistence, and a working Colab notebook you can run today.

Why TPUs (and Why They've Been Hard to Use)

Google's Tensor Processing Units offer strong performance for training and fine-tuning large models, and on Google Cloud they can cost less than comparable GPU clusters, depending on workload and current pricing. The catch? The TPU software stack has centered on JAX, and the PyTorch ecosystem, where most HuggingFace models live, has historically required workarounds or full rewrites.

Enter TorchAX: a PyTorch/XLA integration layer that compiles PyTorch operations to run natively on TPUs. Instead of converting your transformers model to JAX, you keep your familiar PyTorch training loop and let torchax handle the TPU translation layer.

The benefits are immediate:

  • No model conversion: Load any HuggingFace model with AutoModelForCausalLM
  • Existing PyTorch knowledge: Your training loops, optimizers, and loss functions work as-is
  • LoRA support: Parameter-efficient fine-tuning works out of the box
  • Cost efficiency: TPU v4 pods can cost less than comparable A100 clusters for large-scale training, depending on workload and current pricing

Setting Up TorchAX for HuggingFace Models

The setup is surprisingly straightforward. You'll need four packages:

pip install torch_xla torchax transformers peft

torchax is the bridge library that compiles PyTorch operations to XLA (the compiler framework TPUs use). peft provides the LoRA implementation for parameter-efficient fine-tuning.

The critical difference from standard PyTorch training is initializing the XLA device:

import torch_xla.core.xla_model as xm
from transformers import AutoModelForCausalLM

device = xm.xla_device()  # torch.device for the TPU, e.g. 'xla:0'
model = AutoModelForCausalLM.from_pretrained('google/gemma-2b')
model = model.to(device)

Once your model is on the XLA device, PyTorch operations are recorded and compiled by XLA for the TPU instead of running eagerly on the host. Your existing training code (dataloaders, loss calculation, backprop) remains largely unchanged.
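
If you want the same script to start on a machine without the TPU stack (a laptop, say), one option is to check for torch_xla before asking for an XLA device. The sketch below is only an illustration: pick_device_name is a hypothetical helper, not part of torchax or torch_xla, and finding the package installed does not prove a TPU is actually attached.

```python
import importlib.util


def pick_device_name(preferred="xla"):
    """Return 'xla' when the torch_xla package is installed, otherwise 'cpu'.

    Hypothetical helper: it only checks that the package can be found,
    not that a TPU is reachable from this host.
    """
    if preferred == "xla" and importlib.util.find_spec("torch_xla") is not None:
        return "xla"
    return "cpu"


device_name = pick_device_name()
print(f"Selected device: {device_name}")
```

On a TPU host the result would be 'xla', and the script can go on to call xm.xla_device() as shown above. Anywhere else it falls back to 'cpu'.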

Fine-Tuning with LoRA: Keeping It Efficient

Fine-tuning a 2B or 7B parameter model on every weight is expensive and often unnecessary. LoRA (Low-Rank Adaptation) inserts small trainable matrices into the model's attention layers while freezing the base weights. For Gemma or Llama models, this typically means training <1% of total parameters.
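
To see where the "<1%" figure comes from, it helps to count parameters. A rank-r adapter on a d_in × d_out weight adds two small matrices with r × (d_in + d_out) parameters in total. The sketch below uses made-up dimensions (hidden size, layer count, base model size) purely for illustration; they are assumptions, not the published configuration of any specific model.

```python
def lora_param_count(d_in, d_out, r):
    """Parameters added by one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


# Hypothetical model shape, chosen only for illustration.
HIDDEN = 2048                  # assumed input/output width of each adapted layer
NUM_LAYERS = 18                # assumed number of transformer blocks
ADAPTED_PER_LAYER = 2          # e.g. the query and value projections
BASE_PARAMS = 2_500_000_000    # assumed size of the frozen base model
RANK = 16

trainable = NUM_LAYERS * ADAPTED_PER_LAYER * lora_param_count(HIDDEN, HIDDEN, RANK)
fraction = trainable / BASE_PARAMS
print(f"{trainable:,} trainable parameters ({fraction:.3%} of the base model)")
```

Under these assumptions the adapters add a few million parameters against billions of frozen ones, which is why the trainable fraction lands well below one percent.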

With the peft library, applying LoRA takes only a few lines:

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,              # Rank of the low-rank matrices
    lora_alpha=32,     # Scaling factor
    target_modules=["q_proj", "v_proj"],  # Which layers to adapt
    lora_dropout=0.1,
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # Verify only LoRA params are trainable

The target_modules parameter is model-specific. For Gemma, Llama, and Mistral-family models, targeting the query and value projection layers (q_proj, v_proj) is standard. GPT-2-style models use different names (for example, a fused c_attn layer), so check the model architecture first.
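
One way to check the architecture is to list the distinct leaf names of the model's linear layers. The sketch below works on plain (name, type name) pairs so it runs without loading a model. leaf_name_counts is a hypothetical helper and the module listing is invented for illustration.

```python
from collections import Counter


def leaf_name_counts(named_types, wanted_type="Linear"):
    """Count leaf module names (last dotted component) for modules of one type."""
    counts = Counter()
    for full_name, type_name in named_types:
        if type_name == wanted_type:
            counts[full_name.rsplit(".", 1)[-1]] += 1
    return counts


# Invented module listing. With a real model it could be built as:
# [(name, type(module).__name__) for name, module in model.named_modules()]
example = [
    ("model.layers.0.self_attn.q_proj", "Linear"),
    ("model.layers.0.self_attn.v_proj", "Linear"),
    ("model.layers.0.input_layernorm", "RMSNorm"),
    ("model.layers.1.self_attn.q_proj", "Linear"),
    ("model.layers.1.self_attn.v_proj", "Linear"),
]
print(leaf_name_counts(example))
```

The names that show up once per layer are the candidates for target_modules. The wanted_type argument is there because some architectures implement their projections with a layer class other than Linear.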

The Training Loop: XLA-Specific Considerations

Your training loop looks nearly identical to standard PyTorch, with two XLA-specific additions:

for batch in dataloader:
    inputs = {k: v.to(device) for k, v in batch.items()}
    outputs = model(**inputs)
    loss = outputs.loss
    loss.backward()

    xm.optimizer_step(optimizer)  # XLA-aware optimizer step
    optimizer.zero_grad()

    xm.mark_step()  # Cut the lazy graph once per step so XLA compiles and runs it

Both calls come from torch_xla's xla_model module (imported as xm):

  1. xm.optimizer_step(optimizer): Replaces optimizer.step() to ensure gradients are synchronized across TPU cores
  2. xm.mark_step(): Tells XLA to compile and execute the accumulated graph; think of it as a barrier that forces the recorded computation to run

Without mark_step(), XLA keeps recording operations into an ever-growing graph, leading to apparent hangs or OOM errors. Calling it once per training step is the usual pattern.
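
The buffering behavior is easier to picture with a toy model. The sketch below is pure Python and is not how torch_xla is implemented. LazyQueue is a hypothetical stand-in that records operations and only runs them when mark_step() is called.

```python
class LazyQueue:
    """Toy stand-in for a lazy runtime: records ops, runs them on mark_step()."""

    def __init__(self):
        self.pending = []   # operations recorded but not yet executed
        self.executed = 0   # total operations actually run so far

    def record(self, fn, *args):
        # Nothing runs here; the call is only appended to the pending graph.
        self.pending.append((fn, args))

    def mark_step(self):
        # Execute everything recorded since the last boundary, then reset.
        results = [fn(*args) for fn, args in self.pending]
        self.executed += len(self.pending)
        self.pending.clear()
        return results


queue = LazyQueue()
for step in range(3):
    queue.record(pow, 2, step)         # stands in for the forward pass
    queue.record(abs, -step)           # stands in for the backward pass
    peak = len(queue.pending)          # graph size just before the boundary
    step_results = queue.mark_step()   # one boundary per step keeps the graph small
```

If the mark_step() call were dropped from the loop, pending would grow with every iteration and nothing would execute, which is the toy version of the hang described above.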

Evaluation and Model Persistence

Evaluation requires moving tensors back to CPU for metrics calculation:

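import torch

model.eval()
with torch.no_grad():
    for batch in eval_dataloader:
        # A collated batch is a plain dict, so move each tensor individually
        inputs = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**inputs)
        logits = outputs.logits.cpu()  # Move to CPU for metric computation
        # Calculate perplexity, accuracy, etc.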
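
For perplexity specifically, a minimal sketch follows. It assumes you have collected each batch's mean loss (in nats per token) along with the number of tokens that loss was averaged over. The perplexity helper is hypothetical and uses only the standard library. It weights each batch by its token count so that short batches do not skew the result.

```python
import math


def perplexity(batch_losses, batch_token_counts):
    """Corpus perplexity from per-batch mean losses (nats) and token counts."""
    total_tokens = sum(batch_token_counts)
    if total_tokens == 0:
        raise ValueError("no tokens to evaluate")
    # Recover each batch's summed negative log-likelihood, then average per token.
    total_nll = sum(loss * n for loss, n in zip(batch_losses, batch_token_counts))
    return math.exp(total_nll / total_tokens)


# A model that is uniformly unsure among 4 tokens has loss ln(4) per token.
print(perplexity([math.log(4), math.log(4)], [10, 30]))
```

In the loop above, the per-batch loss would come from outputs.loss (moved to the host with .item()), and the token count from the number of label positions that are not masked out.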

Saving the fine-tuned LoRA adapter is handled by peft:

model.save_pretrained('./gemma-lora-adapter')
tokenizer.save_pretrained('./gemma-lora-adapter')

To reload later:

from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained('google/gemma-2b')
model = PeftModel.from_pretrained(base_model, './gemma-lora-adapter')

The LoRA adapter is typically 10-50MB, compared to multi-gigabyte full model checkpoints.
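
The small size follows directly from the parameter count. Below is a rough estimate that counts weights only and ignores file-format overhead. adapter_size_mib is a hypothetical helper, and the parameter count is an illustrative assumption (rank-16 adapters on two 2048-wide projections in each of 18 layers).

```python
def adapter_size_mib(num_params, bytes_per_param=4):
    """Approximate on-disk size of adapter weights in MiB (weights only)."""
    return num_params * bytes_per_param / (1024 ** 2)


ASSUMED_ADAPTER_PARAMS = 2_359_296  # illustrative count, not a measured value

print(f"float32: {adapter_size_mib(ASSUMED_ADAPTER_PARAMS):.1f} MiB")
print(f"bfloat16: {adapter_size_mib(ASSUMED_ADAPTER_PARAMS, bytes_per_param=2):.1f} MiB")
```

Higher ranks or more target modules scale the estimate linearly, which is how adapters end up in the tens of megabytes.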

Try It Yourself: Colab Notebook

Google Colab offers free TPU runtime access (TPU v2-8), making this technique immediately accessible. The original tutorial includes a ready-to-run notebook that fine-tunes Gemma-2B on the Alpaca instruction dataset in under 30 minutes.

Key considerations for Colab TPU runtimes:

  • Enable TPU via Runtime → Change runtime type → TPU
  • TPU v2 has 8GB memory per core—use gradient checkpointing for larger models
  • Free-tier sessions are limited to 12 hours

The Bottom Line

TorchAX eliminates the JAX rewrite barrier that's kept PyTorch developers off TPUs for years. If you're fine-tuning HuggingFace models at scale, especially instruction-tuning or domain adaptation, TPUs with torchax + LoRA offer a compelling cost/performance alternative to GPU clusters—without abandoning the PyTorch ecosystem you already know.

The workflow is production-ready: parameter-efficient training, full save/reload support, and seamless integration with the transformers library. For teams already on Google Cloud or researchers with TPU access, this is the easiest path to large-scale fine-tuning today.