mnist手写数据集神经网络C语言,matlab练习程序(神经网络识别mnist手写数据集)...

记得上次练习了神经网络分类,不过当时应该有些地方写的还是不对。

这次用神经网络识别mnist手写数据集,主要参考了深度学习工具包的一些代码。

mnist数据集训练数据一共有28*28*60000个像素,标签有60000个。

测试数据一共有28*28*10000个,标签10000个。

这里神经网络输入层是784个像素,用了100个隐含层,最终10个输出结果。

arc代表的是神经网络结构,可以增加隐含层,不过我试了没太大效果,毕竟梯度消失。

因为是最普通的神经网络,最终识别错误率大概在5%左右。

迭代曲线:

2740ffd26ee02b223da5b7762fb35223.png

代码如下:

clear all;

close all;

clc;

load mnist_uint8;

train_x = double(train_x) / 255;

test_x = double(test_x) / 255;

train_y = double(train_y);

test_y = double(test_y);

mu=mean(train_x);

sigma=max(std(train_x),eps);

train_x=bsxfun(@minus,train_x,mu); %每个样本分别减去平均值

train_x=bsxfun(@rdivide,train_x,sigma); %分别除以标准差

test_x=bsxfun(@minus,test_x,mu);

test_x=bsxfun(@rdivide,test_x,sigma);

arc = [784 100 10]; %输入784,隐含层100,输出10

n=numel(arc);

W = cell(1,n-1); %权重矩阵

for i=2:n

W{i-1} = (rand(arc(i),arc(i-1)+1)-0.5) * 8 *sqrt(6 / (arc(i)+arc(i-1)));

end

learningRate = 2; %训练速度

numepochs = 5; %训练5遍

batchsize = 100; %一次训练100个数据

m = size(train_x, 1); %数据总量

numbatches = m / batchsize; %一共有numbatches这么多组

%% 训练

L = zeros(numepochs*numbatches,1);

ll=1;

for i = 1 : numepochs

kk = randperm(m);

for l = 1 : numbatches

batch_x = train_x(kk((l - 1) * batchsize + 1 : l * batchsize), :);

batch_y = train_y(kk((l - 1) * batchsize + 1 : l * batchsize), :);

%% 正向传播

mm = size(batch_x,1);

x = [ones(mm,1) batch_x];

a{1} = x;

for ii = 2 : n-1

