data层用在训练或测试阶段,为模型提供数据接口,caffe可以接受的数据类型包括数据库类型(如LMDB、LevelDB)、hdf5、内存数据、图片数据等。
1、数据库类型
该类型数据必须指定数据库文件夹路径,该文件夹内包含一个data.mdb文件和一个lock.mdb文件;还需要指定batch_size.
可选参数包括:
-
rand_skip: 在开始的时候,路过某个数据的输入。通常对异步的SGD很有用。
-
backend: 选择是采用LevelDB还是LMDB, 默认是LevelDB.
prototxt文件对应内容:
layer {
name: "mnist"
type: "Data"
top: "data"
top: "label"
include {
phase: TRAIN
}
transform_param {
scale: 0.003
mean_value: 52
}
data_param {
source: "examples/mnist/mnist_train_lmdb"
batch_size: 16
backend: LMDB
}
}
用python API定义该层代码:
n.data, n.label = caffe.layers.Data(batch_size=16,
source= "examples/mnist/mnist_train_lmdb",
ntop=2,
backend = P.Data.LMDB,
include=dict(phase=caffe.TRAIN),
transform_param=dict(scale=0.003, mean_value=52))
2、hdf5类型
同样需要指定扩展名为h5的数据文件路径,此外,也可指定包含多个h5路径的文件;也需要指定batch_size.
prototxt文件对应内容:
layer {
name: "InputData"
type: "HDF5Data"
top: "data"
top: "label"
include {
phase: TRAIN
}
hdf5_data_param {
source: "./training_data_paths.txt"
batch_size: 64
}
}
python API代码:
net.data, net.label = caffe.layers.HDF5Data(
name="InputData",
source='./training_data_paths.txt',
batch_size=64,
include=dict(phase=caffe.TRAIN),
ntop=2
)
3、图片类型数据
这种类型经常用来做分类任务,图片数据文件每一行给出图片的路径及该图片对应的类别。
layer {
name: "InputData"
type: "ImageData"
top: "data"
top: "label"
transform_param {
mirror: true
crop_size: 40
}
image_data_param {
source: "train.txt"
batch_size: 32
shuffle: true
new_height: 48
new_width: 48
is_color: true
root_folder: "/"
}
}
python API代码:
net.data ,net.label = caffe.layers.ImageData(
name="InputData"
source="train.txt",
batch_size=32,
new_width=48,
new_height=48,
ntop=2,
is_color=True,
shuffle=True,
root_folder='/',
transform_param=dict(crop_size=40,mirror=True))
4、内存型数据
直接用内存中的数据训练模型,这类数据往往是ndarray型。
layer {
name: "data"
type: "ImageData"
top: "data"
top: "label"
transform_param {
scale: 0.00390625
mean_value: 20.0
}
image_data_param {
source: "img_list"
}
memory_data_param {
batch_size: 16
channels: 1
height: 1
width: 230
}
}
python API代码:
def conv_pool_net():
n = caffe.NetSpec()
n.data, n.label = L.ImageData(source='img_list',
memory_data_param=dict(batch_size=16,
height=1,
width=230,
channels=1),
ntop=2,transform_param=dict(scale=0.00390625,
mean_value=20))
return n.to_proto()
print(str(conv_pool_net()))
5、部署时注意的问题
部署时,data层要做一下转换(图片来自网络):
最后来张网上的总结图片,总结的很好:
参考:
https://www.cnblogs.com/houjun/p/9909764.html