net = torch.nn.Linear(10,1)
print(net)
print('---------------------')
net = torch.nn.DataParallel(net, device_ids=[0,3])
print(net)
#######################
##以下是输出:
Linear(in_features=10, out_features=1, bias=True)
---------------------
DataParallel(
(module): Linear(in_features=10, out_features=1, bias=True)
)
3. 如何保存和加载多GPU网络?
Pytorch模型保存和加载的方法,可以看我的这篇博客。在这里不做详细的介绍了,这里只展示如何来保存和加载多GPU网络,它与普通网络有一点细微的不同。废话不多说,直接上代码:
net = torch.nn.Linear(10,1) # 先构造一个网络
net = torch.nn.DataParallel(net, device_ids=[0,3]) #包裹起来
torch.save(net.module.state_dict(), './networks/multiGPU.pt') #保存网络
# 加载网络
new_net = torch.nn.Linear(10,1)
new_net.load_state_dict(torch.load("./networks/multiGPU.pt"))