A Debugging Journey into XLA, TPUs, and JAX
Recently I published a guide on how to define a custom TPU/GPU kernel in Keras. Soon after, an eagle-eyed reader (h/t Aditya Kane) noticed that the benchmark, which was meant to show a Pallas kernel for a fused matmul + ReLU operation beating a standard vanilla implementation, showed the opposite. In fact, the vanilla JAX implementation was much faster:
This was confusing. I was pretty confident that the comparison worked correctly in my testing. After some digging, lo and behold, the comparison did in fact work correctly on the original hardware I wrote it on - a TPU v3:
The Yak Shaving Deepens
The final version for keras.io, however, was rendered on a TPU v5e, not a v3. So why the difference? Actually, it gets even more confusing — when you try running the current guide on a public Colab TPU v5e, it runs into a compiler OOM:
Specifically a VMEM compiler OOM. The v5e has 128MiB of VMEM, though Pallas kernels restrict themselves to a 16MiB block by default.
But if I was getting this VMEM OOM on v5e, how was I able to publish the guide in the first place? The same code ran fine on the same v5e chip for rendering the guide. So why did it work on one v5e and fail on another v5e?
After much debugging, it turned out that my original TPU VM was running an older version of JAX (0.6.2) than the current version (0.9.2). Between those two versions, JAX enabled double buffering for Pallas kernels in VMEM, which loads the next tensor from HBM into VMEM while the current one is still being computed on. This improves performance, but it doubles the amount of VMEM the kernel needs.
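To make the memory cost of double buffering concrete, here is a minimal pure-Python sketch of the idea. It is not how Pallas implements it: `run_double_buffered` and the list-based "buffers" are hypothetical stand-ins, and Python runs the prefetch and the compute one after the other, whereas real hardware overlaps the HBM-to-VMEM copy with the computation. The point is only that two tiles are resident at the same time.

```python
def run_double_buffered(tiles, compute):
    """Process a non-empty list of tiles using two buffers (hypothetical sketch)."""
    buffers = [None, None]           # stand-ins for two VMEM-resident tiles
    peak_resident = 0
    results = []
    buffers[0] = list(tiles[0])      # initial "HBM -> VMEM" copy
    for i in range(len(tiles)):
        cur, nxt = i % 2, (i + 1) % 2
        if i + 1 < len(tiles):
            # Prefetch the next tile; on real hardware this copy would
            # overlap with the compute step below.
            buffers[nxt] = list(tiles[i + 1])
        resident = sum(b is not None for b in buffers)
        peak_resident = max(peak_resident, resident)
        results.append(compute(buffers[cur]))
        buffers[cur] = None          # the current buffer is free again
    return results, peak_resident


tiles = [[1, 2], [3, 4], [5, 6]]
results, peak = run_double_buffered(tiles, sum)
print(results, peak)
```

With single buffering the peak would be one resident tile. Here it is two, which is where the doubled VMEM requirement comes from.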
The original code was trying to load one tile from the outer dimension as well as the entire contracting dimension into VMEM (128 x 8192 x 4 bytes x 2 operands = 8MiB). With double buffering enabled (x2 = 16MiB), the input blocks alone fill the entire default 16MiB VMEM allocation, which results in the OOM.
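The arithmetic above can be checked in a few lines. This is only a back-of-the-envelope sketch using the numbers from the text. The variable names are mine, and it counts only the two input blocks, ignoring the output block and anything else the compiler places in VMEM.

```python
MIB = 1024 * 1024

tile_rows = 128          # one tile of the outer dimension
contracting_dim = 8192   # the entire contracting dimension
bytes_per_element = 4    # 4-byte elements, as in the text
num_operands = 2         # both matmul inputs

single_buffered = tile_rows * contracting_dim * bytes_per_element * num_operands
double_buffered = 2 * single_buffered
default_limit = 16 * MIB

# Bytes left over for everything else in the kernel.
headroom = default_limit - double_buffered

print(single_buffered / MIB)   # 8.0
print(double_buffered / MIB)   # 16.0
print(headroom)                # 0
```

With zero headroom left for the output block, any additional VMEM use pushes the kernel over the limit.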
At this point, I considered a few options:
First, I tried reducing the tile size (as defined by BlockSpec). However, the last dimension of the BlockSpec must be at least 128, which is what it was already set to.
Second, one big source of memory usage is that the entire contracting dimension needs to be loaded into memory. I could write a loop inside the Pallas kernel to manually accumulate the contracting dimension so I could tile it up even further, but that would almost certainly be slower than the built-in XLA matmul implementation.
Third (and I wish I had found the API for this earlier), I could just tell Pallas to allocate more than 16MiB of VMEM for the computation:
compiler_params=pl.tpu.CompilerParams(vmem_limit_bytes=20_000_000),
This allows the kernel to execute without OOMing. However, it’s still slower than JAX:
Benchmarking Matrix Size: 8192x8192
------------------------------
Standard Keras (Matmul + ReLU) Average Latency: 7.940 ms
Pallas Fused (Matmul + ReLU) Average Latency: 26.156 ms
This is where it became clear that it’s really hard to beat XLA at its own game. I don’t think it’s realistic to write a custom Pallas kernel for matmul + ReLU that will beat XLA’s native optimized version. So I’ve decided to pivot the guide to demonstrating something that will more robustly benefit from a custom kernel, like a sparse MoE kernel. Stay tuned.
Afterword: Why did it work on a TPU v3?
As I mentioned in the introduction, I originally wrote the guide on a TPU v3, which gave me the results I was expecting — i.e. the custom Pallas kernel was indeed much faster than the vanilla JAX implementation.
A TPU v3 only has 16MiB of VMEM per core, whereas the TPU v5e has 128MiB. Because of the smaller VMEM size, XLA couldn’t parallelize the computation as efficiently, which left it bound by HBM bandwidth.
Even though the HBM bandwidth on a TPU v3 is 900GB/s, that is still too slow to deliver data as fast as the cores can compute a 128x128 tile. With less VMEM available to parallelize the computation, the v3 cores sat idle waiting on data from HBM. Because the workload was memory-bottlenecked, the custom kernel (which cuts the number of HBM loads in half) came out faster on v3.
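For a rough sense of scale, the sketch below estimates how long it takes to move one 8192x8192 matrix of 4-byte elements across HBM at the quoted 900GB/s. It assumes GB means 10^9 bytes and that the transfer runs at full bandwidth. The "extra round trip" line also assumes that an unfused implementation writes the matmul result to HBM and reads it back for the ReLU, which is an illustration and not something measured here.

```python
n = 8192
bytes_per_element = 4     # 4-byte elements, as in the VMEM arithmetic earlier
hbm_bandwidth = 900e9     # bytes per second, from the 900GB/s figure (GB = 1e9 bytes assumed)

matrix_bytes = n * n * bytes_per_element

# Time to stream one full matrix over HBM at peak bandwidth, in milliseconds.
transfer_ms = matrix_bytes / hbm_bandwidth * 1e3

# Assumed unfused case: one extra write plus one extra read of the output.
extra_round_trip_ms = 2 * transfer_ms

print(f"{matrix_bytes / 2**20:.0f} MiB per matrix")
print(f"{transfer_ms:.3f} ms per transfer")
print(f"{extra_round_trip_ms:.3f} ms for an extra round trip")
```

Each full pass over a matrix costs roughly 0.3 ms of pure HBM time under these assumptions, so avoiding passes matters most when the cores are already waiting on memory.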



