FlashAttention-3:具有异步性和低精度的快速准确注意力机制

一、简介

FlashAttention 对注意力计算进行重新排序的算法,并利用 tiling 和重计算来显著加快计算速度,将内存使用量从序列长度的二次减少到线性。

Flash Attention 核心解决方案主要有两项:

  1. 融合算子 + Softmax Tiling:采用 Online Softmax 算法,实现了 Softmax 在 GPU 上的分块计算,节省了大量的 GMEM 读写;
  2. 重计算(Recomputation): 前向计算不保存 Attention 矩阵,仅保留数据量更小的 logsumexp,在反向计算时重新计算 Attention 矩阵,也减少了 Attention 矩阵的读写。

FlashAttention-2,在算法、并行化和工作分区等方面有了显著改进。

Flash Attention 的 Online Softmax 算法

首先相比 A100 80%-90% 的利用率,FA2 GPU 利用率在 H100 上仅为 35%-40%,存在着一定优化空间。其次 Hopper 架构的 WGMMA 和 TMA 新指令在增速提效的同时,可以方便我们在 tile-based 的维度进行算子开发。所以作者又提出了更加变态的 Flash Attention V3,可以把 GPU 效能利用率提高到 75%,速度比 Flash Attention V2 提升 1.5~2 倍。

FlashAttention-3,它采用了加速 Hopper GPU 注意力的三种主要技术

  • 通过 warp-specialization 重叠整体计算和数据移动;
  • 交错分块 matmul 和 softmax 运算;
  • 利用硬件支持 FP8 低精度的不连贯处理。

FlashAttention-3 的改进将带来

  • 更高效的 GPU 利用率:H100 理论最大 FLOPS 利用率为 75%,而之前仅为 35%。这使得 LLM 的训练和运行速度比以前的版本快得多。

  • 较低精度下更好的性能:FlashAttention-3 可以在保持精度的同时使用较低精度的数字 (FP8)。这可以实现更快的处理速度并可能降低内存使用量,从而为运行大规模人工智能操作的客户节省成本并提高效率。

  • 能够在 LLM 中使用更长的上下文:通过加速注意力机制,FlashAttention-3 使 AI 模型能够更有效地处理更长的文本片段。这使得应用程序能够理解并生成更长、更复杂的内容而不会减慢速度。

二、Hopper 架构 GPU

2.1 更多更强的 SM

Flash Attention V2不同的 task 会分配到 SM 上面做运算,所以可以说 SM 数量越多,GPU 算力越强,而这边 A100 有 108 个 SM, H100 提升到了 132 个 ,而且 H100 的 SM 在 FP16 精度上面的运算速度,在 MMA(矩阵相乘累加)的任务上是 A100 SM 的 2 倍,所以算一算整体速度 H100 是 A100 的 3 倍左右。

2.2 FP8 Tensor Core

另外 H100 上面还增加了处理 FP8 精度的 tensor core ,可以处理两种不同 format 的 FP8 矩阵运算,而因为 FP8 的位元表示是 FP16 的一半,所以 H100 FP8 的运算是 A100 FP16 速度的 6.4 倍。

2.3 Thread Block Cluster

在过往 A100 GPU 上面,我们会把多个 thread 分成 3 个阶层,分别是 thread、thread block 和 Grid,而因为 H100 的 SM 更多更强,单用 3 层去分派已经满足不了更复杂和更庞大的运算任务,所以H100 引入了 Thread Block Cluster 层,让 thread 的调度和记忆体管理上的颗粒度可以再分的更细致。

我们可以看到 Grid 的颗粒度,可以对应到 GEME 的记忆体区块,而这就是我们所熟知的 HBM,通常这块记忆体的 bandwidth 是最慢的,所以我们要做加速运算,会尽量减少从这块 GEME 拿数据。

接下来是 Thread Block Cluster,硬体上可以对应到 Graph Processing Cluster (GPC),而GPC 提供了所谓的 SM-to-SM Network,来加速不同 SM 之间的数据传输。

