caffe lstm训练mnist手写数字

我们可以把深度学习能做的事情分为两类:时间无关的事情和时间相关的事情。时间无关的话,比如人脸识别,给神经网络一张照片,神经网络就能告诉你这是谁,这是和时间无关的。时间相关的话,比如,我要知道一段视频里的人是在吃饭还是在打哈欠,这个可能通过一张照片是无法判别的,但是通过多张连续的图片,构成一段视频,我们就可以判别这个人是在打哈欠还是在吃饭了。
可是,手写数字明明是一张张无时间关系的静态图片啊?我们怎么用它来训练处理序列数据的lstm这种递归神经网络呢?其实,我们可以这样想:我们把一张手写数字的图片(28*28),看做是28个有序列的数据,每一个数据大小为28个字节,因为这28行我们可以理解为是随着时间的推移,一行一行写出来的。这样,我们就可以用lstm来训练手写数字了。

在 caffe 上配置 LSTM 时,数据的维度比较复杂。比如在 CNN 处理图片时,caffe 的数据维度一般是 NCHW,N 是 batch size,C 是 channel,W 是 width,H 是 height;但是 LSTM 的数据维度是 TN*…,T 是 sequence length,N 是 batch size,但是本人以为这里的N是同时处理的数据流的个数,我们做Mnist手写数字识别的时候,只有一个流,因此这里N为1。T则是序列的长度,毫无疑问,我们这里的序列长度就是28呀。

lstm需要输入两个Blob,一个为(T,N,V)的待处理的数据,另一个为(T,N,1)的标识连续帧的数据。比如我们的例子中,28个长度为28的数据构成一个序列,标识连续帧的cont应该为(0,1,1…共有27个1),而且只需要一行就可以了,因为所有的序列都是一样的,只需要为一个序列做标识就可以了。

cont我们以hdf5的格式给出,对hdf5不了解的先了解下hdf5吧,总之还是蛮简单的,我们只需要生成一行(0,1,1…共有27个1)这样的数据就可以了。
代码如下:

#include "hdf5.h"
#include <stdio.h>
#include <stdlib.h>

#define FILE            "h5ex_d_rdwr.h5"
#define DATASET         "cont"
#define DIM0            1
#define DIM1            28

int main(void)
{
	hid_t       file, space, dset;          /* Handles */
	herr_t      status;
	hsize_t     dims[2] = { DIM0, DIM1 };
	int         wdata[DIM0][DIM1],    
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值