![v2-10cf651a28ef42b5e14b334a44c7dbcd_1440w.jpg?source=172ae18b](http://img-02.proxy.5ce.com/view/image?&type=2&guid=c4b53eb2-bd2f-eb11-8da9-e4434bdf6706&url=https://pic4.zhimg.com/v2-10cf651a28ef42b5e14b334a44c7dbcd_1440w.jpg?source=172ae18b)
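Linear regression is one of the most basic regression models and its principle is simple, so this article only introduces it briefly and provides a PyTorch implementation.
The article has four parts:
(1) Theoretical basis of the linear regression model
(2) Network design and PyTorch implementation
(3) Data construction, training, and testing
(4) Visualizing the fitting process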
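1. Theoretical basis of the linear regression model
Here the simple linear regression model is y = kx + b. In this model, x is the input, k and b are the parameters to learn, and y is the network's predicted output. Learning aims to make the prediction as close as possible to the true label y_real, so the mean squared error (MSE) serves as the loss function. In the formula below, N is the number of samples:
$$L = \mathrm{MSE}(y, y_{real}) = \frac{1}{N}\sum_{i=1}^{N}\left(y_i - y_{real,i}\right)^2$$
Optimization minimizes L, which yields suitable values of k and b.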
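The next block is a minimal pure-Python sketch (no PyTorch) that makes the loss concrete. It computes the MSE of the line y = kx + b and its gradients with respect to k and b. For this model, these are the quantities `loss.backward()` produces. The helper names `mse`, `mse_grad`, and `numeric_grad` are hypothetical. The finite-difference check exists only to confirm the analytic formulas dL/dk = (2/N)·Σ(kx + b − y)·x and dL/db = (2/N)·Σ(kx + b − y).

```python
def mse(k, b, xs, ys):
    """Mean squared error of the line y = kx + b on the points (xs, ys)."""
    n = len(xs)
    return sum((k * x + b - y) ** 2 for x, y in zip(xs, ys)) / n


def mse_grad(k, b, xs, ys):
    """Analytic gradients (dL/dk, dL/db) of the MSE above."""
    n = len(xs)
    residuals = [k * x + b - y for x, y in zip(xs, ys)]
    dk = 2.0 / n * sum(r * x for r, x in zip(residuals, xs))
    db = 2.0 / n * sum(residuals)
    return dk, db


def numeric_grad(k, b, xs, ys, eps=1e-6):
    """Central finite differences, used only to check mse_grad."""
    dk = (mse(k + eps, b, xs, ys) - mse(k - eps, b, xs, ys)) / (2 * eps)
    db = (mse(k, b + eps, xs, ys) - mse(k, b - eps, xs, ys)) / (2 * eps)
    return dk, db


xs = [0.0, 0.25, 0.5, 0.75, 1.0]
ys = [2 * x + 0.5 for x in xs]       # points exactly on y = 2x + 0.5
print(mse(2.0, 0.5, xs, ys))         # 0.0 at the true parameters
print(mse_grad(0.0, 0.0, xs, ys))    # approximately (-2.0, -3.0)
print(numeric_grad(0.0, 0.0, xs, ys))
```

Both gradients are negative at k = b = 0. Gradient descent therefore pushes k and b upward, toward 2 and 0.5.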
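2. Network design and PyTorch implementation
2.1 Network structure
As before, the model subclasses PyTorch's nn.Module class. Assigning an nn.Linear layer to an instance attribute registers its parameters with the module. Implementing the forward() method defines the forward pass: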
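```python
import torch

class LinearRegression(torch.nn.Module):
    """
    Linear regression module; the input and output features both default to 1
    """
    def __init__(self):
        super().__init__()
        # y = kx + b: the layer's weight plays the role of k, its bias the role of b
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        out = self.linear(x)
        return out
```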
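To see that `nn.Linear(1, 1)` really computes y = kx + b, the next sketch sets the weight to 2 and the bias to 0.5 by hand and feeds in a few inputs. It also lists the parameters the module registered. The values 2 and 0.5 are arbitrary choices for illustration.

```python
import torch

class LinearRegression(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

model = LinearRegression()
# self.linear is an attribute, so its weight (k) and bias (b) are registered
names = [name for name, _ in model.named_parameters()]
print(names)  # ['linear.weight', 'linear.bias']

# Manually set k = 2 and b = 0.5 (inside no_grad: this is not a training step)
with torch.no_grad():
    model.linear.weight.fill_(2.0)
    model.linear.bias.fill_(0.5)

x = torch.tensor([[0.0], [1.0], [3.0]])  # shape (N, 1): N samples, 1 feature each
out = model(x).squeeze(1).tolist()
print(out)  # [0.5, 2.5, 6.5]
```

The input must have shape (N, 1), not (N,). `nn.Linear(1, 1)` expects the last dimension to be the feature dimension of size 1.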
2.2 The optimizer is SGD:
```python
# inside Linear_Model.create_model()
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)
```
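Plain SGD (no momentum, no weight decay) applies the update p ← p − lr · ∂L/∂p to every parameter. The next sketch checks that one `optimizer.step()` matches this formula. The seed, the toy data, and lr = 0.1 are assumptions made only for the check.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(1, 1)
x = torch.linspace(0, 1, 10).unsqueeze(1)
y = 2 * x + 0.5
lr = 0.1

optimizer = torch.optim.SGD(model.parameters(), lr=lr)
before = [p.detach().clone() for p in model.parameters()]  # k, b before the step

loss = torch.nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
grads = [p.grad.detach().clone() for p in model.parameters()]
optimizer.step()

# Compare each updated parameter with the hand-computed SGD update
matches = [torch.allclose(p.detach(), p0 - lr * g)
           for p, p0, g in zip(model.parameters(), before, grads)]
print(matches)  # [True, True]
```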
2.3 The loss function is MSE:
```python
# inside Linear_Model.__init__()
self.loss_function = torch.nn.MSELoss()
```
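With its default `reduction="mean"`, `torch.nn.MSELoss` is the MSE formula from Section 1. The next sketch compares it with a hand-written mean of squared differences on three made-up points.

```python
import torch

prediction = torch.tensor([[1.0], [2.0], [4.0]])
target = torch.tensor([[1.5], [2.0], [3.0]])

loss_fn = torch.nn.MSELoss()                # reduction="mean" by default
builtin = loss_fn(prediction, target)
manual = ((prediction - target) ** 2).mean()
print(builtin.item(), manual.item())        # both equal (0.25 + 0 + 1) / 3
```

Keep the prediction and the target the same shape. If the prediction is (N, 1) and the target is (N,), broadcasting silently produces an (N, N) difference, and PyTorch only emits a warning. This is one reason the data below is reshaped with `unsqueeze`.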
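3. Data construction, training, and testing
3.1 Constructing the data
The data follows a linear function y = kx + b. Real data is rarely exactly linear, so a random perturbation is added as noise. The code below sets k = 2 and adds no explicit b. Because torch.rand draws noise uniformly from [0, 1), the noise has mean 0.5 and effectively acts as an intercept b ≈ 0.5: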
```python
import torch
import matplotlib.pyplot as plt

def create_linear_data(nums_data, if_plot=False):
    """
    Create data for the linear model
    Args:
        nums_data: number of data points wanted
        if_plot: whether to show a scatter plot of the data
    Returns:
        dict with "x" and "y", each of shape (nums_data, 1)
    """
    x = torch.linspace(0, 1, nums_data)
    x = torch.unsqueeze(x, dim=1)  # (nums_data,) -> (nums_data, 1)
    k = 2
    # Uniform noise in [0, 1) makes the data non-exact (mean offset 0.5)
    y = k * x + torch.rand(x.size())
    if if_plot:
        plt.scatter(x.numpy(), y.numpy(), c=x.numpy())
        plt.show()
    data = {"x": x, "y": y}
    return data

data = create_linear_data(300, if_plot=True)
print(data["x"].size())  # torch.Size([300, 1])
```
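Training should recover roughly the line that ordinary least squares gives on the same data. The next sketch computes that closed-form fit, k̂ = Σ(x − x̄)(y − ȳ) / Σ(x − x̄)² and b̂ = ȳ − k̂x̄. It uses data built the same way as `create_linear_data`. The fixed seed is an assumption added for repeatability.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed so the run is repeatable
x = torch.linspace(0, 1, 300).unsqueeze(1)
y = 2 * x + torch.rand(x.size())  # same construction as create_linear_data

# Ordinary least squares for y = kx + b
x_mean, y_mean = x.mean(), y.mean()
k_hat = ((x - x_mean) * (y - y_mean)).sum() / ((x - x_mean) ** 2).sum()
b_hat = y_mean - k_hat * x_mean
print(k_hat.item(), b_hat.item())  # close to 2 and 0.5
```

The intercept comes out near 0.5 rather than 0 because the noise from `torch.rand` is not zero-mean.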
The resulting data looks like this:
![v2-9676aa803f6d25ad4c1f4be6aa939e79_b.jpg](http://img-01.proxy.5ce.com/view/image?&type=2&guid=c4b53eb2-bd2f-eb11-8da9-e4434bdf6706&url=https://pic2.zhimg.com/v2-9676aa803f6d25ad4c1f4be6aa939e79_b.jpg)
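3.2 Training the network
Training proceeds as follows. Read the data, feed it into the network, and get the network's output. Then compute the loss between the output and the labels, backpropagate to compute the gradients, and update the parameters to reduce the loss. The corresponding code is: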
```python
def train(self, data, model_save_path="linear.pth"):
    """
    Train the model and save the parameters
    Args:
        data: dict with "x" and "y", where y ≈ kx + b
        model_save_path: file name for the saved parameters
    Returns:
        None
    """
    x = data["x"]
    y = data["y"]
    for epoch in range(self.epoches):
        prediction = self.model(x)                # forward pass
        loss = self.loss_function(prediction, y)  # MSE against the labels
        self.optimizer.zero_grad()                # clear old gradients
        loss.backward()                           # compute new gradients
        self.optimizer.step()                     # update k and b
        if epoch % 500 == 0:
            print("epoch: {}, loss is: {}".format(epoch, loss.item()))
    torch.save(self.model.state_dict(), model_save_path)
```
The last line, torch.save(), saves the model's parameters (its state_dict) for use in the test stage.
The training output is:
![v2-64795b74f036eb19d3232aeb6936d656_b.jpg](http://img-01.proxy.5ce.com/view/image?&type=2&guid=c4b53eb2-bd2f-eb11-8da9-e4434bdf6706&url=https://pic3.zhimg.com/v2-64795b74f036eb19d3232aeb6936d656_b.jpg)
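The next sketch is a self-contained version of the same training loop that reports the learned k and b. It assumes a learning rate of 0.1 and 2000 epochs rather than the article's 0.001 and 10000. A rough estimate for x in [0, 1] suggests why. With lr = 0.001, the slowest direction of this loss surface shrinks only to about 27% of its initial error after 10000 epochs, so the line may not be fully converged. The seed is also an assumption.

```python
import torch

torch.manual_seed(0)
x = torch.linspace(0, 1, 300).unsqueeze(1)
y = 2 * x + torch.rand(x.size())

model = torch.nn.Linear(1, 1)
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # assumed lr

losses = []
for epoch in range(2000):
    prediction = model(x)                 # forward pass
    loss = loss_function(prediction, y)   # MSE against the labels
    optimizer.zero_grad()                 # clear old gradients
    loss.backward()                       # compute dL/dk, dL/db
    optimizer.step()                      # update k and b
    losses.append(loss.item())

k = model.weight.item()
b = model.bias.item()
print(f"k = {k:.3f}, b = {b:.3f}, final loss = {losses[-1]:.4f}")
```

The final loss cannot go below the noise variance. For uniform noise on [0, 1), that variance is 1/12 ≈ 0.083.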
3.3 Testing the model
The test stage loads the parameters saved during training and assigns them back to the network:
```python
def test(self, data, model_path="linear.pth"):
    """
    Reload the saved parameters, run the model, and plot the prediction
    Args:
        data: dict with "x" and "y", where y ≈ kx + b
        model_path: path of the saved parameters
    Returns:
        None
    """
    x = data["x"]
    y = data["y"]
    self.model.load_state_dict(torch.load(model_path))
    with torch.no_grad():  # inference only, no gradients needed
        prediction = self.model(x)
    plt.scatter(x.numpy(), y.numpy(), c=x.numpy())
    plt.plot(x.numpy(), prediction.numpy(), color="r")
    plt.show()
```
![v2-be3829f22681d703c8c06e1d1aa1d4fd_b.jpg](http://img-02.proxy.5ce.com/view/image?&type=2&guid=c4b53eb2-bd2f-eb11-8da9-e4434bdf6706&url=https://pic2.zhimg.com/v2-be3829f22681d703c8c06e1d1aa1d4fd_b.jpg)
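The save and reload cycle can be checked in isolation, as the next sketch shows. It stores a bare `nn.Linear(1, 1)` whose k and b are set by hand to stand in for trained values. It writes to a temporary file and loads the file into a fresh layer of the same structure. Wrapped in the article's `LinearRegression` class, the same state_dict keys would be prefixed with `linear.`.

```python
import os
import tempfile
import torch

model = torch.nn.Linear(1, 1)
with torch.no_grad():
    model.weight.fill_(2.0)  # stand-in for a trained k
    model.bias.fill_(0.5)    # stand-in for a trained b

restored = torch.nn.Linear(1, 1)  # must have the same structure as the saved model
with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "linear.pth")
    torch.save(model.state_dict(), path)  # saves only the parameters, not the class
    restored.load_state_dict(torch.load(path))
restored.eval()  # good habit; has no effect on nn.Linear

with torch.no_grad():  # inference: no computation graph is built
    prediction = restored(torch.tensor([[1.0]]))

keys = list(restored.state_dict().keys())
print(keys)               # ['weight', 'bias']
print(prediction.item())  # 2.5
```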
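4. Visualizing the fitting process and complete code
The complete code follows. Its compare_epoches() method trains the model and plots the fitted line at 16 evenly spaced epochs in a 4x4 grid, which shows how the fit evolves during training.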
```python
import torch
import matplotlib.pyplot as plt


def create_linear_data(nums_data, if_plot=False):
    """
    Create data for the linear model
    Args:
        nums_data: number of data points wanted
        if_plot: whether to show a scatter plot of the data
    Returns:
        dict with "x" and "y", each of shape (nums_data, 1)
    """
    x = torch.linspace(0, 1, nums_data)
    x = torch.unsqueeze(x, dim=1)
    k = 2
    y = k * x + torch.rand(x.size())  # uniform noise in [0, 1)
    if if_plot:
        plt.scatter(x.numpy(), y.numpy(), c=x.numpy())
        plt.show()
    data = {"x": x, "y": y}
    return data


class LinearRegression(torch.nn.Module):
    """
    Linear regression module; the input and output features both default to 1
    """
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        out = self.linear(x)
        return out


class Linear_Model():
    def __init__(self):
        """
        Initialize the linear model
        """
        self.learning_rate = 0.001
        self.epoches = 10000
        self.loss_function = torch.nn.MSELoss()
        self.create_model()

    def create_model(self):
        self.model = LinearRegression()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)

    def train(self, data, model_save_path="linear.pth"):
        """
        Train the model and save the parameters to model_save_path
        """
        x = data["x"]
        y = data["y"]
        for epoch in range(self.epoches):
            prediction = self.model(x)
            loss = self.loss_function(prediction, y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if epoch % 500 == 0:
                print("epoch: {}, loss is: {}".format(epoch, loss.item()))
        torch.save(self.model.state_dict(), model_save_path)

    def test(self, data, model_path="linear.pth"):
        """
        Reload the saved parameters, run the model, and plot the prediction
        """
        x = data["x"]
        y = data["y"]
        self.model.load_state_dict(torch.load(model_path))
        with torch.no_grad():
            prediction = self.model(x)
        plt.scatter(x.numpy(), y.numpy(), c=x.numpy())
        plt.plot(x.numpy(), prediction.numpy(), color="r")
        plt.show()

    def compare_epoches(self, data):
        """
        Train while plotting the fitted line at 16 evenly spaced epochs (4x4 grid)
        """
        x = data["x"]
        y = data["y"]
        num_pictures = 16
        interval = max(1, self.epoches // num_pictures)
        plt.figure(figsize=(10, 10))
        current_fig = 0
        for epoch in range(self.epoches):
            prediction = self.model(x)
            loss = self.loss_function(prediction, y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if epoch % interval == 0 and current_fig < num_pictures:
                current_fig += 1
                plt.subplot(4, 4, current_fig)
                plt.scatter(x.numpy(), y.numpy(), c=x.numpy())
                plt.plot(x.numpy(), prediction.detach().numpy(), color="r")
        plt.show()


if __name__ == "__main__":
    print(LinearRegression())  # show the network structure
    data = create_linear_data(100)
    linear = Linear_Model()
    # linear.train(data)
    # linear.test(data)
    linear.compare_epoches(data)
```
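The next sketch shows the fitting process numerically, without matplotlib. Like `compare_epoches`, it records the state at 16 evenly spaced epochs before each update. Here the state is (k, b, loss) rather than a plotted line. The seed, the 1600 epochs, and lr = 0.1 are assumptions chosen so the run converges quickly.

```python
import torch

torch.manual_seed(0)
x = torch.linspace(0, 1, 100).unsqueeze(1)
y = 2 * x + torch.rand(x.size())

model = torch.nn.Linear(1, 1)
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # assumed lr

epoches, num_snapshots = 1600, 16
interval = epoches // num_snapshots
history = []  # (epoch, k, b, loss) at each snapshot
for epoch in range(epoches):
    prediction = model(x)
    loss = loss_function(prediction, y)
    if epoch % interval == 0:
        # Same timing as compare_epoches: state before this epoch's update
        history.append((epoch, model.weight.item(), model.bias.item(), loss.item()))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

for epoch, k, b, l in history:
    print(f"epoch {epoch:5d}: k = {k:.3f}, b = {b:.3f}, loss = {l:.4f}")
```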