有时如果需要自己添加层到Caffe当中,这时候需要自己编写py文件
#!/usr/bin/envs python
#coding=utf-8
import sys
caffe_root='/home/dyh/caffe/caffe/'
sys.path.insert(0, caffe_root+'python')
import numpy as np
import caffe
import cv2
import yaml
#类名一定要和配置文件里面的layer相同
class MyLayer(caffe.Layer): #这是一个caffe类,不要变
#caffe规定按如下方式写
def setup(self, bottom, top):
self.num = yaml.load(self.param_str)['num']
print('Parameter num:',self.num)
def reshape(self, bottom, top):
pass
def forward(self, bottom, top):
top[0].reshape(*bottom[0].shape)#输出和输入的shape一样
print(bottom[0].data.shape)
print(bottom[0].data)
#测试一下加一个数值能不能行
top[0].data[...] = bottom[0].data + self.num#输出的data就等于输入的加上自定义层的参数
print('输出',top[0].data[...])
def backward(self,top,bottom,propagate_down):
pass
net = caffe.Net('/home/dyh/eclipse-workspace/MyCaffeLayer/src/conv.prototxt',caffe.TEST)
im = np.array(cv2.imread('/home/dyh/eclipse-workspace/MyCaffeLayer/src/cat.jpg'))
print(im.shape)
#添加一个维度[h,w,c]->[b,h,w,c],
im_input = im[np.newaxis,:,:]
print(im_input.shape)
#要将[b,h,w,c]->[b,c,h,w]要符合caffe的规定
im_input2 = im_input.transpose([0,3,1,2])
print(im_input2.shape)
#*表示zip成一个seq,网络的data层是1,3,100,100,reshape和im_input2一样
net.blobs['data'].reshape(*im_input2.shape)
#数据很多,用...代替
net.blobs['data'].data[...] = im_input2#把数据单发到data层
net.forward()
conv.prototxt的简单内容如下,具体内容如何需要具体实施,这个只是简单的加法操作
name: "convolution"
input: "data"
input_dim: 1
input_dim: 3
input_dim: 100
input_dim: 100
layer {
name: "conv"
type: "Convolution"
bottom: "data"
top: "conv"
convolution_param {
num_output:3
kernel_size: 5
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler{
type: "constant"
value: 0
}
}
}
#自己定义的层按如下格式
layer{
name:"MyPythonLayer"#名字自己写,后续接的上就行
type:'Python'#类型一定是这个
bottom:'conv'
top:'output'
python_param{
module:'mypythonlayer'#此处和自己定义层的py文件名一样
layer:'MyLayer'#此处要和自己定义的class一样
param_str:"'num': 21"#这里数字前面要加空格,注意!!因该是提取的时候需要以空格分界
}
}
输出如下
(496, 700, 3)
(1, 496, 700, 3)
(1, 3, 496, 700)
(1, 3, 492, 696)
[[[[ 11.0702715 11.13561 11.150491 ... 13.735498
13.558763 13.209033 ]
[ 11.158289 11.258202 11.289419 ... 13.523567
13.414061 12.992344 ]
[ 11.340612 11.440386 11.423347 ... 13.58939
13.315084 13.156676 ]
...
[ 8.85025 8.900326 8.8069105 ... 9.247419
9.061868 9.001272 ]
[ 8.848199 8.788747 8.691421 ... 9.226819
9.138367 8.871523 ]
[ 8.832649 8.70563 8.590294 ... 9.1012335
9.061207 8.858202 ]]
[[-33.13818 -32.91376 -32.57329 ... -17.359497
-17.008904 -16.725342 ]
[-33.628136 -33.54023 -33.262566 ... -17.534294
-17.166283 -16.70674 ]
[-33.889587 -33.84192 -33.683243 ... -17.484074
-17.18139 -16.803131 ]
...
[-58.23536 -58.174156 -58.293545 ... -10.737489
-10.5319605 -10.538096 ]
[-58.214077 -58.24072 -58.419518 ... -10.745275
-10.566773 -10.593955 ]
[-58.26086 -58.29385 -58.45631 ... -11.090007
-10.717339 -10.459574 ]]
[[ -3.7222815 -3.7532597 -3.728013 ... -2.322114
-2.5792527 -2.3096333 ]
[ -3.9120486 -3.9047034 -3.8735025 ... -2.3396149
-2.4620175 -2.3884416 ]
[ -4.085577 -4.0669966 -3.9437149 ... -2.4459817
-2.2547388 -2.3984628 ]
...
[ -4.3689103 -4.2448835 -4.1534586 ... -0.92515635
-0.8958546 -0.985209 ]
[ -4.2951756 -4.1776423 -4.202016 ... -0.96538967
-1.0222634 -0.97236586]
[ -4.2042675 -4.1664195 -4.2614 ... -0.9604343
-1.1583881 -0.96595764]]]]
('\xe8\xbe\x93\xe5\x87\xba', array([[[[ 32.07027 , 32.13561 , 32.15049 , ..., 34.735497 ,
34.55876 , 34.209034 ],
[ 32.158287 , 32.2582 , 32.28942 , ..., 34.523567 ,
34.414062 , 33.992344 ],
[ 32.340614 , 32.440384 , 32.423347 , ..., 34.58939 ,
34.315086 , 34.156677 ],
...,
[ 29.85025 , 29.900326 , 29.806911 , ..., 30.24742 ,
30.061867 , 30.001272 ],
[ 29.848198 , 29.788746 , 29.691422 , ..., 30.226818 ,
30.138367 , 29.871523 ],
[ 29.83265 , 29.705631 , 29.590294 , ..., 30.101234 ,
30.061207 , 29.858202 ]],
[[-12.13818 , -11.913761 , -11.573292 , ..., 3.640503 ,
3.9910965, 4.274658 ],
[-12.628136 , -12.54023 , -12.262566 , ..., 3.4657059,
3.8337173, 4.2932606],
[-12.889587 , -12.841919 , -12.683243 , ..., 3.5159264,
3.8186092, 4.196869 ],
...,
[-37.23536 , -37.174156 , -37.293545 , ..., 10.262511 ,
10.4680395, 10.461904 ],
[-37.214077 , -37.24072 , -37.419518 , ..., 10.254725 ,
10.433227 , 10.406045 ],
[-37.26086 , -37.29385 , -37.45631 , ..., 9.909993 ,
10.282661 , 10.540426 ]],
[[ 17.277718 , 17.24674 , 17.271988 , ..., 18.677887 ,
18.420748 , 18.690367 ],
[ 17.087952 , 17.095297 , 17.126497 , ..., 18.660385 ,
18.537983 , 18.611559 ],
[ 16.914423 , 16.933002 , 17.056286 , ..., 18.554018 ,
18.745262 , 18.601538 ],
...,
[ 16.63109 , 16.755116 , 16.846542 , ..., 20.074844 ,
20.104145 , 20.014791 ],
[ 16.704824 , 16.822357 , 16.797985 , ..., 20.03461 ,
19.977737 , 20.027634 ],
[ 16.795732 , 16.83358 , 16.7386 , ..., 20.039566 ,
19.841612 , 20.034042 ]]]], dtype=float32))
可以看到原来输入进来的图片的像素值都已经增加了21