训练整个网络的详细注释
简单的例子反复写,熟能成巧,直到下笔有神
#训练一个网络
#主要分为以下几个部分:
#(1) 数据准备
#(2)定义网络
#(3)其实还包括损失函数和优化器,但是由于其是调用现成的函数,那么我们简单列举即可
#(4)开始训练,迭代
import torch
from sklearn.datasets import load_boston
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import numpy as np
import torch.utils.data as Data
import torch.nn as nn
from torch.optim import SGD
#数据准备流程
boston_x,boston_y = load_boston(return_X_y=True)
# print("boston_x.shape:",boston_x.shape)
# plt.figure()
# plt.hist(boston_y,bins=20)
# plt.show()
#数据标准化
ss = StandardScaler(with_mean=True,with_std=True)
boston_xs = ss.fit_transform(boston_x)
train_xt = torch.from_numpy(boston_xs.astype(np.float32))
train_yt = torch.from_numpy(boston_y.astype(np.float32))
train_data = Data.TensorDataset(train_xt,train_yt)
train_loader = Data.DataLoader(
dataset=train_data,
batch_size=128,
shuffle=True,
num_workers=0,
)
#定义网络
class MLPNET(nn.Module):
def __init__(self):
super(MLPNET, self).__init__()
#写法1
# self.hidden1 = nn.Linear(
# in_features=13,
# out_features=10,
# bias=True,
# )
# self.active1 = nn.ReLU()
# self.hidden2 = nn.Linear(10,10)
# self.active2 = nn.ReLU()
# self.regression = nn.Linear(10,1)
#写法2
self.hidden = nn.Sequential(
nn.Linear(13,10),
nn.ReLU(),
nn.Linear(10,10),
nn.ReLU(),
)
self.regression = nn.Linear(10,1)
def forward(self,x):
#写法1
# x = self.hidden1(x)
# x = self.active1(x)
# x = self.hidden2(x)
# x = self.active2(x)
# output = self.regression(x)
# return output
#写法2
x = self.hidden(x)
output = self.regression(x)
return output
mlpnet = MLPNET()
print(mlpnet)
optimizer = SGD(mlpnet.parameters(),lr=0.001)
loss_func = nn.MSELoss()
train_loss_all = []
for epoch in range(30):
for step,(b_x,b_y) in enumerate(train_loader):
output = mlpnet(b_x).flatten()
train_loss = loss_func(output,b_y)
optimizer.zero_grad()
train_loss.backward()
optimizer.step()
train_loss_all.append(train_loss.item())
plt.figure()
plt.plot(train_loss_all,"r-")
plt.title("Train loss per iteration")
plt.show()
# #保存模型:整个模型
# torch.save(mlpnet,"E:\Master_Cource\高级编程\propcess\mlpnet.pkl")
# #保存后加载出来看看
# mlpnetload = torch.load("E:\Master_Cource\高级编程\propcess\mlpnet.pkl")
# print(mlpnetload) #显示的是模型的结构
#只保存模型的参数
torch.save(mlpnet.state_dict(),"E:\Master_Cource\高级编程\propcess\mlpnet_state_dict.pkl")
mlpnet_param = torch.load("E:\Master_Cource\高级编程\propcess\mlpnet_state_dict.pkl")
print(mlpnet_param) #显示的是Tensor,模型各层的参数