pytorch不加载fc_pytorch模型保存与加载以及常见问题

本文详细介绍了在PyTorch中如何保存和加载模型,包括使用多GPU的情况。同时,文章列举并解决了在模型加载时可能遇到的问题,如状态字典的键不匹配、版本差异导致的问题、部分层不参与训练等,提供了相应的解决方案。
摘要由CSDN通过智能技术生成

一. 模型保存与加载

#多gpu

import os

os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4' #choose

model = TheModelClass(*args, **kwargs)

model = torch.nn.DataParallel(model).cuda()

#加载与训练模型 file.pth.tar

checkpoint = torch.load('file.pth.tar')

model.load_state_dict[checkpoint]

#保存训练模型,保存路径为model_file

torch.save(model.state_dict(), model_file)

二.常见问题

1.Missing key(s) in state_dict: Unexpected key(s) in state_dict:

如果加载的预训练模型之前使用了torch.nn.DataParallel(),而此时的训练并没有使用,则会出现这样的错误。【该问题常出现在"训练模型与当前模型参数不完全一致时,需要update,而update又必须在DataParallel之前" 的这种情况下】

我们需要去掉参数中的前缀 module.

#方法1:

model.load_state_dict({k.replace('module.',''):v for k,v in torch.load('myfile.pth').items()})

#方法2:

state_dict = torch.load('file.pth.tar')

from collections import OrderedDict

new_state_dict = OrderedDict()

for k, v in state_dict.items():

name = k[7:] # remove `module.`

new_state_dict[name] = v

model.load_state_dict(new_state_dict)

2.Pytorch会由于使用版本不同,加载模型可能会出现“num_batches_tracked”空的键值。导致加载模型不匹配问题。

同理,还有些出现“running"等问题,处理方法相同

#去掉“num_batches_tracked”空的键值

model_dict = net.state_dict()

new_model_dict = {}

for i in model_dict.items():

if "num_batches_tracked" in i[0]:

print (i[0])

else:

new_model_dict[i[0]] = i[1]

pretraine_dict = model.load_state_dict(torch.load('file.pth.tar'))

load_dict={}

for kv1,kv2 in zip(new_model_dict.items(),pretrained_dict.items()):

load_dict[kv1[0]] = kv2[1]

model_dict.update(load_dict)

net.load_state_dict(model_dict)

3.常见的比较简单的是,除了最后层其他层均加载

#如:加载除了 ‘fc’ 层之外的层!

net = vgg().......

model_dict = net.state_dict()

#filter

pretrained_dict = {k: v for k, v in new_state_dict.items() if k.find('fc')==-1}

#update

model_dict.update(pretrained_dict)

net.load_state_dict(model_dict)

4.加载的模型的参数 多余 当前模型本身的参数

model_dict = model.state_dict()

pretrained_dict = torch.load('file.pth.tar')

# 1. filter out unnecessary keys

pretrained_dict = {k: v for k, vin pretrained_dict.items() if k in model_dict}

# 2. overwrite entries in the existing state dict

model_dict.update(pretrained_dict)

# 3. load the new state dict

model.load_state_dict(model_dict)

5.加载的模型参数 少于 当前模型本身的参数

常用于自己改变了网络结构,给网络添加了新的层

model.load_state_dict(checkpoint['state_dict'],strict=False)

#load_state_dict严格匹配参数的键名称

#strict=False表示只加载与键值匹配的参数,并忽略其他参数键。

6.设置部分层不参与训练

for name,param in model.base_model.named_parameters():

if name 满足某些条件:

param.requires_grad = False

#同时,optimizer也需要做相应调整,只优化相应层

params = filter(lambda p: p.requires_grad, model.parameters()

optimizer = torch.optim.SGD(params,

args.lr,

momentum = args.momentum,

weight_decay = args.weight_decay)

7.pytorch版本问题:AttributeError: ‘module’ object has no attribute '_rebuild_tensor_v2’

这是因为训练模型时使用的是新版本的pytorch,而加载时使用的是旧版本的pytorch

解决办法:在代码开头加上:

参考链接:https://discuss.pytorch.org/t/question-about-rebuild-tensor-v2/14560

import torch._utils

try:

torch._utils._rebuild_tensor_v2

except AttributeError:

def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):

tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)

tensor.requires_grad = requires_grad

tensor._backward_hooks = backward_hooks

return tensor

torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值