尧图网站建设 尧图网络
  • 首页
  • 关于我们
  • 服务项目
  • 案例展示
  • 建站流程
  • 资讯中心
  • 联系我们
首页/资讯中心/详情

深度学习进阶(三十一)FlashAttention:IO 感知的精确注意力

深度学习进阶(三十一)FlashAttention:IO 感知的精确注意力
📅 发布时间:2026/6/19 16:39:09

上一篇我们把现代大模型的五个核心模块拼回了 LLaMA 这个完整案例中,可以看到注意力机制仍然是计算最密集的部分。

而这个密集程度在序列变长时,会变得越来越恐怖:

标准自注意力的计算复杂度和空间复杂度都是 \(O(n^2)\):序列长度翻倍,计算量翻四倍,内存占用也翻四倍。

而在之前,我们用 KV Cache 解决了推理阶段的重复计算问题,但训练和长序列推理中的注意力计算本身,仍然是一个巨大的计算瓶颈。
因此,在展开正式多模态前,会再插入几篇现代工程优化技术的相关内容作为支撑。

一直以来,针对 Attention 的计算量问题主要有两条路线:

  1. 近似注意力:如稀疏注意力、低秩近似等,总结来说就是压缩,用质量换速度。
  2. 更好的硬件利用:充分利用 GPU 的计算单元,让显存的数据搬运不再成为瓶颈,这也是现代 LLM 的主流。

22 年的论文:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 走的就是第二条路。

它不仅没有牺牲精度,反而在注意力计算上实现了 2-4 倍的实际加速,并将注意力中间结果的显存占用从 \(O(n^2)\) 降到了 \(O(n)\)。 其核心观点是:

注意力计算慢,不是计算量太大,而是显存搬运太频繁。

这一技术涉及到 GPU 相关的硬件知识,因此先行补充:

1. GPU 上的注意力

1.1 计算指标和 GPU 核心构建

我们知道支撑 LLM 庞大算力的基础设施是 GPU,因此要理解 FlashAttention,首先需要理解 GPU 是怎么工作的。

首先展开两个核心指标:

  1. FLOPs(Floating Point Operations):指浮点运算次数,代表“计算量”。也就是 GPU 的计算单元需要做多少次乘法和加法,衡量的是算法的算力消耗。
  2. IO(Input/Output):指数据在不同内存之间的搬运次数与字节量,代表“访存量”。也就是数据从慢速显存搬到快速缓存,计算完再搬回去的开销。

很显然,前者是我们的算法账,而后者则是硬件相关的开销,在训练和推理中,二者都缺一不可。
我们已经十分熟悉在代码上进行算法优化的逻辑,而 FlashAttention 是首个把 IO 优化思想引入注意力机制的工作。

由此我们继续展开,现代 GPU 有两种主要的存储空间:

  1. HBM(High Bandwidth Memory,高带宽显存):容量大(A100 有 40/80GB),但带宽有限,我们在市面上听说的多少 GB 的显卡,指的都是 HBM。
  2. SRAM(片上共享内存):容量极小(A100 每个 SM 只有 192KB),但带宽极高,速度极快,它的层次和 CPU 中的 cache 相似,但更加受我们操作者控制。

为了获得最高计算效率,GPU 的计算单元往往只直接读写 SRAM,不会直接访问 HBM。
所以,要从 HBM 读取数据到 SRAM,再从 SRAM 做计算,最后把结果写回 HBM,这一过程涉及大量 IO 操作。

abbb133b-ef69-4e58-83ce-9fc133efaa3f.png

1.2 注意力中的 IO 操作

我们对标准注意力计算的过程已经十分熟悉了,这里我们扩展开来,看看其计算在 GPU 上的具体 IO 过程:
a1537138-5570-488c-abf1-9a5d3c93f83a.png

注意这里反复进行了 6 次 HBM ↔ SRAM 的数据搬运。
举个例子:对于序列长度 \(n=1024\)、维度(多头总和) \(d=64\)、批量大小 \(b=1\) 的情况,我们进行一个估计:

\[\text{总搬运量} \approx \underbrace{6}_{\text{搬运趟数}} \times \underbrace{(2n^2 + nd)}_{\text{每趟搬运的元素量}} \times \underbrace{2}_{\text{FP16字节数}} \quad \text{字节} \]

假设 FP16,最终大约的数据量就是

\[6 \times (2 \times 1024^2 + 1024 \times 64) \times 2 \approx 25\quad MB \]

现在扩展上下文到 \(n=64K\) ,这个数字就会暴涨:

\[\text{IO} \approx 6\times (2\times65536^2+65536\times64) \times2 \approx 96\ \text{GB} \]

