1. 保存在 GPU 上,在 CPU 上加载
# `model` is assumed to be an already-built (presumably GPU-resident) Net
# instance created before this snippet -- TODO confirm.
PATH = './model.pth'
# Persist only the module's state_dict, not the whole module object.
torch.save(model.state_dict(), PATH)

# Rebuild the architecture, then have torch.load remap every stored tensor
# onto the CPU while deserializing, so no CUDA device is needed to load.
device = torch.device('cpu')
model = Net()
model.load_state_dict(
    torch.load(PATH, map_location=device)
)
2. 保存在 GPU 上,在 GPU 上加载
# Load a checkpoint that was saved from GPU tensors back onto the GPU.
device = torch.device("cuda")
model = Net()
# map_location=device places the stored tensors on the current CUDA device.
# Without it, torch.load restores each tensor to the exact device it was saved
# from (e.g. 'cuda:1'), which raises on a machine that lacks that device.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (on torch>=1.13 consider passing weights_only=True).
model.load_state_dict(torch.load(PATH, map_location=device))
# load_state_dict copies values into the module's existing tensors (created on
# the CPU by Net(), presumably -- confirm), so the module itself must still be
# moved; inputs fed to the model also need .to(device).
model.to(device)
3. 保存在 CPU 上,在 GPU 上加载
# Load a checkpoint written from CPU tensors directly onto a GPU.
device = torch.device("cuda")
model = Net()
model.load_state_dict(
    # Remap the CPU tensors onto the chosen GPU while deserializing;
    # change the index in "cuda:0" to select a different GPU.
    torch.load(PATH, map_location="cuda:0")
)
# load_state_dict copies values into the module's existing tensors, so the
# module still has to be moved to the GPU explicitly.
# NOTE(review): "cuda" resolves to the current CUDA device, which may differ
# from "cuda:0" above -- keep the two consistent.
model.to(device)