手写数字识别网络代码实现

 首先我们需要去鱼书对应官网下载对应数据集,然后按照书上步骤可一步一步实现,我这里就讲解可能遇到的报错的点。

首先第一个报错点

network=pickle.load(f)

由于书上是在读入MNIST数据集后会自动生成pickle文件夹,但是我们却没有在我们的目录中看到这个文件夹,这时候我们就需要 到我们下载的文件夹ch03中找到sample_weight.pkl放到我们的工作台下。

类似于这样。然后再执行pickle.load(f)就不会报错了。

具体代码如下


import numpy as np
import sys,os
sys.path.append(os.pardir)#为了导入父目录中的文件而进行的设定
from mnist import load_mnist #通过mnist的load_mnist()可以读入MNIST数据
from PIL import Image

def img_show(img):
    pil_img=Image.fromarray(np.uint8(img))
    pil_img.show()

def softmax(x):
    exp_a=np.exp(x)
    sum_exp_a=np.sum(exp_a)
    y=exp_a/sum_exp_a
    return y
def sigmoid(x):
    return 1/(1+np.exp(-x))


(x_train,t_train),(x_test,t_test)=load_mnist(flatten=True,normalize=False)
img=x_train[0]
label=t_train[0]
print(label)

print(img.shape)
img=img.reshape(28,28)
print(img.shape)

#img_show(img)

#接下来对MNIST数据集实现神经网络的推理处理
'''
神经网络的输入层有784个神经元,输出层有10个神经元。输入层的784来源于图像大小28x28=784,输出层的10来源于10类别分类(数字0-9,共10个类别)
此外,这个神经网络有两个隐藏层,第一个隐藏层有50个神经元,第二个隐藏层有100个神经元,当然这个50,100可以设置为任意值
'''

def get_data():
    (x_train,t_train),(x_test,t_test)=\
        load_mnist(normalize=True,flatten=True,one_hot_label=False)
    return x_test,t_test
def init_network():
    with open("sample_weight.pkl",'rb') as f:
        network=pickle.load(f)
    return network

def predict(network,x):
    W1,W2,W3=network['W1'],network['W2'],network['W3']
    b1,b2,b3=network['b1'],network['b2'],network['b3']
    a1=np.dot(x,W1)+b1
    z1=sigmoid(a1)
    a2=np.dot(z1,W2)+b2
    z2=sigmoid(a2)
    a3=np.dot(z2,W3)+b3
    y=softmax(a3)
    return y

x,t=get_data()
network=init_network()
accuracy_cnt=0
for i in range(len(x)):
    y=predict(network,x[i])
    p=np.argmax(y)#获取概率最高的元素的索引
    if p==t[i]:
        accuracy_cnt+=1

print("Accuracy:"+str(float(accuracy_cnt)/len(x)))

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值