Writing a Python Layer in Caffe
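When training with Caffe, the usual workflow is to convert the data to LMDB format, point to it in train.prototxt, and then start training. LMDB makes data loading efficient, but it is inflexible. When your labels are in a custom format, you can instead define the data layer yourself through the Python interface that Caffe provides.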
1. Things to watch out for
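- When building Caffe, set WITH_PYTHON_LAYER := 1 (in Makefile.config).
- Use the compiled [xxx/build/tools/caffe] binary together with the matching [xxx/python/] directory from the same build.
- The Python class must define setup, forward, reshape and backward. For a data layer, reshape and backward can be empty (just pass), but both methods must still exist.
- Put the finished Python file under [xxx/python/] (or any other directory on PYTHONPATH) so that Caffe can import it.
- The variable self.param_str_ holds the parameter string you need to parse; newer versions of Caffe renamed it to self.param_str.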
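Because the parameters arrive as a single string, the layer has to parse them itself. A minimal sketch, assuming the string is a Python dict literal as in the prototxt example in this post: the hypothetical helper `read_layer_params` picks whichever attribute name the installed Caffe version exposes and parses it with `ast.literal_eval`, which only accepts literals and is therefore safer than `eval`.

```python
import ast


def read_layer_params(layer):
    """Parse a Python layer's parameter string into a dict.

    Hypothetical helper: newer Caffe exposes `param_str`,
    older versions expose `param_str_`.
    """
    if hasattr(layer, "param_str"):
        raw = layer.param_str
    else:
        raw = layer.param_str_
    # literal_eval only evaluates literals (dicts, numbers, strings, ...),
    # so code written into the prototxt cannot be executed.
    params = ast.literal_eval(raw)
    if not isinstance(params, dict):
        raise ValueError("param_str must be a dict literal, got %r" % (raw,))
    return params


if __name__ == "__main__":
    from types import SimpleNamespace

    # Stand-ins for a layer object under the new and old attribute names
    new_style = SimpleNamespace(param_str="{'batch_size': 32, 'scale': 0.0078125}")
    old_style = SimpleNamespace(param_str_="{'batch_size': 32, 'scale': 0.0078125}")
    print(read_layer_params(new_style))
    print(read_layer_params(old_style))
```

Inside a real layer you would call `params = read_layer_params(self)` at the top of `setup`.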
2. Reference demo for a Python layer
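```python
import sys
# python/ directory of the same Caffe build that will run the training
sys.path.append("/export/docker/JXQ-23-46-49.h.chinabank.com.cn/surui/project_caffe/segmentation/ENet/caffe-enet/python")
import ast
import caffe
import numpy as np
from random import shuffle
from PIL import Image

channel = 3  # RGB images


class ImageData(caffe.Layer):
    def setup(self, bottom, top):
        # Newer Caffe exposes self.param_str; older versions use self.param_str_
        param_str = self.param_str if hasattr(self, "param_str") else self.param_str_
        # literal_eval only accepts literals, so it is safer than eval()
        params = ast.literal_eval(param_str)
        source = params["source"]
        self.batch_size = params["batch_size"]
        self.scale = params["scale"]
        self.image_height = params["image_height"]
        self.image_width = params["image_width"]
        top[0].reshape(self.batch_size, channel, self.image_height, self.image_width)  # image
        top[1].reshape(self.batch_size)  # label
        self.img_path_labels = self.read_txt(source)
        self.index = 0

    def forward(self, bottom, top):
        for i in range(self.batch_size):
            image, label = self.next_image()
            top[0].data[i, ...] = image
            top[1].data[i, ...] = label

    def reshape(self, bottom, top):
        # Top shapes are fixed in setup(), so there is nothing to do here
        pass

    def backward(self, bottom, top):
        # A data layer has no bottom blobs, so there is no gradient to propagate
        pass

    def read_txt(self, source):
        # Each line of the list file: "<image path> <label>"
        img_path_labels = []
        with open(source) as f:
            for line in f:
                line = line.strip()
                if line:
                    img_path, label = line.split(" ")
                    img_path_labels.append((img_path, label))
        return img_path_labels

    def next_image(self):
        # End of an epoch: wrap around and reshuffle
        if self.index == len(self.img_path_labels):
            self.index = 0
            shuffle(self.img_path_labels)
        imagePath, label = self.img_path_labels[self.index]
        # PIL's resize takes one (width, height) tuple; convert("RGB") guarantees 3 channels
        image = Image.open(imagePath).convert("RGB").resize((self.image_width, self.image_height))
        image = np.asarray(image, dtype=np.float32)
        image = image.transpose(2, 0, 1)  # [h,w,c] -> [c,h,w], the layout Caffe expects
        image = image * self.scale
        self.index += 1
        return image, float(label)  # labels are read from the file as strings
```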
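The per-image preprocessing and batch filling can be checked without Caffe or real images. The sketch below assumes images are already decoded to `uint8` arrays in PIL's height × width × channel order; it shows the transpose to (C, H, W) (swapping axes 0 and 2 would instead give (C, W, H)) and how a batch in Caffe's (N, C, H, W) blob layout is filled. The names `to_chw` and `make_batch` are hypothetical.

```python
import numpy as np


def to_chw(image_hwc, scale):
    """Convert an H x W x C uint8 image to a scaled C x H x W float32 array."""
    # (H, W, C) -> (C, H, W); swapping axes 0 and 2 would give (C, W, H) instead.
    chw = np.transpose(image_hwc, (2, 0, 1)).astype(np.float32)
    return chw * scale


def make_batch(samples, scale):
    """Stack (image, label) pairs into data (N, C, H, W) and label (N,) arrays."""
    h, w, c = samples[0][0].shape
    data = np.zeros((len(samples), c, h, w), dtype=np.float32)
    labels = np.zeros(len(samples), dtype=np.float32)
    for i, (image, label) in enumerate(samples):
        data[i, ...] = to_chw(image, scale)
        labels[i] = float(label)  # labels read from the list file are strings
    return data, labels


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Two fake 160 x 200 RGB images, mimicking PIL's (H, W, C) output
    fake = [(rng.integers(0, 256, size=(160, 200, 3), dtype=np.uint8), "1"),
            (rng.integers(0, 256, size=(160, 200, 3), dtype=np.uint8), "0")]
    data, labels = make_batch(fake, scale=0.0078125)
    print(data.shape, labels)
```

Note that 0.0078125 is 1/128, so a scale of this size maps 8-bit pixel values into roughly [0, 2).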
3. Prototxt format for loading training data through the Python layer
```
layer {
  name: "input_data"
  type: "Python"
  top: "data"
  top: "label"
  include {
    phase: TRAIN
  }
  python_param {
    module: "data_layer"  # Python module name, i.e. data_layer.py
    layer: "ImageData"    # Python class name
    param_str: "{'batch_size':32, 'scale':0.0078125, 'image_height':160, 'image_width':200, 'source':'xxx/trainval.txt'}"
  }
}
```
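The `source` entry points to a plain-text list file. Judging from `read_txt` in the demo, each line holds an image path and a label separated by a space, for example `xxx/images/0001.jpg 3` (paths here are hypothetical). The sketch below, a hypothetical `parse_list_file`, is a slightly more defensive parser: it skips blank lines, splits on the last whitespace only so paths may contain spaces, and assumes integer class labels.

```python
import io


def parse_list_file(lines):
    """Parse 'image_path label' lines into (path, int_label) pairs."""
    pairs = []
    for lineno, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines, e.g. a trailing newline
        # Split on the last whitespace so the path itself may contain spaces
        parts = line.rsplit(None, 1)
        if len(parts) != 2:
            raise ValueError("line %d: expected 'path label', got %r" % (lineno, line))
        path, label = parts
        pairs.append((path, int(label)))
    return pairs


if __name__ == "__main__":
    sample = io.StringIO("xxx/images/0001.jpg 3\nxxx/images/my photo.jpg 0\n\n")
    print(parse_list_file(sample))
```

With a real file you would pass the open handle directly: `with open(source) as f: pairs = parse_list_file(f)`.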
4. References
https://blog.csdn.net/haima1998/article/details/79066084
https://www.jianshu.com/p/e05d1b210fcb