KV Cache与推理优化
KV Cache 原理
在Transformer自注意力计算中,Attention分数的计算公式为:
Attention(Q,K,V)=softmax(dkQKT)V
在自回归生成过程中,每一步只生成一个新token。如果每次都重新计算所有历史token的K和V,会造成大量重复计算。
KV Cache的核心思想:将已经计算过的历史token的Key和Value缓存下来,下次推理时只需要计算新token的K和V,直接拼接缓存即可,避免重复计算。
推理阶段KV Cache流程:
- Prefill阶段:处理输入prompt的所有token,一次性计算所有K和V,缓存起来
- Decoding阶段:每一步只计算新生成token的K和V,追加到缓存中,然后用完整的K+V计算Attention
- 这样每一步解码只需要计算一个新token的K和V,大大减少计算量
KV Cache 显存占用计算
计算公式
对于一个LLaMA风格模型:
- 模型层数:L
- 注意力头数:H
- 每个头的维度:D
- 序列长度:N
- 数据类型:每个元素占B字节
KV Cache总显存占用:
总大小=2×L×H×D×N×B
系数2来自Key和Cache各一份。
实际例子
LLaMA-13B:
- L = 40层
- H = 40头
- D = 128维
- N = 2048序列长度
- FP16,B = 2字节
计算:
2×40×40×128×2048×2=1.7GB
所以单个序列的KV Cache在LLaMA-13B中最多占用1.7GB显存,这是相当可观的。
更长序列占用更大,比如16k上下文就需要约13.6GB。
KV Cache 优化策略
传统KV缓存管理存在的问题:
- 动态变化:不同序列长度变化很大,难以预测内存需求
- 内存碎片:频繁分配释放导致碎片化,实际浪费60%-80%的内存
- 过度预留:为了应对最坏情况预留过多内存
常见优化策略:
1. 分块存储 / 分页管理
PagedAttention(vLLM):
- 将每个序列的KV缓存划分为固定大小的块(页)
- 块不需要连续存储,通过块表映射逻辑块到物理块
- 按需分配物理块,内存浪费只发生在最后一个块,浪费率低于4%
- 支持高效内存共享,并行采样时多个输出序列共享prompt的KV缓存
2. 滑动窗口优化
StreamingLLM:
- 基于Attention Sink现象:文本最初几个token总是吸收大量注意力
- 只保留初始几个token(sink)加上最近N个token在窗口内
- 踢出中间token,不需要保存所有历史
- 可以在不重新训练的情况下支持无限长度流式生成
3. 量化压缩
对KV Cache进行量化,用更低精度存储:
- FP16 → INT8:减少一半显存
- 极端情况可以INT4,但精度损失较大
- 大部分推理框架都支持量化KV Cache
4. 重用缓存
FasterTransformer优化:
- 缓存激活值和输出
- 重复使用缓存,避免多层反复计算
- GPT-3 96层只需要1/96的内存用于激活
多轮对话中的 KV Cache 管理
多轮对话场景面临的挑战:
- 对话轮次越多,KV Cache越大,显存占用持续增长
- 不同对话长度差异大,内存分配困难
- 长时间对话容易OOM
常见管理策略:
1. 滑动窗口截断
- 只保留最近N个token的KV Cache
- 丢弃最早的token
- 问题:会丢失早期上下文信息,影响生成质量
2. StreamingLLM + Attention Sink
- 保留前N个sink token + 最近M个token
- 保持生成质量的同时,将KV Cache大小固定在(N+M)以内
- 支持百万token级别的流式对话
3. SwiftInfer优化
- 基于TensorRT重新实现StreamingLLM
- 重新优化KV Cache机制和位置偏移注入
- 在StreamingLLM基础上再提升46%推理速度
4. 按需回收
- 对话结束后立即释放该对话的所有KV缓存
- 在PagedAttention中可以逐个块回收,碎片化少
面试常见问题
Q1: 什么是KV Cache?为什么需要它?
A: KV Cache是在自回归推理过程中,缓存历史token计算出的Key和Value,避免每一步都重新计算所有历史token的K和V,大大减少重复计算,提高推理速度。没有KV Cache的话,每一步都要重新计算整个序列,时间复杂度从O(1)变成O(n),推理会非常慢。
Q2: KV Cache占多少显存?如何计算?
A: KV Cache显存占用 = 2 × 层数 × 头数 × 头维度 × 序列长度 × 每个元素字节数。2来自K和V各一份。例如LLaMA-13B在2048序列长度下约占1.7GB(FP16),序列越长占用越大。
Q3: 为什么KV Cache容易造成显存浪费?
A: 因为不同序列长度变化很大且不可预测,传统连续分配方式容易产生内存碎片,而且需要过度预留内存,实际系统中常常浪费60%-80%的内存。
Q4: PagedAttention是如何解决KV缓存管理问题的?
A: PagedAttention借鉴操作系统分页思想,将KV缓存分成固定大小的块,块可以不连续存储,通过块表映射。按需分配物理块,浪费只发生在最后一块,浪费率低于4%,大幅提高内存利用率。还支持内存共享,降低并行采样的内存开销。
Q5: 多轮对话中KV Cache不断增长怎么办?
A: 常用方法:1) 滑动窗口,只保留最近N个token;2) StreamingLLM保留sink token + 滑动窗口,保持质量;3) 量化压缩KV Cache;4) 对话结束后及时回收。
Q6: KV Cache在训练阶段需要吗?
A: 训练阶段是一次性处理整个序列,并行计算所有位置,不需要缓存历史,所以不需要KV Cache。只有推理自回归生成时才需要KV Cache。
Q7: 量化KV Cache会影响生成质量吗?
A: INT8量化对质量影响很小,大多数情况几乎不可感知,可以节省一半显存,是非常实用的优化。INT4压缩更多,但可能有可感知的质量损失。