我们可以看到在 A100 当中,如果不同的 Thread Block 之间要互相传递数据的话,需要透过 HBM,但是 H100 中我们可以直接透过 SM-to-SM Network 来做更有效率的传输。 而这边数据 physical 的位置就是在 L2 Cache,Logical 的名称叫做 distributed shared memory (DSMEM) 。

接下来是 Thread Block,其也称之为 cooperative thread arrays (CTA),其对应到 SM,这边在前面就有提到过,在 Thread Block 里面不同的 Thread 如果要做数据传递的话,就是透过 shared memory (SMEM)。对于每一个 thread 最多可以有 256 个 private register (RMEM)。

L1 Cache 在 SM 当中,而 L0 Cache 在 Warp 里面。

2.4 WGMMA (Warpgroup MMA)

Warpgroup 指的是 4 个连续的 warps,共 128 个连续的 threads,正好对应了一个 SM 最多可并行计算的线程数。在 H100 上,我们可以以 Warpgroup 为粒度调度 GEMM 运算。下面说明了 A100 和 H100 调度 GEMM 的 API 的区别:

  • A100上,wmma.mma.sync (warp-level) 和 mma.sync(thread-level) 均为调用 Tensor Core 计算的同步 API,也就是必须等到结果计算出来,线程才能继续执行下一个指令;
  • H100上,新增的 wgmma.mma_async(warpgroup-level) 可以异步运行 Tensor Core,也就是可以与其他单元并行计算(例如 CUDA Core)。WGMMA operand A 可以从 RMEM/SMEM 读取,operand B 只能从 SMEM 读取,更多关于 WGMMA 指令的数据类型、shape 要求和数据排布等细节,可参考 PTX 相关文档

2.5 Tensor Memory Accelerator

TMA 是 H100 新增的硬件单元,它允许程序在 GMEM 和 SMEM 之间异步且双向地传输 1D 到 5D 的张量。通过这个专门用于数据移动的硬件单元,线程可以被解放出来做其他工作,而不是计算地址和管理数据移动,这消除了 Hopper 架构之前 SM 必须使用寄存器在不同内存空间之间移动数据的需求。

TMA 指令非常轻量化,只需要一个线程即可启动 TMA 传输。

TMA 单元减少了 SM 线程的计算需求

TMA 不仅负责数据本身的移动,还可以计算所需的目标内存地址,应用任何数据变换(如归约操作和按位操作等),并可以处理布局转换,以“交错”(swizzled)模式将数据传输到 SMEM,使其在使用时不会产生任何存储体冲突(bank conflicts)。

TMA 不仅可以将相同的数据传输到调用 SM 的 SMEM,还可以传输到同一 Thread Block Cluster 中的其他 SM 的 SMEM。这被称为 multicast

如果需要,TMA 还可以将相同的数据 multicast 到同一 Thread Block Cluster 中的其他 SM。一旦数据传输完成,TMA 会通知相关的消费者数据已准备就绪。

2.6 Register Dynamic Reallocation

最后一个也很猛的功能就是动态 reallocate register,也就是 Warp Group (4 个 Warps)间的 register 可以动态做 reallocate,让我们有更多的 RMEM 可以用。

三、FlashAttention 3

在 Hopper 架构下,我们可以充分利用 Warp Specialization + Intra-warpgroup overlapping 的异步性,实现计算与通信、计算与计算之间的 overlap。

A100 之前的异步Warp SpecializationWarp Specialization 的目标是掩盖通信延迟,让计算单元(如 CUDA Core / Tensor Core)尽可能满载运行。具体做法是往 SM 里塞尽可能多的 warps,通过 SM 中的 warp schedulers 在不同的 warp 间切换实现异步。例如,如果一个 warp 正在等待数据,可以切换成另一个 warp 进行计算。由于所有 warp 中所有的线程均保存在 register file 中,warp 的上下文切换是几乎没有成本的,在一个时钟周期里就可以完成。

