@本文来源于公众号:csdn2299,喜欢可以关注公众号 程序员学府
有时候我们训练了一个模型, 希望保存它下次直接使用,不需要下次再花时间去训练 ,本节我们来讲解一下PyTorch快速搭建神经网络及其保存提取方法详解
一、PyTorch快速搭建神经网络方法
先看实验代码:
import torch
import torch.nn.functional as F
# 方法1,通过定义一个Net类来建立神经网络
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)
def forward(self, x):
x = F.relu(self.hidden(x))
x = self.predict(x)
return x
net1 = Net(2, 10, 2)
print('方法1:\n', net1)
# 方法2 通过torch.nn.Sequential快速建立神经网络结构
net2 = torch.nn.Sequential(
torch.nn.Linear(2, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 2),
)
print('方法2:\n', net2)
# 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同
'''''
方法1:
Net (
(hidden):