自注意力机制
概述
自注意力(Self-Attention),也叫内部注意力,是Transformer的核心机制。它允许序列中的每个位置都能关注到序列中其他所有位置,通过加权求和的方式整合全局信息,从而得到每个位置更好的表示。
相比RNN的顺序计算和CNN的局部感受野,自注意力能够直接建模任意两个位置之间的依赖关系,不管它们相距多远。
Self-Attention 计算流程
完整计算分为四步:
步骤1:计算Q/K/V
将输入X分别通过三个可训练的权重矩阵投影,得到查询(Query)、键(Key)、值(Value):
Q=XWQK=XWKV=XWV
其中:
- Q(查询):当前位置要"查询"的信息向量,表示"我在找什么"
- K(键):每个位置"提供"的信息向量,表示"我这里有什么"
- V(值):每个位置实际输出的内容向量,表示"我要输出什么"
维度说明:
- 输入序列长度:n = seq_len
- 输入维度:d_model
- Q/K维度:d_k
- V维度:d_v(通常d_k = d_v = d_model / num_heads)
步骤2:计算注意力分数
通过点积计算Query和每个Key的相似度,得到注意力分数:
Attention Score=QKT
得到一个 (n × n) 的注意力分数矩阵,矩阵中第i行第j列表示第i个位置对第j个位置的注意力分数。分数越高,表示越关注。
步骤3:Scaling + Softmax
将注意力分数除以 √d_k 然后用Softmax归一化,得到注意力权重:
Attention(Q,K,V)=softmax(dkQKT)V
这就是著名的 Scaled Dot-Product Attention。
步骤4:加权求和得到输出
用Softmax得到的注意力权重对V做加权求和,得到最终输出:
Output=AV
其中A是注意力权重矩阵,形状 n × n,V是n个Value向量,输出形状 n × d_v。
Scaled Dot-Product Attention 公式推导
完整推导如下:
设输入序列长度为 n,每个Query和Key维度都是 d_k:
- Q:
(n, d_k), K: (n, d_k), V: (n, d_v)
- 计算点积:
QK^T: (n, n),元素 (i,j) = Σ_{t=1到d_k} Q_{i,t} × K_{j,t}
- 缩放:每个元素除以
√d_k
- Softmax:对每一行做Softmax,得到概率分布,每行和为1
- 乘以V:得到最终输出
(n, d_v)
最终公式:
Attention(Q,K,V)=softmax(dkQKT)V
为什么要除以 √d_k?
这是一个非常高频的面试题。原因要从方差的角度来分析:
假设 Q 和 K 的各个分量都是均值为0,方差为1的独立随机变量:
- QK^T 中每个元素
a_ij = Σ_{t=1}^{d_k} q_{it} × k_{tj}
- 这个和的方差是多少呢?Var(qk) = E[qk]^2 - (E[qk])^2,因为E[q] = E[k] = 0,E[q^2] = 1,所以 Var(qk) = E[(qk)^2] = E[q^2]E[k^2] = 1×1 = 1
- 所以 Var(a_ij) = Σ_{t=1}^{d_k} Var(q_{it}k_{tj}) = d_k × 1 = d_k
可以看到,点积结果的方差会随着d_k增大而增大。当d_k很大时,点积结果会分布在一个很大的区间,某些点会特别大,特别小。
Softmax函数在输入绝对值很大时,输出会接近one-hot分布(即只有一个位置接近1,其他接近0),这会使得梯度非常小,导致训练困难。
除以 √d_k 之后,方差重新变回1:
Var(dkaij)=dk1Var(aij)=dk1×dk=1
这样,不管d_k多大,点积输出的方差都稳定在1,Softmax不会进入饱和区,梯度保持健康,训练更稳定。
总结一下:缩放是为了防止点积结果方差过大导致Softmax饱和,从而稳定训练。
多头注意力(Multi-Head Attention)
多头注意力就是把原始的d_model维度分成h个小的头,每个头独立计算自注意力,然后把结果拼接起来。
公式表示:
MultiHead(Q,K,V)=Concat(head1,head2,...,headh)WO
headi=Attention(QWiQ,KWiK,VWiV)
多头注意力的优势:
- 不同头可以学习到不同的注意力模式
- 一个头可以关注局部语法关系,另一个头可以关注远距离语义依赖
- 多头相当于并行多个独立的注意力,增强表达能力
- 拆分多头后每个头维度减小,总的计算量和单个大注意力差不多
掩码注意力(Masked Attention)
在Decoder的自注意力中,为了保证自回归生成的因果性,每个位置只能关注到它之前的位置,不能看到未来的位置。因此需要在Softmax之前给未来位置加上一个无穷大的负掩码:
Attention Score[i][j]=−∞,if j>i
这样经过Softmax后,未来位置的权重就变成了0,相当于不会被关注。
掩码可以保证训练时不会泄露答案,和推理时的行为保持一致。
面试常见问题
1. 自注意力的时间复杂度和空间复杂度是多少?
回答要点:
- 时间复杂度:O(n² × d_model),其中n是序列长度,需要计算n×n的注意力矩阵
- 空间复杂度:O(n²),需要存储n×n的注意力权重矩阵
- 相比之下,RNN时间复杂度是O(n × d_model²),空间复杂度O(d_model²)
- 短序列:自注意力更快;长序列:自注意力计算量增长快
2. Self-Attention、Cross-Attention、Masked Self-Attention 有什么区别?
回答要点:
| 类型 | Q来源 | K/V来源 | 掩码 | 使用场景 |
|---|
| Self-Attention | Q = K = V 都来自同一层输出 | 同一层输出 | 无(双向) | Encoder自注意力,所有位置可见 |
| Masked Self-Attention | Q = K = V 都来自同一层输出 | 同一层输出 | 有掩码,只能看前面 | Decoder自注意力,保证因果性 |
| Cross-Attention / Encoder-Decoder Attention | Q来自Decoder上一层输出 | K/V来自Encoder输出 | 无 | Decoder中关注源序列 |
3. 为什么点积注意力比加法注意力好?
原始论文中对比了两种注意力:
- 点积注意力:计算QK^T,并行性好,利用矩阵乘法加速,实际更快
- 加法注意力:用一个MLP计算注意力分数,理论上可以处理非线性,但难以并行
现在几乎都用点积注意力,因为硬件对矩阵乘法优化得很好,点积更快。
4. 自注意力如何处理变长输入?
自注意力本身不限制输入长度,可以处理任意长度的输入,只需要根据实际序列长度计算对应的注意力矩阵即可。但是长度增加会导致计算量平方增长,O(n²),所以太长会很慢且耗显存。
5. 自注意力比RNN好在哪里?
回答要点:
- 并行计算:整个序列可以同时计算,RNN必须顺序计算,训练速度慢很多
- 长程依赖:直接建模任意位置连接,不会梯度消失,RNN长距离信息容易丢失
- 计算路径长度:任意两个位置之间只需要一步,RNN需要n步,减少长程信息损失
6. 自注意力有什么缺点?
回答要点:
- 计算复杂度:O(n²),当序列很长时,计算量和显存占用增长很快
- 没有内置位置信息:需要额外加位置编码,RNN本身就有顺序性
- 对短序列冗余:很多注意力权重其实是不必要的
7. 什么是注意力偏置?为什么需要它?
注意力偏置是在计算注意力分数时额外加上一项,常见于相对位置编码(比如T5、ALiBi)。通过在注意力分数矩阵上加上与相对位置相关的偏置,让模型能够感知到位置信息,而不需要直接加在embedding上。