# Import the notebook's dependencies one at a time, printing a marker after
# each so that a failing import is easy to locate in the cell output.
from IPython import display
print("IPython ok");
from matplotlib import pyplot as plt
print("plt ok");
from mxnet import autograd,nd
print("mxnet ok");
import random
print("randomn ok");
# NOTE(review): os and sys are not referenced in the visible part of this file.
import os,sys
IPython ok
plt ok
mxnet ok
randomn ok
# Model: y = Xw + b + noise
# Number of input features and number of samples.
num_inputs = 2
num_examples = 1000

# Ground-truth weights w and bias b used to synthesise the labels.
true_w = [2, -3.4]
true_b = 4.2

# Random feature matrix (plays the role of X): num_examples rows, each row a
# vector of length num_inputs.
features = nd.random.normal(scale=1, shape=(num_examples, num_inputs))

# Labels (play the role of y): one value per sample, built from the two
# feature columns and the ground-truth parameters.
first_term = true_w[0] * features[:, 0]
second_term = true_w[1] * features[:, 1]
labels = first_term + second_term + true_b

# Perturb the labels in place with random noise drawn with scale=0.01.
noise = nd.random.normal(scale=0.01, shape=labels.shape)
labels += noise

# Show the first sample.
print(features[0], labels[0])
[-0.48362762 -0.921403 ]
<NDArray 2 @cpu(0)>
[6.353317]
<NDArray 1 @cpu(0)>
# Show the first sample again: its feature vector and its label.
print(features[0], labels[0])
# Helpers for drawing the scatter plot.
def use_svg_display():
    """Ask IPython to render matplotlib figures in the 'svg' (vector) format.

    NOTE(review): newer IPython releases reportedly deprecate
    ``display.set_matplotlib_formats`` -- confirm against the installed
    version before relying on it.
    """
    image_format = 'svg'
    display.set_matplotlib_formats(image_format)
def set_figsize(figsize=(3.5, 2.5)):
    """Switch to SVG rendering and set matplotlib's default figure size.

    Args:
        figsize: Value stored under the ``'figure.figsize'`` rcParams key.
            Defaults to ``(3.5, 2.5)``.
    """
    use_svg_display()
    # Store the requested size in matplotlib's global rc settings.
    rc_settings = plt.rcParams
    rc_settings['figure.figsize'] = figsize
# Draw a scatter plot of the second feature column against the labels.
set_figsize()
second_feature = features[:, 1].asnumpy()
label_values = labels.asnumpy()
# The third positional argument (1) is the marker size. The trailing semicolon
# is kept from the original; presumably it suppresses the echoed return value
# in the notebook.
plt.scatter(second_feature, label_values, 1);
[-0.48362762 -0.921403 ]
<NDArray 2 @cpu(0)>
[6.353317]
<NDArray 1 @cpu(0)>
语法提示
一、关于range()
#range()返回的是一个可迭代的range对象(既不是iterator,也不是list),例a=range(5)则返回range(0, 5),可以用for i in a: 来获取0到4
另外,如果希望直接得到一个list,那么就用list(range(5)),就可以获得0到4的list。 这个是python3和python2的区别之一
# Mini-batch reader: walks over the whole data set in random order.
def data_iter(batch_size, features, labels):
    """Yield shuffled mini-batches of ``(features, labels)``.

    Args:
        batch_size: Number of samples per mini-batch. The final batch is
            shorter when the sample count is not a multiple of it.
        features: Input features; ``len(features)`` gives the sample count
            and ``features.take(idx)`` selects samples by index.
        labels: Labels, selected with the same indices via ``labels.take``.

    Yields:
        A ``(features_batch, labels_batch)`` pair for each mini-batch.
    """
    total = len(features)
    # range() gives an iterable range object; list() materialises it as
    # [0, 1, ..., total - 1] so that it can be shuffled in place.
    order = list(range(total))
    random.shuffle(order)
    # Step through the shuffled indices batch_size positions at a time.
    for start in range(0, total, batch_size):
        # Clamp the upper bound so the last slice stops at the sample count.
        stop = min(start + batch_size, total)
        batch_indices = nd.array(order[start:stop])
        # take() returns the elements at the given indices, so each yield
        # hands back a whole batch rather than one sample at a time.
        yield features.take(batch_indices), labels.take(batch_indices)
