批量修改pth文件里的参数名

#Batch modify pth files#
import torch

#导入pth文件
path_2='./model.pth'
model_2=dict(torch.load(path_2))

#原pth文件有的params 将其换成state_dict
model_2['state_dict'] = model_2.pop('params')

dict=[]
for k in model_2['state_dict'].keys():
    k_="{}".format(k) #在输出变量时加上引号
    dict.append(k)

#修改成新名字
for k in dict:
    #k是旧名
    k_="{}".format(k) #在输出变量时加上引号
    older_val=model_2['state_dict'][k_]
    #print(k_)

    #新名
    k_new="generator.{}".format(k_)
    #print(k_new)

    # 修改参数名,pop该方法返回从列表中移除的元素对象。
    model_2['state_dict'][k_new] = model_2['state_dict'].pop(k_)

torch.save(model_2,'./model_changed.pth')
print(model_2)

#merge path_1\path_2


import torch

path_1='/model_1.pth'
path_2='/model_2.pth'

model_1=torch.load(path_1)
model_2=torch.load(path_2)

for k,v in model_2.items():
   for i,j in v.items():
      model_1['state_dict'][i]=j
      #print(j)
      #print(i)
      #print(model_2['state_dict'][i])

torch.save(model_1,'./model_3.pth')

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值