当模型大到单卡无法容纳时,我们需要使用多卡训练。流水线并行的目标是:
面临的主要挑战:
朴素模型并行将模型按层拆分到不同GPU,做一轮forward再统一做一轮backward:
存在两个主要问题:
K/(K+M),当GPU数量K越大,空置比例越接近1,资源浪费严重。中间结果内存占用大:每块GPU需要保存所有micro-batch的中间激活,内存压力大。
Gpipe是Google提出的经典流水线并行框架,主要通过两个核心技术解决上述问题:
在原有的mini-batch基础上,进一步切分成多个更小的micro-batch,让流水线跑起来:
切分后,bubble的时间复杂度变为:
bubble时间 = (K - 1) / (M + K - 1)M >= 4K 时,bubble产生的空转占比已经很小,可以忽略不计。动机:虽然切分micro-batch解决了GPU空置问题,但每个micro-batch都需要保存中间激活,内存压力仍然很大。
核心思想:用时间换空间。几乎不保存中间结果,等到backward的时候,再重新forward计算一遍。只保存每个stage的输入,其余中间结果算完就丢弃。
空间复杂度对比:
O(N * L * d / K),N是mini-batch大小,L是每层宽度,d是模型深度O(N * d / K + N * L / M),M是micro-batch数量当L变大时,Gpipe对GPU内存压力显著减小。
| 特性 | GPipe | PipeDream |
|---|---|---|
| 梯度更新 | 同步更新 | 异步更新 |
| 权重一致性 | 每轮更新后权重一致 | 使用旧权重计算梯度,有权重偏移 |
| GPU利用率 | 较低(等待流水线) | 较高,减少气泡 |
| 实现复杂度 | 简单 | 复杂 |
目前PyTorch原生的流水线接口基于Gpipe。
张量并行也称为模型并行,是一种层内并行策略:将单个层的权重切分到多个GPU上,而不是整个层放到一个GPU上。
Y = X * A:3D并行就是将三种并行策略组合起来,共同训练超大模型:
| 并行方式 | 显存效率 | 通信效率 | 实现难度 |
|---|---|---|---|
| 数据并行 | 低(每个卡都存完整模型) | 中 | 简单 |
| 张量并行 | 高(成比例减少) | 低(频繁通信) | 难 |
| 流水线并行 | 中(按层切分) | 高(P2P通信) | 中 |
显存效率排序:张量并行 > 流水线并行 > 数据并行 通信效率排序:流水线并行 > 数据并行 > 张量并行
8 * 35 * 8 = 2240块A100A:
A: 激活检查点是一种时间换空间的显存优化技术。不保存所有中间激活,只保存检查点位置的激活。反向传播时,重新计算检查点之间的中间激活。这样可以大幅减少显存占用,但增加计算量。
A: 根据ZeRO论文,虽然ZeRO3也能达到类似的显存优化效果,但张量并行通信量太高,只能限于节点内(需要NVLINK)。当GPU数量增加到千量级,3D并行的效率明显优于纯ZeRO3。
A: 不适合。张量并行需要节点内NVLINK超高速连接,没有NVLINK通信瓶颈会非常严重。万兆网条件下,ZeRO的通信量都很大,更不用说3D并行的张量并行了。这种场景下,优先尝试ZeRO,如果还是放不下再考虑PP。
| 场景 | 推荐方案 |
|---|---|
| 单GPU,显存够用 | 直接单GPU |
| 单GPU,显存不够 | Offload到CPU |
| 单节点多卡,模型能放进单卡 | DDP 或 ZeRO stage 1/2 |
| 单节点多卡,模型放不进单卡 | 张量并行 或 ZeRO stage 3 |
| 多节点多卡,高带宽网络 | ZeRO 或 3D并行 |
| 多节点多卡,低带宽网络 | DP + PP + TP + ZeRO-1 |
A: 气泡是GPU空闲等待的时间。朴素流水线并行中,很多GPU处于空闲。解决方法是将大batch切分成多个micro-batch,让流水线"流动"起来,减少气泡占比。当micro-batch数量远大于GPU数量时,气泡占比可以忽略。
A: re-materialization就是激活检查点技术。因为切分micro-batch后,如果每个micro-batch都保存中间激活,内存压力还是很大。通过不保存中间激活,反向传播时重新计算,可以大幅节省显存,代价是增加一些计算量,典型的时间换空间。
A:
A: 1F1B是PipeDream提出的调度方法,在饱和流水线后,每个step每个GPU做一次forward然后立刻做一次backward,相比Gpipe的全部forward完再全部backward,可以进一步减少气泡,提高GPU利用率。
A: 对于千亿/万亿参数模型,单种并行策略无法满足要求:
组合三种并行,可以在显存效率和计算效率之间取得最好平衡。