1. 训练
# 1. Print the versions of the current environment (PyTorch, CUDA, cuDNN)
#    and the name of GPU 0.
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())
print(torch.cuda.get_device_name(0))

# Fix every RNG seed and force deterministic cuDNN kernels (autotuner off)
# for reproducibility between runs.
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# 2. Set up the device information and create the model.
model = UNetSeeInDark()
model._initialize_weights()

gpus = [0, 1, 2, 3]
# Derive the primary device from the GPU list instead of hard-coding
# 'cuda:0', so `device` always agrees with device_ids[0] when `gpus` is edited.
device = torch.device('cuda', gpus[0])

# To run without data parallelism, just comment out the nn.DataParallel line
# below; the model is still moved to `device` afterwards.
model = nn.DataParallel(model, device_ids=gpus)
model = model.to(device)
2. 载入并行训练(nn.DataParallel)保存的模型时,可能需要去掉 state_dict 键名中的 'module.' 前缀
def get_model(m_path=r'D:\savedmodel\ruipai_n2n\checkpoint_8000.pth'):
    """Build a ``UNetSeeInDark`` on ``cuda:0``, load its weights and return it.

    A checkpoint saved from an ``nn.DataParallel``-wrapped model stores every
    state-dict key with a leading ``module.``. That prefix is removed before
    loading, so checkpoints from both parallel and non-parallel training load
    into the bare (unwrapped) model.

    Args:
        m_path: Path of the checkpoint file. Its content is passed to
            ``load_state_dict``, so it must be a state dict. Defaults to the
            checkpoint location previously hard-coded in this function.

    Returns:
        The model, placed on ``cuda:0``, with the checkpoint weights loaded
        and switched to evaluation mode.
    """
    # 1. Print the versions of the current environment and the name of GPU 0.
    print(torch.__version__)
    print(torch.version.cuda)
    print(torch.backends.cudnn.version())
    print(torch.cuda.get_device_name(0))

    # 2. Create the bare (non-parallel) model on the first GPU.
    device = torch.device('cuda:0')
    model_copy = UNetSeeInDark().to(device)

    # 3. Load the checkpoint, mapping its tensors onto the target device.
    checkpoint = torch.load(m_path, map_location=device)

    # Strip the DataParallel prefix only where it leads the key; a blanket
    # str.replace would also alter a 'module.' appearing mid-key.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }

    model_copy.load_state_dict(state_dict)
    model_copy.eval()
    return model_copy