手写数字识别神经网络完整代码,带详细注释。

本文通过详尽的注释,介绍了一个使用神经网络进行手写数字识别的完整代码实现。涉及数据集包括训练集(60000条)和测试集(10000条),如mnist_train.csv和mnist_test.csv。通过运行操作文件handwrite_number.py,可以完成数据整理、网络训练及识别过程。
摘要由CSDN通过智能技术生成

神经网络使用的两个数据集:一个是训练集(60000条),一个是测试测试集(10000条),下载后直接放在程序目录下就可以了

mnist手写字体训练集6000条,mnist_train.csv

mnist手写字体测试集10000条,mnist_test.csv

神经网络文件(nn3l.py)

"""
@文件:nn3l.py
@功能:这是一个3层的神经网络
@作者:Kwina
@日期:2020年7月14日
@说明:传入数据和输出数据都为列表形式
"""
import numpy as np
import scipy.special


class NeuralNetwork:
    """
    神经网络
    """
    def __init__(self, inputs, hiddens, outputs, learn_rate):
        self.inodes = inputs  # 输入层节点数
        self.hnodes = hiddens  # 隐藏层节点数
        self.onodes = outputs  # 输出层节点数
        self.lr = learn_rate  # 学习效率
        self.epoch = 0  # 迭代次数
        self.file_prefix = None  # 上次保存的权重文件的前缀
        self.wih = None  # 输入层至隐藏层的权重矩阵
        self.who = None  # 隐藏层至输出层的权重矩阵

    def set_weight(self):
        """
        设置权重
        """
        # 初始权重矩阵
        if self.epoch == 0:
            self.wih = np.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))
            self.who = np.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))
        else:
            self.wih = NeuralNetwork.txt2array2d(self.file_prefix + '_wih.txt')
            self.who = NeuralNetwork.txt2array2d(self.file_prefix + '_who.txt')

    def train(self, inputs_list, targets_list):
        """
        训练网络
        :param inputs_list:训练源数据列表
        :param targets_list:正确结果列表
        """
        inputs = np.array(inputs_list, ndmin=2).T
        targets = np.array(targets_list, ndmin=2).T

        # 隐藏层计算
        hidden_inputs = np.dot(self.wih, inputs)
        hidden_outputs = NeuralNetwork.sigmoid(hidden_inputs)

        # 输出层计算
        final_inputs = np.dot(self.who, hidden_outputs)
        final_outputs = NeuralNetwork.sigmoid(final_inputs)

        # 误差计算
        output_errors = targets - final_outputs
        hidden_errors = np.dot(self.who.T, output_errors)

        # 反向传播更新权值
        self.wih += self.lr * np.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), inputs.T)
        self.who += self.lr * np.dot((output_errors * final_outputs * (1.0 - final_outputs)), hidden_outputs.T)

    def query(self, inputs_list):
        """
        识别未知数据
        :param inputs_list: 数据应该和训练源数据长度一样
        :param epoch: 使用迭代多少代的数据
        :return: 数据列表,数字最大就是识别结果
        """

        # 输入
        inputs = np.array(inputs_list, ndmin=2).T

        # 隐藏层计算
        hidden_inputs = np.dot(self.wih, inputs)
        hidden_outputs = NeuralNetwork.sigmoid(hidden_inp
  • 1
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值