一般而言,我们会指定一些 warp 进行数据传输(producer warp),另一些 warp 读取数据进行计算(consumer warp),两者通过 barrier 进行数据依赖的同步。通过 warp scheduler 的调度,数据复制的延时就可以很好地被计算所隐藏,反之亦然。

A100 的异步:Multistage。A100 新增的cp.async指令_,_可以在同一 warp 中实现前一块数据的计算和后一块数据通信的 overlap,因此就能通过编排流水线的方式实现异步,这就是 Multistage。由于在 warp 内部实现了异步,采用 warp 间异步的 warp specialization 便不再需要。Multistage 也是 FA2 的工程实现方式。

由于 warp 需要保留当前计算的数据以及预留后面传输过来的数据,通常 warp 要保留至少 2 份数据缓存空间,即 double buffer。如果 stage 数量进一步增加,就需要保留更多的 buffer。

H100 的异步:Warp Specialization + Intra-warpgroup overlapping。一方面,由于 TMA 在硬件上实现了数据传输的异步,我们不再需要 Multistage 那样由 warp 自行处理数据传输了。另一方面,由于 WGMMA 指令的出现,从 warpgroup 维度调度线程能够享受 WGMMA 的异步性。同时 1)Hopper 架构新增了在不同 warpgroup 间重新分配寄存器(warpgroup-wide register reallocation)的 API setmaxnreg;2)TMA 仅需一个线程发送指令即可运行。我们可以给 producer 分配最少的资源,consumer 分配更多的资源,从而最大化有效算力。因此 Warp Specialization 方案能够提供更快的运算速度。

同时,在 consumer warpgroup 内部,我们仍然可以采用 GEMM 和 softmax 的 overlap 来实现两个 warpgroup 计算和计算的同时进行,也就是 Intra-warpgroup overlapping。这就是 FA3 采用的异步策略。

由于 H100 的 Tensor Core 运算速度更快,我们需要更极致的异步来掩盖通信延时,因此结合 Warp Specialization 和 Intra-warpgroup overlapping 的优势便能够实现 FA3 快速的运算。

我们用一张图简单说明 Warp SpecializationMultistage,以及将 Warp Specialization 和 Multistage 的思想结合,变为 Ping-Pong Scheduling 这三者的区别:

三种异步的比较

3.1 Warp-specialization

Flash Attention V3可以把 data 的传递用 Producer-Consumer 的形式定义

  • Producer 可以对应到 TMA
  • Consumer 可以对应到 Tensor Core

简单来说就是 TMA 拿的数据提供给 Tensor Core 运算。

而这边所提到的 Warp-specialization,指的就是我们可以把 Thread Block 里面的 warps 分成 Producer Warp Group 和 Consumer Warp Group。

  • Producer Warp Group 做的事情就是用 TMA 把 data 从 HBM 拉到 shared memory

  • Consumer Warp Group 做的事情就是用 Tensor Core 来计算这些 data。

这边我们进一步看到演算法的地方,在 consumer warp group 的地方有 SS-GEMM 和 RS-GEMM 两种不同的矩阵运算,这边 SS 的意思就是第一个 operand 是来自 shared memory,而 RS 则是来自 register。

要先有 Q 我们才能做后面的运算,所以说 Q 一定要先用 TMA 从 HBM 拉到 shared memory,至于 K 和 V 我们可以 asynchrony 的做,所以我们一开始的时候会初始化一个 s-stage circular SMEM buffer 去纪录 KV load 到 shared memory ,所以一进入 producer 这个 for loop 的时候,我们不会管 consumer 到底有没有把 K 和 V 拿去做矩阵运算,直接继续读 KV 直到 buffer 满了,也就是经过 s 次。

而满了之后我们就会开始等 consumer 算完 attention 并释放这个 stage 的 buffer,之后 producer 才会再读取新的 K 和 V。所以算 S 的时候 source 是来自 shared memory,而算 O 的时候 source 是来自 register。

另外一个值得一提的就是我们这边会用 Register Dynamic Reallocation 去(de)allocations register,增加可以使用的 register 数量。毕竟我们可以看到我们这样分 producer 和 consumer warp group 又做 asynchronous 操作,会需要很多 register。

