模型的保存和加载
1 只保存和加载模型参数
torch.save(model.state_dict(), PATH) ###将模型的参数保存到这个地址下,后缀名为pt
model = model(*args, **kwargs) ###定义模型
model.load_state_dict(torch.load(PATH, map_location=lambda storage, loc: storage)) ##导入模型参数
2 保存和加载整个模型
torch.save(model,path)
model=torch.load(path)
这种方式可以直接保存整个模型,在应用的时候不用再重新定义模型。
定义网络结构
这里定义了最简单的网络结构。两层的全连接层
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.layer1=nn.Linear(1,3) ###线性层
self.layer2=nn.Linear(3,1)
def forward(self,x):
x=self.layer1(x)
x=torch.relu(x) ###relu激活函数
x=self.layer2(x)
return x
训练神经网络
import torch
import torch.nn as nn
import numpy as np