1.单卡训练,单卡加载
这里我为了把三个模块save到同一个文件里,我选择对所有的模型先封装成一个checkpoint字典,然后保存到同一个文件里,这样就可以在加载时只需要加载一个参数文件。
保存:
states = {
'state_dict_encoder': encoder.state_dict(),
'state_dict_decoder': decoder.state_dict(),
}
torch.save(states, fname)
加载:
#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)
2.单卡训练,多卡加载
保存:
states = {
'state_dict_encoder': encoder.state_dict(),
'state_dict_decoder': decoder.state_dict(),
}
torch.save(states, fname)
加载:
加载过程也没有任何改变,但是要注意,先加载模型参数,再对模型做并行化处理
#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)
3.多卡训练,单卡加载
注意,如果你考虑到以后可能需要单卡加载你多卡训练的模型,建议在保存模型时,去除模型参数字典里面的module,如何去除呢,使用model.module.state_dict()代替model.state_dict()
保存:
states = {
'state_dict_encoder': encoder.module.state_dict(), #不是encoder.state_dict()
'state_dict_decoder': decoder.module.state_dict(),
}
torch.save(states, fname)
加载:
要注意由于我们保存的方式是以单卡的方式保存的,所以还是要先加载模型参数,再对模型做并行化处理
#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)
3.多卡训练,单卡加载,方法二
使用model.state_dict()保存,但是单卡加载的时候,要把模型做并行化(在单卡上并行)
保存:
states = {
'state_dict_encoder': encoder.state_dict(),
'state_dict_decoder': decoder.state_dict(),
}
torch.save(states, fname)
加载:
要注意由于我们保存的方式是以多卡的方式保存的,所以无论你加载之后的模型是在单卡运行还是在多卡运行,都先把模型并行化再去加载
#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)
4.多卡保存,多卡加载
这就和多卡保存,单卡加载第二中方式一样了 使用model.state_dict()保存,加载的时候,要先把模型做并行化(在多卡上并行)
保存:
states = {
'state_dict_encoder': encoder.state_dict(),
'state_dict_decoder': decoder.state_dict(),
}
torch.save(states, fname)
加载: 要注意由于我们保存的方式是以多卡的方式保存的,所以无论你加载之后的模型是在单卡运行还是在多卡运行,都先把模型并行化再去加载
#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)