LayerNormalization详解
概述
Layer Normalization(LayerNorm,层归一化)是Transformer中必不可少的组件,用于稳定训练,防止梯度消失爆炸。和CV中常用的Batch Normalization(BatchNorm,批归一化)不同,LayerNorm在特征维度做归一化,更适合NLP和Transformer场景。
归一化基础
神经网络训练过程中,随着参数不断更新,每一层的输入分布会发生变化,这个现象叫Internal Covariate Shift(内部协变量偏移)。归一化的作用就是将输入分布拉回均值0方差1的稳定分布,加速训练收敛。
三种常见归一化对比:
| 归一化方法 | 归一化维度 | 适用场景 |
|---|
| BatchNorm | 对批维度(同一个batch所有样本同一特征) | CV,固定长度,大batch |
| LayerNorm | 对特征维度(每个样本单独归一化) | NLP,Transformer,变长序列 |
| RMSNorm | 只做缩放,不减去均值 | 大模型,简化LayerNorm |
| InstanceNorm | 对通道内空间维度 | 图像生成,风格迁移 |
| GroupNorm | 分组归一化 | 检测分割,小batch |
LayerNorm 公式
计算公式
对于一个样本的隐状态 x∈Rd,LayerNorm计算步骤:
-
计算均值:
μ=d1∑i=1dxi
-
计算方差:
σ2=d1∑i=1d(xi−μ)2
-
归一化:
x^i=σ2+ϵxi−μ
-
缩放平移(带可训练参数γ和β:
yi=γix^i+βi
其中ε是一个很小的数(如1e-6),防止除零。γ和β是维度d的可训练参数,让模型可以学习到最优的缩放和偏移。
核心思想
LayerNorm对每个样本单独在特征维度做归一化,不依赖batch内其他样本,所以:
- 不受batch size影响,batch size=1也能用
- 适合变长序列,每个样本长度不同也没问题
- 每个位置独立归一化,符合Transformer处理每个token的特点
RMSNorm(Root Mean Square Layer Normalization)
RMSNorm是LayerNorm的简化版本,只做均方根缩放,去掉了减去均值这一步。
计算公式
RMS(x)=d1∑i=1dxi2
yi=RMS(x)xi⋅γi
特点
- 简化了LayerNorm,去掉了计算均值和减去均值的步骤
- 计算速度更快,显存占用更少
- 实践中效果和LayerNorm差不多,甚至略有提升
- 现在很多大模型(如LLaMA)改用RMSNorm节省计算
LayerNorm vs BatchNorm
区别对比
| 对比维度 | LayerNorm | BatchNorm |
|---|
| 归一化维度 | 特征维度(每个样本) | 批维度(同一特征跨样本) |
| 均值方差 | 每个样本单独算 | 整个batch一起算 |
| batch size依赖 | 不依赖,batch=1也能用 | 依赖,batch小了方差不准效果差 |
| 变长序列支持 | 天然支持,每个长度都能算 | 不友好,padding位置统计不准 |
| 适用场景 | NLP,Transformer,RNN | CV,CNN,固定尺寸 |
为什么Transformer用LayerNorm不用BatchNorm?
- 序列长度变化:Transformer处理变长序列,不同样本长度不同,BatchNorm在变长序列上很难做
- 推理batch小:推理时batch size可能是1,BatchNorm统计的均值方差不准,效果很差
- 独立token处理:Transformer每个token位置独立处理,LayerNorm在特征维度归一化更自然
- BN假设不成立:BN假设分布在batch上稳定,但Transformer是变长,这个假设不成立
LayerNorm更符合Transformer的结构特点,所以Transformer普遍用LayerNorm。
Pre-LN vs Post-LN
LayerNorm放在哪个位置,对训练稳定性影响很大。常见的两种放置方式:
Post-LN(原始Transformer位置)
结构:
Attention -> Add -> LN -> FFN -> Add -> LN
也就是:
output=LN(x+sublayer(x))
缺点:深层网络中,梯度范数会随着深度指数增长,导致训练不稳定,容易发散。
Pre-LN(现在主流做法)
结构:
LN -> Attention -> Add -> LN -> FFN -> Add
也就是:
output=x+sublayer(LN(x))
优点:梯度范数在各层近似恒定,深层网络训练更稳定,不容易发散。
缺点:相比Post-LN,模型最终效果略差一点。
Sandwich-LN
在Pre-LN基础上,额外在输出再插一个LayerNorm,也就是两层都有。CogView用来避免值爆炸,但可能导致训练不稳定。
DeepNorm
DeepNorm是更深层网络的改进:
- 在LayerNorm之前,对残差连接做上缩放(α > 1)
- 初始化时对参数做下缩放(β < 1)
优点:可以缓解深层模型爆炸式更新,把模型更新限制在常数范围,让深层训练更稳定。
代码实现:
不同大模型使用情况
![不同模型使用的归一化方法:
| 模型 | 归一化类型 | 位置 |
|---|
| BERT | LayerNorm | Pre-LN |
| GPT | LayerNorm | Pre-LN |
| LLaMA | RMSNorm | Pre-LN |
| BLOOM | LayerNorm (+embedding后加LN | Pre-LN |
| T5 | LayerNorm | Pre-LN |
注意:BLOOM在embedding后额外加了一层LayerNorm,虽然提升训练稳定性,但可能带来性能损失。
面试常见问题
1. LayerNorm和BatchNorm的区别是什么?各适合什么场景?
回答要点:
- BatchNorm对批维度归一化,跨样本对同一特征算均值方差,适合CV大batch场景
- LayerNorm对特征维度归一化,每个样本单独算,不受batch size影响,适合NLP、Transformer变长序列场景
- Transformer处理变长输入,推理batch可能很小,所以用LayerNorm不用BatchNorm
2. RMSNorm和LayerNorm有什么区别?
回答要点:
- LayerNorm要减均值再做归一化,RMSNorm省略了减均值这一步,只做均方根缩放
- RMSNorm计算更快,参数更少,效果差不多,现在大模型越来越倾向用RMSNorm简化计算
3. Pre-LN和Post-LN的区别是什么?为什么现在都用Pre-LN?
回答要点:
- Post-LN:LayerNorm放在残差连接之后,原始Transformer做法。深层网络梯度范数会增大,训练不稳定
- Pre-LN:LayerNorm放在子层之前,残差连接里面。各层梯度范数近似恒定,训练更稳定,不容易发散
- 深度Transformer训练稳定性更重要,所以现在主流都用Pre-LN
4. 归一化为什么能加速训练?
回答要点:
- 解决内部协变量偏移:每一层输入分布变化大,归一化拉回稳定分布,让学习率可以调更大,收敛更快
- 缓解梯度消失/爆炸:让各层输出分布稳定,梯度传播更顺畅
- 让损失曲面更平滑,优化更容易,收敛更快
5. LayerNorm的可训练参数γ和β是做什么的,可以去掉吗?
回答要点:
- 如果去掉γ和β,归一化后一定是固定均值0方差1,限制了模型表达能力
- γ和β允许模型学习最优的缩放和偏移,可以恢复出最优分布
- 保留了表达能力,又有归一化的好处,一般都保留
- 去掉会降低模型表达能力,所以不建议去掉
6. DeepNorm解决什么问题?
回答要点:
- 深层Transformer随着深度增加,模型更新幅度容易爆炸,训练不稳定
- DeepNorm通过缩放残差和参数初始化缩放,把每一层的更新幅度限制在常数范围
- 让更深层网络训练更稳定,可以训练几百层的大模型