- 多GPU:对于一个model,使用多GPU时:
此时变为DataParallel对象,如果要访问原来model的参数attribute,则用model.module.attribute (原来model.attribute)model = torch.nn.DataParallel(model, device_ids=gpu_ids).cuda()
- model的保存:
torch.save(network.cpu().state_dict(), save_path)
model的载入:
model.load_state_dict(torch.load("./model/PCB-32/net.pth"))
-
pytorch--little tricks
最新推荐文章于 2023-04-03 18:47:03 发布