1.初始化参数和获取训练数据
%matplotlib inline
import gluonbook as gb
from mxnet import autograd, nd
batch_size=256
train_iter, test_iter = gb.load_data_fashion_mnist(batch_size)#在线下载训练数据
num_inputs=784 #每个图像的宽高为28*28
num_outputs=10 #10个标签
w=nd.random.normal(scale=0.01,shape=(num_inputs,num_outputs)) #随机生成 w(服从正态分布)
b=nd.zeros(num_outputs)
w.attach_grad() #创建梯度
b.attach_grad()
2.softmax计算与定义模型
表达样本预测各个输出的概率,softmax 运算会先通过 exp 函数对每个元素做指数运算,再对 exp 矩阵同⾏元素求和,最后令矩阵每⾏各元素与该⾏元素之和相除。这样⼀来,最终得到的矩阵每⾏元素和为 1 且⾮负。
def softmax(X):
X_exp=X.exp() #x=e^x
pation=X_exp.sum(axis=1,keepdims=True) #对矩阵的每一行求和,最后矩阵由m*n -> m*1
return X_exp/pation #运用了广播
由训练数据(X)生成预测数据(y)
def net(X):
return softmax(nd.dot(X.reshape((-1, num_inputs)), w) + b)
3.定义损失函数与准确率
softmax使用交叉熵损失函数,定义为:
我们知道最小化 ℓ(Θ) 等价于最⼤化 exp(−nℓ(Θ)) =,即最小化交叉熵损失函数等价于最⼤化训练数据集所有标签类别的联合预测概率。
def cross_entropy(y_hat, y):
return - nd.pick(y_hat, y).log()
然后求出预测结果的准确率,函数中参数意义:
假设输入X表示标签2(猫),一共有3个标签(鸡、猫、鸭)
y=(0,1,0)(真实结果)
y_hat=(0.1,0.6,0.3)(可能预测值1)··········>预测正确
y_hat=(0.5,0.2,0.3)(可能预测值2)··········>预测错误
y_hat.argmax(axis=1)返回矩阵 y_hat每⾏中最⼤元素的索引
def accuracy(y_hat, y):
return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar()
4.训练模型
定义超参数:
num_epochs, lr = 5, 0.1#迭代周期数 num_epochs 和学习率 lr
开始训练,训练步骤:
1:从数据集中去小批量数据,根据X得出预测的Y=nd.dot(x,w)+b
2.预测Y与从数据中获取的Y(数据真实结果)构成损失函数
3.求损失函数求梯度,使用梯度下降让损失函数只变小
4.重复第一步
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,params=None, lr=None, trainer=None):
for epoch in range(num_epochs):
for X, y in train_iter:
with autograd.record(): #申请存储梯度所需要的内存
y_hat = net(X)
l = loss(y_hat, y) #损失函数
l.backward() #⾃动求梯度(对一次函数来说相当于导数)
if trainer is None:
gb.sgd(params, lr, batch_size)#梯度下降
else:
trainer.step(batch_size)
最后运行结果
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs,batch_size, [w, b], lr)
for X, y in test_iter:
break
true_labels = gb.get_fashion_mnist_labels(y.asnumpy())
pred_labels = gb.get_fashion_mnist_labels(net(X).argmax(axis=1).asnumpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]
gb.show_fashion_mnist(X[0:9], titles[0:9])
结果如下: