PyTorch 修改权重/字典 Key

想对wiograd后的训练添加预训练权重,因为修改卷积层kernel尺寸后,用了新的key名, 所以修改了一下.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import *
import os
import argparse
from model.vggnet_4_bn import VGG

parser=argparse.ArgumentParser()
# parser.add_argument('--pre_weights', type = str, default = 'rename_winograd/rename_winograd.pth', help = 'pretrained weights')
parser.add_argument('--pre_weights', type = str, default = 'ckp_bn_01_vgg4/model_5_0.9785.pth', help = 'pretrained weights')
opt=parser.parse_args()
print(opt)

model = VGG()
model.load_state_dict(torch.load(opt.pre_weights))
model.cuda()
# 修改model的名字为  features.0.weight  ---> features.0.inner_conv2d.weight, 保留 features.0.weight
from collections import OrderedDict
new_dict = OrderedDict()
for key in model.state_dict():
	if key == "features.0.weight":
		new_dict["features.0.inner_conv2d.weight"] = model.state_dict()[key]
		new_dict["features.0.weight"] = model.state_dict()[key]
	elif key == "features.0.bias":
		new_dict["features.0.inner_conv2d.bias"] = model.state_dict()[key]
		new_dict["features.0.bias"] = model.state_dict()[key]
	elif key == "features.4.weight":
		new_dict["features.4.inner_conv2d.weight"] = model.state_dict()[key]
		new_dict["features.4.weight"] = model.state_dict()[key]
	elif key == "features.4.bias":
		new_dict["features.4.inner_conv2d.bias"] = model.state_dict()[key]
		new_dict["features.4.bias"] = model.state_dict()[key]
	elif key == "features.8.weight":
		new_dict["features.8.inner_conv2d.weight"] = model.state_dict()[key]
		new_dict["features.8.weight"] = model.state_dict()[key]
	elif key == "features.8.bias":
		new_dict["features.8.inner_conv2d.bias"] = model.state_dict()[key]
		new_dict["features.8.bias"] = model.state_dict()[key]
	elif key == "features.12.weight":
		new_dict["features.12.inner_conv2d.weight"] = model.state_dict()[key]
		new_dict["features.12.weight"] = model.state_dict()[key]
	elif key == "features.12.bias":
		new_dict["features.12.inner_conv2d.bias"] = model.state_dict()[key]
		new_dict["features.12.bias"] = model.state_dict()[key]
	else:
		new_dict[key] = model.state_dict()[key]
		
print(new_dict.keys())
MODEL_PATH = "/home/aiden00/pytorch_classfication_person/personvscar_pytorch_pq/rename_winograd/" 
if not os.path.exists(MODEL_PATH):
	os.makedirs(MODEL_PATH)	
torch.save(new_dict, MODEL_PATH + 'model_' + 'winograd' + '.pth')  


  • 10
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值