保存和加载数据模型
状态字典state_dict
以字典的形式保存模型各层的参数, model.state_dict()
以字典的形式保存优化器的参数,optimizer.state_dict()
保存和加载模型
保存/加载state_dict(推荐使用)
# 保存 PATH = './download' torch.save(model.state_dict(), PATH) # 加载 model = TheModelClass() model.load_state_dict(torch.load(PATH)) model.eval()
保存/加载完整模型
# 保存 torch.save(model, PATH) #加载 model = torch.load(PATH) model.eval()
保存和加载Checkpoint
# 保存 epoch = 7 loss = torch.nn.CrossEntropyLoss() torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss,}, PATH) # 加载 model = TheModelClass() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) >>>optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] model.eval() #- or - #model.train()```
在一个文件中保存多个模型
# 保存 modelA, modelB = model, model optimizerA, optimizerB = optimizer, optimizer torch.save({ 'modelA_state_dict': modelA.state_dict(), 'modelB_state_dict': modelB.state_dict(), 'optimizerA_state_dict': optimizerA.state_dict(), 'optimizerB_state_dict': optimizerB.state_dict(), }, PATH) #加载 modelA = TheModelClass() modelB = TheModelClass() optimizerA = optimizer optimizerB = optimizer checkpoint = torch.load(PATH) modelA.load_state_dict(checkpoint['modelA_state_dict']) modelB.load_state_dict(checkpoint['modelB_state_dict']) optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']) optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']) modelA.eval() modelB.eval() #- or - #modelA.train() #modelB.train()```
使用在不同模型参数下的热启动模式
无论是从缺少某些键的 state_dict 加载还是从键的数目多于加载模型的 state_dict , 都可以通过在load_state_dict()函数中将strict参数设置为 False 来忽略非匹配键的函数。
如果要将参数从一个层加载到另一个层,但是某些键不匹配,主要修改正在加载的 state_dict 中的参数键的名称以匹配要在加载到模型中的键即可。# 保存 torch.save(modelA.state_dict(), PATH) #加载 modelB = TheModelClass() modelB.load_state_dict(torch.load(PATH), strict=False)```
通过设备保存/加载模型
保存/加载到CPU
#保存 torch.save(model.state_dict(), PATH) device = torch.device('cpu') model = TheModelClass() model.load_state_dict(torch.load(PATH, map_location=device))```
保存/加载到GPU
# 保存 torch.save(model.state_dict(), PATH) #加载device = torch.device("cuda") model = TheModelClass() model.load_state_dict(torch.load(PATH)) model.to(device) #确保在你提供给模型的任何输入张量上调用input = input.to(device)```
保存到CPU,加载到GPU
# 保存 torch.save(model.state_dict(), PATH) #加载 device = torch.device("cuda") model = TheModelClass(\*args,\**kwargs) model.load_state_dict(torch.load(PATH,map_location="cuda:0")) ## Choose whatever GPU device number you want model.to(device) #确保在你提供给模型的任何输入张量上调用input = input.to(device)``
pytorch保存和加载数据模型
最新推荐文章于 2024-07-21 23:40:23 发布