pytorch学习
简单例子(包含数据集加载、训练、模型保存和测试)
一个基本的例子:https://zhuanlan.zhihu.com/p/508721527
加载数据
1 2 torchvision.datasets.CIFAR10 torch.utils.data.DataLoader()
训练网络
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 import torch.nn as nnimport torch.nn.functional as Fclass Net (nn.Module): def __init__ (self ): super (Net, self ).__init__() self .conv1 = nn.Conv2d(3 , 6 , 5 ) self .pool = nn.MaxPool2d(2 , 2 ) self .conv2 = nn.Conv2d(6 , 16 , 5 ) self .fc1 = nn.Linear(16 * 5 * 5 , 120 ) self .fc2 = nn.Linear(120 , 84 ) self .fc3 = nn.Linear(84 , 10 ) def forward (self, x ): x = self .pool(F.relu(self .conv1(x))) x = self .pool(F.relu(self .conv2(x))) x = x.view(-1 , 16 * 5 * 5 ) x = F.relu(self .fc1(x)) x = F.relu(self .fc2(x)) x = self .fc3(x) return x net = Net()
这里的维度变化:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 输入: [batch_size, 3 , 32 , 32 ] conv1: [batch_size, 6 , 28 , 28 ] pool1: [batch_size, 6 , 14 , 14 ] conv2: [batch_size, 16 , 10 , 10 ] pool2: [batch_size, 16 , 5 , 5 ] flatten: [batch_size, 16 ×5 ×5 =400 ] fc1: [batch_size, 120 ] fc2: [batch_size, 84 ] fc3: [batch_size, 10 ]
训练
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 import torch.optim as optimimport torch.nn as nncriterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(),lr=0.01 ,momentum=0.1 ) for epoch in range (2 ): runnning_loss = 0.0 for i,data in enumerate (trainloader,0 ): input ,labels = data optimizer.zero_grad() output = net(input ) loss =criterion(output,labels) optimizer.backward() optimizer.step() running_loss += loss.item() if i%2000 ==1999 : print ('[%d, %5d] loss: %.3f' % (epoch + 1 , i + 1 , running_loss / 2000 )) running_loss = 0.0 print ('finish' )
保存模型
1 2 PATH = './model.path' torch.save(net.state_dict(),PATH)
测试
1 2 3 4 5 6 7 testdataloader = torch.utils.data.Dataloader(testset,batch_size=4 ,shuffle=False ) dataiter = iter (testdataloader) images, labels = dataiter.next () imshow(torchvision.utils.make_grid(images)) print ('GroundTruth: ' , ' ' .join('%5s' % classes[labels[j]] for j in range (4 )))
分布式训练:torch.distributed
https://zhuanlan.zhihu.com/p/1982568110990067100
torch.distributed
是 PyTorch 生态系统中分布式训练的核心模块,为现代深度学习提供了强大的多进程、多节点并行计算能力
分布式进程间通信
1 2 3 torch.distributed.all_reduces() torch.distributed.send() torch.dsitributed.recv()
分布式多进程
1 2 3 4 5 6 7 torch.distributed.init_process_group(backend='nccl' ) // NCCL后端(GPU) torch.distributed.init_process_group(backend= 'gloo' ) // gloohouduan (CPU/GPU) torch.distributed.init_process_group(backend= 'mpi' ) // MPI后端(HPC集群) print (torch.distributed.is_nccl_available()) print (torch.distributed.is_gloo_available()) print (torch.distributed.is_mpi_available())
NCCL (NVIDIA Collective Communications Library)
—— GPU 训练的绝对首选
NCCL 是 NVIDIA 专门为自家 GPU 开发的高性能通信库。
硬件支持: 仅限 NVIDIA GPU。
传输效率: 极高。它能感知拓扑结构(如 PCIe, NVLink),并自动选择最快的路径。在多机多卡环境下,它能充分利用 InfiniBand (IB) 网络。
通信模式: 支持所有常见的集体通信算子(All-Reduce, All-Gather, Broadcast 等)。
缺点: 不支持 CPU 之间的通信;报错信息相对晦涩(通常显示为简单的 NCCL Error)。
Gloo
—— 兼容性之王,CPU 训练的首选
Gloo 是由 Facebook 开发的跨平台通信库。
硬件支持: 同时支持 CPU 和 GPU(但在 GPU 上比 NCCL 慢得多)。
传输效率: 中等。它主要基于 TCP 协议进行网络传输,没有针对 GPU 硬件链路(如 NVLink)做极致优化。
优势: * 极其稳定: 很少出现超时崩溃,是调试分布式代码时的“安全网”。
全能型: 如果你的模型部分在 CPU 运行,部分在 GPU 运行,Gloo 是唯一的选择。
应用场景: CPU 集群训练、小规模测试、或者网络环境不支持 RDMA/IB 的普通以太网环境。
MPI (Message Passing Interface)
—— 高性能计算(HPC)的遗产
MPI 是分布式计算领域的老牌标准,在深度学习兴起前就已统治学术界和超级计算机多年。
硬件支持: 取决于具体的 MPI 实现(如 OpenMPI, Intel MPI)。
使用门槛: 比较复杂。在 PyTorch 中使用 MPI 后端,通常需要你手动编译支持 MPI 的 PyTorch 版本,且必须在系统层安装相应的库。
优势: 在传统的大规模超算集群上,MPI 拥有极佳的作业调度和容错能力。
现状: 除非你的训练任务运行在特定的国家实验室超算中心,或者有非常特殊的集群管理需求,否则不推荐 在现代深度学习任务中使用它(因为 NCCL 已经足够强大)。
分布式数据并行
DDP FSDP等
1 2 3 torch.nn.parallel.DistributedDataParallel - DDP类 torch.distributed.fsdp.FullyShardedDataParallel - FSDP类 torch.utils.data.distributed.DistributedSampler - 分布式采样器
1. torch.nn.parallel.DistributedDataParallel (DDP)
作用
数据并行训练的主要实现,将模型复制到每个GPU上,通过梯度同步实现并行训练。
核心原理
每个GPU有独立的模型副本
前向传播时,每个GPU处理不同的数据
反向传播时,通过All-Reduce操作同步梯度
每个GPU使用同步后的梯度更新模型
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 import torchimport torch.distributed as distimport torch.nn as nnfrom torch.nn.parallel import DistributedDataParallel as DDPimport torch.multiprocessing as mpimport osdef setup (rank, world_size ): """初始化分布式环境""" os.environ['MASTER_ADDR' ] = 'localhost' os.environ['MASTER_PORT' ] = '12355' dist.init_process_group("nccl" , rank=rank, world_size=world_size) torch.cuda.set_device(rank) def cleanup (): """清理分布式环境""" dist.destroy_process_group() def train_ddp (rank, world_size ): """DDP训练函数""" setup(rank, world_size) model = nn.Sequential( nn.Linear(10 , 20 ), nn.ReLU(), nn.Linear(20 , 2 ) ).to(rank) ddp_model = DDP(model, device_ids=[rank]) batch_size = 32 inputs = torch.randn(batch_size, 10 ).to(rank) labels = torch.randint(0 , 2 , (batch_size,)).to(rank) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01 ) for epoch in range (10 ): outputs = ddp_model(inputs) loss = criterion(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() if rank == 0 : print (f"Epoch {epoch} , Loss: {loss.item():.4 f} " ) cleanup() if __name__ == "__main__" : world_size = torch.cuda.device_count() mp.spawn(train_ddp, args=(world_size,), nprocs=world_size, join=True )
2. torch.distributed.fsdp.FullyShardedDataParallel (FSDP)
作用
全分片数据并行 ,更高级的并行策略,可以训练超大模型(如数十亿参数)。
与DDP的区别 作用
全分片数据并行 ,更高级的并行策略,可以训练超大模型(如数十亿参数)。
与DDP的区别
既然只存了一部分,计算时怎么办?
这是 FSDP 最巧妙的地方。它采用了一种**“按需索取,用完即丢”**的策略。
当模型进行前向传播(Forward Pass)到某一层时:
收集 (All-Gather): 该层所属的 GPU 会向其他 GPU 广播自己负责的那部分参数,使得所有 GPU 在那一瞬间都拥有了该层的完整参数。
计算: 每张 GPU 用完整参数处理自己那份 Batch 数据。
释放 (Discard): 计算完成后,GPU 立即丢弃掉从别人那里拿来的参数,只保留自己负责的那 1/8 原始分片。
**反向传播(Backward Pass)**也是同样的逻辑:计算完梯度并同步后,非自己负责的梯度立即释放。
优点(鱼): 极大地降低了单卡的显存门槛。它让你可以像写普通数据并行代码一样,训练超出单卡容量的大模型。
代价(熊掌): 通信开销增加。因为在每一层计算前都要进行 All-Gather 通信,如果网络带宽(如没有 IB 网络)不够快,训练速度会变慢。
示例代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 import torchimport torch.nn as nnfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDPfrom torch.distributed.fsdp import MixedPrecisionimport torch.distributed as distfrom torch.distributed.fsdp import ShardingStrategyimport torch.multiprocessing as mpimport osdef setup_fsdp (rank, world_size ): """初始化FSDP环境""" os.environ['MASTER_ADDR' ] = 'localhost' os.environ['MASTER_PORT' ] = '12355' dist.init_process_group("nccl" , rank=rank, world_size=world_size) torch.cuda.set_device(rank) def train_fsdp (rank, world_size ): """FSDP训练函数 - 适合超大模型""" setup_fsdp(rank, world_size) class LargeModel (nn.Module): def __init__ (self ): super ().__init__() self .layer1 = nn.Linear(1024 , 4096 ) self .layer2 = nn.Linear(4096 , 4096 ) self .layer3 = nn.Linear(4096 , 1024 ) self .layer4 = nn.Linear(1024 , 10 ) def forward (self, x ): x = self .layer1(x) x = self .layer2(x) x = self .layer3(x) x = self .layer4(x) return x model = LargeModel().to(rank) sharding_strategy = ShardingStrategy.FULL_SHARD mp_policy = MixedPrecision( param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float32, ) fsdp_model = FSDP( model, sharding_strategy=sharding_strategy, mixed_precision=mp_policy, device_id=rank, ) optimizer = torch.optim.Adam(fsdp_model.parameters(), lr=0.001 ) for step in range (100 ): inputs = torch.randn(16 , 1024 ).to(rank) labels = torch.randint(0 , 10 , (16 ,)).to(rank) outputs = fsdp_model(inputs) loss = nn.functional.cross_entropy(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() if rank == 0 and step % 10 == 0 : print (f"Step {step} , Loss: {loss.item():.4 f} " ) dist.destroy_process_group() if __name__ == "__main__" : world_size = 4 mp.spawn(train_fsdp, args=(world_size,), nprocs=world_size, join=True )
3. torch.utils.data.distributed.DistributedSampler
作用
分布式采样器 ,确保每个GPU/进程获得不同的数据子集,避免数据重复。
工作原理
将完整数据集划分为多个子集
每个进程获得不同的子集
每个epoch打乱数据,但保持进程间数据不重叠
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 import torchfrom torch.utils.data import Dataset, DataLoaderfrom torch.utils.data.distributed import DistributedSamplerimport torch.distributed as distimport torch.multiprocessing as mpimport osclass CustomDataset (Dataset ): def __init__ (self, size=1000 ): self .data = torch.randn(size, 10 ) self .labels = torch.randint(0 , 2 , (size,)) def __len__ (self ): return len (self .data) def __getitem__ (self, idx ): return self .data[idx], self .labels[idx] def setup_sampler (rank, world_size ): """初始化分布式环境""" os.environ['MASTER_ADDR' ] = 'localhost' os.environ['MASTER_PORT' ] = '12355' dist.init_process_group("nccl" , rank=rank, world_size=world_size) def train_with_sampler (rank, world_size ): """使用DistributedSampler的训练函数""" setup_sampler(rank, world_size) dataset = CustomDataset(size=1000 ) sampler = DistributedSampler( dataset, num_replicas=world_size, rank=rank, shuffle=True , seed=42 ) batch_size = 32 dataloader = DataLoader( dataset, batch_size=batch_size, sampler=sampler, num_workers=2 , pin_memory=True , drop_last=True ) print (f"Rank {rank} : Sampler has {len (sampler)} samples" ) for epoch in range (3 ): sampler.set_epoch(epoch) total_samples = 0 for batch_idx, (data, labels) in enumerate (dataloader): total_samples += len (data) if batch_idx % 5 == 0 : print (f"Rank {rank} , Epoch {epoch} , Batch {batch_idx} : " f"Sample indices range: {batch_idx*batch_size} to " f"{(batch_idx+1 )*batch_size-1 } " ) print (f"Rank {rank} , Epoch {epoch} processed {total_samples} samples" ) dist.destroy_process_group() def compare_samplers (): """对比普通Sampler和DistributedSampler的区别""" dataset = CustomDataset(size=100 ) print ("=== 普通Sampler(单GPU)===" ) regular_loader = DataLoader(dataset, batch_size=10 , shuffle=True ) all_indices = [] for data, label in regular_loader: all_indices.extend(range (len (all_indices)*10 , len (all_indices)*10 + len (data))) print (f"总样本数: {len (all_indices)} " ) print (f"样本索引: {all_indices[:20 ]} ..." ) print ("\n=== DistributedSampler(4个GPU)===" ) for rank in range (4 ): sampler = DistributedSampler( dataset, num_replicas=4 , rank=rank, shuffle=True , seed=42 ) indices = list (sampler) print (f"GPU {rank} 获得 {len (indices)} 个样本" ) print (f" 前10个索引: {indices[:10 ]} " ) print (f" 这些索引来自: {[dataset.labels[i].item() for i in indices[:10 ]]} " ) if __name__ == "__main__" : world_size = 4 mp.spawn(train_with_sampler, args=(world_size,), nprocs=world_size, join=True ) compare_samplers()
自动梯度同步
参数和梯度的分布式同步
核心概念
进程组(Process Group)
定义 :一组可以相互通信的进程,是分布式通信的基本单位。
相关API :
torch.distributed.group.WORLD - 默认全局进程组
torch.distributed.new_group(ranks) - 创建自定义进程组
举例 :
1 2 3 4 5 # 默认全局进程组,包含所有进程 global_group = torch.distributed.group.WORLD # 自定义进程组,只包含rank 0,1,2 custom_group = torch.distributed.new_group(ranks=[0, 1, 2])
rank(进程标识符)
定义 :进程在进程组中的唯一标识符,从0开始编号。
相关API :
torch.distributed.get_rank() - 获取当前进程的global_rank
分类 :
global_rank :全局进程组中的rank,范围为0到world_size-1
local_rank :单节点内的本地rank,通常对应GPU编号
举例 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 # 单机4卡训练场景 # 节点0: 4个进程,local_rank = 0,1,2,3 # 节点1: 4个进程,local_rank = 0,1,2,3 # global_rank 计算公式: # global_rank = node_rank * local_world_size + local_rank # 示例:2节点,每节点4卡 # 节点0上的GPU0: local_rank=0, global_rank=0 # 节点0上的GPU1: local_rank=1, global_rank=1 # 节点0上的GPU2: local_rank=2, global_rank=2 # 节点0上的GPU3: local_rank=3, global_rank=3 # 节点1上的GPU0: local_rank=0, global_rank=4 # 节点1上的GPU1: local_rank=1, global_rank=5 # 节点1上的GPU2: local_rank=2, global_rank=6 # 节点1上的GPU3: local_rank=3, global_rank=7 rank = torch.distributed.get_rank() # 获取global_rank local_rank = int(os.environ.get('LOCAL_RANK', 0)) # 本地rank
world_size(进程总数)
定义 :进程组中的总进程数。
相关API :
torch.distributed.get_world_size() - 获取进程总数
举例 :
1 2 3 4 5 6 7 # 单机4GPU训练 world_size = 4 # 总共4个进程 # 多机训练:2节点 x 4GPU = 8进程 world_size = 8 # 总共8个进程 world_size = torch.distributed.get_world_size()
初始化和进程组管理
初始化分布式环境
torch.distributed 的初始化是分布式训练的第一步,主要通过 torch.distributed.init_process_group() 函数完成。
基本初始化 API
1 2 3 4 5 6 7 torch.distributed.init_process_group( backend='nccl', # 通信后端 init_method='env://', # 初始化方法 world_size=4, # 总进程数 rank=0, # 当前进程的rank timeout=datetime.timedelta(seconds=30) # 超时时间 )
初始化方法
环境变量 (env://) - 最常用
文件系统 (file://) - 开发环境
TCP (tcp://) - 自定义
📍 为什么需要地址?
init_method 中的地址是分布式训练的协调中心 ,解决了多进程发现和同步的关键问题:
核心作用 :
进程发现 :告诉每个进程如何找到其他进程
信息交换 :让进程们交换 rank、world_size 等配置信息
同步启动 :确保所有进程同时开始训练
为什么需要地址?
1 2 3 机器A (进程0,1) ──┐ ├───> 集合点 (MASTER_ADDR:MASTER_PORT) 机器B (进程2,3) ──┘
分布式进程可能在不同机器上运行 ,需要一个**共同的 rendezvous point(集合点)**来协调。没有这个地址:
进程A 不知道进程B 在哪里
无法知道自己是 rank 0 还是 rank 1
无法知道总共有多少个进程
通信完全无法建立
这就是为什么需要一个”集合点”来协调所有分布式进程!
那么,这个集合点什么时候会用到呢?
答案:主要在初始化阶段使用,通信阶段通常不依赖
初始化阶段 ✅ 必须使用
1 2 3 4 5 # 所有进程都需要通过 init_method 协调 torch.distributed.init_process_group( init_method='tcp://192.168.1.100:12345', # ← 这里需要地址 backend='nccl' )
通信阶段 ❌ 通常不依赖
1 2 3 # 一旦建立连接,通信是点对点的 torch.distributed.all_reduce(tensor) # 不需要再指定地址 torch.distributed.broadcast(tensor, src=0) # 直接进程间通信
工作流程
1 2 3 4 5 6 7 8 9 初始化阶段: 进程A ──连接──> MASTER_ADDR:MASTER_PORT ←──连接── 进程B 进程C ──连接──> MASTER_ADDR:MASTER_PORT ←──连接── 进程D 通信阶段: 进程A ───直接通信─── 进程B 进程C ───直接通信─── 进程D ↘ ↗ MASTER节点(不再参与通信)
特殊情况仍需依赖
动态进程组管理 :添加新进程时可能仍需主节点协调
某些NCCL配置 :复杂网络拓扑下主节点参与通信协调
错误恢复场景 :进程崩溃后重新通过主节点协调恢复
实际意义 :init_method 主要是分布式训练的**“媒人”,牵线搭桥后就可以”退场”了!**
环境变量配置
在使用 init_method='env://' 时,需要设置以下环境变量:
MASTER_ADDR: 主节点地址
MASTER_PORT: 主节点端口
RANK: 当前进程的rank (0 到 world_size-1)
WORLD_SIZE: 总进程数
进程组管理
默认进程组
初始化后会创建一个默认的全局进程组,可以通过以下函数获取信息:
1 2 3 4 5 6 7 8 rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() is_initialized = torch.distributed.is_initialized()
自定义进程组
可以创建自定义的进程组来实现更灵活的通信模式:
1 2 3 4 5 new_group = torch.distributed.new_group(ranks=[0 , 1 , 2 ]) group_size = torch.distributed.get_world_size(new_group)
进程组生命周期
1 2 3 4 5 torch.distributed.destroy_process_group() torch.distributed.destroy_process_group(new_group)
张量并行
列并行 (Column Parallelism) :
1 2 3 设备0: y0 = x @ W0 → y0 设备1: y1 = x @ W1 → y1 最终: y = concat([y0, y1])
行并行 (Row Parallelism) :
1 2 3 设备0: x0 @ W0 → 部分结果 设备1: x1 @ W1 → 部分结果 最终: y = sum(所有部分结果)
基于torch.distributed的实现
列并行实现 (Column Parallelism)
列并行通过分割输出维度来实现并行化,特别适合注意力机制中的 Query-Key-Value 投影。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 class ColumnParallelLinear (nn.Module): """ 列并行线性层:分割输出维度,实现并行计算 计算过程:y = x @ W + b 分割策略:W 按列分割,每个进程处理部分输出特征 """ def __init__ (self, input_size: int , output_size: int , world_size: int ): super ().__init__() self .world_size = world_size self .rank = dist.get_rank() assert output_size % world_size == 0 , f"output_size ({output_size} ) 必须能被 world_size ({world_size} ) 整除" self .local_output_size = output_size // world_size self .weight = nn.Parameter(torch.randn(self .local_output_size, input_size)) self .bias = nn.Parameter(torch.randn(self .local_output_size)) def forward (self, input_tensor: torch.Tensor ) -> torch.Tensor: """ 前向传播:执行局部计算并聚合结果 Args: input_tensor: 输入张量 [batch_size, input_size] Returns: output_tensor: 输出张量 [batch_size, output_size] """ local_output = torch.matmul(input_tensor, self .weight.t()) + self .bias gathered_outputs = [torch.zeros_like(local_output) for _ in range (self .world_size)] dist.all_gather(gathered_outputs, local_output) final_output = torch.cat(gathered_outputs, dim=-1 ) return final_output
行并行实现 (Row Parallelism)
行并行通过分割输入维度来实现并行化,常用于多头注意力中的输出投影层。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 class RowParallelLinear(nn.Module): """ 行并行线性层:分割输入维度,实现归约并行 计算过程:y = x @ W + b 分割策略:W 按行分割,每个进程处理部分输入特征 """ def __init__(self, input_size: int, output_size: int, world_size: int): super().__init__() self.world_size = world_size self.rank = dist.get_rank() # 验证维度可分割性 assert input_size % world_size == 0, f"input_size ({input_size}) 必须能被 world_size ({world_size}) 整除" # 计算每个进程的局部输入维度 self.local_input_size = input_size // world_size # 初始化局部权重矩阵 [output_size, local_input_size] self.weight = nn.Parameter(torch.randn(output_size, self.local_input_size)) # 初始化全局偏置向量 [output_size] - 所有进程共享相同偏置 self.bias = nn.Parameter(torch.randn(output_size)) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: """ 前向传播:分割输入、局部计算、全局归约 Args: input_tensor: 输入张量 [batch_size, input_size] Returns: output_tensor: 输出张量 [batch_size, output_size] """ # Step 1: 将输入张量按最后一个维度分割 input_chunks = torch.chunk(input_tensor, self.world_size, dim=-1) local_input = input_chunks[self.rank] # 选择当前进程对应的输入块 # Step 2: 局部矩阵乘法 local_output = torch.matmul(local_input, self.weight.t()) # Step 3: 全局归约 - 对所有进程的局部输出求和 dist.all_reduce(local_output, op=dist.ReduceOp.SUM) # Step 4: 添加偏置(所有进程都需要添加相同的偏置以保持梯度同步) local_output += self.bias return local_output
多层感知机张量并行实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 class TensorParallelMLP (nn.Module): """张量并行多层感知机""" def __init__ (self, hidden_size, world_size ): super ().__init__() self .world_size = world_size self .rank = dist.get_rank() self .fc1 = ColumnParallelLinear(hidden_size, hidden_size * 4 , world_size) self .fc2 = RowParallelLinear(hidden_size * 4 , hidden_size, world_size) self .gelu = nn.GELU() def forward (self, x ): hidden = self .fc1(x) hidden = self .gelu(hidden) output = self .fc2(hidden) return output
梯度同步处理
张量并行需要特殊的梯度处理:
1 2 3 4 5 6 7 8 9 10 11 12 def tensor_parallel_backward_hook (module, grad_input, grad_output ): """处理张量并行的梯度同步""" if hasattr (module, 'weight' ): dist.all_reduce(module.weight.grad, op=dist.ReduceOp.SUM) if hasattr (module, 'bias' ) and module.bias is not None : dist.all_reduce(module.bias.grad, op=dist.ReduceOp.SUM) return grad_input model.fc1.register_backward_hook(tensor_parallel_backward_hook) model.fc2.register_backward_hook(tensor_parallel_backward_hook)
完整训练示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 def train_tensor_parallel (): """张量并行训练示例""" dist.init_process_group(backend='nccl' ) torch.cuda.set_device(dist.get_rank()) world_size = dist.get_world_size() rank = dist.get_rank() model = TensorParallelMLP(hidden_size=1024 , world_size=world_size) model = model.cuda() optimizer = torch.optim.Adam(model.parameters()) for batch in dataloader: input_data = batch['input' ].cuda() output = model(input_data) loss = compute_loss(output, batch['target' ].cuda()) optimizer.zero_grad() loss.backward() optimizer.step() if rank == 0 : print (f"Loss: {loss.item()} " ) if __name__ == '__main__' : train_tensor_parallel()
内存效率分析
张量并行的优势 :
内存节省 :每个设备只存储模型的一部分
通信效率 :只同步必要的中间结果
扩展性 :支持超大规模模型
与数据并行的对比 :
特性
数据并行
张量并行
内存使用
O(N)
O(N/P)
通信量
高(梯度同步)
中等(中间结果)
适用场景
中等模型
超大模型
实现复杂度
低
高
注意事项
1. 维度要求
权重矩阵维度必须能被world_size整除
输入输出维度需要仔细规划
2. 负载均衡
3. 通信优化
4. 调试困难
高级应用:Megatron-LM风格实现
对于更复杂的Transformer模型,可以参考Megatron-LM的实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 class ParallelTransformerLayer(nn.Module): """并行Transformer层""" def __init__(self, hidden_size, num_heads, world_size): super().__init__() # 张量并行注意力 self.attention = TensorParallelMultiHeadAttention( hidden_size, num_heads, world_size ) # 张量并行前馈网络 self.feed_forward = TensorParallelMLP(hidden_size, world_size) self.norm1 = nn.LayerNorm(hidden_size) self.norm2 = nn.LayerNorm(hidden_size) def forward(self, x): # 残差连接 + 注意力 attn_output = self.attention(self.norm1(x)) x = x + attn_output # 残差连接 + 前馈 ff_output = self.feed_forward(self.norm2(x)) x = x + ff_output return x
这个实现展示了如何将张量并行应用到完整的Transformer架构中,实现超大规模语言模型的训练。
状态检查函数
1 2 3 4 5 # 检查分布式环境状态 print(torch.distributed.is_initialized()) # 是否已初始化 print(torch.distributed.get_rank()) # 当前进程rank print(torch.distributed.get_world_size()) # 进程总数 print(torch.distributed.get_backend()) # 当前后端
进程组管理工具
1 2 3 4 5 6 default_group = torch.distributed.group.WORLD group_size = torch.distributed.get_world_size(default_group) print (torch.distributed.get_rank() in torch.distributed.get_process_group_ranks(default_group))
调试和监控工具
性能监控
megatron timer
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 import timedef benchmark_communication (tensor_size ): tensor = torch.randn(tensor_size).cuda() torch.cuda.synchronize() start_time = time.time() torch.distributed.all_reduce(tensor) torch.cuda.synchronize() end_time = time.time() return end_time - start_time
内存监控
1 2 3 4 5 6 def print_gpu_memory (): if torch.cuda.is_available(): print (f"Rank {torch.distributed.get_rank()} : " f"GPU memory: {torch.cuda.memory_allocated()/1e9 :.2 f} GB / " f"{torch.cuda.memory_reserved()/1e9 :.2 f} GB" )
实用工具函数
分布式随机种子设置
1 2 3 4 5 6 7 8 def set_random_seed (seed=42 ): """设置分布式训练的随机种子""" rank = torch.distributed.get_rank() torch.manual_seed(seed + rank) if torch.cuda.is_available(): torch.cuda.manual_seed(seed + rank) torch.cuda.manual_seed_all(seed + rank)
分布式验证
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 def distributed_validation (model, val_loader, criterion ): """分布式验证函数""" model.eval () total_loss = 0.0 total_correct = 0 total_samples = 0 with torch.no_grad(): for inputs, targets in val_loader: outputs = model(inputs) loss = criterion(outputs, targets) torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM) total_loss += loss.item() _, predicted = outputs.max (1 ) correct = predicted.eq(targets).sum ().item() total_correct += correct total_samples += targets.size(0 ) world_size = torch.distributed.get_world_size() total_loss /= world_size total_correct /= world_size total_samples /= world_size accuracy = total_correct / total_samples return total_loss, accuracy