我们结合以下的流程图,从微观层面介绍 Warp Specialization 单个 SM 中 Producer 和 Consumer 是如何进行协作和实现异步性的。

Producer 和 Consumer 的异步
  1. producer warpgroup 获取 SMEM 缓冲区的 barrier lock。

  2. producer warpgroup 通过单个线程向 TMA 芯片发起 TMA 请求。

  3. TMA 计算所需的实际 SMEM 地址,将数据移动到 SMEM,并在移动时会进行数据布局转换(如 swizzling),以便在 SMEM 中实现最快速(无 bank conflict)的访问。

  4. 数据也可以 multicast 到其他 SM,或者可能需要等待来自其他 TMA multicast 的数据以完成加载。(thread block cluster 可以在多个 SM 之间共享 SMEM)

  5. 此时,barrier 被更新以信号通知数据已到达 SMEM。

  6. 相关的 consumer warpgroup 现在开始工作,发出多个 wgmma.mma_async 命令,这些命令将数据从 SMEM 读取到 Tensor Core,随后进行矩阵乘法计算。

  7. MMA 累加后的值在完成计算后被写入 RMEM。

  8. consumer warpgroup 释放 SMEM 上的 barrier。

  9. producer warpgroup 开始工作,发出下一条 TMA 指令以重新填充现在空闲的 SMEM 缓冲区。

  10. consumer warpgroup 同时对累加结果进行后处理(epilogue),然后将数据从 RMEM 移动到不同的 SMEM 缓冲区。

  11. consumer warpgroup 发出 cp.async_bulk 命令,将数据从 SMEM 移动到 GMEM。

从宏观层面看,为最大化提升性能,我们希望一个 SM 仅占有一个 thread block,这个 block 中的 warpgroup 由多个 Producer 和多个 Consumer 组成。下面以 1 Producer + 2 Consumers 为例。

Ping-Pong Architecture
  • Producer:warpgroup 中每个线程分配 24 个 registers,主要职责是分发 TMA 指令,由 TMA 将数据从 GMEM 移至 SMEM。数据传输完成后,TMA 会通知相应的 consumer 数据已准备就绪。Producer 会推举出一个 leader 线程发送 TMA 异步指令,指令结束后即停止运行,等待 SMEM buffer 释放;
  • Consumers:每个 warpgroup 的线程分配 240 个 registers,主要职责是获取 SMEM buffer 的数据,计算 GEMM 和 softmax,释放 buffer 并通知 producer 数据已被释放。随后处理收尾的计算任务、计算结果的数据传输等工作,这也被称为 epilogue 阶段。

这里寄存器的分配个数是通过setmaxnreg 指定的。寄存器分配需要满足一系列的约束条件:

  1. setmaxnreg可指定特定 warpgroup 每个线程所分配到的寄存器数量。这个数量必须在 [24, 256] 之间,且必须为 8 的倍数;
  2. 每个 warpgroup 的每个线程分配的寄存器不超过 255 个(CUDA/NVCC 限制);
  3. 每个线程所在的所有 warpgroup 分配的寄存器总和不超过 512 个(因为一个 SM 内总共有 64k 个寄存器,一个 warpgroup 包含 128 个线程,所以每个线程只能保留 64k/128 = 512 个寄存器。在我们的例子中,每个线程都位于 1 Producer + 2 Consumers 中,因此寄存器数量为 24 + 240 + 240 = 504 < 512);
  4. 每个线程所在的所有 warpgroup 分配的寄存器总和必须为 warpgroup 数量的整数倍(例如这里有 3 个 warpgroup,那么 24/240/240 共 504 个寄存器恰好是 3 的倍数,而 32/240/240 即使符合上面的三个条件,但总和 512 并非为 3 的倍数,因此不成立)。

为尽可能减少 Producer 的寄存器,增加 consumer 的寄存器,24/240/240 就是 1 Producer + 2 Consumers 的最佳分配方案。对于 1 Producer + 3 Consumers 而言,32/160/160/160 也是最佳的分配方案。