而一个 GPU 的 HBM 带宽大约是 2 TB/s,也就是说仅仅数据搬运(IO)就要 40 多毫秒。注意这只是一次注意力计算。

相比之下,A100 的算力高达 312 TFLOPS,完成这些矩阵乘法的实际计算(FLOPs)可能只需要不到 1 毫秒。计算单元极快。
于是我们发现了问题:

**标准注意力的瓶颈不在计算(FLOPs),而在访存(IO)。是数据搬运太慢,导致 GPU 大量时间在“空转等数据”。

由此,FlashAttention 开始了优化:

2. FlashAttention 的核心思路

FlashAttention 的想法很直接:

与其把整个 Q、K、V 搬来搬去,不如每次只加载一小块到 SRAM 上进行完所有计算,再写回 HBM。

这样,原本需要在 HBM 和 SRAM 之间来回搬运多次的数据,现在就只需要一次完整的读取和一次写入。
但这里有一个硬性问题:

Softmax 是一个全局操作:要计算某个位置的 softmax,需要知道所有位置的分数。

而如果要使用刚刚说的分块计算,显然每个块只能看到自己的局部信息,怎么算全局 softmax 得到注意力权重?

答案是 18 年 NVIDIA 研究者的论文 Online normalizer calculation for softmax提出的 Online Softmax(在线 Softmax)。

2.1 Online Softmax

我们知道,对于一个包含 \(n\) 个元素的向量 \(x = [x_1, x_2, \dots, x_n]\),Softmax 函数将其转换为概率分布 \(y = [y_1, y_2, \dots, y_n]\) 的标准公式为:

\[y_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}} \]

但在实际工程中,我们往往并不会使用这个公式,而是使用 Safe Softmax:

\[y_i = \frac{e^{x_i - m}}{\sum_{j=1}^{n} e^{x_j - m}} \quad \text{其中 } m = \max(x) \]

这是因为指数函数 \(e^x\) 增长极快。如果 \(x_i\) 比较大(比如 1000),\(e^{1000}\) 会直接超出计算机浮点数的表示范围,变成 inf(无穷大),导致最终计算结果是 NaN。
因此,Safe Softmax 在分子分母同除以一个常数 \(e^m\),通常取向量中的最大值 \(m = \max(x)\)。
因为指数相减等于相除,不会改变最终的相对比例,但把所有指数都拉到了 \(\le 0\) 的范围,避免了上溢出。
但这仍然没有解决我们现在的问题,因为无论是取最大值,还是计算分母,我们还是需要遍历所有元素。

而 Online Softmax 的思想是这样的:

当数据无法一次性全部读入内存,或者需要分块计算时,可以使用流式更新的方式计算最大值 \(m\) 和分母 \(d\)(即 \(\sum e^{x_j - m}\))。

假设我们正在逐个读取向量的元素,到第 \(i\) 个元素时,更新最大值:

\[m_i = \max(m_{i-1}, x_i) \]

然后更新分母:

\[d_i = d_{i-1} \cdot e^{m_{i-1} - m_i} + e^{x_i - m_i} \]

最终结果:

\[y_i = \frac{e^{x_i - m_n}}{d_n} \]

其核心技巧是 \(d_i\) 的公式中的 \(e^{m_{i-1} - m_i}\) ,这一项的作用是当遇到更大的新最大值 \(m_i\) 时,把之前累加的分母 \(d_{i-1}\) “按比例缩小”,使其基准与新的最大值对齐。

我们来看一个简单实例,假设向量为 \(x=[2,1,3]\),初始化 \(m_0=-\infty, d_0=0\) :

第一步,我们读入 2,更新最大值和分母:

\[m_1=\max(-\infty,2)=2 , d_1 = 0\cdot e^{-\infty} + e^{2-2} = 1 \]

此时:

\[m_1=2,\qquad d_1=1 \]

继续,第二步读入 1,再次更新:

\[m_2=\max(2,1)=2,d_2 = 1\cdot e^{2-2} + e^{1-2}\approx1.368 \]

现在:

\[m_2=2,\qquad d_2\approx1.368 \]

继续第三步读入 3:

\[m_3=\max(2,3)=3 \]

注意这里最大值发生变化,因此之前累加的分母会重新缩放:

\[d_3 = d_2\cdot e^{2-3} + e^{3-3}= 1.368\times e^{-1}+1\approx1.503 \]

最终得到:

\[m=3,d\approx1.503 \]

\[\text{Softmax}(x) = \left[ \frac{e^{2-3}}{1.503}, \frac{e^{1-3}}{1.503}, \frac{e^{3-3}}{1.503} \right]\approx [0.245,\;0.090,\;0.665] \]

与标准 Safe Softmax 的结果完全一致。 整个过程中,我们只需要维护两个标量:\(m_i,d_i\) 即可,而不需要等全部数据读完后再计算。

但很显然,这种 Online Softmax 是逐个更新,并不符合 FlashAttention 分块计算逻辑,因此,我们还要再进行适配:

2.2 把 Online Softmax 嵌入注意力

在注意力中,我们不仅要算 softmax,还要用 softmax 的结果乘以 V,得到最终输出,由此,FlashAttention 的最终做法是:

固定一个Q块,不断流过所有KV块,累积输出并修正。

d3ad6edf-f62c-423b-ac1a-113d2ebee69b.png

再展开一下这个过程,对于某个 Q 块,读取当前的 K 块和 V 块到 SRAM 并进行以下步骤:

  1. 对当前的 Q 块与 K 块算分数矩阵。
  2. 对分数矩阵做局部 softmax,得到局部权重。
  3. 用局部权重乘以 V 块,得到局部输出。
  4. 基于当前已处理的块维护全局 \(m\) 和 \(d\),并用它们修正之前累积的输出。

为了方便理解,我们只看单个 Query Token。
对于这个 Token 来说,它需要和所有 Key Token 计算注意力分数,然后做一次完整的 Softmax。因此,它会维护三个状态:

  1. \(m\):目前见过的最大注意力分数。
  2. \(d\):Softmax 的归一化分母。
  3. \(O\):当前已经累积得到的输出向量。

当 KV 被分成多个块依次加载时,这三个量会随着块的处理不断更新,假设已经处理了前 \(j-1\) 个 KV 块,我们得到如下累积量:

  1. \(m_{\text{prev}}\):前 \(j-1\) 个块中的最大分数。
  2. \(d_{\text{prev}}\):前 \(j-1\) 个块的指数修正和。
  3. \(O_{\text{prev}}\):前 \(j-1\) 个块的部分输出。

现在开始处理第 \(j\) 个 KV 块:

  1. 当前 Query 与这一块 Key 计算注意力分数:

\[S_j=QK_j^T \]

  1. 求出当前块中的最大分数:

\[m_j=\max(S_j) \]

  1. 更新截至目前的全局最大值:

\[m_{\text{new}}= \max(m_{\text{prev}},m_j) \]

  1. 按照 Online Softmax 的思想,修正之前累计的分母:

\[d_{\text{new}} = e^{m_{\text{prev}}-m_{\text{new}}} d_{\text{prev}} + e^{m_j-m_{\text{new}}} \sum e^{S_j-m_j} \]

这里第一项对应历史 KV 块的贡献,第二项对应当前 KV 块的贡献。当新的最大值出现时,历史部分会自动重新缩放到新的基准下。

4c25839e-068b-4cb9-9486-f5a0e1f50d95 (1).png

  1. 修正之前的输出,并加上当前块的贡献:

\[O_{\text{new}}= \frac{d_{\text{prev}}}{d_{\text{new}}} e^{m_{\text{prev}}-m_{\text{new}}} O_{\text{prev}} + \frac{e^{m_j-m_{\text{new}}}}{d_{\text{new}}} \text{softmax}_j(S_j)V_j \]

0282d607-638c-46cd-8403-0f0dc3fd8d46.png
这样处理完当前块后,就得到了新的:

\[(m_{\text{new}},d_{\text{new}},O_{\text{new}}) \]

随后继续读取下一个 KV 块重复这一过程,直到所有 KV 块都被遍历完成。

需要强调的是,上述推导是针对单个 Query Token 进行说明的。实际 FlashAttention 中会同时处理一整个 Q 块,因此这里的 \(m\)、\(d\) 和 \(O\) 实际上并不是标量,而是对应每行 Query Token 分别维护的一组向量状态。

总结来说,FlashAttention 可以看成代码里的嵌套循环,其中外层循环遍历 KV 块,内层循环遍历 Q 块。
整个过程中,只有 QKV 块从HBM 加载,最终的输出 O 写入 HBM,中间的所有结果都只存在于 SRAM 中,不写回 HBM。
这就是 FlashAttention 在注意力中引入 IO 优化思想得到的突破。

3. FlashAttention 的局限和后续发展

FlashAttention 虽然解决了 IO 瓶颈,但它仍然没有完全发挥 GPU 的计算能力,论文报告,FlashAttention 在 A100 上只达到了 25-40% 的理论峰值 FLOPs。
其原因大体分析如下:

  1. 非矩阵乘法的开销:softmax 的指数运算、逐元素的乘法和加法等操作,不如矩阵乘法高效。
  2. 线程块和 warp 的调度:FlashAttention 在并行化上还有优化空间,部分线程块可能处于空闲状态。
  3. 共享内存的读写:虽然减少了 HBM 读写,但 SRAM 内部的通信也存在开销。

