pycaffe定义数据读取层

pycaffe简介

pycaffe的是caffe的python接口,由于caffe是利用C++编写的深度学习框架,有时我们为了验证想法,如果利用C++去实现的话,将会使得验证周期加长;而pycaffe可利用python接口,快速的完成新层的定义,下面将介绍下如何利用pycaffe定义数据读取层。

利用pycaffe编写数据读取层

为方便讲解,以centernet为例讲解如何使用pycaffe建立数据读取层。
centernet论文中关于损失的描述如下:
L d e t = L k + λ s i z e L s i z e + λ o f f L o f f L_{det}=L_{k}+\lambda_{size}L_{size}+\lambda_{off}L_{off} Ldet=Lk+λsizeLsize+λoffLoff
总共有三个loss,其中 L k L_{k} Lk为中心点热图的损失; L s i z e L_{size} Lsize为预测物体大小w和h的损失, L o f f L_{off} Loff为中心点偏移的损失。

明白了如何计算损失,那么只需要做对应的数据标签即可。具体步骤如下所示:

1.首先新建一个py文件,并输入如下代码:

import sys
sys.path.append('/home/zkai/ZK/study/yolov3/caffe/FaceBoxes-master_addseg/python')
#上述添加的路径为caffe所在目录下的python
import caffe
import cv2
import numpy as np

2.定义数据读取类:

主要实现setup和forward函数;setup函数读取相应参数,如batch_size、source、crop_size等等参数,这些参数同样需要在网络结构参数net.prototxt中设置好。

class Centernet_Read_data(caffe.Layer):
    def setup(self, bottom, top):
        self.top_names = ['data', 'centermap','xy_offset','wh','pts_offset']
        params = eval(self.param_str)
        self.batch_size = params['batch_size']
        self.source=params['source']

        self.crop_h=params['crop_size'][0]
        self.crop_w=params['crop_size'][1]
        self.num_classes=params['num_classes']
        self.stride=params['stride']
      
        with open(self.source,'r')as f:
            annos=f.readlines()
        self.lines=[]
        for anno in annos:
            anno=anno[:-1].split(' ')
            imgname=anno[0]
            xywh=np.array(list(map(float,anno[1:]))).reshape(-1,4)
            self.lines.append([imgname,xywh])
        random.shuffle(self.lines)  
        self.datalen=len(annos)  
        self.index=0
		top[0].reshape(self.batch_size, 3, params['crop_size'][0], params['crop_size'][1])
        top[1].reshape(self.batch_size, self.num_classes, params['crop_size'][0]/self.stride, params['crop_size'][1]/self.stride) #heatmap
        top[2].reshape(self.batch_size, 2, params['crop_size'][0]/self.stride, params['crop_size'][1]/self.stride) #heatmap offset
        top[3].reshape(self.batch_size, 2, params['crop_size'][0]/self.stride, params['crop_size'][1]/self.stride) #w h
    def forward(self, bottom, top):
        for i in range(self.batch_size):
            if self.index>=self.datalen:
                self.index=0
            img=cv2.imread(self.lines[self.index][0],1)
            x1,y1,x2,y2=self.lines[self.index][1]#.copy()
            new_img=img.copy(
            new_x1=x1
            new_y1=y1
            new_x2=x2
            new_y2=y2     
            '''add data augumentation code in this area'''

			'''resize image to network input size and adjust x1,y1,x2,y2 to input size'''
			img_h,img_w,_=new_img.shape
            if img_h==0 or img_w==0:
                print(self.lines[self.index][0])
            scale_x=float(self.crop_w)/img_w
            scale_y=float(self.crop_h)/img_h
            scale=min(scale_x,scale_y)
            n_w=int(img_w*scale)
            n_h=int(img_h*scale)
            new_pts[:,0]=new_pts[:,0]*scale
            new_pts[:,1]=new_pts[:,1]*scale


            re_img=cv2.resize(new_img,(n_w,n_h))
            dx=(self.crop_w-n_w)//2
            dy=(self.crop_h-n_h)//2
            n_img=np.zeros([self.crop_h,self.crop_w,3],dtype=np.float32)
            
            n_img[dy:n_h+dy,dx:n_w+dx,:]=re_img    
            n_imgdata=np.transpose(n_img,[2,0,1])/255.0 
            # np.save("%d_img.npy"%i,n_imgdata)
            new_x1=new_x1*scale
            new_y1=new_y1*scale
            new_x2=new_x2*scale
            new_y2=new_y2*scale

            new_x1=new_x1+dx
            new_y1=new_y1+dy
            new_x2=new_x2+dx
            new_y2=new_y2+dy
            new_pts[:,0]=new_pts[:,0]+dx
            new_pts[:,1]=new_pts[:,1]+dy

            center_x=((new_x1+new_x2)/2)/self.stride
            center_y=((new_y1+new_y2)/2)/self.stride
            w=(new_x2-new_x1)#/self.stride
            h=(new_y2-new_y1)#/self.stride

            c_x=int(center_x)
            c_y=int(center_y)
            v_x=center_x-c_x
            v_y=center_y-c_y


            heatmap=np.zeros([1,self.crop_h//self.stride,self.crop_w//self.stride],dtype=np.float32)
            heatmap=put_heatmap(heatmap,0,[c_x,c_y],3)
            # np.save('heatmap.npy',heatmap)

            top[0].data[i,...]=n_imgdata
            top[1].data[i,...]=heatmap
            # np.save('%d_heatmap.npy'%i,top[1].data[i,...])
            top[2].data[i,0,c_y,c_x]=v_x
            top[2].data[i,1,c_y,c_x]=v_y
            top[3].data[i,0,c_y,c_x]=np.log(w/self.stride)
            top[3].data[i,1,c_y,c_x]=np.log(h/self.stride)
    def reshape(self, bottom, top):
        pass
    def backward(self, top, propagate_down, bottom):
        pass     
            

3.网络结构net.prototxt中设置相应参数:

数据层设置如下的参数, type: “Python”, module: “centernet_heatmap_read_data”,module为文件名称;layer为定义的类名称,Centernet_Read_data

layer {
  name: "data"
  type: "Python"
  top: "data"
  top: "heatmap"
  python_param {
    module: "centernet_read_data"
    layer: "Centernet_Read_data"
    param_str: "{\'source\':'/home/zkai/ZK/study/centernet/caffe-centernet/data/train1.txt' ,\'batch_size\':16,\'stride\':4,\'num_classes\':1,\'crop_size\':[512,512]}"
  }

同时训练时如果出现如下错误:
No module named centernet_read_data
可在训练终端输入export PYTHONPATH=/path/to/layer location; /path/to/layer location为刚刚建立的文件所在路径。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值