# 存取模型
PATH='./xxx.xx' # 模型名字
torch.save(net.state_dict(),PATH) # 保存,net为需要保存的网络
pretrained_net = torch.load(PATH) # 读取
net2 = Net() # Net()为保存的模型同结构的模型
net2.load_state_dict(pretrained_net) # 加载权重
pytorch模型存取与加载
最新推荐文章于 2024-05-03 23:20:06 发布