PyTorch parallel computing: nn.parallel.replicate, scatter, gather, parallel_apply

This post walks through the key steps of multi-GPU parallelism in PyTorch: replicating the model with nn.parallel.replicate, splitting the input batch with scatter, running the replicas with parallel_apply, and collecting the results with gather. The example shows that even when the batch size does not divide evenly across devices, the split adjusts automatically and the computation still proceeds in parallel.

import torch
import torch.nn as nn
class DataParallelModel(nn.Module):
    """A minimal model: a single linear block mapping 10 features to 20."""

    def __init__(self):
        super().__init__()
        self.block1 = nn.Linear(10, 20)

    def forward(self, x):
        x = self.block1(x)
        return x
def data_parallel(module, input, device_ids, output_device=None):
    if not device_ids:
        return module(input)

    if output_device is None:
        output_device = device_ids[0]

    # Step 1: copy the module onto every device.
    replicas = nn.parallel.replicate(module, device_ids)
    print(f"replicas:{replicas}")

    # Step 2: split the input batch across the devices.
    inputs = nn.parallel.scatter(input, device_ids)
    print(f"inputs:{type(inputs)}")
    for i in range(len(inputs)):
        print(f"input {i}:{inputs[i].shape}")

    # scatter may return fewer chunks than there are devices for small
    # batches, so keep only as many replicas as there are input chunks.
    replicas = replicas[:len(inputs)]

    # Step 3: run each replica on its own chunk, in parallel.
    outputs = nn.parallel.parallel_apply(replicas, inputs)
    print(f"outputs:{type(outputs)}")

    # Step 4: collect the per-device outputs back onto one device.
    return nn.parallel.gather(outputs, output_device)
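
To isolate the final step, here is a hedged illustration of nn.parallel.gather on its own, using hypothetical per-device outputs shaped like those in the run above ((3, 20) on cuda:0 and (2, 20) on cuda:1); it shows that gather concatenates uneven chunks along the batch dimension onto the target device.

# Hypothetical per-device outputs, shaped as in the example run above.
outputs = [torch.randn(3, 20, device="cuda:0"),
           torch.randn(2, 20, device="cuda:1")]
merged = nn.parallel.gather(outputs, target_device=0)  # concatenates along dim 0
print(merged.shape, merged.device)  # torch.Size([5, 20]) cuda:0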