上一篇我们把现代大模型的五个核心模块拼回了 LLaMA 这个完整案例中,可以看到注意力机制仍然是计算最密集的部分。
而这个密集程度在序列变长时,会变得越来越恐怖:
标准自注意力的计算复杂度和空间复杂度都是 \(O(n^2)\):序列长度翻倍,计算量翻四倍,内存占用也翻四倍。
而在之前,我们用 KV Cache 解决了推理阶段的重复计算问题,但训练和长序列推理中的注意力计算本身,仍然是一个巨大的计算瓶颈。
因此,在展开正式多模态前,会再插入几篇现代工程优化技术的相关内容作为支撑。
一直以来,针对 Attention 的计算量问题主要有两条路线:
- 近似注意力:如稀疏注意力、低秩近似等,总结来说就是压缩,用质量换速度。
- 更好的硬件利用:充分利用 GPU 的计算单元,让显存的数据搬运不再成为瓶颈,这也是现代 LLM 的主流。
22 年的论文:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 走的就是第二条路。
它不仅没有牺牲精度,反而在注意力计算上实现了 2-4 倍的实际加速,并将注意力中间结果的显存占用从 \(O(n^2)\) 降到了 \(O(n)\)。 其核心观点是:
注意力计算慢,不是计算量太大,而是显存搬运太频繁。
这一技术涉及到 GPU 相关的硬件知识,因此先行补充:
1. GPU 上的注意力
1.1 计算指标和 GPU 核心构建
我们知道支撑 LLM 庞大算力的基础设施是 GPU,因此要理解 FlashAttention,首先需要理解 GPU 是怎么工作的。
首先展开两个核心指标:
- FLOPs(Floating Point Operations):指浮点运算次数,代表“计算量”。也就是 GPU 的计算单元需要做多少次乘法和加法,衡量的是算法的算力消耗。
- IO(Input/Output):指数据在不同内存之间的搬运次数与字节量,代表“访存量”。也就是数据从慢速显存搬到快速缓存,计算完再搬回去的开销。
很显然,前者是我们的算法账,而后者则是硬件相关的开销,在训练和推理中,二者都缺一不可。
我们已经十分熟悉在代码上进行算法优化的逻辑,而 FlashAttention 是首个把 IO 优化思想引入注意力机制的工作。
由此我们继续展开,现代 GPU 有两种主要的存储空间:
- HBM(High Bandwidth Memory,高带宽显存):容量大(A100 有 40/80GB),但带宽有限,我们在市面上听说的多少 GB 的显卡,指的都是 HBM。
- SRAM(片上共享内存):容量极小(A100 每个 SM 只有 192KB),但带宽极高,速度极快,它的层次和 CPU 中的 cache 相似,但更加受我们操作者控制。
为了获得最高计算效率,GPU 的计算单元往往只直接读写 SRAM,不会直接访问 HBM。
所以,要从 HBM 读取数据到 SRAM,再从 SRAM 做计算,最后把结果写回 HBM,这一过程涉及大量 IO 操作。

1.2 注意力中的 IO 操作
我们对标准注意力计算的过程已经十分熟悉了,这里我们扩展开来,看看其计算在 GPU 上的具体 IO 过程:

