FlashAttention challenges ML researchers to think about systems-level improvements
Fast and Memory-Efficient Exact Attention with IO-Awareness
Introduction and Motivation
Transformers are the hottest thing on the planet right now. One of their biggest weaknesses is that they have a limited context window — typically 1-2K tokens. This is because the attention block (the core of the transformer model) scales quadratically with respect to the sequence length.
Most of this cost comes not from the computation itself, but from reading and writing the large intermediate attention matrices between the GPU's slow high-bandwidth memory (HBM) and its fast on-chip SRAM.
Core Insight: By restructuring the attention algorithm to reduce the number of reads and writes to HBM, we can speed up the attention block, and with it the transformer architecture as a whole.
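To get a rough feel for why the attention matrix cannot simply live on-chip, here is a back-of-the-envelope sketch. The 4K sequence length is an illustrative assumption; the A100 figures (roughly 108 SMs with 192 KB of SRAM each) come from the paper's hardware discussion.

```python
# Rough size of one attention score matrix for a single head, in fp16.
seq_len = 4096                      # illustrative sequence length (assumption)
bytes_per_elem = 2                  # fp16
attn_matrix_bytes = seq_len * seq_len * bytes_per_elem
print(f"score matrix: {attn_matrix_bytes / 2**20:.0f} MiB per head")  # ~32 MiB

# An A100 has roughly 108 SMs x 192 KB of on-chip SRAM (~20 MiB total),
# but 40-80 GB of HBM. The score matrix does not fit in SRAM, so a standard
# implementation keeps it in HBM and streams it back and forth repeatedly.
sram_total_bytes = 108 * 192 * 1024
print(f"total SRAM:   {sram_total_bytes / 2**20:.0f} MiB")            # ~20 MiB
```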
Development Details
This is what a typical attention block looks like. We start with the query and key matrices Q and K. Their matrix product QKᵀ gives us the attention matrix A, which is then masked, put through a softmax, and then a dropout. Finally, this matrix is multiplied with the value matrix V to get our desired output.
In this sequence, we typically have to materialize at least one of these intermediate matrices (write it out to HBM), and the attention matrix is often big enough that computing something like the softmax requires multiple trips to memory. This matrix is also kept in memory for the backward pass.
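As a concrete reference point, here is a minimal sketch of that standard pipeline in PyTorch. The shapes, the causal mask, the 1/√d scaling, and the dropout rate are illustrative assumptions, not the paper's code.

```python
import torch
import torch.nn.functional as F

def standard_attention(Q, K, V, dropout_p=0.1):
    """Naive attention: materializes the full (N x N) attention matrix."""
    N, d = Q.shape[-2], Q.shape[-1]
    A = Q @ K.transpose(-2, -1) / d**0.5           # (N, N) attention scores
    mask = torch.tril(torch.ones(N, N, dtype=torch.bool, device=Q.device))
    A = A.masked_fill(~mask, float("-inf"))        # causal mask (illustrative)
    P = F.softmax(A, dim=-1)                       # full N x N matrix in memory
    P = F.dropout(P, p=dropout_p, training=True)
    return P @ V                                   # (N, d) output

# Example: one head, sequence length 1024, head dimension 64.
Q = torch.randn(1024, 64); K = torch.randn(1024, 64); V = torch.randn(1024, 64)
out = standard_attention(Q, K, V)
```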
FlashAttention uses two clever techniques:
Tiling: Decompose the softmax and compute it block by block instead of in one pass, so the full attention matrix never has to be materialized in HBM. This cuts the I/O cost of computing attention over long sequences (see the sketch after this list).
Recomputation: Instead of storing the attention matrix for the backward pass, simply recompute it on the fly, which is cheaper than the extra reads and writes to HBM.
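Here is a minimal sketch of the tiling idea (a numerically stable online softmax over key/value blocks), written in plain PyTorch rather than the paper's fused CUDA kernels. The block size, shapes, and the omission of masking and dropout are simplifying assumptions.

```python
import torch

def tiled_attention(Q, K, V, block_size=256):
    """Computes softmax(QK^T / sqrt(d)) V one key/value block at a time.

    Only an (N x block_size) slice of scores exists at any moment; a running
    row-wise max and running sum let us rescale earlier partial results as new
    blocks arrive (the online-softmax trick that FlashAttention builds on).
    """
    N, d = Q.shape
    scale = d ** -0.5
    O = torch.zeros(N, d)                          # running (unnormalized) output
    row_max = torch.full((N, 1), float("-inf"))    # running row-wise max
    row_sum = torch.zeros(N, 1)                    # running softmax denominator

    for start in range(0, N, block_size):
        Kb = K[start:start + block_size]           # current key block
        Vb = V[start:start + block_size]           # current value block
        S = (Q @ Kb.T) * scale                     # (N, block_size) scores only

        new_max = torch.maximum(row_max, S.max(dim=-1, keepdim=True).values)
        P = torch.exp(S - new_max)                 # block softmax numerator
        correction = torch.exp(row_max - new_max)  # rescale old accumulators
        row_sum = row_sum * correction + P.sum(dim=-1, keepdim=True)
        O = O * correction + P @ Vb
        row_max = new_max

    return O / row_sum                             # final normalization

# Sanity check against the naive computation (no mask/dropout in either path).
Q = torch.randn(512, 64); K = torch.randn(512, 64); V = torch.randn(512, 64)
ref = torch.softmax((Q @ K.T) * 64**-0.5, dim=-1) @ V
assert torch.allclose(tiled_attention(Q, K, V), ref, atol=1e-4)
```

The real kernel fuses all of these steps so that each block of Q, K, and V is loaded into SRAM once, processed, and written back, rather than making a separate HBM round trip per operation.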
Evaluation
The core result: FlashAttention is faster. A lot faster, even compared to heavily optimized baselines such as Nvidia's MLPerf 1.1 BERT training submission.
On Path-X, a benchmark task with especially long sequences (length 16K) where Transformer models have until now either timed out or performed at chance (50%), FlashAttention's longer feasible context makes it the first Transformer to achieve better-than-random performance.
Limitations and Future Work
Because of the fine-grained control over memory transfers and I/O it requires, FlashAttention is written in CUDA rather than in a higher-level framework like PyTorch. This makes it harder to build upon or modify with ease.
Other parts of the transformer architecture, beyond attention, could also benefit from I/O-aware implementations.
Most large models are trained not on one GPU but on many. Extending the I/O-aware analysis to the multi-GPU setting, where inter-GPU communication adds another level to the memory hierarchy, is an important next step.
Generally, I think this paper is an amazing example of MLSys research, obtaining performance improvement by leveraging the intersection of systems, networking, and ML.
References
[1] Paper link: https://arxiv.org/pdf/2205.14135.pdf