A deep dive into implementing and optimizing Flash Attention using NVIDIA's cuTile Python library on Blackwell GPUs. It covers the full kernel implementation, including online softmax, causal masking, and grouped-query attention. The core of the post is a "trap and rescue" optimization journey: naively increasing the tile size from 64×64 to 256×128 degrades performance by 18-43%. Applying fast math (flush_to_zero, approximate division), K-loop splitting for causal masks, block ID remapping, and autotuning then recovers that loss and exceeds the baseline by up to 1.66×. Each optimization step is backed by Nsight Compute profiling data showing register usage, occupancy, and compute/memory throughput.
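To make the online softmax and causal K-loop split ideas concrete, here is a rough, hardware-agnostic sketch in plain NumPy. It is not cuTile code and not the article's kernel. The function names (`reference_attention`, `tiled_attention`, `_update`) and tile sizes are hypothetical. It assumes a single head and equal query and key lengths. Each K/V tile updates a running row max and running row sum, so the full n×n score matrix is never built. For causal attention, the K loop is split into two phases. The first phase covers tiles that every row in the Q tile can fully see, and these skip the mask. The second phase covers tiles that straddle the diagonal, and only these apply the mask.

```python
import numpy as np


def reference_attention(q, k, v, causal=False):
    """Naive attention: materialize the full score matrix, then softmax."""
    n, d = q.shape
    s = (q @ k.T) / np.sqrt(d)
    if causal:
        # Query i may only attend to keys j <= i.
        s = np.where(np.triu(np.ones((n, n), dtype=bool), k=1), -np.inf, s)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


def _update(qi, kj, vj, scale, m_i, l_i, acc, mask=None):
    """One online-softmax step over a single K/V tile (hypothetical helper)."""
    s = (qi @ kj.T) * scale
    if mask is not None:
        s = np.where(mask, s, -np.inf)
    m_new = np.maximum(m_i, s.max(axis=1))   # updated running row max
    alpha = np.exp(m_i - m_new)              # rescales the old sum/accumulator
    p = np.exp(s - m_new[:, None])           # unnormalized probabilities
    l_new = alpha * l_i + p.sum(axis=1)      # updated running row sum
    acc_new = alpha[:, None] * acc + p @ vj  # updated unnormalized output
    return m_new, l_new, acc_new


def tiled_attention(q, k, v, tile_m=64, tile_n=64, causal=False):
    """Flash-attention-style forward pass (sketch; assumes len(q) == len(k))."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n, v.shape[1]))
    for m0 in range(0, n, tile_m):
        qi = q[m0:m0 + tile_m]
        rows = qi.shape[0]
        m_i = np.full(rows, -np.inf)
        l_i = np.zeros(rows)
        acc = np.zeros((rows, v.shape[1]))
        if causal:
            # Tiles whose last key index is <= m0 are visible to every row
            # of this Q tile, so they need no mask.
            split = ((m0 + 1) // tile_n) * tile_n
            # Keys past the last query row of this tile are never visible.
            end = m0 + rows
        else:
            split = end = n
        # Phase 1: unmasked K tiles.
        for n0 in range(0, split, tile_n):
            m_i, l_i, acc = _update(qi, k[n0:n0 + tile_n], v[n0:n0 + tile_n],
                                    scale, m_i, l_i, acc)
        # Phase 2: diagonal K tiles that need the causal mask.
        for n0 in range(split, end, tile_n):
            kj = k[n0:n0 + tile_n]
            q_idx = np.arange(m0, m0 + rows)[:, None]
            k_idx = np.arange(n0, n0 + kj.shape[0])[None, :]
            m_i, l_i, acc = _update(qi, kj, v[n0:n0 + tile_n], scale,
                                    m_i, l_i, acc, mask=k_idx <= q_idx)
        out[m0:m0 + rows] = acc / l_i[:, None]
    return out
```

For example, `tiled_attention(q, k, v, tile_m=256, tile_n=128, causal=True)` should match `reference_attention(q, k, v, causal=True)` up to floating-point tolerance. Because key 0 is visible to every query, the first tile processed always has a finite row max, so the running max never stays at negative infinity.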

21 min read. From developer.nvidia.com
Table of contents
What is attention?
Understanding online softmax
Causal attention and grouped-query attention
Part 1: The flash attention kernel in CUDA Tile
Launching the kernel: Host-side code
Part 2: The “trap and rescue” optimization journey
  1. The trap of larger tiles
  2. The rescue with fast math
  3. K-loop split
  4. ProgramId remapping
  5. Autotuning
Summary: The optimization stack
Getting started