注意这里反复进行了 6 次 HBM ↔ SRAM 的数据搬运。
举个例子:对于序列长度 \(n=1024\)、维度(多头总和) \(d=64\)、批量大小 \(b=1\) 的情况,我们进行一个估计:
假设 FP16,最终大约的数据量就是
现在扩展上下文到 \(n=64K\) ,这个数字就会暴涨:
而一个 GPU 的 HBM 带宽大约是 2 TB/s,也就是说仅仅数据搬运(IO)就要 40 多毫秒。注意这只是一次注意力计算。
相比之下,A100 的算力高达 312 TFLOPS,完成这些矩阵乘法的实际计算(FLOPs)可能只需要不到 1 毫秒。计算单元极快。
于是我们发现了问题:
**标准注意力的瓶颈不在计算(FLOPs),而在访存(IO)。是数据搬运太慢,导致 GPU 大量时间在“空转等数据”。
由此,FlashAttention 开始了优化:
2. FlashAttention 的核心思路
FlashAttention 的想法很直接:
与其把整个 Q、K、V 搬来搬去,不如每次只加载一小块到 SRAM 上进行完所有计算,再写回 HBM。
这样,原本需要在 HBM 和 SRAM 之间来回搬运多次的数据,现在就只需要一次完整的读取和一次写入。
但这里有一个硬性问题:
Softmax 是一个全局操作:要计算某个位置的 softmax,需要知道所有位置的分数。
而如果要使用刚刚说的分块计算,显然每个块只能看到自己的局部信息,怎么算全局 softmax 得到注意力权重?
答案是 18 年 NVIDIA 研究者的论文 Online normalizer calculation for softmax提出的 Online Softmax(在线 Softmax)。
2.1 Online Softmax
我们知道,对于一个包含 \(n\) 个元素的向量 \(x = [x_1, x_2, \dots, x_n]\),Softmax 函数将其转换为概率分布 \(y = [y_1, y_2, \dots, y_n]\) 的标准公式为:
但在实际工程中,我们往往并不会使用这个公式,而是使用 Safe Softmax:
这是因为指数函数 \(e^x\) 增长极快。如果 \(x_i\) 比较大(比如 1000),\(e^{1000}\) 会直接超出计算机浮点数的表示范围,变成 inf(无穷大),导致最终计算结果是 NaN。
因此,Safe Softmax 在分子分母同除以一个常数 \(e^m\),通常取向量中的最大值 \(m = \max(x)\)。
因为指数相减等于相除,不会改变最终的相对比例,但把所有指数都拉到了 \(\le 0\) 的范围,避免了上溢出。
但这仍然没有解决我们现在的问题,因为无论是取最大值,还是计算分母,我们还是需要遍历所有元素。
而 Online Softmax 的思想是这样的:
当数据无法一次性全部读入内存,或者需要分块计算时,可以使用流式更新的方式计算最大值 \(m\) 和分母 \(d\)(即 \(\sum e^{x_j - m}\))。
假设我们正在逐个读取向量的元素,到第 \(i\) 个元素时,更新最大值:
然后更新分母:
最终结果:
其核心技巧是 \(d_i\) 的公式中的 \(e^{m_{i-1} - m_i}\) ,这一项的作用是当遇到更大的新最大值 \(m_i\) 时,把之前累加的分母 \(d_{i-1}\) “按比例缩小”,使其基准与新的最大值对齐。
我们来看一个简单实例,假设向量为 \(x=[2,1,3]\),初始化 \(m_0=-\infty, d_0=0\) :
第一步,我们读入 2,更新最大值和分母:
此时:
继续,第二步读入 1,再次更新:
现在:
继续第三步读入 3:
注意这里最大值发生变化,因此之前累加的分母会重新缩放:
最终得到:
与标准 Safe Softmax 的结果完全一致。 整个过程中,我们只需要维护两个标量:\(m_i,d_i\) 即可,而不需要等全部数据读完后再计算。
但很显然,这种 Online Softmax 是逐个更新,并不符合 FlashAttention 分块计算逻辑,因此,我们还要再进行适配:
2.2 把 Online Softmax 嵌入注意力
在注意力中,我们不仅要算 softmax,还要用 softmax 的结果乘以 V,得到最终输出,由此,FlashAttention 的最终做法是:
固定一个Q块,不断流过所有KV块,累积输出并修正。

再展开一下这个过程,对于某个 Q 块,读取当前的 K 块和 V 块到 SRAM 并进行以下步骤:
- 对当前的 Q 块与 K 块算分数矩阵。
- 对分数矩阵做局部 softmax,得到局部权重。
- 用局部权重乘以 V 块,得到局部输出。
- 基于当前已处理的块维护全局 \(m\) 和 \(d\),并用它们修正之前累积的输出。
为了方便理解,我们只看单个 Query Token。
对于这个 Token 来说,它需要和所有 Key Token 计算注意力分数,然后做一次完整的 Softmax。因此,它会维护三个状态:
- \(m\):目前见过的最大注意力分数。
- \(d\):Softmax 的归一化分母。
- \(O\):当前已经累积得到的输出向量。
当 KV 被分成多个块依次加载时,这三个量会随着块的处理不断更新,假设已经处理了前 \(j-1\) 个 KV 块,我们得到如下累积量:
- \(m_{\text{prev}}\):前 \(j-1\) 个块中的最大分数。
- \(d_{\text{prev}}\):前 \(j-1\) 个块的指数修正和。
- \(O_{\text{prev}}\):前 \(j-1\) 个块的部分输出。
现在开始处理第 \(j\) 个 KV 块:
- 当前 Query 与这一块 Key 计算注意力分数:
- 求出当前块中的最大分数:
- 更新截至目前的全局最大值:
- 按照 Online Softmax 的思想,修正之前累计的分母:
这里第一项对应历史 KV 块的贡献,第二项对应当前 KV 块的贡献。当新的最大值出现时,历史部分会自动重新缩放到新的基准下。

