基于MNIST数据集的神经网络推理处理 python Demo

#  神经网络的推理处理
#  机器学习中的大部分问题可以分为回归和分类问题
#
import pickle
import sys
import os

import numpy as np

sys.path.append(os.pardir)

from dataset.mnist import load_mnist

#  恒等函数
def softmax(i):
    j = np.max(i)  # 溢出策略
    exa = np.exp(i - j)
    suma = np.sum(exa)
    sy = exa / suma
    return sy

# sigmoid 函数实现 各层之间进行平滑变化
def sigmoid(i):
    return 1 / (1 + np.exp(-i))

#  将数据限定在某个范围内的处理称为正规化(normalization)
#  对神经网络的输入数据进行某种既定的转换称为预处理
#       很多预处理都会考虑到数据的整体分布
#       比如利用数据的整体的均值或标准差,移动数据,使得数据整体以0为中心分布,或者进行正规化,把数据的延展控制在一定的范围内,边界清晰
#   还有将数据整体的分布形状均匀化的方法,即数据白化

# 获取数据
def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test


#  初始化网络 读取文件中的权重参数
def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network


#  推理预测
def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']
    a1 = np.dot(x, W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)
    return y


#  进行推理然后评价识别精度
x, t = get_data()


network = init_network()
W1, W2, W3 = network['W1'], network['W2'], network['W3']

# print(x.shape)
# print(x[0].shape)
# print(W1.shape)
# print(W2.shape)
# print(W3.shape)

# np.argmax 函数获取矩阵中概率最高的元素
xx = np.array(([
    [0.1, 0.8, 0.1],
    [0.3, 0.1, 0.6],
    [0.2, 0.5, 0.3],
    [0.8, 0.1, 0.1]
]))
print(xx)
yy = np.argmax(xx, axis=1)
print(yy)

#  直接使用模型进行训练
accuracy_cnt = 0

#  使用批处理思想进行更改
batch_size = 100  # 批数量
for i in range(0, len(x), batch_size):
    x_batch = x[i:i+batch_size]
    y_batch = predict(network, x_batch)
    p = np.argmax(y_batch, axis=1)  # 获取元素最大索引
    accuracy_cnt += np.sum(p == t[i:i+batch_size])

#  原始方式
# for i in range(len(x)):
#     y = predict(network, x[i])
#     p = np.argmax(y)  # 获取概率最高元素的索引
#     if p == t[i]:
#         accuracy_cnt += 1

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))



  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

P("Struggler") ?

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值