多头注意力与注意力变体
概述
标准多头注意力(Multi-Head Attention,MHA)虽然效果好,但在推理时需要保存很大的KV缓存,内存开销大。近年来出现了多个优化变体,包括Multi-Query Attention(MQA)、Grouped-Query Attention(GQA),以及IO感知的FlashAttention等,在保持效果接近的同时提升了推理速度,减少了内存占用。
Multi-Head Attention 回顾
原理回顾
标准多头注意力将 d_model 切分成 h 个相等的部分,每个部分是一个头,每个头独立计算注意力,最后拼接:
dk=hdmodel
每个头都有独立的 WQi,WKi,WVi 投影矩阵,每个头都输出自己的结果,最后通过一个输出投影矩阵 WO 整合。
维度变换:
- 输入:
(batch, seq_len, d_model)
- 投影后:每个头 Q/K/V
(batch, seq_len, d_k),h个头就是 (batch, h, seq_len, d_k)
- 注意力计算:每个头独立算,得到
(batch, h, seq_len, d_k)
- 拼接:
(batch, seq_len, h × d_k) = (batch, seq_len, d_model)
- 输出投影:得到最终结果
(batch, seq_len, d_model)
优势
- 多子空间学习:不同头可以学习不同类型的依赖关系
- 表达能力强:多个不同的注意力分布集成,效果更好
- 不增加太多计算量:总分维还是 d_model,计算量和单头差不多
问题
训练过程影响不大,推理时问题明显:每个头都需要保存自己的Key和Value缓存,所以KV缓存大小和头数成正比,参数量大的模型内存开销很高,推理速度被内存带宽限制。
Multi-Query Attention (MQA)
核心思想
Multi-Query Attention 由 Google 在2019年提出,核心思想是:所有头共享同一组Key和Value,只保留Query分开。
也就是说:
- Query:仍然每个头独立,保持 h 个Query投影
- Key/Value:整个层只保留一组Key/Value投影,所有头共享
这样,KV缓存大小就减少了 h 倍,大大减少了内存占用,提升推理速度。
优缺点
优点:
- KV缓存大小减少 h 倍,显存占用大幅降低
- 推理速度明显提升,尤其是长序列推理
- 参数减少,模型推理时内存访问更快
缺点:
- 会带来轻微的效果损失,但大部分任务下降幅很小
- 实际工程中可以通过略微增大模型参数量弥补效果损失
应用模型
PaLM、ChatGLM2、Falcon 等都采用了 Multi-Query Attention。不同模型的适配方式不同:
- Falcon:把隐藏维度从 4096 增大到 4544,多余参数分配给Attention和FFN
- ChatGLM2:把FFN中间维度从 11008 增大到 13696,多余参数分给FFN
Grouped-Query Attention (GQA)
核心思想
Grouped-Query Attention 是 MHA 和 MQA 的折中:不是所有头共享一组KV,而是把多头分成若干组,组内共享KV。
- MHA:每个头一组,每组一个KV → 效果最好,KV最大
- GQA:g 个头一组,每组共享一个KV → 效果和KV大小都在中间
- MQA:所有头一组,整个层一个KV → 效果略差,KV最小
优势
GQA在效果和速度之间取得了很好的平衡:
- 相比MHA,KV缓存减少,推理更快
- 相比MQA,效果下降更少,更接近MHA
- 论文实验表明GQA-8在保持效果几乎不变的情况下,可以达到和MQA相近的推理速度
应用模型
LLaMA2-34B/70B、ChatGLM2 等使用了 Grouped-Query Attention。
三种架构对比
| 特性 | Multi-Head Attention | Grouped-Query Attention | Multi-Query Attention |
|---|
| KV分组 | 每个头独立 | g组共享 | 全层共享 |
| KV缓存大小 | 大(h倍) | 中等(h/g倍) | 小(1倍) |
| 推理速度 | 慢 | 中 | 快 |
| 模型效果 | 最好 | 接近MHA | 轻微下降 |
| 适用场景 | 训练、小模型 | 大模型推理平衡 | 大模型快速推理 |
FlashAttention
核心问题
标准注意力计算需要把整个QKV矩阵都读到高带宽内存(HBM)中,计算中间结果也要写回HBM。HBM容量大但访问速度相对较慢,而芯片上的SRAM速度快但容量小。FlashAttention利用了SRAM做分块计算,减少了对HBM的访问。
核心思想
FlashAttention的核心是分块Softmax + 重计算,用分块计算来适配SRAM容量:
- 分块:把Q/KV按序列长度切分成小块,依次把小块读到SRAM中计算
- 在线Softmax:分块计算Softmax,通过保留logsumexp的信息,保证分块计算结果和全局计算等价
- 避免HBM读写:大部分计算在SRAM完成,只在输入输出读写HBM,减少了数据移动
- 反向传播重计算:不保存正向的中间结果,反向传播时重新计算,节省HBM空间
优点
- 显存节省:相比标准实现,可以节省大量HBM空间,支持更长序列
- 速度更快:减少了HBM访问,更高效地利用SRAM,实际速度提升明显
- 不改变结果:数学上等价于标准Softmax,没有精度损失
关键词
- HBM(高带宽内存):GPU显存,容量大但速度较慢
- SRAM(片上静态内存):速度快容量小,在GPU上就是on-chip缓存
- 分块Softmax:保证分块计算结果等价于全局计算
- 重计算:反向不存中间结果,重新计算节省空间
- Kernel融合:把多个算子融合成一个kernel,减少启动开销
应用
LLaMA、Falcon等主流开源模型都使用FlashAttention加速训练和推理。
其他注意力变体
稀疏注意力(Sparse Attention)
- 思想:不是每个位置都关注所有位置,只关注局部或某些特定位置,减少计算量
- 代表:Sparse Transformer、Longformer的滑动窗口注意力
- 优点:O(n × w)复杂度,适合长序列;保留局部相关性
- 缺点:需要特殊实现,难以通用加速
线性注意力(Linear Attention)
- 思想:改变计算顺序,将复杂度从O(n²)降为O(n)
- 原来:
softmax(QK^T)V,先算n×n矩阵,再乘V → O(n²)
- 现在:
Q (K^T V),先算d×d矩阵,再乘Q → O(n × d²)
- 优点:线性复杂度,适合超长序列
- 缺点:近似计算,效果略有下降
低秩近似注意力
- 思想:注意力矩阵是低秩的,可以用低秩分解近似,减少计算量
- 通过分解大矩阵为几个小矩阵乘积,降低复杂度
金字塔形注意力/分层注意力
- 不同层使用不同的窗口大小,底层关注局部,高层关注全局
- 混合精度计算,平衡效果和速度
面试常见问题
1. MQA为什么能提升推理速度?
回答要点:
- 推理时需要保存KV缓存,每个新生成token都要追加到KV缓存,供下一次计算
- 标准MHA每个头都有独立的K和V,KV缓存大小和头数成正比
- MQA所有头共享同一组KV,KV缓存大小减少h倍,显存占用减少
- 推理速度瓶颈主要是内存带宽(内存受限),减少显存就是减少内存访问,所以速度提升
2. MQA、GQA之间的区别和联系?
回答要点:
- MQA是GQA分组数为1的特例,GQA是MQA和MHA的推广
- MHA:每个头一个KV → 效果好,KV大,慢
- GQA:分成g组,每组共享KV → 效果、速度、内存都在中间
- MQA:只有一组,全层共享KV → 效果略降,KV最小,最快
- 实际应用中,GQA往往是更好的平衡点
3. FlashAttention为什么快?为什么省显存?
回答要点:
- IO感知:标准注意力计算频繁读写HBM,FlashAttention利用片上SRAM做分块计算,减少HBM访问次数
- Kernel融合:多个小算子融合成一个大kernel,减少kernel启动开销和数据移动
- 重计算技巧:反向传播不保存正向中间结果,需要时重新计算,节省显存
- 数学上和标准注意力完全等价,没有精度损失,只是更高效利用硬件
4. 标准注意力计算的IO瓶颈是什么?
回答要点:
- 注意力计算需要访问Q/KV三次(读),写注意力矩阵一次,写输出一次
- 这些都要在HBM中读写,HBM带宽是瓶颈,计算不是瓶颈
- FlashAttention通过分块让大部分计算在高速SRAM进行,只读写一次输入输出,大大减少HBM访问
5. 为什么现在大模型都在优化KV缓存?
回答要点:
- 自回归推理每一步都需要用到之前所有位置的KV,必须缓存下来
- 大模型参数量大,序列长,KV缓存占用显存很大,经常成为推理瓶颈
- 优化KV缓存可以直接减少显存占用,提高推理吞吐量,支持更长上下文
6. 稀疏注意力真的更快吗?
回答要点:
- 理论上计算量更少,但实际硬件很难加速不规则稀疏矩阵计算
- 访存模式不规则,缓存命中率低,往往反而比稠密计算更慢
- 除非有专门的硬件支持,否则实际收益不大