但这也同时说明了该方向仍然存在极大的优化空间。
在此之后,23 年的论文:FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning提出了 FlashAttention-2,对 FlashAttention 的工作分配进行了系统性的优化,其核心内容是通过应用更多矩阵乘法,优化线程并行,线程内分工的方式,以更好的并行得到加速效果。

FlashAttention-2 中的各种“榨干 GPU” 的优化其实更符合现代工业标准。但这里涉及到较多的操作系统和硬件内容,就不再详细展开了。
FlashAttention-2 将 GPU 利用率从 FA1 的 25-40% 提升到了 50-73%,相比 FA1 实现了大约 2 倍的速度提升。

还没结束,24 年的论文:FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision 将目光投向了新一代的 Hopper 架构 GPU(H100)。

这是因为前两代 FlashAttention 虽然在 A100 上表现优异,但在 H100 上,FA2 也只达到约 35% 的硬件利用率。
在前两代的基础上,FA3 利用 H100 特有的的 TMA(Tensor Memory Accelerator) 硬件单元和实现异步处理,又在 FP8 支持的前提下应用了更多解耦分工和细节优化,最终对比如下:

版本 年份 硬件 核心改进 GPU 利用率 加速比
FA1 2022 A100 IO 感知 + 分块 + Online Softmax 25-40% 基准
FA2 2023 A100 更好的工作分配和并行化 50-73% ~2× vs FA1
FA3 2024 H100 异步 TMA + Warp 特化 + FP8 75% (FP16) / ~60% (FP8) 1.5-2× vs FA2

总结来说,FlashAttention 系列产生的影响远超论文本身的数值,其对显存占用降低让长上下文成为可能,同时让精确注意力重新成为主流。

在 FA 出现之前,为了解决显存占用 \(O(n^2)\) 的问题,研究者们提出了大量近似注意力方法:稀疏注意力、低秩注意力等。这些方法各有各的妥协。
而 FlashAttention 证明了:不需要精度妥协,只要把 IO 优化做好,精确注意力不仅更快,效果也更好。 此后,近似注意力方法的热度大幅下降。

并且,这种把目光投向硬件,IO 感知的思路已经超越了注意力本身:FlashFFTConv、FlashDecoding、甚至一些 Mamba 的工程实现都受到其启发。
如今, FlashAttention 系列仍然在进化中。

相关新闻

  • 6个免费方法让你的手机视频秒变MP4 - 软件工具教程方法
  • Kali Linux实战:ARP欺骗攻击原理、环境搭建与Wireshark流量分析
  • 杭州靠谱品牌首饰回收排行,光谱验金透明称重全款现结 - 奢品小当家

最新新闻

  • FanControl:Windows平台专业风扇智能温控的完整解决方案
  • 建构之法阅读笔记5
  • 别被线上虚高报价骗了!广州正规回收认准收的顶,报价即成交价 - 奢侈品回收测评
  • Honey Select 2终极游戏增强补丁:一键解锁完整游戏体验的完整解决方案
  • MC9S12XE Flash操作全解析:从物理原理到Bootloader实战
  • Python自动化抢票终极指南:5分钟掌握大麦网高效抢票技术

日新闻

  • 5分钟掌握Python进化算法:Geatpy高性能优化工具完全指南
  • Microchip 24AA044 EEPROM选型与应用全指南:从参数解析到实战编程
  • 华为的鸿蒙到底有多牛?为什么称作遥遥领先?

周新闻

  • 3步解锁iOS设备:applera1n激活锁绕过完全指南
  • 39 2026 人工智能证书终极盘点,普通人选 AI 证书可以从这些方向入手
  • Redis 暴露公网有多危险?从端口检查到补救步骤

月新闻

  • 【总结】入门篇:50句话让你记住架构核心概念
  • WeChatMsg技术方案解析:实现Mac微信数据自主管理的完整解决方案
  • WeChatMsg:革新性微信数据备份方案,打造你的专属数字记忆库

关于尧图

  • 公司简介
  • 团队介绍
  • 企业文化
  • 荣誉资质

服务项目

  • 定制开发
  • 电商建站
  • UI 设计
  • 运维服务

快速链接

  • 案例展示
  • 建站流程
  • 常见问题
  • 资讯中心

联系方式

  • 📍北京市朝阳区互联网产业园 A 座 10 层
  • 📞400-888-8888
  • ✉️contact@rkmt.cn
  • 🕐周一至周日 9:00-21:00

© 2024 北京尧图网络科技有限公司 版权所有 | 京 ICP 备 XXXXXXXX 号