import torch def change_feature(check_point): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 由于本文中是使用cpu,因此使用torch.load中将设备加载到cpu中,实际上可以直接使用torch.load进行加载,默认是cpu设备。 check_point = torch.load(check_point, map_location=device) import collections dicts = collections.OrderedDict() for k, value in check_point.items(): print("names:{}".format(k)) # 打印结构 print("shape:{}".format(value.size())) # if "module" in k: # 去除命名中的module # k = k.split(".")[1:] # k = ".".join(k) # print k # dicts[k] = value # torch.save(dicts, 'logs/vgg/2022.7.21_1/ep004-loss7.409-val_loss7.197.pth') if __name__ == "__main__": model_path ='model_data/ssd_weights.pth' change_feature(model_path)
权值文件查看代码pytorch
于 2022-07-23 21:06:56 首次发布