数据并行是分布式训练中最常用的并行策略之一,其核心思想是:
要保证所有worker上的模型参数一致,需要两个关键步骤:
初始同步:确保所有worker都从相同的初始化模型参数开始训练。训练开始前,通常会将0号卡的模型参数通信同步到其他卡。
迭代同步:每次训练迭代中,反向传播计算完梯度后,在优化器更新参数之前,插入reduce通信操作来规约梯度,确保所有worker上的梯度都是相同的。
由于相同的初始化 + 相同的梯度,优化器更新后可以保证所有worker上的模型参数始终一致。
梯度分桶:动机是集体通信在大张量上比在小张量上效率更高。将梯度分成多个桶批量通信。
计算与通信重叠:有了梯度分桶之后,在等待同一个桶内的梯度计算完后,就可以开始进行通信操作,让计算和通信并行执行。
跳过梯度同步:通过梯度累加,减少梯度通信的频次,例如每N步才同步一次梯度。
nn.DataParallel是PyTorch最早提供的单机多卡数据并行实现:
1import torch
2import torch.nn as nn
3
4model = Model()
5device_ids = [0, 1]
6model = nn.DataParallel(model, device_ids=device_ids)负载不均衡:输出默认汇总到第0块卡,导致第一块卡的显存占用远大于其他卡,容易出现OOM。
单进程多线程:DP使用单进程控制多GPU,受Python GIL限制,不能充分利用多CPU核心。
通信效率低:所有梯度都要汇总到主卡再广播,通信瓶颈明显,速度慢。
不支持多机多卡:DP只能在单机多卡环境下使用,无法扩展到多节点。
内存冗余:每个GPU都需要保存完整的模型副本,显存利用率低。
Q: 为什么第一块卡的显存会占用更多? A: 因为output_device默认是device_ids[0],每次输出loss都会在第一块GPU相加计算,造成额外负载。
net.module.state_dict()而不是直接保存整个网络,加载时先创建模型再加载:1# 保存
2torch.save(net.module.state_dict(), './model.pth')
3
4# 加载
5new_net = Model()
6new_net.load_state_dict(torch.load("./model.pth"))Q: DP训练时出现warning如何解决?
UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars;
will instead unsqueeze and return a vector.size_average=False, reduce=True,每个GPU的损失相加但不除以batch大小,最后汇总后再除以整个batch大小,得到正确的平均loss。DDP通过多进程实现分布式训练,每个GPU对应一个进程,解决了DP的负载不均衡和GIL瓶颈问题。核心改进在于使用Ring-AllReduce算法来均衡通信负载。
1import torch.distributed as dist
2dist.init_process_group(backend="nccl")1from torch.utils.data.distributed import DistributedSampler
2train_sampler = DistributedSampler(train_dataset)
3dataloader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)1import torch.nn.parallel.DistributedDataParallel as DDP
2model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)1python -m torch.distributed.run --nnodes=1 --nproc_per_node=2 --node_rank=0 \
2 --master_port=6005 train.py进程组初始化后,rank=0的进程会将网络初始化参数broadcast到其它每个进程,确保初始参数一致。
每个进程各自读取不同的训练数据,DistributedSampler保证进程间数据不重叠。
前向传播和loss计算在每个进程(每个CUDA设备)上独立完成,不需要gather到主进程。
反向阶段,梯度信息通过all-reduce操作,每个进程中的param.grad都会变成所有进程梯度的平均值。
为了提高效率,梯度信息被划分成了多个buckets分桶传输。
因为初始参数相同,梯度经过all-reduce后也相同,所以每个进程更新完参数后,权重自然保持一致,不需要额外broadcast。
注意:BatchNorm的running stats需要在每次迭代中从rank 0broadcast到其他进程。
Ring-AllReduce是DDP实现高效梯度同步的核心,由百度最先提出。它将通信压力分散到所有GPU上,消除了中心节点瓶颈。
假设有N块GPU,每块GPU上的梯度也被切成N份。Ring-AllReduce分为两个阶段:Reduce-Scatter和All-Gather。
定义网络拓扑:每个GPU只和相邻的两个GPU通信。
每次发送对应位置的数据给下一个GPU,同时从上一个GPU接收数据进行累加。
经过N-1次迭代后,每块GPU上都有一块数据拥有了对应位置完整的聚合结果。
依然按照相邻GPU通信的原则,但这次不做累加,而是直接替换。
以Reduce-Scatter结束时每个GPU获得的完整数据块作为起点。
再经过N-1次迭代后,每块GPU上都汇总到了完整的梯度数据。
对于K块GPU:
| 对比维度 | nn.DataParallel | DistributedDataParallel |
|---|---|---|
| 实现方式 | 单进程多线程 | 多进程,每个GPU一个进程 |
| GIL限制 | 受GIL影响,效率低 | 不受GIL限制,效率高 |
| 负载均衡 | 主卡负载不均衡 | 各卡负载均衡 |
| 扩展性 | 仅支持单机多卡 | 支持单机多卡和多机多卡 |
| 通信效率 | 低,中心节点瓶颈 | 高,Ring-AllReduce均衡负载 |
| 速度 | 较慢 | 较快 |
| 使用复杂度 | 简单,只需包装模型 | 相对复杂,需要进程管理 |
A:
A:
A: 每个进程都有自己的优化器。因为all-reduce之后每个进程的梯度都一样,初始参数也一样,所以每个进程独立更新参数后结果自然一致,不需要只在主更更新再广播。
find_unused_parameters=True,DDP会只对用到的参数做all-reduce,否则会报错。默认是False,因为开启会有一点额外开销。A:
model.module.state_dict(),和DP类似。A: