__init__():初始化函数,需要指定输入层、隐藏层、输出层的节点数以及学习率。
train(): 训练函数。
query():查询函数,接受神经网络的输入,返回网络的输出。
神经网络代码
class neuralNetwork:
    """Fully connected three-layer network (input, hidden, output).

    Both the hidden and the output layer use the logistic sigmoid
    (``scipy.special.expit``) as activation function. Weights are updated
    one sample at a time by gradient descent on the output error.

    Attributes:
        inodes, hnodes, onodes: number of nodes in each layer.
        wih: input -> hidden weight matrix, shape ``(hnodes, inodes)``.
        who: hidden -> output weight matrix, shape ``(onodes, hnodes)``.
        lr: learning rate applied to every weight update.
        activation_function: callable applied element-wise to layer inputs.
    """

    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
        """Store the layer sizes and learning rate and draw initial weights.

        Args:
            inputnodes: number of input nodes.
            hiddennodes: number of hidden nodes.
            outputnodes: number of output nodes.
            learningrate: step size used by ``train``.
        """
        self.inodes = inputnodes
        self.hnodes = hiddennodes
        self.onodes = outputnodes

        # Initial weights are sampled from a normal distribution centred on
        # 0.0 whose standard deviation is 1 / sqrt(nodes feeding the layer).
        self.wih = numpy.random.normal(
            0.0, pow(self.inodes, -0.5), (self.hnodes, self.inodes)
        )
        self.who = numpy.random.normal(
            0.0, pow(self.hnodes, -0.5), (self.onodes, self.hnodes)
        )

        self.lr = learningrate

        # expit is the logistic sigmoid 1 / (1 + exp(-x)).
        self.activation_function = scipy.special.expit

    def _forward(self, inputs):
        """Propagate a column vector through the network.

        Args:
            inputs: 2-D array of shape ``(inodes, 1)``.

        Returns:
            Tuple ``(hidden_outputs, final_outputs)`` of column vectors.
        """
        hidden_outputs = self.activation_function(numpy.dot(self.wih, inputs))
        final_outputs = self.activation_function(numpy.dot(self.who, hidden_outputs))
        return hidden_outputs, final_outputs

    def train(self, inputs_list, targets_list):
        """Run one gradient-descent step on a single training sample.

        Args:
            inputs_list: sequence of ``inodes`` input values.
            targets_list: sequence of ``onodes`` target output values.
        """
        # Turn the flat sequences into column vectors.
        inputs = numpy.array(inputs_list, ndmin=2).T
        targets = numpy.array(targets_list, ndmin=2).T

        hidden_outputs, final_outputs = self._forward(inputs)

        output_errors = targets - final_outputs
        # The hidden-layer error is the output error split back along the
        # hidden -> output weights. It must be computed before self.who is
        # modified below.
        hidden_errors = numpy.dot(self.who.T, output_errors)

        # out * (1 - out) is the derivative of the sigmoid at that output.
        self.who += self.lr * numpy.dot(
            output_errors * final_outputs * (1.0 - final_outputs),
            numpy.transpose(hidden_outputs),
        )
        self.wih += self.lr * numpy.dot(
            hidden_errors * hidden_outputs * (1.0 - hidden_outputs),
            numpy.transpose(inputs),
        )

    def query(self, inputs_list):
        """Return the network output for one input sample.

        Args:
            inputs_list: sequence of ``inodes`` input values.

        Returns:
            2-D array of shape ``(onodes, 1)`` with the output activations.
        """
        inputs = numpy.array(inputs_list, ndmin=2).T
        _, final_outputs = self._forward(inputs)
        return final_outputs
传参
使用比输入节点的数量小的值,强制网络尝试总结输入的主要特点。如果选择太少的隐藏层节点,会限制网络的能力,使网络难以找到足够的特征或模式
应该选择多少个隐藏层节点,并不存在一个最佳方法
大的学习率会导致梯度下降过程中出现来回跳动和超调;减小学习率可以缓解这一问题,但会延长学习时间
# Layer sizes: 784 inputs (one per pixel of a flattened 28x28 image),
# 200 hidden nodes and 10 outputs (one per digit label).
input_nodes, hidden_nodes, output_nodes = 784, 200, 10

# Step size used for every weight update.
learning_rate = 0.1

# The network instance shared by the training and evaluation code below.
n = neuralNetwork(
    inputnodes=input_nodes,
    hiddennodes=hidden_nodes,
    outputnodes=output_nodes,
    learningrate=learning_rate,
)
导入mnist训练数据集
# Mount Google Drive under ./mount so the dataset stored there can be read.
from google.colab import drive
drive.mount('./mount')

# Read the whole training CSV into memory, one string per record.
# The `with` block closes the file even if reading fails part-way.
with open("mount/My Drive/Colab Notebooks/mnist_data/mnist_train.csv", 'r') as training_data_file:
    training_data_list = training_data_file.readlines()
训练
选择0.01作为输入范围的下限,是为了避免0值输入人为地导致权重更新失败。输入的上限不必取0.99,因为输入为1.0并不会造成这个问题;只需要避免目标输出值为1.0即可。
# Number of complete passes over the training set.
epochs = 10
for e in range(epochs):
    for record in training_data_list:
        # Ignore blank lines (e.g. a stray empty line at the end of the CSV),
        # which would otherwise fail when the label is parsed.
        if not record.strip():
            continue
        # Each record is "label,pixel,pixel,...".
        all_values = record.split(',')
        # Rescale the pixel values from 0..255 to 0.01..1.00.
        # numpy.asfarray was removed in NumPy 2.0; asarray with dtype=float
        # is what it did internally.
        inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01
        # Target vector: 0.01 everywhere except 0.99 at the index of the label.
        targets = numpy.zeros(output_nodes) + 0.01
        targets[int(all_values[0])] = 0.99
        n.train(inputs, targets)
测试集测试
# Read the whole test CSV into memory, one string per record.
# The `with` block closes the file even if reading fails part-way.
with open("mount/My Drive/Colab Notebooks/mnist_data/mnist_test.csv", 'r') as test_data_file:
    test_data_list = test_data_file.readlines()

# One entry per test record: 1 if the network's answer was right, else 0.
scorecard = []
for record in test_data_list:
    # Ignore blank lines so they neither crash the parser nor get scored.
    if not record.strip():
        continue
    # Each record is "label,pixel,pixel,...".
    all_values = record.split(',')
    correct_label = int(all_values[0])
    # Same rescaling as used for training: 0..255 -> 0.01..1.00.
    # numpy.asfarray was removed in NumPy 2.0; asarray with dtype=float
    # is what it did internally.
    inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01
    outputs = n.query(inputs)
    # The index of the largest output activation is the predicted label.
    label = numpy.argmax(outputs)
    scorecard.append(1 if label == correct_label else 0)

# Fraction of test records classified correctly.
scorecard_array = numpy.asarray(scorecard)
print("performance = ", scorecard_array.sum() / scorecard_array.size)
# Result recorded from one run: performance = 0.9716
尝试识别自己写的手写数字
修改图片尺寸
import glob
from PIL import Image
import numpy
import os

# Folder holding the original hand-written digit images, and the folder the
# 28x28 copies are written to.
srcPath = 'mount/My Drive/Colab Notebooks/ps_predata/'
dstPath = 'mount/My Drive/Colab Notebooks/ps_data/'

for img_file_name in glob.glob(srcPath + '*.PNG'):
    with Image.open(img_file_name) as image:
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        resized_image = image.resize((28, 28), Image.LANCZOS)
    resized_image.show()
    # Save under the source file's own name. Indexing os.listdir() with a
    # running counter is not guaranteed to line up with glob's results (the
    # order may differ and the folder may contain non-PNG entries), which
    # would store an image under another image's name.
    resized_image.save(dstPath + os.path.basename(img_file_name))
使用自己做的数据集进行测试
import imageio

# Not used by the code in this block; kept so the name stays defined.
our_own_dataset = []
k = 0      # number of correctly classified images
count = 0  # number of images evaluated
for image_file_name in glob.glob("mount/My Drive/Colab Notebooks/ps_data/*.PNG"):
    count = count + 1
    # The expected digit is the single character right before the
    # four-character ".PNG" extension.
    correct_label = int(image_file_name[-5:-4])
    # NOTE(review): `as_gray` is accepted by imageio's legacy (v2) reader;
    # confirm that the installed imageio version still supports it.
    img_array = imageio.imread(image_file_name, as_gray=True)
    # Invert the grey values and flatten the image to a 784-element vector.
    img_data = 255.0 - img_array.reshape(784)
    # Same rescaling as for the MNIST records: 0..255 -> 0.01..1.00.
    img_data = (img_data / 255.0 * 0.99) + 0.01
    outputs = n.query(img_data)
    # The index of the largest output activation is the predicted label.
    label = numpy.argmax(outputs)
    if correct_label == label:
        k = k + 1

# Guard against an empty folder, which would otherwise divide by zero.
if count:
    print('accuracy', k/count)
else:
    print('accuracy', 'n/a (no *.PNG files found)')
数据预处理遇到的问题:
像素值:mnist黑底白字,自己做的数据集白底黑字,灰度值取反
背景颜色:除了数字外的灰度值即背景应该统一,否则会成为干扰的“噪声”
数字颜色:数字灰度值太小有一定的影响
mnist手写字体数据集下载地址
https://pjreddie.com/media/files/mnist_test.csv
Colab笔记本
https://colab.research.google.com/?pli=1#scrollTo=Nma_JWh-W-IF