pytorch生成种子,可重复训练:
##牺牲计算效率,提升准确率
from torch.backends import cudnn
cudnn.benchmark = False # if benchmark=True, deterministic will be False
cudnn.deterministic = True
###设置种子,保证可重复性
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
pytorch保存和导入权重
方法之一:
# 第一步:读取当前模型参数
model_dict = mlp.state_dict()
# 第二步:读取预训练模型
pretrained_dict = torch.load("E:\\桌面\\test_new_10\\weights", map_location = "GPU")
pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
# 第三步:使用预训练的模型更新当前模型参数
model_dict.update(pretrained_dict)
# 第四步:加载模型参数
model.load_state_dict(model_dict)
方法之二:
保存的代码:
# 保存模型示例代码
print('===> Saving models...')
state = {
'state': model.state_dict(),
'epoch': epoch # 将epoch一并保存
}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')
读取的代码:
print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
try:
checkpoint = torch.load('./checkpoint/autoencoder.t7')
model.load_state_dict(checkpoint['state']) # 从字典中依次读取
start_epoch = checkpoint['epoch']
print('===> Load last checkpoint data')
except FileNotFoundError:
print('Can\'t found autoencoder.t7')
else:
start_epoch = 0
print('===> Start from scratch')