import torch as tc
import matplotlib.pyplot as plt
class EasyReg(object):
    """Simple linear regression model y ≈ x @ w + b trained by plain gradient descent."""

    def __init__(self, x, y, w, b):
        # x: input samples, y: noisy targets.
        # w, b: parameter tensors created with requires_grad=True by the caller.
        self.x = x
        self.y = y
        self.w = w
        self.b = b

    def predict(self):
        """Return the current model prediction x @ w + b."""
        # BUG FIX: the original returned `... + b`, silently reading the
        # module-level global `b` instead of the instance's own parameter.
        return tc.matmul(self.x, self.w) + self.b

    def lossFunc(self, ypre):
        """Mean squared error between the targets and the prediction `ypre`."""
        return (self.y - ypre).pow(2).mean()

    def training(self, learningRate, iterNum):
        """Run `iterNum` gradient-descent steps of size `learningRate`.

        Prints the parameters and loss every 30 iterations and returns the
        prediction computed after the final update.
        """
        for i in range(iterNum):
            ypre = self.predict()
            loss = self.lossFunc(ypre)
            # Gradients accumulate in PyTorch; clear stale ones before backward().
            if self.w.grad is not None:
                self.w.grad.data.zero_()
            if self.b.grad is not None:
                self.b.grad.data.zero_()
            loss.backward()
            # In-place SGD update on the raw tensors, bypassing autograd tracking.
            self.w.data -= learningRate * self.w.grad.data
            self.b.data -= learningRate * self.b.grad.data
            if i % 30 == 0:
                print("w: ", self.w.data, "b: ", self.b.data, "loss: ", loss.data)
        return self.predict()
# Synthetic data set: 100 points with x uniform in [0, 5) and y = 3x + noise.
# NOTE: the tc.rand calls stay in this exact order to keep the RNG stream,
# and the names x/y/w/b stay global because EasyReg is built from them.
x = tc.rand([100, 1], dtype=tc.float32) * 5
r = tc.rand([100, 1])
y = x * 3 + r

# Randomly initialised trainable parameters for the line y = w*x + b.
w = tc.rand([1, 1], dtype=tc.float32, requires_grad=True)
b = tc.rand(1, dtype=tc.float32, requires_grad=True)

model = EasyReg(x, y, w, b)
fitted = model.training(0.1, 1000)

# Scatter the noisy samples, then overlay the fitted line in red.
xs = x.numpy().reshape(-1)
plt.scatter(xs, y.data.numpy().reshape(-1))
plt.plot(xs, fitted.data.numpy().reshape(-1), c="red")
plt.show()
# For a perturbed target y(x) = y + e, we look for a straight line that tracks
# y as closely as possible: model it as y = w*x + b, with the loss defined as
# the mean squared error between the actual and predicted values. Gradient
# descent steadily reduces this loss during training, eventually producing the
# best-fitting line.