我的研究重点原本是在 Torch 上,我也很喜欢用 Torch 去实现网络。但最近不得不转到 caffe 上。
在实现上篇博文:论文阅读:Reading Text in the Wild with Convolutional Neural Networks 的代码时,在 Bounding Box Regression 部分,需要用 caffe 来实现这个网络。
而一开始我构建论文中提到的这个 CNN 网络时,并没有用 caffe 提供的接口,而是直接手写 train.prototxt
、test.prototxt
文件。结果除了很多错,caffe 的 layers 名称也在变化,所以极其不推荐这种方式。我这个网络较浅时还好,一旦网络很深,写起来累死人。
我将这个网络用 caffe 中的 draw_net.py
脚本生成图像,网络结构如下:
调用 caffe 的 python 接口,网络构建方式如下:
import os, sys, glob
CAFFE_PATH = '/home/chenxp/caffe/python'
sys.path.append(CAFFE_PATH)
import caffe
import caffe.io
from caffe import layers as L
from caffe import params as P
import h5py
def Bounding_Box_Reg(hdf5, batch_size):
n = caffe.NetSpec()
n.data, n.label = L.HDF5Data(batch_size=batch_size, source=hdf5, ntop=2)
n.conv1 = L.Convolution(n.data, kernel_size=5, num_output=64, pad=2, weight_filler=dict(type='xavier'))
n.relu1 = L.ReLU(n.conv1, in_place=True)
n.pool1 = L.Pooling(n.conv1, kernel_size=2, stride=2, pool=P.Pooling.MAX)
n.conv2 = L.Convolution(n.pool1, kernel_size=5, num_output=128, pad=2, weight_filler=dict(type='xavier'))
n.relu2 = L.ReLU(n.conv2, in_place=True)
n.pool2 = L.Pooling(n.conv2, kernel_size=2, stride=2, pool=P.Pooling.MAX)
n.conv3 = L.Convolution(n.pool2, kernel_size=3, num_output=256, pad=1, weight_filler=dict(type='xavier'))
n.relu3 = L.ReLU(n.conv3, in_place=True)
n.pool3 = L.Pooling(n.conv3, kernel_size=2, stride=2, pool=P.Pooling.MAX)
n.conv4 = L.Convolution(n.pool3, kernel_size=3, num_output=512, pad=1, weight_filler=dict(type='xavier'))
n.relu4 = L.ReLU(n.conv4, in_place=True)
n.pool4 = L.Pooling(n.conv4, kernel_size=2, stride=2, pool=P.Pooling.MAX)
n.ip1 = L.InnerProduct(n.pool4, num_output=4000, weight_filler=dict(type='xavier'))
n.dp1 = L.Dropout(n.ip1, dropout_ratio=0.5)
n.ip2 = L.InnerProduct(n.ip1, num_output=4, weight_filler=dict(type='xavier'))
n.loss = L.EuclideanLoss(n.ip2, n.label)
return n.to_proto()
with open('BBR_train.prototxt', 'w') as f:
f.write(str(Bounding_Box_Reg('train.h5', 16)))
with open('BBR_test.prototxt', 'w') as f:
f.write(str(Bounding_Box_Reg('test.h5', 16)))
最后生成两个 prototxt
文件,一个是 BBR_train.prototxt
,另一个是 BBR_test.prototxt
文件。
下图是自动生成的 BBR_train.prototxt
文件:
这样的方式更快速,还不容易出错~^_^