以batch的思想实现mnist神经网络

该代码示例展示了如何使用批量(batch)处理优化神经网络的推理过程,以MNIST数据集为例。通过加载预训练的模型,对图像数据进行批量预测,并计算预测准确性。文章强调了批量处理可以减少数据传输时间,提高计算速度。
摘要由CSDN通过智能技术生成

0. 问题的出现

我们可以一张一张的将图像输入神经网络进行处理

  • 但是花费时间会长

为什么不一次性多传入点照片呢?

  • 例如一次直接传入100张图片

1. batch

捆,批的意思

  • 要尽可能地让计算机处于计算状态
  • 而不是数据传送状态
  • 这样效率会更好更高
  • 计算结果的速度会更快

2. 源码

# 在本实例中以batch的思想实现mnist神经网络
import pickle

import numpy as np
import sys, os

sys.path.append(os.pardir)
from dataset.mnist import load_mnist


def sigmod(x):
    y = 1 / (1 + np.exp(-x))

    return y


def softmax1(x):
    if x.ndim == 2:
        x = x.T
        x = x - np.max(x, axis=0)
        y = np.exp(x) / np.sum(np.exp(x), axis=0)
        return y.T

    x = x - np.max(x)  # 溢出对策
    return np.exp(x) / np.sum(np.exp(x))


def softmax(x):
    max = np.max(x)
    exp_a = np.exp(x - max)
    sum_exp_a = np.sum(exp_a)

    y = exp_a / sum_exp_a
    return y


# 拿到数据集
def get_mnist():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test


# 神经网络学习
def init_network():
    # pickle得到学习好的神经网络
    with open(file='sample_weight.pkl', mode='rb') as f:  # 从pickle中读取训练好的网络
        network = pickle.load(file=f)  # 使用pickle类读取打开的pkl文件

    return network


# 神经网络推理
def predict(network, x):
    pass
    # 拿到网络权重
    W1 = network['W1']
    W2 = network['W2']
    W3 = network['W3']

    b1 = network['b1']
    b2 = network['b2']
    b3 = network['b3']

    # 推理
    a1 = np.dot(x, W1) + b1
    z1 = sigmod(a1)

    a2 = np.dot(z1, W2) + b2
    z2 = sigmod(a2)

    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)

    return y


# 主入口
x, t = get_mnist()  # 输入层的值
network = init_network()

batch_size = 100
accuracy_count = 0

# 计算正确率
for i in range(0, len(x), batch_size):
    x_batch = x[i:i + batch_size]
    y_batch = predict(network, x_batch)  # 得到的y是一个ndarray数组
    predict_result = np.argmax(y_batch, axis=1)  # 预测结果

    rigth_count = np.sum(predict_result == t[i:i + batch_size])
    # if predict_result == t[i:i + batch_size]:
    accuracy_count = accuracy_count + rigth_count

print('Accuracy is ' + str(accuracy_count / len(x) * 100) + ' %')

3. 关键参数展示

    p = np.argmax(y_batch, axis=1)

在这里插入图片描述
参数p和t进行比较,从而来让准确度+1

思考

在这里插入图片描述

  • 行表示元祖的第一个值,一般都是输入的个数
  • 列表示原则的第二个值,这里是表示预测的结果

从0到9正好是个数,都分布在这里,所以应该是从行的方向上找最大值的索引,axis=1

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

王摇摆

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值