Producer 和 Consumers 之间的通信机制是依靠 CUTLASS 的 Asynchronous Pipeline Class + Barriers 来实现的。

3.2 Pingpong scheduling

作者看起来对于读数据和运算同时做还不够满意,所以又加上了一个矩阵运算和 Softmax 同时运算的效能优化。主要原因是因为softmax 当中的 exp 运算是由 multi-function unit运算,所以说当 Tensor Core 做矩阵运算的时候,我们同时可以做 softmax 的运算。

这边我们可以看到下面的 pipeline,我们主要会有三个运算的步骤

  • 第一个是 QK 矩阵运算 GEMM0
  • 第二的是Softmax算出P
  • 最后一个 PV 矩阵运算 GEMM1

如果有两个 warp group,我们可以用黑色的虚线也就是 synchronization barriers,强制 warp group 2 GEMM0 做完后,warp group 1 才能做 GEMM1 ,而这个就是所谓的 Pingpong scheduling。原本 Flash Attention V2 在算 Softmax 的时候,会浪费掉算力,但是使用了 Pingpong scheduling 我们可以把矩阵运算塞满整个时程。

Ping-Pong Scheduling

Ping-pong scheduling 主要发生在两个 consumer warpgroup 之间。由于 WGMMA 的异步性,我们可以同时运行 softmax 和 GEMM 计算,按照下图的调度并用bar.sync在虚线处同步,可以让两个 warpgroup 轮流交替进行 GEMM 计算,以实现更高的 Tensor Core 算力利用率。

3.3 Intra-warpgroup overlapping

如果只有一个 warp group我们一样也可以做 softmax GEMM overlapping,做法就是我们会把这一个 iteration 的 PV 矩阵运算,留到下一个 iteration 算 softmax 的时候同时一起做。下图展示的是 2-stage 流水线方案。

Intra-warpgroup overlap

注意,在 2-stage 方案中,寄存器需要同时保存前一份数据 softmax 的计算结果和后一份数据 GEMM0 的计算结果,因此寄存器的压力会比没有流水线的情况要大。

我们可以从下面的演算法看到,在进入 inner for loop 之前我们会先计算第一个 Scur = QK(第 4 行),然后计算 softmax(第 6 行)。接下来进入到 inner for loop 我们会先在一开始,就去计算下一笔数据的 Snext = QK(第 9 行),而计算的同时我们马上把这一笔的 V load 进来,然后发起运算这一笔数据的 O = PcurV 运算(第 11 行),到了下一步之后,我们会等待刚刚下一笔的 Snext 算完,然后接着计算 softmax(第 13 行),这个时候 O = PcurV 也正在同时运算,最后我们再把 Snext 复制到 Scur,接着下一个 iteration,依此类推。

接下来我们来看 Backward 的部分,基本上一样也是把它拆成 producer 和 consumer warp group 一个 load 数据一个做运算。我们可以看到因为我们要先 recompute S = QK(第 21 行)和回推 dP = dOV(第 23 行),所以这个地方是 SS-GEMM,而 dV(第 26 行)和 dK(第 27 行)的更新,因为 QdO 是透过 s-stage buffer 管理的,所以是 RS-GEMM。

不过这个地方有个麻烦的东西, 就是 dQ 的更新 ,他必须拉上来做更新

dQi ← dQi + dSi(j)Kj ∈ ℝBr × d

因为这里不像 dVdK 是针对 j 维度去更新,dQ 是针对 i 维度去更新,会让不同的 thread block 同时对同一个地方进行写入,所以会造成 memory contention 的问题。所以这个地方作者开了另一个 warp 专门来处理 dQ ,也就是说 dSK(第 28 行)算出来的东西,照理来说我们要加回 dQ,但是我们聪明的作者,使用了 semaphore(信号量) 把算出来的结果以 atomic 的方式加回去 HBM 上的 dQ

3.2.1 3-stage pipelining

