pytorch 使用DataParallel 单机多卡和单卡保存和加载模型的正确方法

1.单卡训练,单卡加载

这里我为了把三个模块save到同一个文件里,我选择对所有的模型先封装成一个checkpoint字典,然后保存到同一个文件里,这样就可以在加载时只需要加载一个参数文件。

保存:

states = {
        'state_dict_encoder': encoder.state_dict(),
        'state_dict_decoder': decoder.state_dict(),
    }
torch.save(states, fname)

加载:

#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)

2.单卡训练,多卡加载

保存:

states = {
        'state_dict_encoder': encoder.state_dict(),
        'state_dict_decoder': decoder.state_dict(),
    }
torch.save(states, fname)

加载:
加载过程也没有任何改变,但是要注意,先加载模型参数,再对模型做并行化处理

#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)

3.多卡训练,单卡加载

注意,如果你考虑到以后可能需要单卡加载你多卡训练的模型,建议在保存模型时,去除模型参数字典里面的module,如何去除呢,使用model.module.state_dict()代替model.state_dict()

保存:

states = {
        'state_dict_encoder': encoder.module.state_dict(), #不是encoder.state_dict()
        'state_dict_decoder': decoder.module.state_dict(),
    }
torch.save(states, fname)

加载:
要注意由于我们保存的方式是以单卡的方式保存的,所以还是要先加载模型参数,再对模型做并行化处理

#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)

3.多卡训练,单卡加载,方法二

使用model.state_dict()保存,但是单卡加载的时候,要把模型做并行化(在单卡上并行)

保存:

states = {
        'state_dict_encoder': encoder.state_dict(), 
        'state_dict_decoder': decoder.state_dict(),
    }
torch.save(states, fname)

加载:
要注意由于我们保存的方式是以多卡的方式保存的,所以无论你加载之后的模型是在单卡运行还是在多卡运行,都先把模型并行化再去加载

#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)

4.多卡保存,多卡加载

这就和多卡保存,单卡加载第二中方式一样了 使用model.state_dict()保存,加载的时候,要先把模型做并行化(在多卡上并行)

保存:

states = {
        'state_dict_encoder': encoder.state_dict(), 
        'state_dict_decoder': decoder.state_dict(),
    }
torch.save(states, fname)

加载: 要注意由于我们保存的方式是以多卡的方式保存的,所以无论你加载之后的模型是在单卡运行还是在多卡运行,都先把模型并行化再去加载

#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch提供了方便的方式来保存加载模型,使得我们可以轻松地在本地保存模型并在需要的时候加载模型。 下面是一个保存加载模型的例子: 1. 保存模型 ``` import torch # 创建模型 model = torch.nn.Linear(10, 2) # 保存模型 torch.save(model.state_dict(), 'model.pth') ``` 在这个例子中,我们使用`torch.save()`函数将模型的参数保存在文件`model.pth`中,该文件将被保存在当前目录中。 2. 加载模型 ``` import torch # 创建模型 model = torch.nn.Linear(10, 2) # 加载模型 model.load_state_dict(torch.load('model.pth')) ``` 在这个例子中,我们使用`torch.load()`函数从文件`model.pth`中加载模型的参数,并使用`model.load_state_dict()`函数将参数加载模型中。 需要注意的是,这种方式只能保存加载模型的参数,而不是整个模型。如果想要保存整个模型,可以使用以下方式: 1. 保存模型 ``` import torch # 创建模型 model = torch.nn.Linear(10, 2) # 保存模型 torch.save(model, 'model.pth') ``` 在这个例子中,我们使用`torch.save()`函数将整个模型保存在文件`model.pth`中,该文件将被保存在当前目录中。 2. 加载模型 ``` import torch # 加载模型 model = torch.load('model.pth') ``` 在这个例子中,我们使用`torch.load()`函数从文件`model.pth`中加载整个模型。注意,这种方式只适用于Python对象的序列化,因此如果模型中有一些自定义的类,则需要自己手动重写加载函数。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值