项目地址:https://github.com/zhongqianli/caffe_python_layer
caffe自定义网络层的一种方式是使用python layer,这种方式需要使用pycaffe运行,命令行的方式运行会报错。
编写DataAugmentationLayer
这个类的基类是caffe.Layer,需要编写setup,reshape,forward,backward四个方法,每个方法都有top和bottom参数,可以通过top[0].data和bottom[0].data获取一个4维的数据,分别是batch_size、通道数、高、宽。
import caffe
import json
import cv2
import numpy as np
import random
# 4 pixel pad, random crop
# img: 64x3x32x32
def zeropadding_and_crop(data):
# # cifar10
# # padding_img = np.pad(img, ((4, 4), (4, 4), (4, 4)), "constant", padder=0)
padding_img = np.zeros((np.shape(data)[0], 3, 40, 40), dtype=np.uint8)
padding_img[..., 4:36, 4:36] = data[...]
# #
# cv2.imshow("pad", data[0][0])
row_rand_num = random.randrange(9)
col_rand_num = random.randrange(9)
croped_img = padding_img[..., row_rand_num : row_rand_num + 32, col_rand_num : col_rand_num + 32]
return croped_img
class DataAugmentationLayer(caffe.Layer):
def setup(self, bottom, top):
pass
def reshape(self, bottom, top):
top[0].reshape(*bottom[0].data.shape)
pass
def forward(self, bottom, top):
top[0].data[...] = zeropadding_and_crop(bottom[0].data)
pass
def backward(self, top, propagate_down, bottom):
pass
使用自定义的网络层DataAugmentationLayer
pycaffe最好使用net.xxx的方式创建网络,因为第二种方式会自动命名,可能会出现一下意想不到的问题。
# 第一种方式,推荐使用
net.data = L.Python(net.data_temp,
python_param=dict(module="custom_data_augmentation",
layer="DataAugmentationLayer"),
include=dict(phase=caffe_pb2.Phase.Value('TRAIN')))
# 第二种方式,不推荐这种方式
data = L.Python(data_temp,
python_param=dict(module="custom_data_augmentation",
layer="DataAugmentationLayer"),
include=dict(phase=caffe_pb2.Phase.Value('TRAIN')))