pytorch学习

简单例子(包含数据集加载、训练、模型保存和测试)

一个基本的例子:https://zhuanlan.zhihu.com/p/508721527

加载数据

1
2
torchvision.datasets.CIFAR10
torch.utils.data.DataLoader()

训练网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()

# 第一层卷积:输入通道3(RGB图像),输出通道6,卷积核大小5×5
self.conv1 = nn.Conv2d(3, 6, 5)

# 池化层:2×2最大池化,步长为2
self.pool = nn.MaxPool2d(2, 2)

# 第二层卷积:输入通道6,输出通道16,卷积核大小5×5
self.conv2 = nn.Conv2d(6, 16, 5)

# 全连接层(线性层)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 输入维度400,输出120
self.fc2 = nn.Linear(120, 84) # 输入120,输出84
self.fc3 = nn.Linear(84, 10) # 输入84,输出10(对应10个类别)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x))) #relu是一种常见的激活函数之一
x = self.pool(F.relu(self.conv2(x))) #卷积之后池化,然后拉平的一维向量传递给线性全连接层
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 这样一个网络就定义好了
net = Net()

这里的维度变化:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
输入: [batch_size, 3, 32, 32]  # 3通道RGB图像

# conv1后
conv1: [batch_size, 6, 28, 28] # (32-5+1)=28,6个通道

# pool1后
pool1: [batch_size, 6, 14, 14] # 2×2池化,尺寸减半

# conv2后
conv2: [batch_size, 16, 10, 10] # (14-5+1)=10,16个通道

# pool2后
pool2: [batch_size, 16, 5, 5] # 2×2池化,尺寸减半

# 展平后
flatten: [batch_size, 16×5×5=400]

# 全连接层
fc1: [batch_size, 120]
fc2: [batch_size, 84]
fc3: [batch_size, 10]

训练

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch.optim as optim
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.1)

# 开训
for epoch in range(2):
runnning_loss = 0.0
for i,data in enumerate(trainloader,0):
input,labels = data
optimizer.zero_grad()
output = net(input)
loss =criterion(output,labels)
optimizer.backward()
optimizer.step() # 优化器优化,反向传播后直接更新参数
running_loss += loss.item()
if i%2000==1999:
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('finish')

保存模型

1
2
PATH = './model.path'
torch.save(net.state_dict(),PATH)

测试

1
2
3
4
5
6
7
testdataloader = torch.utils.data.Dataloader(testset,batch_size=4,shuffle=False)
dataiter = iter(testdataloader)
images, labels = dataiter.next()

# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

分布式训练:torch.distributed

https://zhuanlan.zhihu.com/p/1982568110990067100

torch.distributed

是 PyTorch 生态系统中分布式训练的核心模块,为现代深度学习提供了强大的多进程、多节点并行计算能力

分布式进程间通信

1
2
3
torch.distributed.all_reduces() # 集合通信
torch.distributed.send() # 点对点通信:发送
torch.dsitributed.recv() # 点对点通信:接收

分布式多进程

1
2
3
4
5
6
7
torch.distributed.init_process_group(backend='nccl') // NCCL后端(GPU)
torch.distributed.init_process_group(backend= 'gloo') // gloohouduan (CPU/GPU)
torch.distributed.init_process_group(backend= 'mpi') // MPI后端(HPC集群)
# 检查可用后端
print(torch.distributed.is_nccl_available()) # 检查NCCL
print(torch.distributed.is_gloo_available()) # 检查Gloo
print(torch.distributed.is_mpi_available()) # 检查MPI
  1. NCCL (NVIDIA Collective Communications Library)

—— GPU 训练的绝对首选

NCCL 是 NVIDIA 专门为自家 GPU 开发的高性能通信库。

  • 硬件支持: 仅限 NVIDIA GPU。
  • 传输效率: 极高。它能感知拓扑结构(如 PCIe, NVLink),并自动选择最快的路径。在多机多卡环境下,它能充分利用 InfiniBand (IB) 网络。
  • 通信模式: 支持所有常见的集体通信算子(All-Reduce, All-Gather, Broadcast 等)。
  • 缺点: 不支持 CPU 之间的通信;报错信息相对晦涩(通常显示为简单的 NCCL Error)。

  1. Gloo

