pytorch保存和加载数据模型

保存和加载数据模型

状态字典state_dict

以字典的形式保存模型各层的参数, model.state_dict()

以字典的形式保存优化器的参数,optimizer.state_dict()

保存和加载模型

保存/加载state_dict(推荐使用)
# 保存
PATH = './download'
torch.save(model.state_dict(), PATH)

# 加载  
model = TheModelClass()
model.load_state_dict(torch.load(PATH))
model.eval()
保存/加载完整模型
# 保存
torch.save(model, PATH)
#加载 
model = torch.load(PATH)
model.eval()

保存和加载Checkpoint

# 保存
epoch = 7
loss = torch.nn.CrossEntropyLoss()
torch.save({
     'epoch': epoch,
     'model_state_dict': model.state_dict(),
     'optimizer_state_dict': optimizer.state_dict(),
     'loss': loss,}, PATH)

# 加载

model = TheModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict']) >>>optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
#- or -
#model.train()```

在一个文件中保存多个模型

# 保存 
modelA, modelB = model, model
optimizerA, optimizerB = optimizer, optimizer
torch.save({
   'modelA_state_dict': modelA.state_dict(),
   'modelB_state_dict': modelB.state_dict(),
   'optimizerA_state_dict': optimizerA.state_dict(),
   'optimizerB_state_dict': optimizerB.state_dict(),
   }, PATH)

#加载 
modelA = TheModelClass()
modelB = TheModelClass()
optimizerA = optimizer
optimizerB = optimizer
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
modelA.eval()
modelB.eval()
#- or -
#modelA.train()
#modelB.train()```

使用在不同模型参数下的热启动模式

无论是从缺少某些键的 state_dict 加载还是从键的数目多于加载模型的 state_dict , 都可以通过在load_state_dict()函数中将strict参数设置为 False 来忽略非匹配键的函数。
如果要将参数从一个层加载到另一个层,但是某些键不匹配,主要修改正在加载的 state_dict 中的参数键的名称以匹配要在加载到模型中的键即可。

# 保存
torch.save(modelA.state_dict(), PATH)
#加载
modelB = TheModelClass()
modelB.load_state_dict(torch.load(PATH), strict=False)```

通过设备保存/加载模型

保存/加载到CPU

#保存  
torch.save(model.state_dict(), PATH)
device = torch.device('cpu')
model = TheModelClass()
model.load_state_dict(torch.load(PATH, map_location=device))```

保存/加载到GPU

# 保存
torch.save(model.state_dict(), PATH)
#加载device = torch.device("cuda")
model = TheModelClass()
model.load_state_dict(torch.load(PATH))
model.to(device)
#确保在你提供给模型的任何输入张量上调用input = input.to(device)```

保存到CPU,加载到GPU

# 保存
torch.save(model.state_dict(), PATH)
#加载
device = torch.device("cuda")
model = TheModelClass(\*args,\**kwargs)
model.load_state_dict(torch.load(PATH,map_location="cuda:0")) ## Choose whatever GPU device number you want
model.to(device)
#确保在你提供给模型的任何输入张量上调用input = input.to(device)``
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值