a{ii} = 1.7159*tanh(2/3.*(a{ii - 1} * W{ii - 1}'));

a{ii} = [ones(mm,1) a{ii}];

end

a{n} = 1./(1+exp(-(a{n - 1} * W{n - 1}')));

e = batch_y - a{n};

L(ll) = 1/2 * sum(sum(e.^2)) / mm;

ll=ll+1;

%% 反向传播

d{n} = -e.*(a{n}.*(1 - a{n}));

for ii = (n - 1) : -1 : 2

d_act = 1.7159 * 2/3 * (1 - 1/(1.7159)^2 * a{ii}.^2);

if ii+1==n

d{ii} = (d{ii + 1} * W{ii}) .* d_act;

else

d{ii} = (d{ii + 1}(:,2:end) * W{ii}).* d_act;

end

end

for ii = 1 : n-1

if ii + 1 == n

dW{ii} = (d{ii + 1}' * a{ii}) / size(d{ii + 1}, 1);

else

dW{ii} = (d{ii + 1}(:,2:end)' * a{ii}) / size(d{ii + 1}, 1);

end

end

%% 更新参数

for ii = 1 : n - 1

W{ii} = W{ii} - learningRate*dW{ii};

end

end

end

%% 测试,相当于把正向传播再走一遍

mm = size(test_x,1);

x = [ones(mm,1) test_x];

a{1} = x;

for ii = 2 : n-1

a{ii} = 1.7159 * tanh( 2/3 .* (a{ii - 1} * W{ii - 1}'));

a{ii} = [ones(mm,1) a{ii}];

end

a{n} = 1./(1+exp(-(a{n - 1} * W{n - 1}')));

[~, i] = max(a{end},[],2);

labels = i; %识别后打的标签

[~, expected] = max(test_y,[],2);

bad = find(labels ~= expected); %有哪些识别错了

er = numel(bad) / size(x, 1) %错误率

plot(L);

关注公众号: MATLAB基于模型的设计 (ID:xaxymaker) ,每天推送MATLAB学习最常见的问题,每天进步一点点,业精于勤荒于嬉。

打开微信扫一扫哦!

用Kersa搭建神经网络【MNIST手写数据集】

MNIST手写数据集的识别算得上是深度学习的”hello world“了,所以想要入门必须得掌握.新手入门可以考虑使用Keras框架达到快速实现的目的. 完整代码如下: # 1. 导入库和模块 fro ...

TensorFlow实战第五课(MNIST手写数据集识别)

Tensorflow实现softmax regression识别手写数字 MNIST手写数字识别可以形象的描述为机器学习领域中的hello world. MNIST是一个非常简单的机器视觉数据集.它由 ...

利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)

.caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

利用卷积神经网络实现MNIST手写数据识别

代码: import torch import torch.nn as nn import torch.utils.data as Data import torchvision # 数据库模块 im ...

TensorFlow系列专题(六):实战项目Mnist手写数据集识别

欢迎大家关注我们的网站和系列教程:http://panchuang.net/ ,学习更多的机器学习.深度学习的知识! 目录: 导读 MNIST数据集 数据处理 单层隐藏层神经网络的实现 多层隐藏层神经 ...

MNIST手写数据集在运行中出现问题解决方案

今天在运行手写数据集的过程中,出现一个问题,代码没有问题,但是运行的时候一直报错,错误如下: urllib.error.URLError:

Pytorch1.0入门实战一:LeNet神经网络实现 MNIST手写数字识别

记得第一次接触手写数字识别数据集还在学习TensorFlow,各种sess.run(),头都绕晕了.自从接触pytorch以来,一直想写点什么.曾经在2017年5月,Andrej Karpathy发表 ...

TensorFlow——MNIST手写数据集

MNIST数据集介绍 MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集. ...

keras—神经网络CNN—MNIST手写数字识别

from keras.datasets import mnist from keras.utils import np_utils from plot_image_1 import plot_imag ...

随机推荐

Cocos2dx中利用双向链表实现无限循环滚动层

[Qboy原创] 在Cocos2dX 3.0 中已经实现一些牛逼的滚动层,但是对于有一些需要实现循环滚动的要求确没有实现,笔者在前段时间的一个做了一个游戏,需求是实现在少有的(13个)英雄中进行循环滚 ...

linux驱动面试题2

1.什么是GPIO? general purpose input/output GPIO是相对于芯片本身而言的,如某个管脚是芯片的GPIO脚,则该脚可作为输入或输出高或低电平使用,当然某个脚具有复用的 ...

初学Python的一些细节

一.python的数据类型 1.python的基本数据类型包括数值数据类型和字符串数据类型:基本数据类型的特点是不允许改变,如果改变基本数据类型的值,会导致内存的重新分配. int 整形 二进制    ...

May 31. 2018 Week 22nd Thursday

The good seaman is known in bad weather. 惊涛骇浪,方显英雄本色. As we all know, the true worth of a person is ...

Mysql命令行tab自动补全方法

在mysql命令行有时为了方便想要按tbl键自动补全命令,以便节约时间. 具体方法如下: 第一步:修改my.cnf vi mysql/etc/my.cnf 将下图红框的代码注释,修改成如下代码: #d ...

Android项目开发第二天,关于GitHub

一. 今天在网上学习了如何使用GitHub,了解了GitHub是干什么的. 作为开源代码库以及版本控制系统,Github拥有超过900万开发者用户.随着越来越多的应用程序转移到了云上,Github已经 ...

TP引用样式表和js文件及验证码

TP引用样式表和js文件及验证码 引入样式表和js文件

PReLU全名Parametric Rectified Linear Unit. PReLU-nets在ImageNet 2012分类数据集top-5上取得了4.94%的错误率,首次超越了人工分类的错 ...

Sentinel 哨兵 实现redis高可用

本文链接:http://www.cnblogs.com/zhenghongxin/p/8885879.html 我们知道redis是有主从复制的,例如下图: 但如果master主进程挂掉之后,没有sl ...

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值