Flash Attention V3 当中,把整个 attention 的运算分成两个 stage 完成,但我们的作者又想到了一个更疯狂的 3-stage pipeline。基本上这个想法就是因为 softmax 花做多时间运算,所以除了把上一次的 PV 运算和这一次的 Softmax 运算同时进行外,当上一次的 PV 算完后,我们马上计算下一次的 QK

Intra-warpgroup overlap

理论上,三个计算步骤可以安排 3-stage 流水线,但由于寄存器数量的限制,强行编排三级流水线,要么会造成寄存器溢出,极大程度影响性能,要么只能选择更小 block size,同样会影响性能。FA3 经性能测试后,采用了 2-stage 的方案。

3.4 Persistent Kernel

在工程实现方面,FA3 算子在每个 SM 上会启动一个 persistent Kernel,成为一个 persistent thread block。这个 persistent block 在它的生命周期内(一次 kernel launch 的计算中)可以处理多个 thread block 的 tile 分块数据,在两个 thread block 的计算之间,可以将前一个 block 的 kernel prologue 阶段和后一个 block 的 launch 阶段同时进行,由此掩盖了同一 SM 上先后两个 thread block 切换的延迟。

在早期的架构上,在 SM 上并发运行多个 thread block 就能很好地处理延时问题。但在 Hopper 架构下,Tensor Core 的计算已经非常快了,这就要求有更深的流水线来掩盖延时,而更深的流水线阻碍了在一个 SM 上运行多个 thread block,因此 persistent thread block 可以在多个 tiles 和多个 warpgroups 上运行 collective main loops。

Persistent Kernel 是一种宏观上的概念,而 Stream-K 算法是一种 Persistent Kernel 的工程实现,具有负载均衡的特性。

四、FP8 低精度运算

4.1 contiguous issue

Flash Attention V3 另一个突破点,就是支持 FP8 精度的运算,但是在 FP8 又会出现新的问题,当给定一个M × K 矩阵𝐴 矩阵𝐵 ,做 A × 𝐵矩阵运算的时候

  • 如果外部 M 或 N 的维度是 contiguous 的,我们会说 A 或 B operand 是 mn-major
  • 如果内部 K 的维度是 contiguous 的,则是 k-major。

虽然 FP16 精度在 SMEM 能接受 mn-major 和 k-major 的输入 operand,但是 FP8 只能接受 k-major 的输入 operand。PyTorch 官方文档解说,如果有一个 tensor x,当我们使用 x.contiguous()的时候,如果这个 tensor 在记忆体当中本身就是连续的,就会返回 x 本身,但如果不是,则会 copy 一份 x 并回传一份连续的 tensor。而操作 x.transpose(0,1)运算时,其并不会创建一个新的 tensor,而是去改变这个 tensor 的 meta data,像是 offset 和 stride,所以经过 transpose 的 x 就不是 contiguous 的。

而 tensor 不是 contiguous,并不是指矩阵里面 element 的 address 随便分散在记忆体当中,而是指 element adress 的 order 被改变过了,所以是不连续的。attention 的两个矩阵运算

S = QK, P = Softmax(S), O = PV

而 FP8 只能支持 k-major,所以必须确保QK运算的时候 head dimension 是连续的,而 PV 运算的时候 sequence dimension 是连续的。

关于这一点一般来说,TMA load 进来的 QKV 基本上 head dimension 会是连续的,所以QK的运算没有问题,主要会有问题的是 PV 运算,我们需要额外加一个 Transpose 让 sequence dimension 符合 k-major。所以在这边作者用了一个方式,就是当 V 加载到 SMEM 后,对 V 做 in-kernel 转置。

实作的方式就是使用 LDSM/STSM 指令 ,这两个指令分别代表的是把数据从 SMEM load 到 RMEM 和把数据从 RMEM store 到 SMEM,而因为这两个指令都是 register efficient 的指令,也就是不会用太多 register,所以我们可以把 in-kernel transpose 操作放在 producer warp group 当中。

4.2 WGMMA Layout