—— 兼容性之王,CPU 训练的首选

Gloo 是由 Facebook 开发的跨平台通信库。

  • 硬件支持: 同时支持 CPU 和 GPU(但在 GPU 上比 NCCL 慢得多)。
  • 传输效率: 中等。它主要基于 TCP 协议进行网络传输,没有针对 GPU 硬件链路(如 NVLink)做极致优化。
  • 优势: * 极其稳定: 很少出现超时崩溃,是调试分布式代码时的“安全网”。
    • 全能型: 如果你的模型部分在 CPU 运行,部分在 GPU 运行,Gloo 是唯一的选择。
  • 应用场景: CPU 集群训练、小规模测试、或者网络环境不支持 RDMA/IB 的普通以太网环境。

  1. MPI (Message Passing Interface)

—— 高性能计算(HPC)的遗产

MPI 是分布式计算领域的老牌标准,在深度学习兴起前就已统治学术界和超级计算机多年。

  • 硬件支持: 取决于具体的 MPI 实现(如 OpenMPI, Intel MPI)。
  • 使用门槛: 比较复杂。在 PyTorch 中使用 MPI 后端,通常需要你手动编译支持 MPI 的 PyTorch 版本,且必须在系统层安装相应的库。
  • 优势: 在传统的大规模超算集群上,MPI 拥有极佳的作业调度和容错能力。
  • 现状: 除非你的训练任务运行在特定的国家实验室超算中心,或者有非常特殊的集群管理需求,否则不推荐在现代深度学习任务中使用它(因为 NCCL 已经足够强大)。

分布式数据并行

DDP FSDP等

1
2
3
torch.nn.parallel.DistributedDataParallel - DDP类
torch.distributed.fsdp.FullyShardedDataParallel - FSDP类
torch.utils.data.distributed.DistributedSampler - 分布式采样器

1. torch.nn.parallel.DistributedDataParallel (DDP)

作用

数据并行训练的主要实现,将模型复制到每个GPU上,通过梯度同步实现并行训练。

核心原理

  1. 每个GPU有独立的模型副本
  2. 前向传播时,每个GPU处理不同的数据
  3. 反向传播时,通过All-Reduce操作同步梯度
  4. 每个GPU使用同步后的梯度更新模型
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# 示例代码
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import os

def setup(rank, world_size):
"""初始化分布式环境"""
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

# 初始化进程组
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)

def cleanup():
"""清理分布式环境"""
dist.destroy_process_group()

def train_ddp(rank, world_size):
"""DDP训练函数"""
setup(rank, world_size)

# 1. 创建模型并移动到当前GPU
model = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 2)
).to(rank)

# 2. 用DDP包装模型
ddp_model = DDP(model, device_ids=[rank])

# 3. 准备数据
# 假设每个GPU的batch_size=32
batch_size = 32
inputs = torch.randn(batch_size, 10).to(rank)
labels = torch.randint(0, 2, (batch_size,)).to(rank)

# 4. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

# 5. 训练循环
for epoch in range(10):
# 前向传播
outputs = ddp_model(inputs)
loss = criterion(outputs, labels)

# 反向传播
optimizer.zero_grad()
loss.backward() # DDP自动同步梯度

# 更新参数(所有GPU参数保持同步)
optimizer.step()

if rank == 0: # 只在主进程打印
print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

cleanup()

# 启动DDP训练
if __name__ == "__main__":
world_size = torch.cuda.device_count() # GPU数量
mp.spawn(train_ddp, args=(world_size,), nprocs=world_size, join=True)

2. torch.distributed.fsdp.FullyShardedDataParallel (FSDP)

作用

全分片数据并行,更高级的并行策略,可以训练超大模型(如数十亿参数)。

与DDP的区别作用

