PPO算法详解
PPO(Proximal Policy Optimization,近端策略优化)是目前RLHF中最常用的强化学习算法,它在策略更新的稳定性和样本效率之间取得了很好的平衡。
PPO 算法原理
背景:策略优化问题
在RLHF中,我们要优化的是策略 πθ(a∣s),目标是最大化预期累积奖励:
J(θ)=Eτ∼πθ[∑t=0T−1R(st,at)]
其中 τ 是状态-动作轨迹。
重要性采样
PPO使用重要性采样,用当前策略 θ 收集样本,然后用旧策略 θold 采样得到的样本来更新当前策略,通过重要性权重修正分布差异:
J(θ)=E(st,at)∼πθold[πθold(at∣st)πθ(at∣st)Aπ(st,at)]
其中 Aπ(st,at) 是优势函数,表示动作 at 相对于状态 st 平均价值的优势。
裁剪目标(Clipped Objective)
PPO的核心创新是裁剪目标,防止策略更新过大:
JCLIP(θ)=E[min(rt(θ)At,clip(rt(θ),1−ϵ,1+ϵ)At)]
其中:
- rt(θ)=πθold(at∣st)πθ(at∣st) 是概率比
- clip(x,1−ϵ,1+ϵ) 将 x 限制在 [1−ϵ,1+ϵ] 区间内
- ϵ 是超参数,通常取0.1或0.2
为什么裁剪?
- 当优势 At>0(这个动作比平均好),clip防止策略更新过大,限制概率比不超过 1+ϵ
- 当优势 At<0(这个动作比平均差),clip限制概率比不低于 1−ϵ
- 这样防止新旧策略差异太大,保证训练稳定,避免崩溃
PPO算法流程
for 迭代 in 范围:
1. 用当前策略θ_old采样N个轨迹(prompts → responses)
2. 计算每个(state, action)的优势估计A_t
3. 最小化裁剪目标J_CLIP,更新θ:
优化K步(通常K=4)
4. θ_old ← θ
PPO 在 RLHF 中的应用
RLHF中的PPO步骤
在大语言模型RLHF中,PPO的流程如下:
1policy_model = 加载SFT模型
2reward_model = 训练好的奖励模型
3
4for iteration in range(num_iterations):
5 # 1. 采样:policy生成回答
6 prompts = sample_prompts_from_dataset()
7 responses = policy_model.generate(prompts)
8
9 # 2. 反馈:计算奖励
10 rewards = reward_model(prompts, responses)
11 # 加上KL散度惩罚
12 rewards -= beta * kl_divergence(policy_model, sft_model, prompts, responses)
13
14 # 3. 学习:更新policy
15 for epoch in range(num_epochs):
16 policy_model.train(prompts, responses, rewards)
角色分工
在RLHF的PPO中:
- Actor(演员):就是我们要训练的策略模型(大语言模型),输入prompt输出response(token序列)
- Critic(评论家):估计给定prompt下的预期总奖励,帮助计算优势函数
- Reference:初始SFT模型,计算KL散度惩罚
- Reward Model:提供奖励信号
用通俗比喻:Actor是学生,Critic是班干部帮着估算得分,Reward Model是老师给分,KL惩罚是不让学生学得太偏。
KL 散度约束
为什么需要KL约束?
在RLHF训练中,如果没有KL约束:
- 模型会为了获得更高奖励快速漂移,远离初始SFT分布
- 可能生成虽然奖励高,但语法不通顺、不自然的文本
- 奖励模型过拟合会导致模型性能崩溃
KL惩罚形式
奖励计算加入KL惩罚:
R=rRM(x,y)−β⋅KL(πθ(y∣x)∥πSFT(y∣x))
其中:
- x 是prompt,y 是生成的response
- β 是KL惩罚系数
- KL散度衡量当前策略和初始SFT策略的差异
KL自适应调整
实践中,通常会自适应调整β:
- 如果KL太大(漂移太快),增大β惩罚
- 如果KL太小,减小β鼓励探索
PPO 超参调优
关键超参数
| 超参数 | 典型值 | 作用 |
|---|
| ϵ (clip范围) | 0.1 ~ 0.2 | 控制策略更新最大幅度,越小越稳定 |
| 迭代内epoch数 | 4 ~ 10 | 对同批样本多次更新,提高样本利用率 |
| 批量大小 | 64 ~ 256 | 更大batch更稳定,但显存要求高 |
| 学习率 | 1e-6 ~ 1e-5 | RLHF中学习率通常比SFT小 |
| KL惩罚系数 β | 0.01 ~ 0.1 | 控制模型漂移,越大越接近初始SFT |
| 优势估计正则化 | - | 使用GAE(Generalized Advantage Estimation)减少方差 |
训练稳定性技巧
- 不要更新太猛:PPO允许同一个样本多轮更新,但不要太多轮,否则容易过拟合
- 梯度裁剪:防止梯度爆炸,通常裁剪到0.5或1.0
- 学习率退火:训练过程中逐步减小学习率
- 提前停止:监控验证集奖励,不要过度训练
PPO 在 RLHF 中的采样
什么是采样过程?
采样就是模型根据prompt输出回答的过程,相当于学生答题,是收集训练数据的过程。
采样策略
PPO中策略由两个部分组成:
- Actor:负责生成回答(决策)
- Critic:估计当前状态下的预期总收益(总结得失)
收益评估
Critic输出的是从当前token开始,能够获得的期望总奖励,结合Reward Model给出的即时奖励,计算优势函数。
面试常见问题
Q: PPO为什么需要裁剪目标?
A: 裁剪的目的是限制新旧策略的差异不要太大,防止一次更新把策略改得太多,导致训练不稳定甚至崩溃。重要性采样允许用旧样本更新,但差异太大时重要性采样方差很高,裁剪可以把过大的更新剪掉,保证训练稳定。
Q: PPO的优缺点是什么?
A:
- 优点:训练稳定,样本效率比信赖域方法高,实现简单,是目前RLHF的标准算法
- 缺点:需要多个模型同时在显存中(actor、critic、reference、RM),显存占用高,训练流程复杂,计算成本高
Q: RLHF中为什么需要KL散度约束?
A: KL散度约束防止当前策略离初始SFT模型太远,避免模型漂移,防止奖励模型的错误引导导致模型生成不自然的文本,稳定训练过程。
Q: PPO在RLHF中同时需要哪几个模型?它们各自的作用是什么?
A: 通常需要4个模型:
- Actor(当前策略):正在训练的大模型,生成回复,需要更新参数
- Critic:估计状态价值,计算优势函数
- Reference SFT:计算KL散度惩罚,防止漂移
- Reward Model:提供奖励信号
所以说PPO训练对显存要求很高,四个模型都需要放在GPU上。
Q: 什么是优势函数?为什么需要它?
A: 优势函数 A(s,a)=Q(s,a)−V(s) 衡量动作a相对于状态s平均价值的优势。它告诉我们这个动作比平均好还是坏,好多少。使用优势函数可以降低估计方差,让训练更稳定。