batch_size = 10
# Fetch and print only the first mini-batch. Asking the generator for a single
# item (with a None default) is equivalent to looping and breaking after the
# first iteration.
first_batch = next(data_iter(batch_size, features, labels), None)
if first_batch is not None:
    X, y = first_batch
    print(X, y)
[[-0.5459386 -1.8447278 ]
[-0.2551912 -1.3350252 ]
[ 0.38828224 -0.21730624]
[-0.41434413 0.79045075]
[-0.9193256 -0.59021354]
[ 0.10230371 0.56393176]
[-0.63922304 -0.4232461 ]
[-0.8403067 -1.2707146 ]
[-0.23893945 -0.86148673]
[ 0.3051429 -1.6782148 ]]
<NDArray 10x2 @cpu(0)>
[ 9.398996 8.23137 5.7183204 0.67691046 4.360369 2.4872274
4.3754888 6.833427 6.65252 10.516858 ]
<NDArray 10 @cpu(0)>
# Initialise the parameters to be estimated.
# Why is w shaped (num_inputs, 1)? Each input row has num_inputs columns and
# the prediction is y_hat = X (matrix-multiplied by) w + b, so w needs
# num_inputs rows and one column.
weight_shape = (num_inputs, 1)
bias_shape = (1,)
w = nd.random.normal(scale=0.01, shape=weight_shape)
b = nd.zeros(shape=bias_shape)

# Compare the initial weights with the ground truth.
print(w)
print(true_w)
[[ 0.00059183]
[-0.00173846]]
<NDArray 2x1 @cpu(0)>
[2, -3.4]
# Allocate gradient storage for the parameters to be estimated:
# attach_grad() requests the memory needed to hold each gradient.
for param in (w, b):
    param.attach_grad()
# Model definition (the original notes this function also ships in the d2lzh
# package).
def linreg(X, w, b):
    """Return the linear-regression prediction ``nd.dot(X, w) + b``.

    Args:
        X: Input features.
        w: Weights.
        b: Bias added to the matrix product.
    """
    # nd.dot performs the matrix multiplication.
    projection = nd.dot(X, w)
    return projection + b
# Loss function.
def squared_loss(y_hat, y):
    """Return the element-wise halved squared error ``(y_hat - y) ** 2 / 2``.

    Args:
        y_hat: Predictions.
        y: Targets; reshaped to ``y_hat.shape`` before subtracting.

    Returns:
        An array with the same shape as ``y_hat``.
    """
    # Reshape first so that mismatched layouts are not silently broadcast
    # against each other.
    residual = y_hat - y.reshape(y_hat.shape)
    return residual ** 2 / 2
# Optimisation algorithm: mini-batch gradient descent.
def sgd(params, lr, batch_size):
    """Update every parameter in place by one gradient-descent step.

    Each parameter becomes ``param - lr * param.grad / batch_size``.

    Args:
        params: Iterable of parameters; each must expose ``.grad`` and
            support slice assignment.
        lr: Learning rate.
        batch_size: Divisor applied to the gradient. Per the original note,
            autograd yields the gradient summed over the mini-batch, so
            dividing by the sample count gives the average.
    """
    for param in params:
        step = lr * param.grad / batch_size
        # Slice assignment overwrites the existing array rather than
        # rebinding the local name.
        param[:] = param - step
# Training.
lr = 0.03        # learning rate
num_epochs = 3   # number of passes over the training set
net = linreg
loss = squared_loss