全分片数据并行,更高级的并行策略,可以训练超大模型(如数十亿参数)。

与DDP的区别

  • DDP: 每个GPU存储完整的模型副本

  • FSDP: 每个GPU只存储模型的一部分,内存效率更高

  • DDP: 每个GPU存储完整的模型副本

  • FSDP: 每个GPU只存储模型的一部分,内存效率更高

既然只存了一部分,计算时怎么办?

这是 FSDP 最巧妙的地方。它采用了一种**“按需索取,用完即丢”**的策略。

当模型进行前向传播(Forward Pass)到某一层时:

  1. 收集 (All-Gather): 该层所属的 GPU 会向其他 GPU 广播自己负责的那部分参数,使得所有 GPU 在那一瞬间都拥有了该层的完整参数。
  2. 计算: 每张 GPU 用完整参数处理自己那份 Batch 数据。
  3. 释放 (Discard): 计算完成后,GPU 立即丢弃掉从别人那里拿来的参数,只保留自己负责的那 1/8 原始分片。

**反向传播(Backward Pass)**也是同样的逻辑:计算完梯度并同步后,非自己负责的梯度立即释放。

  • 优点(鱼): 极大地降低了单卡的显存门槛。它让你可以像写普通数据并行代码一样,训练超出单卡容量的大模型。
  • 代价(熊掌): 通信开销增加。因为在每一层计算前都要进行 All-Gather 通信,如果网络带宽(如没有 IB 网络)不够快,训练速度会变慢。

示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
import torch.distributed as dist
from torch.distributed.fsdp import ShardingStrategy
import torch.multiprocessing as mp
import os

def setup_fsdp(rank, world_size):
"""初始化FSDP环境"""
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)

def train_fsdp(rank, world_size):
"""FSDP训练函数 - 适合超大模型"""
setup_fsdp(rank, world_size)

# 1. 创建超大模型(这里用简化示例)
# 实际中可能是包含数亿参数的Transformer
class LargeModel(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(1024, 4096) # 大矩阵
self.layer2 = nn.Linear(4096, 4096)
self.layer3 = nn.Linear(4096, 1024)
self.layer4 = nn.Linear(1024, 10)

def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x

# 2. 初始化模型
model = LargeModel().to(rank)

# 3. 配置FSDP(多种分片策略可选)
sharding_strategy = ShardingStrategy.FULL_SHARD # 完全分片

# 混合精度配置
mp_policy = MixedPrecision(
param_dtype=torch.float16, # 参数使用半精度
reduce_dtype=torch.float16, # 梯度归约使用半精度
buffer_dtype=torch.float32, # 缓冲区使用全精度
)

# 4. 用FSDP包装模型
fsdp_model = FSDP(
model,
sharding_strategy=sharding_strategy,
mixed_precision=mp_policy,
device_id=rank,
)

# 5. 训练逻辑
optimizer = torch.optim.Adam(fsdp_model.parameters(), lr=0.001)

for step in range(100):
# 生成数据
inputs = torch.randn(16, 1024).to(rank) # 较小的batch_size
labels = torch.randint(0, 10, (16,)).to(rank)

# 前向传播
outputs = fsdp_model(inputs)
loss = nn.functional.cross_entropy(outputs, labels)

# 反向传播
optimizer.zero_grad()
loss.backward()

# 梯度会在FSDP内部自动同步和分片
optimizer.step()

if rank == 0 and step % 10 == 0:
print(f"Step {step}, Loss: {loss.item():.4f}")

dist.destroy_process_group()

# 启动FSDP训练
if __name__ == "__main__":
world_size = 4 # 假设有4个GPU
mp.spawn(train_fsdp, args=(world_size,), nprocs=world_size, join=True)

3. torch.utils.data.distributed.DistributedSampler

作用

分布式采样器,确保每个GPU/进程获得不同的数据子集,避免数据重复。

工作原理

  1. 将完整数据集划分为多个子集
  2. 每个进程获得不同的子集
  3. 每个epoch打乱数据,但保持进程间数据不重叠
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
import os

# 1. 创建自定义数据集
class CustomDataset(Dataset):
def __init__(self, size=1000):
self.data = torch.randn(size, 10)
self.labels = torch.randint(0, 2, (size,))

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
return self.data[idx], self.labels[idx]

def setup_sampler(rank, world_size):
"""初始化分布式环境"""
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("nccl", rank=rank, world_size=world_size)

def train_with_sampler(rank, world_size):
"""使用DistributedSampler的训练函数"""
setup_sampler(rank, world_size)

# 1. 创建数据集
dataset = CustomDataset(size=1000)

# 2. 创建分布式采样器
sampler = DistributedSampler(
dataset,
num_replicas=world_size, # 进程总数
rank=rank, # 当前进程排名
shuffle=True, # 是否打乱数据
seed=42 # 随机种子,确保可复现
)

# 3. 创建数据加载器
batch_size = 32
dataloader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler, # 使用分布式采样器
num_workers=2, # 数据加载工作进程数
pin_memory=True, # 启用内存锁页,加速GPU传输
drop_last=True # 丢弃最后一个不完整的batch
)

