import torch
import collections
# Rename a single key in a saved PyTorch weights file and write the result
# to a new file, leaving the original file untouched.
weights_files = "weight.pth.tar"
# NOTE(review): torch.load relies on pickle-based deserialization, so only run
# this on weight files from a trusted source.
# Assumes the file holds a mapping (it must expose .items()) - TODO confirm
# for checkpoints that nest the state dict under another key.
weights = torch.load(weights_files)
# Rebuild the mapping in its original order, substituting 'new_key' for 'key';
# every other entry is carried over unchanged.
weights = collections.OrderedDict(
    ("new_key" if k == "key" else k, v) for k, v in weights.items()
)
for k in weights:
    print(k)  # print each parameter name so the rename can be checked by eye
torch.save(weights, "new_weight.pth.tar")
# Modify the key names in a PyTorch weights file.
# First published 2023-11-28 15:02:43.