0%

PyTorch多卡训练初探

在自己当前的分类模型集成框架里加入了多卡训练, 借此梳理相关的概念.


PyTorch DDP(Distributed Data Parallel)包括了三个核心概念, 分别为rank, local_ranknode_rank.

  • rank: 即Gobal Rank, 为物理GPU的编号
  • local_rank: 当前节点的进程编号. 每个节点GPU都从0开始编号
  • node_rank: 当前节点的编号, 在多节点训练时使用

需要注意的是, local_ranktorchrun自动传入的环境变量, 可以通过以下方式获取:

1
local_rank = int(os.environ.get('LOCAL_RANK', 0))

此外pytorch的DDP中还有一些需要初始化的参数, 包括world_size, 为全局进程总数, 是所有节点上所有GPU进程的总数量; nproc_per_node是单节点进程数, 表示单个物理机上启动的训练进程的数量, master_addr是主节点IP地址, 用于初始化分布式进程组的rank=0进程所在机器的IP地址; master_port为主节点端口号, 用于初始化分布式进程组的协调进程所监听的TCP端口号, 与master_addr共同指定了协调服务的位置; backend为通信后端, 用于指定进程间通信所用协议, 常用nccl来进行GPU之间的高速通信.

结合torchrun启动分布式训练的命令行如下:

1
CUDA_VISIBLE_DEVICES=1,2 torchrun --nproc_per_node=2 train.py

启动后首先需要对模型进行DDP场景下的加载, 对应代码如下:

1
2
3
4
5
6
7
8
9
10
11
if config.train.distributed:
# 启用 SyncBatchNorm
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = model.to(config.device)

# 使用 DistributedDataParallel 封装模型
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[int(config.device.split(":")[1])],
output_device=int(config.device.split(":")[1])
)

以上代码中的nn.SyncBatchNorm.convert_sync_batchnorm(model)的核心作用是跨所有GPU聚合批次的统计信息, 用于共享GPU之间不同数据的计算结果. nn.parallel.DistributedDataParallel(...)在每个进程/GPU 上拥有一个完整的模型副本, 独立进行前向和反向传播. 在反向传播完成后, 它会自动协调(使用 All-Reduce 操作)所有副本的梯度, 并同步更新模型参数, 确保所有 GPU 上的模型保持一致.

此外, 在进行分布式训练时需要对数据的加载进行对应处理, 使不同GPU读取到不同的训练或测试数据, 对应的代码实现如下:

1
2
3
4
5
if config.train.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) # 创建分布式数据采样器
else:
train_sampler = None
# train_loader = DataLoader(train_dataset, sampler=train_sampler, ...)

以上代码中的torch.utils.data.distributed.DistributedSampler为分布式采样器, 根据当前的rank和world_size自动将整个数据集逻辑地分割成world_size个不重叠的子集, 从而保证每个GPU在每个epoch中都能获取唯一且互不重叠的数据批次.

在训练进行中, 同样需要对数据采集器进行必要的同步操作, 对应代码如下:

1
2
3
if config.train.distributed:
# 在每个 epoch 开始时调用
train_loader.sampler.set_epoch(epoch)

以上代码中的epoch为当前轮次的编号, 该函数的核心功能是重新打乱数据, 确保每个epoch的数据划分都是随机且均匀的.

对于模型的运行与训练结果保存, 在模型训练层面, DDP模式下训练循环和单卡类似, 在梯度反向传播中DDP机制会自动在所有进程之间同步梯度, 无需额外干预和对代码的修改. 对于训练权重的保存, DDP模式只在rank=0的进程上执行保存操作.

通过以上操作即可实现多卡的模型训练与保存.