6、MNIST数据分类(详细函数代码)

1、MNIST数据训练要点

手写数字识别:

01 像素:28*28=784

02 标签:神经网络对图像进行分类,分配正确的标签,这些标签是0到9共10个数字中的一个,这意味着神经网络有10个输出层节点,每个节点对应一个可能得答案或标签,如果答案是“0”,输出吃呢个第一个节点激发而其余的输出节点则保持抑制状态。

03 输出:试图让神经网络生成0和1的输出,对于激活函数而言是不可能的,这回导致大的权重和饱和网络。因此需要重新调整这些数据,这里使用0.01和0.99来代替0和1.

注:其余基础知识,见之前更新的参数知识

2、详细代码(备有详细注释)

# -*- coding: utf-8 -*-

import numpy

import scipy

import scipy.special

class neuralNetwork:
    
    # initialise the neural network 初始化网络
    def _init_(self,inputnodes,hiddennodes,outputnodes,learningrate):
        
        # set number of nodes in each input,hidden,output layer
        self.inodes=inputnodes
        self.hnodes=hiddennodes
        self.onodes=outputnodes
        
        # link weight matrices,with and who 初始化权重
        # w11,w12,w21,w22
        # 正态分布中心设置为0.0,pow(self.hnodes,-0.5)表示节点数目的0.5次方,最后一个参数是numpy数组的形状大小
        
        self.wih=numpy.random.normal(0.0,pow(self.hnodes,-0.5),(self.hnodes,self.inodes))
        self.who=numpy.random.normal(0.0,pow(self.onodes,-0.5),(self.onodes,self.hnodes))
        
        # learning rate 学习率
        self.lr=learningrate
        
        # 激活函数 activation function is the sigmoid function 
        # lambda 是定义函数的简短形式
        
        self.activation_function =lambda x: scipy.special.expit(x)
        
        # train the network 训练网络
        
    def train(self,inputs_list,targets_list):
            
            # convert input list to 2d array
        inputs=numpy.array(inputs_list,ndmin=2).T
        targets=numpy.array(targets_list,ndmin=2).T
            
            # calculate signals into hidden layer  计算隐藏层输入
        hidden_inputs=numpy.dot(self.wih,inputs)
            
            # calculate the signals emerging from hidden layer 计算隐藏层输出信号
        hidden_outputs=self.activation_function(hidden_inputs)
            # 计算输出的
        
        final_inputs=numpy.dot(self.who,hidden_outputs)
            
        final_outputs=self.activation_function(final_inputs)
            
            # output layer error is the (target-actual) 计算误差
        output_errors=targets-final_outputs
            # hidden layer error is the output_error,split by weights.. 隐藏层误差
        hidden_errors=numpy.dot(self.who.T,output_errors)
            
            # 在隐藏层和输出层之间更新权重
            
        self.who += self.lr * numpy.dot((output_errors * final_outputs*(1.0-final_outputs)),numpy.transpose(hidden_outputs))
            
        self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), numpy.transpose(inputs))
            
        pass
            
        
        # query the neural network 查询网络
        
    def query(self,inputs_list):
            
            # convert inputs list tp 2d array 把列表编程数组
        inputs=numpy.array(inputs_list,ndmin=2).T
           
            # calculate signals into hidden layer  计算隐藏层输入
        hidden_inputs=numpy.dot(self.wih,inputs)
            
            # calculate the signals emerging from hidden layer 计算隐藏层输出信号
        hidden_outputs=self.activation_function(hidden_inputs)
            
            # 计算输出的
        
        final_inputs=numpy.dot(self.who,hidden_outputs)
            
        final_outputs=self.activation_function(final_inputs)
            
        return final_outputs
    
# 输入相关参数进行计算

input_nodes = 784
hidden_nodes = 100
output_nodes = 10

learning_rate = 0.3

# create instance of neural network

n = neuralNetwork()

n._init_(784,100,10,0.3)
# load the mnist trai

training_data_file = open("D:/DATA/pycase/number2/project/pretice/mnist_train_100.csv", 'r')
training_data_list = training_data_file.readlines()
training_data_file.close()

for record in training_data_list:

    all_values = record.split(',') # split the record by the ',' commas
# scale and shift the inputs
    inputs = (numpy.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01

 # create the target output values (all 0.01, except the desiredlabel which is 0.99)

    targets = numpy.zeros(output_nodes) + 0.01 # all_values[0] is the target label for this record

    targets[int(all_values[0])] = 0.99
    
    n.train(inputs, targets)
    
    pass

# load the mnist test data CSV file into a list
test_data_file = open("D:/DATA/pycase/number2/project/pretice/mnist_test_10.csv", 'r')
test_data_list = test_data_file.readlines()
test_data_file.close()


all_values=test_data_list[0].split(',')

print(all_values[0])

 

 

 

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值