另外一个麻烦的点就是,因为我们想要把所有 attention 的运算用 single kernel 来算,但是 FP32 accumulator 会和 FP8 operand layouts 产生冲突。这边我们可以看到底下针对 FP32 和 FP8 的 WGMMA Layout,他们是长得不一样的。

P 目前的排布 P 目前的排布

P 最终期望的排布 P 最终期望的排布

直观上看是必须做不同线程之间的数据 shuffle 操作的。但论文中给了另外一种方案:只在线程内部做寄存器的交换,即将 {d0..d7} 变换为 {d0 d1 d4 d5 d2 d3 d6 d7},每个线程会对该线程在寄存器持有的所有数据,以每 8 个数据为一组做上面的变换。

实际上这样的操作等价于对 P 矩阵做了列的交换,于是我们在对 V 做转置的时候,只要做一个相同的行变换,利用 PV = P 的列变换 V 的行变换 这个数学性质,我们就可以保证矩阵运算结果的正确性,同时避免了线程间的 shuffle 操作。

为什么等价于对 P 矩阵的列交换呢?可以想象在做了线程内部寄存器的交换后,FP32 T0{d2, d3} 这个寄存器和 T0{d4, d5} 寄存器做了交换,T1{d2, d3} 这个寄存器和 T1{d4, d5} 寄存器做了交换,T2 和 T3 同理。然后将交换后的矩阵和FP8做对比,就可以发现实际上两者只有列顺序的不同,对于 T0 而言,只需要将变换后的 0,1,8,9 列再移动到前四列,就成为了FP8的排布。

线程内交换寄存器后,P 矩阵仅需做列交换就可以完成数据排布转换

4.3 Block Quantization

针对 Quantization,因为 Flash Attention V3 是一个 block 一个 block 来运算,所以 Quantization 也是以 Block 为单位。把 Quantization 这件事情 Fused 到 Attention 的前一个步骤,就是 Rotary Embedding ,而因为 Rotary Embedding 是 memory-bandwidth bound,也就是说 I/O 时间大于运算时间,所以这边进行 Quantization 操作并不会减慢运算速度

4.4 Incoherent processing

最后就是如何避免 Quantization Error,这里我们可以在进入 Quantization 之前做这样的操作:(QM)(KM) = QK,其中 MM = I,而这里的 M 是一个随机的 orthogonal matrix。

而因为 QMKM 的每个 entry 都是 QK entry 的 random sum,所以这个方法可以帮助我们来降低 Quantization Error。

另外为了加速 QMKM 的运算,这边使用了 Hadamard transform ,也就是把 M 设为值为±1 的 random diagonal matrices 和 Hadamard matrix 的 product,所以在矩阵相乘的时候时间复杂度就可以从 O(d2)降到 O(dlog d),同样的这边 QMKM 的运算一样可以 Fused 到 rotary embedding 当中。

五、Performance

FA3 算子的前向 BF16 最高算力可达 850 TFLOPS,对比 FA2 约 300 TFLOPS 的算力,FA3 提升了 2.8x 的算力利用率,且这个算力已经非常接近 H100 的峰值算力 989.4 TFLOPS 了。

BF16 前向性能

FP8 精度的算力最高也有 1322 TFLOPS,对比 H100 峰值算力 1978.9 TFLOPS,FA3 对硬件 FP8 的算力利用率也非常夸张了。

FP8 前向性能

在 FA3 论文中还展示了 GEMM-Softmax pipelining 和 Warp Specialization 两个方案的消融实验,以及 FP8 量化的数值误差。

Reference

  1. FlashAttention-3 Paper
  2. FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
  3. Flash Attention 3 深度解析
  4. 榨乾GPU效能的Flash Attention 3
  5. 英伟达又赚到了!FlashAttention3来了:H100利用率飙升至75%

FlashAttention-3:具有异步性和低精度的快速准确注意力机制
https://mztchaoqun.com.cn/posts/D98_FlashAttention-3/
作者
mztchaoqun
发布于
2025年12月2日
许可协议