- 修正之前的输出,并加上当前块的贡献:

这样处理完当前块后,就得到了新的:
随后继续读取下一个 KV 块重复这一过程,直到所有 KV 块都被遍历完成。
需要强调的是,上述推导是针对单个 Query Token 进行说明的。实际 FlashAttention 中会同时处理一整个 Q 块,因此这里的 \(m\)、\(d\) 和 \(O\) 实际上并不是标量,而是对应每行 Query Token 分别维护的一组向量状态。
总结来说,FlashAttention 可以看成代码里的嵌套循环,其中外层循环遍历 KV 块,内层循环遍历 Q 块。
整个过程中,只有 QKV 块从HBM 加载,最终的输出 O 写入 HBM,中间的所有结果都只存在于 SRAM 中,不写回 HBM。
这就是 FlashAttention 在注意力中引入 IO 优化思想得到的突破。
3. FlashAttention 的局限和后续发展
FlashAttention 虽然解决了 IO 瓶颈,但它仍然没有完全发挥 GPU 的计算能力,论文报告,FlashAttention 在 A100 上只达到了 25-40% 的理论峰值 FLOPs。
其原因大体分析如下:
- 非矩阵乘法的开销:softmax 的指数运算、逐元素的乘法和加法等操作,不如矩阵乘法高效。
- 线程块和 warp 的调度:FlashAttention 在并行化上还有优化空间,部分线程块可能处于空闲状态。
- 共享内存的读写:虽然减少了 HBM 读写,但 SRAM 内部的通信也存在开销。
但这也同时说明了该方向仍然存在极大的优化空间。
在此之后,23 年的论文:FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning提出了 FlashAttention-2,对 FlashAttention 的工作分配进行了系统性的优化,其核心内容是通过应用更多矩阵乘法,优化线程并行,线程内分工的方式,以更好的并行得到加速效果。
FlashAttention-2 中的各种“榨干 GPU” 的优化其实更符合现代工业标准。但这里涉及到较多的操作系统和硬件内容,就不再详细展开了。
FlashAttention-2 将 GPU 利用率从 FA1 的 25-40% 提升到了 50-73%,相比 FA1 实现了大约 2 倍的速度提升。
还没结束,24 年的论文:FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision 将目光投向了新一代的 Hopper 架构 GPU(H100)。
这是因为前两代 FlashAttention 虽然在 A100 上表现优异,但在 H100 上,FA2 也只达到约 35% 的硬件利用率。
在前两代的基础上,FA3 利用 H100 特有的的 TMA(Tensor Memory Accelerator) 硬件单元和实现异步处理,又在 FP8 支持的前提下应用了更多解耦分工和细节优化,最终对比如下:
| 版本 | 年份 | 硬件 | 核心改进 | GPU 利用率 | 加速比 |
|---|---|---|---|---|---|
| FA1 | 2022 | A100 | IO 感知 + 分块 + Online Softmax | 25-40% | 基准 |
| FA2 | 2023 | A100 | 更好的工作分配和并行化 | 50-73% | ~2× vs FA1 |
| FA3 | 2024 | H100 | 异步 TMA + Warp 特化 + FP8 | 75% (FP16) / ~60% (FP8) | 1.5-2× vs FA2 |
总结来说,FlashAttention 系列产生的影响远超论文本身的数值,其对显存占用降低让长上下文成为可能,同时让精确注意力重新成为主流。
在 FA 出现之前,为了解决显存占用 \(O(n^2)\) 的问题,研究者们提出了大量近似注意力方法:稀疏注意力、低秩注意力等。这些方法各有各的妥协。
而 FlashAttention 证明了:不需要精度妥协,只要把 IO 优化做好,精确注意力不仅更快,效果也更好。 此后,近似注意力方法的热度大幅下降。
并且,这种把目光投向硬件,IO 感知的思路已经超越了注意力本身:FlashFFTConv、FlashDecoding、甚至一些 Mamba 的工程实现都受到其启发。
如今, FlashAttention 系列仍然在进化中。