# 4. 验证采样效果
print(f"Rank {rank}: Sampler has {len(sampler)} samples")

# 5. 训练循环
for epoch in range(3):
# 重要:每个epoch开始前设置epoch
sampler.set_epoch(epoch)

total_samples = 0
for batch_idx, (data, labels) in enumerate(dataloader):
total_samples += len(data)

# 这里应该是实际训练代码
# ...

if batch_idx % 5 == 0:
print(f"Rank {rank}, Epoch {epoch}, Batch {batch_idx}: "
f"Sample indices range: {batch_idx*batch_size} to "
f"{(batch_idx+1)*batch_size-1}")

print(f"Rank {rank}, Epoch {epoch} processed {total_samples} samples")

dist.destroy_process_group()

def compare_samplers():
"""对比普通Sampler和DistributedSampler的区别"""
dataset = CustomDataset(size=100)

# 普通数据加载(单GPU)
print("=== 普通Sampler(单GPU)===")
regular_loader = DataLoader(dataset, batch_size=10, shuffle=True)
all_indices = []
for data, label in regular_loader:
all_indices.extend(range(len(all_indices)*10, len(all_indices)*10 + len(data)))
print(f"总样本数: {len(all_indices)}")
print(f"样本索引: {all_indices[:20]}...")

# 分布式采样器模拟
print("\n=== DistributedSampler(4个GPU)===")
for rank in range(4):
sampler = DistributedSampler(
dataset,
num_replicas=4,
rank=rank,
shuffle=True,
seed=42
)
indices = list(sampler)
print(f"GPU {rank} 获得 {len(indices)} 个样本")
print(f" 前10个索引: {indices[:10]}")
print(f" 这些索引来自: {[dataset.labels[i].item() for i in indices[:10]]}")

# 启动分布式训练
if __name__ == "__main__":
world_size = 4
mp.spawn(train_with_sampler, args=(world_size,), nprocs=world_size, join=True)

# 对比演示
compare_samplers()

自动梯度同步

参数和梯度的分布式同步

  • 集成在DDP/FSDP中的自动梯度同步机制

核心概念

进程组(Process Group)

定义:一组可以相互通信的进程,是分布式通信的基本单位。

相关API

  • torch.distributed.group.WORLD - 默认全局进程组
  • torch.distributed.new_group(ranks) - 创建自定义进程组

举例

1
2
3
4
5
# 默认全局进程组,包含所有进程
global_group = torch.distributed.group.WORLD

# 自定义进程组,只包含rank 0,1,2
custom_group = torch.distributed.new_group(ranks=[0, 1, 2])

rank(进程标识符)

定义:进程在进程组中的唯一标识符,从0开始编号。

相关API

  • torch.distributed.get_rank() - 获取当前进程的global_rank

分类

  • global_rank:全局进程组中的rank,范围为0到world_size-1
  • local_rank:单节点内的本地rank,通常对应GPU编号

