# 一、为什么需要多 GPU?

当我们训练深度学习模型(特别是大模型)时,单张 GPU 往往不够:

  • 数据太大,单卡放不下;
  • 模型太复杂,训练太慢;
  • 想更快训练完成。

于是就有了「多 GPU 训练」这个想法:
让多张显卡一起干活,分担计算任务。


# 二、多 GPU 的三种主要方式

多 GPU 的实现方式有很多,但核心分为三类

方式 中文名称 思路 优缺点
1. 模型并行(Model Parallel) 模型拆开 把模型的不同层分给不同 GPU 适合模型太大放不下一张卡;通信频繁,实现复杂
2. 数据并行(Data Parallel) 数据拆开 模型每张卡都复制一份,每张卡训练一部分数据 通用、最常见
3. 分布式数据并行(Distributed Data Parallel, DDP) 高效数据并行 每张卡独立进程,自动同步梯度 工业级方案,最快、最稳定

# 三、核心思想:数据并行

多 GPU 训练中,数据并行是最常用的方式。

其思想非常简单:

模型每张卡都有一份副本;每张卡训练不同部分的数据;最后汇总结果。

假设你有 2 张 GPU,每次 batch 有 4 张图片:

批次数据: [1, 2, 3, 4]

GPU0 ← [1, 2]
GPU1 ← [3, 4]

训练流程:

  1. 每张 GPU 都有一份相同的模型副本。
  2. 各 GPU 前向计算 → 得到自己的 loss。
  3. 各 GPU 反向传播 → 得到自己的梯度。
  4. 梯度同步(allreduce) → 求平均。
  5. 各 GPU 同步更新参数。

这样每张 GPU 的参数始终保持一致。


# 四、关键概念详解

下面是训练中会出现的一些关键术语

概念 说明
train 训练集,用于更新模型参数。
val (validation) 验证集,用于评估模型是否过拟合,不参与训练。
test 测试集,用于最终测试模型效果。
query 查询集,常见于检索任务,用来搜索匹配结果。
gallery / bounding_box 检索任务中的数据库部分(比如人脸库、行人库)。
batch(批次) 一次喂给模型的数据量。
split_batch 把一个 batch 的数据拆成几份给不同 GPU。
allreduce 把多 GPU 的梯度相加、平均,让模型参数保持一致。
scatter 把数据分发到多个 GPU。
gather / concat 把多个 GPU 的结果合并回来。

# 五、三种实现方式详解

# 1. 模型并行(Model Parallel)

思想:

模型太大,一张卡放不下,就把不同层放到不同 GPU 上。

示意:

GPU0: 负责模型前半部分
GPU1: 负责模型后半部分

特点:

  • 模型能变得更大;
  • 但通信频繁,效率不高;
  • 实现较复杂。

适用:GPT、BERT 这类超大模型。


# 2. 数据并行(Data Parallel,DP)

思想:

每张卡都有一份完整模型副本,但处理不同部分的数据。

PyTorch 用法:

model = torch.nn.DataParallel(model)

流程:

  1. 数据被平均分配给每张 GPU(scatter)。
  2. 各 GPU 前向 & 反向传播。
  3. 主 GPU(GPU0)收集所有梯度,求平均。
  4. GPU0 更新参数,再广播回所有 GPU。

缺点:

  • GPU0 负担过重(通信瓶颈);
  • 不能多机训练。

适用:快速实验、小模型。


# 3. 分布式数据并行(Distributed Data Parallel,DDP)

思想:

数据并行的改进版,每张 GPU 独立进程,自动通信同步。

PyTorch 用法:

from torch.nn.parallel import DistributedDataParallel as DDP
model = DDP(model)

流程图示(假设2张GPU)

Step 1: 复制模型
GPU0: 模型A
GPU1: 模型A

Step 2: 拆分数据
GPU0: batch[0:2]
GPU1: batch[2:4]

Step 3: 前向+反向
各GPU独立算梯度

Step 4: allreduce同步
梯度平均后广播到每张卡

Step 5: 更新参数(同步)

一句话总结:

DDP 是 PyTorch 官方推荐的多 GPU 训练标准方案。


# 六、底层机制理解:scatter / allreduce / gather

操作 作用 举例
scatter 拆数据 把 100 张图片分成两份发到 GPU0、GPU1
allreduce 同步梯度 GPU0 的梯度 + GPU1 的梯度 → 平均后广播
gather / concat 合并结果 把每个 GPU 的预测拼起来做整体评估

这些操作是 数据并行的核心通信机制
在 DDP 中,它们是由 PyTorch 自动完成的。


# 七、在 DDP 中训练需要注意什么

事项 原因
每个 GPU 独立进程 保证同步和效率
使用 DistributedSampler 自动分配数据给不同 GPU
设置环境变量(rank, world_size) 标识 GPU 编号和总数
初始化通信(init_process_group) 建立 NCCL 通信通道
每个进程保存日志、模型 避免覆盖

# 八、三个阶段的总结对比表

阶段 干的事 特点 示例
模型并行 模型拆分 各 GPU 存不同层 GPT、BERT
数据并行(DP) 数据拆分 + 主卡同步 简单但低效 nn.DataParallel
分布式数据并行(DDP) 数据拆分 + 自动同步 高效工业级 DistributedDataParallel

# 九、类比理解:多 GPU 像工厂流水线

工人 工作
GPU0 处理样本1-50
GPU1 处理样本51-100
  • 每个工人有同样的手册(模型);
  • 干完活后汇报经验(梯度);
  • 经理平均总结经验(allreduce);
  • 更新手册;
  • 下一轮所有人继续。

这,就是数据并行和 DDP 的本质。