用pycaffe训练人工神经网络步骤

from:http://blog.csdn.net/leo_is_ant/article/details/50506256

Caffe 是一个做CNN的工具。但是如果我只想搭建一套普通的神经网络,那么应该怎么做呢?这篇文章参考了一下两篇文章来一发CNN搭建神经网络的实验。

 

http://nbviewer.ipython.org/github/joyofdata/joyofdata-articles/blob/master/deeplearning-with-caffe/Neural-Networks-with-Caffe-on-the-GPU.ipynb

https://github.com/Franck-Dernoncourt/caffe_demos

 

  第一步构建我们的feature,label。我们可以把feature,与label整理成sklearn类似的数据格式。具体格式如下(featureMat是一个list1[list2] list2 是特征向量label是list 表征每个label 当然转换成 numpy.array()好像也可以)总之下如图:


                

第二步我们需要转换成caffe需要的数据格式,一开始想转换成HDF5格式,后来train这一步出现错误,error提示信息为numberof labels must match number of predictions;查看代码发现之前demo是做多标签的。但是这里需求是一个但标签多类分类(loss为softmax而非cross-entropy)所以我将数据按照lmdb格式组织了起来。

 

第三步写solver和train_val的prototxt。与model里的prototxt一样,只不过这里没有卷积层,是一个全连接加上tanh的传统神经网络结构。网络结构图奉上:




第四步训练测试。至此完成了caffe普通神经网络的训练…..貌似把人家搞退化了。

最后相关代码奉上:


lmdb生成代码

[python]  view plain  copy
  1. def load_data_into_lmdb(lmdb_name, features, labels=None):  
  2.     env = lmdb.open(lmdb_name, map_size=features.nbytes*2)  
  3.       
  4.     features = features[:,:,None,None]  
  5.     for i in range(features.shape[0]):  
  6.         datum = caffe.proto.caffe_pb2.Datum()  
  7.           
  8.         datum.channels = features.shape[1]  
  9.         datum.height = 1  
  10.         datum.width = 1  
  11.           
  12.         if features.dtype == np.int:  
  13.             datum.data = features[i].tostring()  
  14.         elif features.dtype == np.float:   
  15.             datum.float_data.extend(features[i].flat)  
  16.         else:  
  17.             raise Exception("features.dtype unknown.")  
  18.           
  19.         if labels is not None:  
  20.             datum.label = int(labels[i])  
  21.           
  22.         str_id = '{:08}'.format(i)  
  23.         with env.begin(write=True) as txn:  
  24.             txn.put(str_id, datum.SerializeToString())  

训练代码

[python]  view plain  copy
  1. def train(solver_prototxt_filename):  
  2.     ''''' 
  3.     Train the ANN 
  4.     '''  
  5.     caffe.set_mode_cpu()  
  6.     solver = caffe.get_solver(solver_prototxt_filename)  
  7.     solver.solve()  
  8.       


预测代码

[python]  view plain  copy
  1. def get_predicted_output(deploy_prototxt_filename, caffemodel_filename, input, net = None):  
  2.     ''''' 
  3.     Get the predicted output, i.e. perform a forward pass 
  4.     '''  
  5.     if net is None:  
  6.         net = caffe.Net(deploy_prototxt_filename,caffemodel_filename, caffe.TEST)  
  7.           
  8.     out = net.forward(data=input)  
  9.     return out[net.outputs[0]]  


  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值