pytorch中模型参数的保存与读取

本文介绍了如何使用PyTorch保存和加载训练好的神经网络模型。通过训练一个简单的神经网络来拟合数据,展示了两种保存模型的方法:一是直接保存整个模型,二是仅保存模型参数。在训练完成后,分别使用两种方法加载模型参数,验证了加载后的模型权重与原始模型一致。
摘要由CSDN通过智能技术生成

对训练好的模型,有两种保存方法:
1.直接将训练好的神经网络进行保存,但是速度会比较慢
2.将训练好的神经网络参数,保存到文件当中,然后进行文件的读取,再将读出的参数赋给新建好的模型,要求新建好的模型与之前的模型相同

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

#1构建数据集:y=x2
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)

#2.搭建神经网络
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.hidden=nn.Linear(1,10)
        self.predict=nn.Linear(10,1)

    def forward(self,x):
        x=F.relu(self.hidden(x))
        x=self.predict(x)
        return x

net=Net()

#3.定义优化器和损失函数
optimizer=torch.optim.SGD(net.parameters(),lr=0.5)
loss_func=torch.nn.MSELoss()

#4.训练模型
for epoch in range(1,101):
    prediction=net(x) #向前传播 得到预期值
    loss=loss_func(prediction,y) #向前传播 算出损失量 构建计算图

    print(epoch,loss)

    optimizer.zero_grad()
    loss.backward() #向后传播 算出梯度 释放计算图
    optimizer.step()# 梯度下降

print(net.hidden.weight.detach())  #net.hidden.weight.item()
print(net.hidden.bias.detach())
print(net.predict.weight.detach())
print(net.predict.weight.detach())

#保存模型参数至文件
torch.save(net.state_dict(), 'net_parameters.pt')
#实例化参数,赋值
m_state_dict = torch.load('net_parameters.pt')
new_net=Net()
new_net.load_state_dict(m_state_dict)

print(new_net.hidden.weight.detach())  #net.hidden.weight.item()
print(new_net.hidden.bias.detach())
print(new_net.predict.weight.detach())
print(new_net.predict.weight.detach())

#保存模型至文件
torch.save(net, 'net.pt')
#实例化模型
new_nets=torch.load('net.pt')

print(new_nets.hidden.weight.detach())  #net.hidden.weight.item()
print(new_nets.hidden.bias.detach())
print(new_nets.predict.weight.detach())
print(new_nets.predict.weight.detach())

新建文件的位置在同一级文件下

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ShuaS2020

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值