FlashAttention (arXiv:2205.14135)
PDF: Flash Attention - 2205.14135.pdf
Overview
- Introduces an IO-aware algorithm for exact attention that reduces reads/writes between GPU HBM (high-bandwidth memory) and SRAM (on-chip cache): standard attention needs Θ(n² + n·d) HBM accesses, FlashAttention needs Θ(n²d²/M), where M is the SRAM size and d the head dimension.
- Uses tiling to divide Q, K, V into blocks that fit in SRAM, computing softmax incrementally using online softmax algorithm with running statistics.
- Implements fused CUDA kernel that performs the attention computation (QK^T, masking, softmax, dropout, and the attention-weights × V matmul) in a single kernel without materializing intermediate n×n matrices in HBM.
- Reports end-to-end training speedups of ~15% on BERT-large and up to 3× on GPT-2, and scales exact attention to 16k-token sequences (64k with the block-sparse variant).
- Exact attention—no approximation or quality degradation compared to standard attention.
Core Concepts
- Memory Hierarchy Awareness: Standard attention is memory-bound, not compute-bound. The bottleneck is moving data between slow HBM (40-80 GB, ~1.5 TB/s) and fast SRAM (20 MB, ~19 TB/s). FlashAttention minimizes HBM accesses.
- Tiling Algorithm: Divide Q into blocks of size Bᵣ (e.g., 128) and K,V into blocks of size Bᶜ (e.g., 64). For each Q block:
- Load Q block to SRAM
- Loop over K,V blocks, computing attention scores and outputs incrementally
- Use online softmax to maintain numerically stable running max and sum
- Online Softmax: Compute softmax in streaming fashion without seeing all values first (a minimal sketch follows this list):
- Track running max m and sum ℓ for each row
- When processing new K block, update: m_new = max(m_old, m_block), ℓ_new = exp(m_old - m_new)·ℓ_old + exp(m_block - m_new)·ℓ_block
- Rescale accumulated output accordingly
- Kernel Fusion: Combine the QK^T matmul, softmax, attention dropout, and the attention-weights × V matmul into a single CUDA kernel. Eliminates intermediate writes/reads of the n×n attention matrix.
- Recomputation in Backward: Don’t store O(n²) attention matrix for backward pass. Instead, recompute attention on-the-fly during backprop using same tiling strategy. Trade computation for memory.
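A minimal, non-fused PyTorch sketch of the forward tiling with online softmax (single head, no masking or dropout; block sizes and variable names are illustrative, not taken from the CUDA kernel):

```python
import torch

def tiled_attention(Q, K, V, Br=128, Bc=64):
    """Reference sketch of FlashAttention's forward pass: loop over K/V and Q
    blocks, keep a running row-wise max m and sum l, and rescale the output
    accumulator whenever the max changes. Q, K, V: (n, d) for a single head."""
    n, d = Q.shape
    scale = d ** -0.5
    O = torch.zeros_like(Q)
    # Running statistics per row (the real kernel keeps these in fp32).
    m = torch.full((n, 1), float("-inf"), dtype=Q.dtype, device=Q.device)  # running max
    l = torch.zeros(n, 1, dtype=Q.dtype, device=Q.device)                  # running sum

    for j in range(0, n, Bc):                    # outer loop over K/V blocks
        Kj, Vj = K[j:j + Bc], V[j:j + Bc]
        for i in range(0, n, Br):                # inner loop over Q blocks
            Qi, mi, li = Q[i:i + Br], m[i:i + Br], l[i:i + Br]
            S = (Qi @ Kj.T) * scale                           # block of scores
            m_new = torch.maximum(mi, S.max(-1, keepdim=True).values)
            P = torch.exp(S - m_new)                          # numerators w.r.t. new max
            l_new = torch.exp(mi - m_new) * li + P.sum(-1, keepdim=True)
            # Rescale previously accumulated output, then add this block's contribution.
            O[i:i + Br] = torch.exp(mi - m_new) * O[i:i + Br] + P @ Vj
            m[i:i + Br], l[i:i + Br] = m_new, l_new

    return O / l                                  # final softmax normalization
```

The real kernel keeps Qi, Kj, Vj, and the running statistics in SRAM and writes only the final O (plus m, ℓ for the backward pass) back to HBM; the Python loops above mirror only the math, not the memory movement.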
Relevance to MegaContext
- Critical for POC Implementation: FlashAttention is baseline requirement for efficient GistNet and LensNet training. Without it, even 8k Working Context would be prohibitively slow.
- Enables long working contexts: W_max=32k becomes feasible with FlashAttention’s linear activation-memory scaling. Standard attention would materialize a 32,768 × 32,768 attention matrix (≈2 GB in FP16 per head); FlashAttention keeps activations at O(n·d), on the order of 10 MB per head (see the estimate after this list).
- GistNet compression efficiency: When GistNet computes cross-attention between a 32-token input block and the gist slot queries, the whole computation fits in a single tile, so the gain comes from kernel fusion (no attention matrix written to HBM) rather than tiling itself. Still useful for keeping per-block compression cheap in the Runtime Loop.
- LensNet scoring speedup: Non-causal attention over W-length working context benefits from FlashAttention’s tiling, potentially reducing LensNet Scoring latency by 2-3×.
- Enables deeper hierarchies: Memory savings allow larger batch sizes during GistNet Training, or longer lookahead horizons H for ΔNLL@H computation.
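Back-of-envelope numbers behind the memory claim above (assumptions: d_head = 128, FP16, per attention head per layer; order of magnitude only):

```python
n, d_head, fp16_bytes = 32_768, 128, 2

standard = n * n * fp16_bytes     # n x n attention matrix materialized in HBM
flash = n * d_head * fp16_bytes   # O(n) output accumulator (plus O(n) softmax stats)

print(f"standard attention: {standard / 2**30:.1f} GiB per head")  # ~2.0 GiB
print(f"FlashAttention:     {flash / 2**20:.1f} MiB per head")      # ~8.0 MiB
```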
What We Can Use
- Integrate FlashAttention-2/3 in all attention modules: Replace PyTorch’s scaled_dot_product_attention with a FlashAttention implementation (see the sketch after this list). Applies to:
- GistNet’s self-attention and cross-attention layers
- LensNet’s dual cross-attention over working context
- Base Runtime’s frozen LLM attention (if we can modify inference code)
- Tune block sizes for MegaContext workloads: The paper sizes blocks from SRAM as Bᶜ = ⌈M/4d⌉ and Bᵣ = min(⌈M/4d⌉, d); common kernel defaults (e.g., Bᵣ=128, Bᶜ=64) are tuned for standard transformer shapes. Profile MegaContext’s specific attention patterns (non-causal LensNet, short cross-attention in GistNet) and adjust tile sizes to minimize HBM↔SRAM traffic.
- Recomputation strategy for training: Apply FlashAttention’s backward pass recomputation to GistNet Training and LensNet Training. Reduces activation memory by 5-10×, enabling larger batch sizes or longer horizons.
- Variable-length attention: Use FlashAttention’s support for variable sequence lengths (packed sequences with cumulative length offsets) to handle Working Context Assembly’s mixed LOD sequences efficiently; LOD0/LOD1/LOD2 entries don’t need padding to a uniform length (see the packed call in the sketch after this list).
- Fused dropout and masking: Implement attention masking (for causal/non-causal switching) and dropout inside the FlashAttention kernel rather than as separate operations, reducing memory traffic by another 20-30%.
- Multi-query attention optimization: FlashAttention-2 includes optimizations for MQA/GQA patterns (shared KV across heads). If we adopt grouped-query attention in GistNet or LensNet, use these optimizations.
- Benchmark attention patterns: Profile MegaContext’s specific attention workloads (32-token compression, W-length scoring, H-horizon lookahead) against FlashAttention’s block sizes to identify bottlenecks and optimize accordingly.
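A sketch of the swap-in, assuming the flash-attn package’s flash_attn_func / flash_attn_varlen_func interfaces (FlashAttention-2); the roles in the comments (LensNet scoring, mixed-LOD packing) are our mapping onto MegaContext, not anything the library knows about:

```python
import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func  # pip install flash-attn

# Dense case (e.g. LensNet's non-causal attention over a W-length working context).
# q, k, v: (batch, seqlen, nheads, headdim), fp16/bf16, on a CUDA device.
q = k = v = torch.randn(2, 4096, 8, 64, dtype=torch.bfloat16, device="cuda")
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)        # (2, 4096, 8, 64)

# Variable-length case (e.g. LOD0/LOD1/LOD2 entries of different lengths, packed
# without padding): sequences are concatenated along dim 0 and described by
# cumulative-length offsets.
seqlens = torch.tensor([512, 64, 2048], device="cuda")             # per-entry lengths
cu_seqlens = torch.cat([seqlens.new_zeros(1), seqlens.cumsum(0)]).to(torch.int32)
total = int(seqlens.sum())
qp = kp = vp = torch.randn(total, 8, 64, dtype=torch.bfloat16, device="cuda")
out_packed = flash_attn_varlen_func(
    qp, kp, vp,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seqlens.max()), max_seqlen_k=int(seqlens.max()),
    dropout_p=0.0, causal=False,
)                                                                   # (total, 8, 64)
```

The packed layout replaces padding masks entirely: each entry attends only within its own segment.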
Limitations & Risks
- Requires CUDA: FlashAttention is GPU-only (CUDA/Triton). No efficient CPU fallback, which limits development/debugging on non-GPU machines. Critical constraint for POC Implementation deployment (a fallback sketch follows this list).
- Backward pass recomputation overhead: Trading memory for compute works during training but adds roughly 10-20% wall-clock time. For long lookahead horizons (e.g., H=128), this overhead compounds.
- Causal vs. non-causal switching: Causal attention gets an extra saving because fully masked blocks can be skipped; LensNet’s non-causal attention over the working context still benefits from tiling but not from that skip, so it needs separate tuning.
- Dropout during inference: FlashAttention’s fused dropout is training-only. If we use dropout-based uncertainty estimation during inference, need separate implementation path.
- Mixed precision numerical issues: Aggressive FP16/BF16 usage in FlashAttention can cause numerical instability with extreme attention distributions. May require careful scaling for LensNet’s signed focus scores.
- Version compatibility: FlashAttention-2/3 APIs differ significantly. Pinning to specific version creates tech debt; upgrading requires retesting all attention modules.
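Related to the CUDA-only limitation (and the fallback question under Open Questions), one possible dispatch shim; a sketch under the assumption that flash-attn is an optional dependency and PyTorch’s scaled_dot_product_attention is the debug path:

```python
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func
    _HAS_FLASH = True
except ImportError:
    _HAS_FLASH = False

def attention(q, k, v, causal=False, dropout_p=0.0):
    """q, k, v: (batch, seqlen, nheads, headdim). Uses FlashAttention when the
    package is installed and the inputs are half-precision CUDA tensors; falls
    back to PyTorch SDPA otherwise (slower, but runs on CPU for debugging)."""
    if _HAS_FLASH and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
    # SDPA expects (batch, nheads, seqlen, headdim), so transpose in and out.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=causal)
    return out.transpose(1, 2)
```

The two paths are numerically close but not bit-identical, so tests that compare against the fallback need a tolerance.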
Potential Follow-Up Reading
- FlashAttention-2 (2307.08691) - 2× faster than v1, better parallelism for long sequences
- FlashAttention-3 (2407.08608) - Hardware-aware optimizations for Hopper GPUs (H100): asynchrony and low-precision (FP8) support
- Paged Attention (vLLM) - Memory management for dynamic batching; complements FlashAttention for serving
- Triton tutorials - Understanding kernel implementation details; useful for custom MegaContext attention patterns
- Memory-Efficient Attention (xformers) - Alternative implementation with similar goals; comparison baseline
- Ring Attention (2310.01889) - Distributed attention across multiple devices; relevant for scaling beyond single GPU
Open Questions for MegaContext
- What’s the actual speedup/memory benefit for MegaContext’s specific attention patterns (non-causal LensNet, small cross-attention GistNet) vs. standard causal LM attention?
- Should we fork FlashAttention to add custom support for mixed LOD sequences (LOD0/LOD1/LOD2 with different embedding scales)?
- Can we combine FlashAttention’s tiling with sparse attention patterns from Sparse Transformers—tile within sparse blocks for maximum efficiency?
- How to handle FlashAttention’s recomputation during counterfactual ΔNLL@H computation in LensNet Training—double recomputation overhead?
- Should POC Implementation have fallback to standard attention for debugging, or always require FlashAttention (simpler but less portable)?
- What’s the memory/speed tradeoff between FlashAttention’s recomputation and Reformer’s reversible layers? Can we combine both?