for epoch in range(num_epochs):
    # Every epoch visits each training sample once; X and y are the features
    # and labels of the current mini-batch.
    for X, y in data_iter(batch_size, features, labels):
        # Record the forward computation so it can be differentiated.
        with autograd.record():
            y_hat = net(X, w, b)
            # Loss of the current estimates on this mini-batch.
            l = loss(y_hat, y)
        # Gradients of the loss with respect to the parameters.
        l.backward()
        # Update the parameters being estimated.
        sgd([w, b], lr, batch_size)
        # NOTE: the loss over the full data set is evaluated and printed after
        # every mini-batch update (not once per epoch), which is what produces
        # many lines per epoch in the recorded output.
        train_l = loss(net(features, w, b), labels)
        mean_loss = train_l.mean().asnumpy()
        print("epoch %d ,loss %f" % (epoch + 1, mean_loss))
epoch 1 ,loss 16.032024
epoch 1 ,loss 15.253522
epoch 1 ,loss 14.489670
epoch 1 ,loss 13.800489
epoch 1 ,loss 12.592360
epoch 1 ,loss 11.466541
epoch 1 ,loss 10.983962
epoch 1 ,loss 10.376640
epoch 1 ,loss 10.049289
epoch 1 ,loss 9.340507
epoch 1 ,loss 8.621167
epoch 1 ,loss 8.071239
epoch 1 ,loss 7.874382
epoch 1 ,loss 7.561985
epoch 1 ,loss 6.782496
epoch 1 ,loss 6.073438
epoch 1 ,loss 5.859311
epoch 1 ,loss 5.377839
epoch 1 ,loss 5.043412
epoch 1 ,loss 4.618695
epoch 1 ,loss 4.380817
epoch 1 ,loss 4.016131
epoch 1 ,loss 3.875211
epoch 1 ,loss 3.573936
epoch 1 ,loss 3.295837
epoch 1 ,loss 3.060077
epoch 1 ,loss 2.751062
epoch 1 ,loss 2.635000
epoch 1 ,loss 2.492720
epoch 1 ,loss 2.375383
epoch 1 ,loss 2.243919
epoch 1 ,loss 2.097693
epoch 1 ,loss 2.012014
epoch 1 ,loss 1.835119
epoch 1 ,loss 1.742900
epoch 1 ,loss 1.721242
epoch 1 ,loss 1.554879
epoch 1 ,loss 1.428716
epoch 1 ,loss 1.357238
epoch 1 ,loss 1.231590
epoch 1 ,loss 1.089459
epoch 1 ,loss 1.049516
epoch 1 ,loss 0.998722
epoch 1 ,loss 0.900703
epoch 1 ,loss 0.838384
epoch 1 ,loss 0.787540
epoch 1 ,loss 0.751771
epoch 1 ,loss 0.713378
epoch 1 ,loss 0.662468
epoch 1 ,loss 0.606591
epoch 1 ,loss 0.581789
epoch 1 ,loss 0.551202
epoch 1 ,loss 0.516823
epoch 1 ,loss 0.493626
epoch 1 ,loss 0.461487
epoch 1 ,loss 0.429556
epoch 1 ,loss 0.398790
epoch 1 ,loss 0.383569
epoch 1 ,loss 0.364635
epoch 1 ,loss 0.348120
epoch 1 ,loss 0.315813
epoch 1 ,loss 0.287658
epoch 1 ,loss 0.262841
epoch 1 ,loss 0.255787
epoch 1 ,loss 0.248302
epoch 1 ,loss 0.234603
epoch 1 ,loss 0.223034
epoch 1 ,loss 0.212251
epoch 1 ,loss 0.200142
epoch 1 ,loss 0.187317
epoch 1 ,loss 0.176090
epoch 1 ,loss 0.169432
epoch 1 ,loss 0.154877
epoch 1 ,loss 0.146811
epoch 1 ,loss 0.142990
epoch 1 ,loss 0.137739
epoch 1 ,loss 0.130291
epoch 1 ,loss 0.117630
epoch 1 ,loss 0.109481
epoch 1 ,loss 0.106655
epoch 1 ,loss 0.099498
epoch 1 ,loss 0.093083
epoch 1 ,loss 0.087889
epoch 1 ,loss 0.082439
epoch 1 ,loss 0.077900
epoch 1 ,loss 0.070729
epoch 1 ,loss 0.068127
epoch 1 ,loss 0.065624
epoch 1 ,loss 0.063323
epoch 1 ,loss 0.059628
epoch 1 ,loss 0.055787
epoch 1 ,loss 0.052828
epoch 1 ,loss 0.051324
epoch 1 ,loss 0.047842
epoch 1 ,loss 0.045890
epoch 1 ,loss 0.042634
epoch 1 ,loss 0.039844
epoch 1 ,loss 0.036246
epoch 1 ,loss 0.034407
epoch 1 ,loss 0.032353
epoch 2 ,loss 0.030018
epoch 2 ,loss 0.026958
epoch 2 ,loss 0.025471
epoch 2 ,loss 0.024189
epoch 2 ,loss 0.022893
epoch 2 ,loss 0.021189
epoch 2 ,loss 0.020086
epoch 2 ,loss 0.019043
epoch 2 ,loss 0.017653
epoch 2 ,loss 0.017073
epoch 2 ,loss 0.016404
epoch 2 ,loss 0.015186
epoch 2 ,loss 0.013850
epoch 2 ,loss 0.012788
epoch 2 ,loss 0.011747
epoch 2 ,loss 0.011144
epoch 2 ,loss 0.010457
epoch 2 ,loss 0.009806
epoch 2 ,loss 0.009321
epoch 2 ,loss 0.008852
epoch 2 ,loss 0.008318
epoch 2 ,loss 0.007895
epoch 2 ,loss 0.007369
epoch 2 ,loss 0.007081
epoch 2 ,loss 0.006513
epoch 2 ,loss 0.006094
epoch 2 ,loss 0.005815
epoch 2 ,loss 0.005599
epoch 2 ,loss 0.005391
epoch 2 ,loss 0.005240
epoch 2 ,loss 0.004764
epoch 2 ,loss 0.004334
epoch 2 ,loss 0.004216
epoch 2 ,loss 0.004025
epoch 2 ,loss 0.003883
epoch 2 ,loss 0.003643
epoch 2 ,loss 0.003530
epoch 2 ,loss 0.003371
epoch 2 ,loss 0.003195
epoch 2 ,loss 0.002966
epoch 2 ,loss 0.002790
epoch 2 ,loss 0.002612
epoch 2 ,loss 0.002443
epoch 2 ,loss 0.002359
epoch 2 ,loss 0.002150
epoch 2 ,loss 0.001961
epoch 2 ,loss 0.001887
epoch 2 ,loss 0.001691
epoch 2 ,loss 0.001565
epoch 2 ,loss 0.001477
epoch 2 ,loss 0.001428
epoch 2 ,loss 0.001376
epoch 2 ,loss 0.001300
epoch 2 ,loss 0.001248
epoch 2 ,loss 0.001166
epoch 2 ,loss 0.001115
epoch 2 ,loss 0.001063
epoch 2 ,loss 0.001003
epoch 2 ,loss 0.000918
epoch 2 ,loss 0.000886
epoch 2 ,loss 0.000796
epoch 2 ,loss 0.000739
epoch 2 ,loss 0.000670
epoch 2 ,loss 0.000618
epoch 2 ,loss 0.000602
epoch 2 ,loss 0.000563
epoch 2 ,loss 0.000538
epoch 2 ,loss 0.000510
epoch 2 ,loss 0.000467
epoch 2 ,loss 0.000445
epoch 2 ,loss 0.000423
epoch 2 ,loss 0.000393
epoch 2 ,loss 0.000382
epoch 2 ,loss 0.000344
epoch 2 ,loss 0.000332
epoch 2 ,loss 0.000308
epoch 2 ,loss 0.000285
epoch 2 ,loss 0.000277
epoch 2 ,loss 0.000262
epoch 2 ,loss 0.000255
epoch 2 ,loss 0.000232
epoch 2 ,loss 0.000218
epoch 2 ,loss 0.000206
epoch 2 ,loss 0.000190
epoch 2 ,loss 0.000187
epoch 2 ,loss 0.000178
epoch 2 ,loss 0.000171
epoch 2 ,loss 0.000167
epoch 2 ,loss 0.000163
epoch 2 ,loss 0.000157
epoch 2 ,loss 0.000154
epoch 2 ,loss 0.000146
epoch 2 ,loss 0.000142
epoch 2 ,loss 0.000134
epoch 2 ,loss 0.000130
epoch 2 ,loss 0.000128
epoch 2 ,loss 0.000123
epoch 2 ,loss 0.000120
epoch 2 ,loss 0.000118
epoch 2 ,loss 0.000115
epoch 3 ,loss 0.000111
epoch 3 ,loss 0.000109
epoch 3 ,loss 0.000106
epoch 3 ,loss 0.000103
epoch 3 ,loss 0.000098
epoch 3 ,loss 0.000095
epoch 3 ,loss 0.000093
epoch 3 ,loss 0.000089
epoch 3 ,loss 0.000086
epoch 3 ,loss 0.000083
epoch 3 ,loss 0.000082
epoch 3 ,loss 0.000081
epoch 3 ,loss 0.000078
epoch 3 ,loss 0.000077
epoch 3 ,loss 0.000074
epoch 3 ,loss 0.000073
epoch 3 ,loss 0.000072
epoch 3 ,loss 0.000071
epoch 3 ,loss 0.000070
epoch 3 ,loss 0.000068
epoch 3 ,loss 0.000066
epoch 3 ,loss 0.000065
epoch 3 ,loss 0.000064
epoch 3 ,loss 0.000064
epoch 3 ,loss 0.000062
epoch 3 ,loss 0.000062
epoch 3 ,loss 0.000062
epoch 3 ,loss 0.000061
epoch 3 ,loss 0.000060
epoch 3 ,loss 0.000060
epoch 3 ,loss 0.000059
epoch 3 ,loss 0.000059
epoch 3 ,loss 0.000059
epoch 3 ,loss 0.000059
epoch 3 ,loss 0.000058
epoch 3 ,loss 0.000058
epoch 3 ,loss 0.000058
epoch 3 ,loss 0.000057
epoch 3 ,loss 0.000057
epoch 3 ,loss 0.000057
epoch 3 ,loss 0.000057
epoch 3 ,loss 0.000056
epoch 3 ,loss 0.000056
epoch 3 ,loss 0.000055
epoch 3 ,loss 0.000055
epoch 3 ,loss 0.000055
epoch 3 ,loss 0.000054
epoch 3 ,loss 0.000054
epoch 3 ,loss 0.000054
epoch 3 ,loss 0.000054
epoch 3 ,loss 0.000054
epoch 3 ,loss 0.000054
epoch 3 ,loss 0.000054
epoch 3 ,loss 0.000053
epoch 3 ,loss 0.000053
epoch 3 ,loss 0.000054
epoch 3 ,loss 0.000053
epoch 3 ,loss 0.000053
epoch 3 ,loss 0.000053
epoch 3 ,loss 0.000053
epoch 3 ,loss 0.000053
epoch 3 ,loss 0.000053
epoch 3 ,loss 0.000053
epoch 3 ,loss 0.000053
epoch 3 ,loss 0.000053
epoch 3 ,loss 0.000053
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
epoch 3 ,loss 0.000052
# Compare the learned parameters with the ground truth, one line per value.
# w is converted with asnumpy(); b is printed as-is, exactly as before.
report = (
    ("model w is:", w.asnumpy()),
    ("true w is:", true_w),
    ("model b is:", b),
    ("true b is:", true_b),
)
for caption, value in report:
    print(caption, value)
model w is: [[ 2.0001972]
[-3.399224 ]]
true w is: [2, -3.4]
model b is:
[4.1994486]
<NDArray 1 @cpu(0)>
true b is: 4.2