举例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 单机4卡训练场景
# 节点0: 4个进程,local_rank = 0,1,2,3
# 节点1: 4个进程,local_rank = 0,1,2,3

# global_rank 计算公式:
# global_rank = node_rank * local_world_size + local_rank

# 示例:2节点,每节点4卡
# 节点0上的GPU0: local_rank=0, global_rank=0
# 节点0上的GPU1: local_rank=1, global_rank=1
# 节点0上的GPU2: local_rank=2, global_rank=2
# 节点0上的GPU3: local_rank=3, global_rank=3
# 节点1上的GPU0: local_rank=0, global_rank=4
# 节点1上的GPU1: local_rank=1, global_rank=5
# 节点1上的GPU2: local_rank=2, global_rank=6
# 节点1上的GPU3: local_rank=3, global_rank=7

rank = torch.distributed.get_rank() # 获取global_rank
local_rank = int(os.environ.get('LOCAL_RANK', 0)) # 本地rank

world_size(进程总数)

定义:进程组中的总进程数。

相关API

  • torch.distributed.get_world_size() - 获取进程总数

举例

1
2
3
4
5
6
7
# 单机4GPU训练
world_size = 4 # 总共4个进程

# 多机训练:2节点 x 4GPU = 8进程
world_size = 8 # 总共8个进程

world_size = torch.distributed.get_world_size()

初始化和进程组管理

初始化分布式环境

torch.distributed 的初始化是分布式训练的第一步,主要通过 torch.distributed.init_process_group() 函数完成。

基本初始化 API

1
2
3
4
5
6
7
torch.distributed.init_process_group(
backend='nccl', # 通信后端
init_method='env://', # 初始化方法
world_size=4, # 总进程数
rank=0, # 当前进程的rank
timeout=datetime.timedelta(seconds=30) # 超时时间
)

初始化方法

  1. 环境变量 (env://) - 最常用

    • 通过环境变量传递配置信息

    • 自动检测 MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE

  2. 文件系统 (file://) - 开发环境

    • 使用共享文件系统进行协调

    • 所有进程写入同一个文件

  3. TCP (tcp://) - 自定义

    • 直接指定主节点地址和端口

    • 格式:tcp://MASTER_ADDR:MASTER_PORT

📍 为什么需要地址?

init_method 中的地址是分布式训练的协调中心,解决了多进程发现和同步的关键问题:

核心作用

  • 进程发现:告诉每个进程如何找到其他进程
  • 信息交换:让进程们交换 rank、world_size 等配置信息
  • 同步启动:确保所有进程同时开始训练

为什么需要地址?

1
2
3
机器A (进程0,1) ──┐
├───> 集合点 (MASTER_ADDR:MASTER_PORT)
机器B (进程2,3) ──┘

分布式进程可能在不同机器上运行,需要一个**共同的 rendezvous point(集合点)**来协调。没有这个地址:

  • 进程A 不知道进程B 在哪里
  • 无法知道自己是 rank 0 还是 rank 1
  • 无法知道总共有多少个进程
  • 通信完全无法建立

这就是为什么需要一个”集合点”来协调所有分布式进程!

那么,这个集合点什么时候会用到呢?

答案:主要在初始化阶段使用,通信阶段通常不依赖

初始化阶段 ✅ 必须使用

1
2
3
4
5
# 所有进程都需要通过 init_method 协调
torch.distributed.init_process_group(
init_method='tcp://192.168.1.100:12345', # ← 这里需要地址
backend='nccl'
)

通信阶段 ❌ 通常不依赖

1
2
3
# 一旦建立连接,通信是点对点的
torch.distributed.all_reduce(tensor) # 不需要再指定地址
torch.distributed.broadcast(tensor, src=0) # 直接进程间通信

工作流程

1
2
3
4
5
6
7
8
9
初始化阶段:
进程A ──连接──> MASTER_ADDR:MASTER_PORT ←──连接── 进程B
进程C ──连接──> MASTER_ADDR:MASTER_PORT ←──连接── 进程D

通信阶段:
进程A ───直接通信─── 进程B
进程C ───直接通信─── 进程D
↘ ↗
MASTER节点(不再参与通信)

特殊情况仍需依赖

  • 动态进程组管理:添加新进程时可能仍需主节点协调
  • 某些NCCL配置:复杂网络拓扑下主节点参与通信协调
  • 错误恢复场景:进程崩溃后重新通过主节点协调恢复

实际意义init_method 主要是分布式训练的**“媒人”,牵线搭桥后就可以”退场”了!**

环境变量配置

在使用 init_method='env://' 时,需要设置以下环境变量:

  • MASTER_ADDR: 主节点地址
  • MASTER_PORT: 主节点端口
  • RANK: 当前进程的rank (0 到 world_size-1)
  • WORLD_SIZE: 总进程数

进程组管理

默认进程组

初始化后会创建一个默认的全局进程组,可以通过以下函数获取信息:

1
2
3
4
5
6
7
8
# 获取当前进程的rank
rank = torch.distributed.get_rank()

# 获取world size
world_size = torch.distributed.get_world_size()

# 检查是否已初始化
is_initialized = torch.distributed.is_initialized()

自定义进程组

可以创建自定义的进程组来实现更灵活的通信模式:

1
2
3
4
5
# 创建新的进程组
new_group = torch.distributed.new_group(ranks=[0, 1, 2])

# 获取特定进程组的大小
group_size = torch.distributed.get_world_size(new_group)

进程组生命周期

1
2
3
4
5
# 销毁进程组
torch.distributed.destroy_process_group()

# 销毁特定进程组
torch.distributed.destroy_process_group(new_group)

张量并行

列并行 (Column Parallelism)

1
2
3
设备0: y0 = x @ W0  →  y0
设备1: y1 = x @ W1 → y1
最终: y = concat([y0, y1])

行并行 (Row Parallelism)

1
2
3
设备0: x0 @ W0  → 部分结果
设备1: x1 @ W1 → 部分结果
最终: y = sum(所有部分结果)

基于torch.distributed的实现

列并行实现 (Column Parallelism)

列并行通过分割输出维度来实现并行化,特别适合注意力机制中的 Query-Key-Value 投影。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class ColumnParallelLinear(nn.Module):
"""
列并行线性层:分割输出维度,实现并行计算

计算过程:y = x @ W + b
分割策略:W 按列分割,每个进程处理部分输出特征
"""

def __init__(self, input_size: int, output_size: int, world_size: int):
super().__init__()
self.world_size = world_size
self.rank = dist.get_rank()

# 验证维度可分割性
assert output_size % world_size == 0, f"output_size ({output_size}) 必须能被 world_size ({world_size}) 整除"

# 计算每个进程的局部输出维度
self.local_output_size = output_size // world_size

# 初始化局部权重矩阵 [local_output_size, input_size]
self.weight = nn.Parameter(torch.randn(self.local_output_size, input_size))

# 初始化局部偏置向量 [local_output_size]
self.bias = nn.Parameter(torch.randn(self.local_output_size))

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
"""
前向传播:执行局部计算并聚合结果

Args:
input_tensor: 输入张量 [batch_size, input_size]

Returns:
output_tensor: 输出张量 [batch_size, output_size]
"""
# Step 1: 局部矩阵乘法
local_output = torch.matmul(input_tensor, self.weight.t()) + self.bias

# Step 2: 收集所有进程的局部输出
gathered_outputs = [torch.zeros_like(local_output) for _ in range(self.world_size)]
dist.all_gather(gathered_outputs, local_output)

# Step 3: 在输出维度上拼接所有局部结果
final_output = torch.cat(gathered_outputs, dim=-1)

return final_output

行并行实现 (Row Parallelism)

行并行通过分割输入维度来实现并行化,常用于多头注意力中的输出投影层。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class RowParallelLinear(nn.Module):
"""
行并行线性层:分割输入维度,实现归约并行

计算过程:y = x @ W + b
分割策略:W 按行分割,每个进程处理部分输入特征
"""

def __init__(self, input_size: int, output_size: int, world_size: int):
super().__init__()
self.world_size = world_size
self.rank = dist.get_rank()

# 验证维度可分割性
assert input_size % world_size == 0, f"input_size ({input_size}) 必须能被 world_size ({world_size}) 整除"

# 计算每个进程的局部输入维度
self.local_input_size = input_size // world_size

# 初始化局部权重矩阵 [output_size, local_input_size]
self.weight = nn.Parameter(torch.randn(output_size, self.local_input_size))

# 初始化全局偏置向量 [output_size] - 所有进程共享相同偏置
self.bias = nn.Parameter(torch.randn(output_size))

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
"""
前向传播:分割输入、局部计算、全局归约

Args:
input_tensor: 输入张量 [batch_size, input_size]

Returns:
output_tensor: 输出张量 [batch_size, output_size]
"""
# Step 1: 将输入张量按最后一个维度分割
input_chunks = torch.chunk(input_tensor, self.world_size, dim=-1)
local_input = input_chunks[self.rank] # 选择当前进程对应的输入块

# Step 2: 局部矩阵乘法
local_output = torch.matmul(local_input, self.weight.t())

# Step 3: 全局归约 - 对所有进程的局部输出求和
dist.all_reduce(local_output, op=dist.ReduceOp.SUM)

# Step 4: 添加偏置(所有进程都需要添加相同的偏置以保持梯度同步)
local_output += self.bias

return local_output

多层感知机张量并行实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class TensorParallelMLP(nn.Module):
"""张量并行多层感知机"""

def __init__(self, hidden_size, world_size):
super().__init__()
self.world_size = world_size
self.rank = dist.get_rank()

# 第一层:列并行
self.fc1 = ColumnParallelLinear(hidden_size, hidden_size * 4, world_size)
# 第二层:行并行
self.fc2 = RowParallelLinear(hidden_size * 4, hidden_size, world_size)

self.gelu = nn.GELU()

def forward(self, x):
# 第一层:列并行
hidden = self.fc1(x)
hidden = self.gelu(hidden)

# 第二层:行并行
output = self.fc2(hidden)
return output

梯度同步处理

张量并行需要特殊的梯度处理:

1
2
3
4
5
6
7
8
9
10
11
12
def tensor_parallel_backward_hook(module, grad_input, grad_output):
"""处理张量并行的梯度同步"""
# 确保梯度正确传播
if hasattr(module, 'weight'):
dist.all_reduce(module.weight.grad, op=dist.ReduceOp.SUM)
if hasattr(module, 'bias') and module.bias is not None:
dist.all_reduce(module.bias.grad, op=dist.ReduceOp.SUM)
return grad_input

# 注册钩子
model.fc1.register_backward_hook(tensor_parallel_backward_hook)
model.fc2.register_backward_hook(tensor_parallel_backward_hook)

完整训练示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def train_tensor_parallel():
"""张量并行训练示例"""

# 初始化分布式环境
dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank())

world_size = dist.get_world_size()
rank = dist.get_rank()

# 创建模型
model = TensorParallelMLP(hidden_size=1024, world_size=world_size)
model = model.cuda()

# 只优化局部参数
optimizer = torch.optim.Adam(model.parameters())

# 训练循环
for batch in dataloader:
input_data = batch['input'].cuda()

# 前向传播
output = model(input_data)
loss = compute_loss(output, batch['target'].cuda())

# 反向传播
optimizer.zero_grad()
loss.backward()

# 梯度同步已在钩子中处理
optimizer.step()

if rank == 0:
print(f"Loss: {loss.item()}")

if __name__ == '__main__':
train_tensor_parallel()

内存效率分析

张量并行的优势

  • 内存节省:每个设备只存储模型的一部分
  • 通信效率:只同步必要的中间结果
  • 扩展性:支持超大规模模型

与数据并行的对比

特性 数据并行 张量并行
内存使用 O(N) O(N/P)
通信量 高(梯度同步) 中等(中间结果)
适用场景 中等模型 超大模型
实现复杂度

注意事项

1. 维度要求
  • 权重矩阵维度必须能被world_size整除
  • 输入输出维度需要仔细规划
2. 负载均衡
  • 确保各设备计算量均衡
  • 避免某些设备成为瓶颈
3. 通信优化
  • 选择合适的通信原语
  • 考虑通信与计算的重叠
4. 调试困难
  • 张量并行调试相对复杂
  • 需要验证梯度同步的正确性

高级应用:Megatron-LM风格实现

对于更复杂的Transformer模型,可以参考Megatron-LM的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class ParallelTransformerLayer(nn.Module):
"""并行Transformer层"""

def __init__(self, hidden_size, num_heads, world_size):
super().__init__()

# 张量并行注意力
self.attention = TensorParallelMultiHeadAttention(
hidden_size, num_heads, world_size
)

# 张量并行前馈网络
self.feed_forward = TensorParallelMLP(hidden_size, world_size)

self.norm1 = nn.LayerNorm(hidden_size)
self.norm2 = nn.LayerNorm(hidden_size)

def forward(self, x):
# 残差连接 + 注意力
attn_output = self.attention(self.norm1(x))
x = x + attn_output

# 残差连接 + 前馈
ff_output = self.feed_forward(self.norm2(x))
x = x + ff_output

return x

这个实现展示了如何将张量并行应用到完整的Transformer架构中,实现超大规模语言模型的训练。


状态检查函数

1
2
3
4
5
# 检查分布式环境状态
print(torch.distributed.is_initialized()) # 是否已初始化
print(torch.distributed.get_rank()) # 当前进程rank
print(torch.distributed.get_world_size()) # 进程总数
print(torch.distributed.get_backend()) # 当前后端

进程组管理工具

1
2
3
4
5
6
# 获取进程组信息
default_group = torch.distributed.group.WORLD # 默认全局组
group_size = torch.distributed.get_world_size(default_group)

# 检查进程是否在组中
print(torch.distributed.get_rank() in torch.distributed.get_process_group_ranks(default_group))

调试和监控工具

性能监控

megatron timer

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import time

# 简单的通信性能测试
def benchmark_communication(tensor_size):
tensor = torch.randn(tensor_size).cuda()

torch.cuda.synchronize()
start_time = time.time()

torch.distributed.all_reduce(tensor)

torch.cuda.synchronize()
end_time = time.time()

return end_time - start_time

内存监控

1
2
3
4
5
6
# GPU内存使用情况
def print_gpu_memory():
if torch.cuda.is_available():
print(f"Rank {torch.distributed.get_rank()}: "
f"GPU memory: {torch.cuda.memory_allocated()/1e9:.2f}GB / "
f"{torch.cuda.memory_reserved()/1e9:.2f}GB")

实用工具函数

分布式随机种子设置

1
2
3
4
5
6
7
8
def set_random_seed(seed=42):
"""设置分布式训练的随机种子"""
rank = torch.distributed.get_rank()
torch.manual_seed(seed + rank)

if torch.cuda.is_available():
torch.cuda.manual_seed(seed + rank)
torch.cuda.manual_seed_all(seed + rank)

分布式验证

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def distributed_validation(model, val_loader, criterion):
"""分布式验证函数"""
model.eval()
total_loss = 0.0
total_correct = 0
total_samples = 0

with torch.no_grad():
for inputs, targets in val_loader:
outputs = model(inputs)
loss = criterion(outputs, targets)

# 收集所有进程的结果
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
total_loss += loss.item()

# 计算准确率
_, predicted = outputs.max(1)
correct = predicted.eq(targets).sum().item()
total_correct += correct
total_samples += targets.size(0)

# 计算全局平均
world_size = torch.distributed.get_world_size()
total_loss /= world_size
total_correct /= world_size
total_samples /= world_size

accuracy = total_correct / total_samples
return total_loss, accuracy