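Implementing custom data loading with pycaffe
Introduction to pycaffe
pycaffe is the Python interface of Caffe. Caffe is a deep learning framework written in C++. When we want to validate an idea, implementing it in C++ lengthens the validation cycle. pycaffe lets us define new layers quickly in Python instead. Below I explain how to define a data-reading layer with pycaffe.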
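Writing a data-reading layer with pycaffe
To keep the explanation concrete, CenterNet is used as the example of building a data-reading layer with pycaffe.
The CenterNet paper describes the loss as follows: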
$$L_{det}=L_{k}+\lambda_{size}L_{size}+\lambda_{off}L_{off}$$
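There are three losses in total:
- $L_{k}$ is the loss on the center-point heatmap;
- $L_{size}$ is the loss on the predicted object width w and height h;
- $L_{off}$ is the loss on the center-point offset.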
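The sketch below shows how the three terms combine. It only evaluates the weighted sum. The default weights λ_size = 0.1 and λ_off = 1 are the values reported in the CenterNet paper. The function name and the example loss values are illustrative only.

```python
def centernet_total_loss(l_k, l_size, l_off, lambda_size=0.1, lambda_off=1.0):
    """L_det = L_k + lambda_size * L_size + lambda_off * L_off (illustrative helper)."""
    return l_k + lambda_size * l_size + lambda_off * l_off

# Made-up loss values, only to show the weighting: 2.0 + 0.1 * 5.0 + 1.0 * 0.3
print(centernet_total_loss(2.0, 5.0, 0.3))
```

The small λ_size keeps the size term, which is measured in pixels, from dominating the heatmap loss.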
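Once the loss computation is clear, the remaining work is to produce the matching labels: a center heatmap, a center-offset map, and a size map. The steps are as follows: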
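Here is a minimal sketch of the label targets for one box. It assumes the box is given as (x1, y1, x2, y2) in input-image pixels and that the output stride is 4. It mirrors what the layer's forward() computes:
- an integer heatmap cell;
- the fractional offset used for L_off;
- a log-scale size target in output-map units.

The paper's L_size regresses the size directly. The log form used here follows this layer's code. `encode_box` is a hypothetical helper name.

```python
import math

def encode_box(x1, y1, x2, y2, stride=4):
    """Hypothetical helper: turn one box (input pixels) into CenterNet-style targets
    on an output map that is `stride` times smaller than the input."""
    center_x = (x1 + x2) / 2.0 / stride
    center_y = (y1 + y2) / 2.0 / stride
    c_x, c_y = int(center_x), int(center_y)        # integer heatmap cell (peak location)
    off_x, off_y = center_x - c_x, center_y - c_y  # sub-cell offset -> L_off target
    # log of the box size in output-map units -> size target, as in the layer's forward()
    log_w = math.log((x2 - x1) / float(stride))
    log_h = math.log((y2 - y1) / float(stride))
    return (c_x, c_y), (off_x, off_y), (log_w, log_h)

print(encode_box(100, 50, 140, 90))  # cell (30, 17), offset (0.0, 0.5), both sizes log(10)
```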
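1. First create a new .py file and add the code below. The file name becomes the module name referenced in net.prototxt (here centernet_read_data.py).
```python
import sys
sys.path.append('/home/zkai/ZK/study/yolov3/caffe/FaceBoxes-master_addseg/python')
# the path added above is the python/ directory inside your Caffe source tree
import random  # used by the layer to shuffle the annotation list
import caffe
import cv2
import numpy as np
```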
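2. Define the data-reading class:
The two methods that do the work are setup and forward.
- setup reads the layer parameters, such as batch_size, source and crop_size. These parameters must also be set in the network definition net.prototxt.
- reshape and backward must exist as well, but they can be empty for a data layer.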
```python
class Centernet_Read_data(caffe.Layer):
    def setup(self, bottom, top):
        # four tops, in the same order as the top: entries in net.prototxt
        self.top_names = ['data', 'heatmap', 'xy_offset', 'wh']
        params = eval(self.param_str)  # param_str is a Python dict literal from the prototxt
        self.batch_size = params['batch_size']
        self.source = params['source']
        self.crop_h = params['crop_size'][0]
        self.crop_w = params['crop_size'][1]
        self.num_classes = params['num_classes']
        self.stride = params['stride']
        # annotation file: one image per line, "imgname x1 y1 x2 y2"
        with open(self.source, 'r') as f:
            annos = f.readlines()
        self.lines = []
        for anno in annos:
            anno = anno.strip().split()
            if not anno:
                continue  # skip blank lines
            imgname = anno[0]
            boxes = np.array(list(map(float, anno[1:]))).reshape(-1, 4)
            self.lines.append([imgname, boxes])
        random.shuffle(self.lines)
        self.datalen = len(self.lines)
        self.index = 0
        out_h = self.crop_h // self.stride  # integer division: blob shapes must be ints
        out_w = self.crop_w // self.stride
        top[0].reshape(self.batch_size, 3, self.crop_h, self.crop_w)     # image
        top[1].reshape(self.batch_size, self.num_classes, out_h, out_w)  # center heatmap
        top[2].reshape(self.batch_size, 2, out_h, out_w)                 # center offset
        top[3].reshape(self.batch_size, 2, out_h, out_w)                 # log(w), log(h)

    def forward(self, bottom, top):
        # offset and size targets are written at one cell only, so clear the previous batch
        top[2].data[...] = 0
        top[3].data[...] = 0
        out_h = self.crop_h // self.stride
        out_w = self.crop_w // self.stride
        for i in range(self.batch_size):
            if self.index >= self.datalen:
                self.index = 0
                random.shuffle(self.lines)  # start a new epoch
            imgname, boxes = self.lines[self.index]
            self.index += 1
            img = cv2.imread(imgname, 1)
            if img is None:
                raise IOError('cannot read image: %s' % imgname)
            # this layer handles one object per image: use the first box
            x1, y1, x2, y2 = boxes[0]
            new_img = img.copy()
            new_x1, new_y1, new_x2, new_y2 = x1, y1, x2, y2
            # add data augmentation code here
            # letterbox: resize keeping the aspect ratio, pad to the crop size, adjust the box
            img_h, img_w, _ = new_img.shape
            scale = min(float(self.crop_w) / img_w, float(self.crop_h) / img_h)
            n_w = int(img_w * scale)
            n_h = int(img_h * scale)
            re_img = cv2.resize(new_img, (n_w, n_h))
            dx = (self.crop_w - n_w) // 2
            dy = (self.crop_h - n_h) // 2
            n_img = np.zeros([self.crop_h, self.crop_w, 3], dtype=np.float32)
            n_img[dy:n_h + dy, dx:n_w + dx, :] = re_img
            n_imgdata = np.transpose(n_img, [2, 0, 1]) / 255.0  # HWC -> CHW, scaled to [0, 1]
            new_x1 = new_x1 * scale + dx
            new_y1 = new_y1 * scale + dy
            new_x2 = new_x2 * scale + dx
            new_y2 = new_y2 * scale + dy
            # box center on the output map, which is stride times smaller than the input
            center_x = ((new_x1 + new_x2) / 2) / self.stride
            center_y = ((new_y1 + new_y2) / 2) / self.stride
            w = new_x2 - new_x1
            h = new_y2 - new_y1
            c_x = min(int(center_x), out_w - 1)
            c_y = min(int(center_y), out_h - 1)
            v_x = center_x - c_x  # sub-cell offset, target for L_off
            v_y = center_y - c_y
            heatmap = np.zeros([self.num_classes, out_h, out_w], dtype=np.float32)
            # put_heatmap() (Gaussian peak at the center, sigma 3) is not shown in this post;
            # it must be defined in this module
            heatmap = put_heatmap(heatmap, 0, [c_x, c_y], 3)
            top[0].data[i, ...] = n_imgdata
            top[1].data[i, ...] = heatmap
            top[2].data[i, 0, c_y, c_x] = v_x
            top[2].data[i, 1, c_y, c_x] = v_y
            top[3].data[i, 0, c_y, c_x] = np.log(w / self.stride)  # size target, log scale
            top[3].data[i, 1, c_y, c_x] = np.log(h / self.stride)

    def reshape(self, bottom, top):
        pass  # top shapes are fixed in setup()

    def backward(self, top, propagate_down, bottom):
        pass  # data layer: nothing to back-propagate
```
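forward() calls put_heatmap(heatmap, 0, [c_x, c_y], 3), but that helper is not included in the post. Below is a minimal sketch of such a helper. It assumes the signature (heatmap, plane_idx, center, sigma) implied by that call, and a fixed Gaussian sigma measured in heatmap cells. It would go in the same module as the layer. The CenterNet paper derives sigma from the object size, so a fixed sigma is a simplification.

```python
import numpy as np

def put_heatmap(heatmap, plane_idx, center, sigma):
    """Sketch: splat a 2D Gaussian with peak 1.0 at `center` into heatmap[plane_idx].

    heatmap: float array of shape (C, H, W); center: (x, y) in heatmap cells.
    Overlapping peaks keep the element-wise maximum, as CenterNet does.
    """
    _, height, width = heatmap.shape
    cx, cy = center
    ys = np.arange(height, dtype=np.float32)[:, None]  # column of row indices
    xs = np.arange(width, dtype=np.float32)[None, :]   # row of column indices
    gaussian = np.exp(-((xs - cx) ** 2 + (ys - cy) ** 2) / (2.0 * sigma ** 2))
    heatmap[plane_idx] = np.maximum(heatmap[plane_idx], gaussian)
    return heatmap
```

Overlapping Gaussians are merged with np.maximum. This means the helper could also be called once per object on the same plane if the layer is later extended to several boxes per image.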
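3. Set the corresponding parameters in net.prototxt:
Give the data layer the parameters below:
- type: "Python";
- module: "centernet_read_data", the name of the Python file without .py;
- layer: "Centernet_Read_data", the name of the class defined above.

Declare one top for each blob that setup() reshapes, in the same order.
```
layer {
  name: "data"
  type: "Python"
  top: "data"
  top: "heatmap"
  top: "xy_offset"
  top: "wh"
  python_param {
    module: "centernet_read_data"
    layer: "Centernet_Read_data"
    param_str: "{\'source\':'/home/zkai/ZK/study/centernet/caffe-centernet/data/train1.txt' ,\'batch_size\':16,\'stride\':4,\'num_classes\':1,\'crop_size\':[512,512]}"
  }
}
```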
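The param_str value is a Python dict literal. Once protobuf unescapes the \' sequences, this string is what setup() receives as self.param_str. The sketch below parses the same string with ast.literal_eval, a safer alternative to the eval() call in setup() because it only accepts literals. It then derives the output map size from crop_size and stride.

```python
import ast

# The param_str from the prototxt above, with \' already unescaped to '
param_str = ("{'source':'/home/zkai/ZK/study/centernet/caffe-centernet/data/train1.txt' ,"
             "'batch_size':16,'stride':4,'num_classes':1,'crop_size':[512,512]}")

params = ast.literal_eval(param_str)  # accepts only Python literals, unlike eval()
crop_h, crop_w = params['crop_size']
out_h, out_w = crop_h // params['stride'], crop_w // params['stride']
print(params['batch_size'], (out_h, out_w))  # 16 (128, 128)
```

With a 512x512 crop and stride 4, the heatmap, offset and size tops are therefore 128x128.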
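If training fails with the following error:
No module named centernet_read_data
run export PYTHONPATH=/path/to/layer_location:$PYTHONPATH in the training terminal. Here /path/to/layer_location is the directory that contains the file you just created; appending $PYTHONPATH keeps any existing entries.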
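Before starting a long training run, it can help to confirm that the module is importable. Below is a minimal sketch using a hypothetical helper name. It checks whether a module can be found, optionally after adding a directory to sys.path, which is the in-process counterpart of the PYTHONPATH export. It only checks the current Python process; Caffe itself still relies on the environment of the shell that launches it.

```python
import importlib
import importlib.util
import os
import sys

def check_layer_importable(module_name, layer_dir=None):
    """Hypothetical helper: return True if `module_name` can be found,
    optionally after prepending `layer_dir` to sys.path."""
    if layer_dir is not None:
        layer_dir = os.path.abspath(layer_dir)
        if layer_dir not in sys.path:
            sys.path.insert(0, layer_dir)
        importlib.invalidate_caches()  # make sure the new path entry is scanned
    return importlib.util.find_spec(module_name) is not None
```

For example: check_layer_importable('centernet_read_data', '/path/to/layer_location').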