# -*- coding: utf-8 -*-
# @Time :2020/3/1 11:38
# @Author :XiaoMa
# @File :5.py
import torch as t
from matplotlib import pyplot as plt
#设置随机数种子
t.manual_seed(1000)
from IPython import display
def get_fake_data(batch_size=8):
'''产生随机数据:y=x*2+3,加上了一些噪声'''
x=t.rand(batch_size,1)*20
y=x*2+(1+t.randn(batch_size,1))*3
return x,y
# x,y=get_fake_data()
# plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
# plt.show()
#随机初始化参数
w=t.rand(1,1)
b=t.zeros(1,1)
lr=0.001
for ii in range(20000):
x,y=get_fake_data()
#forward:计算loss
y_pred=x.mm(w) + b.expand_as(y)
loss=0.5*(y_pred-y)**2 #均方误差
loss=loss.sum()
#backward:手动计算梯度
dloss=1
dy_pred=dloss*(y_pred-y)
dw=x.t().mm(dy_pred)
db=dy_pred.sum()
#更新参数
w.sub_(lr*dw)
b.sub_(lr*db)
if ii %1000==0:
#画图
display.clear_output(True)
x=t.arange(0,20).view(-1,1)
y=x.float().mm(w) + b.expand_as(x)
plt.plot(x.numpy(),y.numpy()) #预测值
x2,y2=get_fake_data(batch_size=20)
plt.scatter(x2.numpy(),y2.numpy()) #true data
plt.xlim(0,20)
plt.ylim(0,41)
plt.show()
plt.pause(0.5)
print('w:',w.item(),'b',b.item())
程序学习的结果:
w: 1.59645676612854 b 0.14356936514377594
w: 1.95401132106781 b 2.5674352645874023
w: 1.9625165462493896 b 2.9557881355285645
w: 1.8890539407730103 b 3.001418352127075
w: 2.0090134143829346 b 2.9461421966552734
w: 2.0715832710266113 b 3.1086935997009277
w: 1.9367080926895142 b 2.9967336654663086
w: 2.0456297397613525 b 2.9070372581481934
w: 1.8773678541183472 b 2.946399450302124
w: 2.1052446365356445 b 3.148202419281006
w: 1.81801438331604 b 3.0318338871002197
w: 2.119907855987549 b 2.948784351348877
w: 2.113875150680542 b 3.1145777702331543
w: 1.9084047079086304 b 2.946380376815796
w: 2.09352970123291 b 2.8815550804138184
w: 2.078930616378784 b 3.0505666732788086
w: 1.9736963510513306 b 3.040769100189209
w: 1.9155480861663818 b 2.9243123531341553
w: 2.0103886127471924 b 3.0951311588287354
w: 1.8826669454574585 b 3.0347161293029785
w接近2,b接近3。
拟合的结果: