本文利用 Python 的 MXNet 库从零实现一个简单的 softmax 回归分类器（使用 Fashion-MNIST 数据集）。源代码如下，主要包括训练函数、网络构建函数、损失函数和预测函数。
import gluonbook as gb
from mxnet import autograd,nd
def softmax(x):
    """Apply the softmax function to every row of ``x``.

    Entries are exponentiated and divided by their row sum (reduction over
    ``axis=1``), so each row becomes a probability distribution.

    Args:
        x: array of raw scores with at least two dimensions; rows along
            axis 0 are normalised independently.

    Returns:
        Array of the same shape as ``x`` whose rows sum to 1.
    """
    # Softmax is unchanged by subtracting a constant from a row. Subtracting
    # the row maximum keeps exp() from overflowing to inf on large scores,
    # which would otherwise yield inf / inf = nan.
    shifted = x - x.max(axis=1, keepdims=True)
    x_exp = shifted.exp()
    partition = x_exp.sum(axis=1, keepdims=True)
    return x_exp / partition
def net(X, num_inputs, w, b):
    """Forward pass of the softmax-regression model.

    Args:
        X: input batch; reshaped to ``(-1, num_inputs)`` so that each example
            is flattened into a single row of features.
        num_inputs: number of features per flattened example.
        w: weight matrix with ``num_inputs`` rows.
        b: bias added to every row of ``X @ w``.

    Returns:
        Row-wise class probabilities, ``softmax(X @ w + b)``.
    """
    # No side effects: this runs once per mini-batch, so it must not print.
    return softmax(nd.dot(X.reshape((-1, num_inputs)), w) + b)
def loss(y_hat, lables):
    """Cross-entropy loss of every example in a batch.

    Args:
        y_hat: predicted class probabilities, one row per example.
        lables: true class index of each example (the parameter name is kept
            as-is so existing keyword callers keep working).

    Returns:
        One loss value per example: ``-log(y_hat[i, lables[i]])``.
    """
    # nd.pick selects, from each row, the probability at that row's label.
    true_class_prob = nd.pick(y_hat, lables)
    return -true_class_prob.log()
def train(train_inter, test_inter, batch_size, num_inputs,
          params=None, lr=None, trainer=None, num_epochs=3):
    """Train the softmax-regression model with mini-batch gradient descent.

    After each epoch, prints the mean training loss, training accuracy and
    test accuracy.

    Args:
        train_inter: iterable of ``(features, labels)`` training batches.
        test_inter: iterable of ``(features, labels)`` batches; used to
            report test accuracy after every epoch.
        batch_size: nominal mini-batch size. Kept for interface
            compatibility. Updates are normalised by the actual number of
            examples in each batch, so a smaller final batch is not
            under-weighted.
        num_inputs: number of features per flattened example.
        params: ``[w, b]``. Gradients are expected to be attached
            (``attach_grad()``), as done in the ``__main__`` block.
        lr: learning rate passed to ``gb.sgd``. Only used when ``trainer`` is
            None.
        trainer: optional object whose ``step(n)`` method is called instead
            of ``gb.sgd``.
        num_epochs: number of passes over ``train_inter`` (default 3).

    Returns:
        The ``params`` list that was passed in.

    Raises:
        ValueError: if ``params`` is not given.
    """
    if params is None:
        raise ValueError('params must be [w, b]')
    for epoch in range(num_epochs):
        loss_sum = 0.0
        correct = 0.0
        seen = 0
        for x, y in train_inter:
            with autograd.record():
                y_hat = net(x, num_inputs, params[0], params[1])
                batch_loss = loss(y_hat, y)
            # batch_loss holds one value per example. backward() on a
            # non-scalar head uses a head gradient of ones, i.e. it
            # differentiates the sum.
            batch_loss.backward()
            # Normalise by the real batch length: the final batch may hold
            # fewer than batch_size examples.
            n = y.shape[0]
            if trainer is None:
                gb.sgd(params, lr, n)
            else:
                trainer.step(n)
            loss_sum += batch_loss.sum().asscalar()
            correct += accuracy(y_hat, y) * n
            seen += n
        test_acc = evaluate_accuraxy(test_inter, params)
        seen = max(seen, 1)  # guard against an empty training iterator
        print('epoch {}, loss {:.4f}, train acc {:.3f}, test acc {:.3f}'.format(
            epoch + 1, loss_sum / seen, correct / seen, test_acc))
    return params
def accuracy(y_hat, lables):
    """Return the fraction of rows whose highest-scoring class equals the label.

    Args:
        y_hat: predicted scores or probabilities, one row per example.
        lables: true class index of each example.

    Returns:
        Python scalar in [0, 1]: the classification accuracy of the batch.
    """
    predicted = y_hat.argmax(axis=1)
    # Labels are cast to float32 before the elementwise comparison.
    # Presumably this matches the dtype of argmax's output; confirm against
    # the MXNet NDArray docs.
    hits = predicted == lables.astype('float32')
    return hits.mean().asscalar()
def evaluate_accuraxy(test_inter, params, num_inputs=None):
    """Compute the model's classification accuracy over a whole dataset.

    Args:
        test_inter: iterable of ``(features, labels)`` batches.
        params: ``[w, b]``, as passed to ``net``.
        num_inputs: features per flattened example. Optional; when omitted it
            is read from ``w.shape[0]``. ``net`` multiplies the
            ``(-1, num_inputs)``-reshaped input by ``w``, so ``w`` must have
            exactly ``num_inputs`` rows. Inferring it avoids depending on a
            module-level ``num_inputs`` that only exists when the file runs
            as a script.

    Returns:
        Fraction of correctly classified examples. Each example is counted
        once, so a smaller final batch is not over-weighted.

    Raises:
        ValueError: if ``test_inter`` yields no examples.
    """
    w, b = params[0], params[1]
    if num_inputs is None:
        num_inputs = w.shape[0]
    correct = 0.0
    total = 0
    for x, y in test_inter:
        n = y.shape[0]
        # accuracy() returns the batch mean; scale it back to a hit count.
        correct += accuracy(net(x, num_inputs, w, b), y) * n
        total += n
    if total == 0:
        raise ValueError('test_inter yielded no examples')
    return correct / total
def pred(test_inter, num_inputs, params, num_shown=9):
    """Predict the first batch of ``test_inter`` and plot a few examples.

    Each plotted image is titled with its true label on the first line and
    the predicted label on the second.

    Args:
        test_inter: iterable of ``(features, labels)`` batches. Only the
            first batch is used.
        num_inputs: features per flattened example.
        params: ``[w, b]``, as passed to ``net``.
        num_shown: how many images to plot (default 9).

    Raises:
        ValueError: if ``test_inter`` yields no batches.
    """
    try:
        x, y = next(iter(test_inter))
    except StopIteration:
        raise ValueError('test_inter yielded no batches') from None
    true_labels = gb.get_fashion_mnist_labels(y.asnumpy())
    predicted = net(x, num_inputs, params[0], params[1]).argmax(axis=1)
    pred_labels = gb.get_fashion_mnist_labels(predicted.asnumpy())
    titles = [true + '\n' + guess
              for true, guess in zip(true_labels, pred_labels)]
    gb.show_fashion_mnist(x[:num_shown], titles[:num_shown])
if __name__ == "__main__":
    # Load the Fashion-MNIST dataset and split it into mini-batches. Both
    # iterators are mxnet.gluon.data.dataloader.DataLoader objects.
    batch_size = 256
    train_iter, test_iter = gb.load_data_fashion_mnist(batch_size)

    # Model size: 784 input features per example, 10 output classes.
    # NOTE: num_inputs stays module-level; other code in this file reads it.
    num_inputs = 784
    num_out = 10

    # Small random weights and a zero bias, both tracked by autograd.
    w = nd.random.normal(scale=0.01, shape=(num_inputs, num_out))
    b = nd.zeros(num_out)
    w.attach_grad()
    b.attach_grad()
    params = [w, b]

    lr = 0.05
    params = train(train_iter, test_iter, batch_size, num_inputs,
                   params=params, lr=lr)
    pred(test_iter, num_inputs, params)
预测结果如下图所示：