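Caffe's built-in convert_imageset.cpp produces an LMDB in which each image and its label are packed together into a single Datum (the Datum structure is given at the end of this post). That Datum can hold only one label. The classification problems we meet in practice, however, often have several labels per image, such as color and type.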
Reposted from: http://blog.csdn.net/u013010889/article/details/53098346
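Several approaches are currently described online:
1. Modify the Caffe source code. This is tedious, but it helps you understand the code.
2. Add several Data layers (for the data and for each label) as inputs. This is simple and workable, but it requires preparing several LMDBs, which is somewhat cumbersome.
3. And so on.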
Note: Caffe's data layer expects input in (batch_size, c, h, w) format, with channels in BGR order.
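PIL returns an image as an (h, w, 3) RGB array, so each image has to be reordered before it goes into the LMDB. The script in this post does that with `im[:,:,::-1]` followed by `im.transpose((2,0,1))`. Below is a minimal NumPy sketch of the same two steps on a tiny array. The helper name `to_caffe_layout` is hypothetical.

```python
import numpy as np

def to_caffe_layout(rgb_hwc):
    """Convert an (h, w, 3) RGB array to Caffe's (3, h, w) BGR layout."""
    bgr_hwc = rgb_hwc[:, :, ::-1]         # reverse the channel axis: RGB -> BGR
    return bgr_hwc.transpose((2, 0, 1))   # move channels first: HWC -> CHW

# A 2x2 RGB image in which every pixel is (R=10, G=20, B=30)
rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[:, :] = (10, 20, 30)

chw = to_caffe_layout(rgb)
print(chw.shape)     # (3, 2, 2)
print(chw[:, 0, 0])  # [30 20 10], i.e. B, G, R
```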
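This post uses one LMDB for the images plus a single multi-label LMDB that holds all labels of each image. A Slice layer then separates the labels inside the network.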
```python
import os
import sys

import lmdb
import numpy as np
from PIL import Image

caffe_root = 'your caffe root path'
sys.path.insert(0, os.path.join(caffe_root, 'python'))
import caffe

# Placeholder input size: set these to your network's input width/height
w, h = 224, 224

# Each line of the label file: "<image name> <label1> <label2>"
img_list = []
label1_list = []
label2_list = []
with open('your label txt', 'r') as file_input:
    for line in file_input:
        content = line.strip().split()
        if not content:
            continue  # skip blank lines
        img_list.append(content[0])  # image name stays a string
        label1_list.append(int(content[1]))
        label2_list.append(int(content[2]))

# 1) Image LMDB: one uint8 Datum of shape (3, h, w), BGR, per image
in_db = lmdb.open('your data(images) lmdb path', map_size=int(1e12))
with in_db.begin(write=True) as in_txn:
    for in_idx, in_ in enumerate(img_list):
        im_file = os.path.join('your images path', in_)
        im = Image.open(im_file).convert('RGB')  # force 3 channels
        im = im.resize((w, h), Image.BILINEAR)
        im = np.array(im)              # (h, w, 3), RGB, uint8
        im = im[:, :, ::-1]            # RGB -> BGR
        im = im.transpose((2, 0, 1))   # HWC -> CHW
        im_dat = caffe.io.array_to_datum(im)
        key = '{:0>10d}'.format(in_idx).encode('ascii')
        in_txn.put(key, im_dat.SerializeToString())
        print('data train: {} [{}/{}]'.format(in_, in_idx + 1, len(img_list)))
in_db.close()
print('train data(images) are done!')

# 2) Label LMDB: one float Datum of shape (2, 1, 1) per image, same keys
in_db = lmdb.open('your labels lmdb path', map_size=int(1e12))
with in_db.begin(write=True) as in_txn:
    for in_idx, in_ in enumerate(img_list):
        target_label = np.zeros((2, 1, 1))
        target_label[0, 0, 0] = label1_list[in_idx]
        target_label[1, 0, 0] = label2_list[in_idx]
        label_data = caffe.io.array_to_datum(target_label)
        key = '{:0>10d}'.format(in_idx).encode('ascii')
        in_txn.put(key, label_data.SerializeToString())
        print('label train: {} [{}/{}]'.format(in_, in_idx + 1, len(img_list)))
in_db.close()
print('train labels are done!')
```
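Both LMDBs are written with exactly the same keys, `'{:0>10d}'.format(in_idx)`. LMDB keeps entries sorted by key bytes, and a Data layer reads its database sequentially in that order. Identical keys are therefore what keep image i and label i at the same position in every batch. The zero padding additionally makes the byte-wise order match the order of the label file. The stdlib-only illustration below uses `sorted()` as a stand-in for LMDB's key ordering.

```python
# Keys as written by the script above (zero-padded) vs. unpadded keys
indices = [0, 1, 2, 9, 10, 11, 100]

padded = ['{:0>10d}'.format(i).encode('ascii') for i in indices]
unpadded = [str(i).encode('ascii') for i in indices]

# sorted() on bytes compares byte by byte, like LMDB's default key order
print([int(k) for k in sorted(padded)])    # [0, 1, 2, 9, 10, 11, 100]
print([int(k) for k in sorted(unpadded)])  # [0, 1, 10, 100, 11, 2, 9]
```

Without padding, the two databases would still line up with each other, because their keys are identical. They would just no longer follow the label file's order.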
```
# Images: (batch_size, 3, h, w), BGR
layer {
  name: "data"
  type: "Data"
  include {
    phase: TRAIN
  }
  data_param {
    source: "your data(images) lmdb"
    batch_size: 10
    backend: LMDB
  }
  top: "data"
}
# Labels: (batch_size, 2, 1, 1); batch_size must match the image layer
layer {
  name: "label"
  type: "Data"
  top: "label"
  include {
    phase: TRAIN
  }
  data_param {
    source: "your labels lmdb"
    batch_size: 10
    backend: LMDB
  }
}
# Split (batch_size, 2, 1, 1) into two (batch_size, 1, 1, 1) blobs
layer {
  name: "slice"
  type: "Slice"
  bottom: "label"
  top: "label_1"
  top: "label_2"
  slice_param {
    axis: 1
    slice_point: 1
  }
}
```
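The Slice layer cuts the label blob along axis 1. Each `slice_point` marks an index at which a new top begins, so `slice_point: 1` gives channel 0 to `label_1` and channel 1 to `label_2`. The NumPy sketch below mimics that with `np.split`. It also shows how the (K, 1, 1) label array would look for K attributes, which is an assumption that extends the post's two-label case. `make_label_array` and `slice_axis1` are hypothetical helpers, not Caffe API.

```python
import numpy as np

def make_label_array(labels):
    """Pack K integer labels into the (K, 1, 1) float array stored per image."""
    return np.asarray(labels, dtype=np.float64).reshape((len(labels), 1, 1))

def slice_axis1(blob, slice_points):
    """Mimic a Slice layer on axis 1: start a new piece at each slice point."""
    return np.split(blob, slice_points, axis=1)

# A batch of 3 images, each with (label1, label2)
batch = np.stack([make_label_array(pair) for pair in [(4, 7), (0, 11), (25, 3)]])
print(batch.shape)  # (3, 2, 1, 1)

label_1, label_2 = slice_axis1(batch, [1])  # slice_point: 1
print(label_1.shape, label_2.shape)         # (3, 1, 1, 1) (3, 1, 1, 1)
print(label_1.ravel().astype(int).tolist())  # [4, 0, 25]
print(label_2.ravel().astype(int).tolist())  # [7, 11, 3]
```

With three attributes, the Slice layer would list three `top` entries plus `slice_point: 1` and `slice_point: 2`, which corresponds to `slice_axis1(batch, [1, 2])` here.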
Testing
Test network
```
# deploy.prototxt
......
layer {
  name: "fc8_1"
  type: "InnerProduct"
  bottom: "fc7_1"
  top: "fc8_1"
  inner_product_param {
    num_output: 26  # number of classes for attribute 1
  }
}
layer {
  name: "prob_1"
  type: "Softmax"
  bottom: "fc8_1"
  top: "prob_1"
}
layer {
  name: "fc8_2"
  type: "InnerProduct"
  bottom: "fc7_2"
  top: "fc8_2"
  inner_product_param {
    num_output: 12  # number of classes for attribute 2
  }
}
layer {
  name: "prob_2"
  type: "Softmax"
  bottom: "fc8_2"
  top: "prob_2"
}
```
Test code (the complete version is in the author's other post, "Testing with a trained model in Caffe"):
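```python
import caffe

net = caffe.Net('deploy.prototxt', 'your trained caffemodel', caffe.TEST)

# Assumed preprocessing, matching the LMDB above (BGR, 0-255, CHW, no mean)
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2, 0, 1))      # HWC -> CHW
transformer.set_channel_swap('data', (2, 1, 0))   # RGB -> BGR
transformer.set_raw_scale('data', 255)            # [0, 1] -> [0, 255]

image_names = []  # your list of image paths
for image_name in image_names:
    image = caffe.io.load_image(image_name)  # HWC, RGB, float in [0, 1]
    net.blobs['data'].data[...] = transformer.preprocess('data', image)
    output = net.forward()
    output_prob_1 = output['prob_1'][0].argmax()  # predicted class, attribute 1
    output_prob_2 = output['prob_2'][0].argmax()  # predicted class, attribute 2
    print(image_name, output_prob_1, output_prob_2)
```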
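The loop above runs one image per forward pass and reads row 0 of each probability blob. If the deploy input instead had a batch dimension N > 1, each `prob_k` blob would have shape (N, number of classes for attribute k), and the same argmax could be taken row by row. The NumPy-only sketch below uses made-up probabilities, and `decode_heads` is a hypothetical helper name.

```python
import numpy as np

def decode_heads(output, head_names=('prob_1', 'prob_2')):
    """Return one (class_1, class_2, ...) tuple per image from softmax blobs."""
    # argmax over the class axis gives the predicted class index of each image
    preds = [output[name].argmax(axis=1) for name in head_names]
    return [tuple(int(p[i]) for p in preds) for i in range(len(preds[0]))]

# Fake outputs for 2 images: 26 classes for attribute 1, 12 for attribute 2
output = {'prob_1': np.zeros((2, 26)), 'prob_2': np.zeros((2, 12))}
output['prob_1'][0, 5] = output['prob_1'][1, 17] = 1.0
output['prob_2'][0, 3] = output['prob_2'][1, 11] = 1.0
print(decode_heads(output))  # [(5, 3), (17, 11)]
```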
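Summary
This post packs multiple labels into a single LMDB and separates them in the prototxt with a Slice layer. This avoids maintaining one LMDB per label, and the whole procedure is simple and easy to follow.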
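Further notes
The Datum structure is defined below. The code above converts arrays to Datum directly with caffe.io.array_to_datum, but you can also fill in the fields yourself according to this structure, e.g. datum.channels = 3 (plus height and width) and datum.data = im.tobytes(), where im is an image array of shape (3, 224, 224). According to the comments in caffe.proto, the data can also be stored in float_data instead.
Caffe's image examples use a single label, so there the label goes into the Datum's label field. That field cannot hold several labels, which is why the method above stores them in the Datum's data payload instead. Because the label array is floating-point, array_to_datum places them in float_data.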
```protobuf
message Datum {
  optional int32 channels = 1;
  optional int32 height = 2;
  optional int32 width = 3;
  optional bytes data = 4;
  optional int32 label = 5;
  repeated float float_data = 6;
  optional bool encoded = 7 [default = false];
}
```
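Here is a rough illustration of how caffe.io.array_to_datum fills the fields above for the two arrays in this post. A uint8 array (the images) is stored as raw bytes in `data`, and any other dtype (the float64 label array) is stored element by element in `float_data`. The sketch reproduces only that rule. It uses a plain dictionary in place of the protobuf message so it runs without Caffe, and `datum_fields` and `fields_to_array` are hypothetical names, not Caffe API.

```python
import numpy as np

def datum_fields(arr):
    """Mimic array_to_datum's storage rule for a (C, H, W) array (illustrative only)."""
    if arr.ndim != 3:
        raise ValueError('expected a (channels, height, width) array')
    c, h, w = arr.shape
    fields = {'channels': c, 'height': h, 'width': w, 'data': b'', 'float_data': []}
    if arr.dtype == np.uint8:
        fields['data'] = arr.tobytes()  # raw bytes in C order
    else:
        # float_data is "repeated float" (32-bit) in caffe.proto
        fields['float_data'] = [float(np.float32(v)) for v in arr.flat]
    return fields

def fields_to_array(fields):
    """Inverse: rebuild the (C, H, W) array from whichever payload is set."""
    shape = (fields['channels'], fields['height'], fields['width'])
    if fields['data']:
        return np.frombuffer(fields['data'], dtype=np.uint8).reshape(shape)
    return np.array(fields['float_data'], dtype=float).reshape(shape)

image = np.arange(12, dtype=np.uint8).reshape(3, 2, 2)  # stands in for a BGR image
label = np.array([4.0, 7.0]).reshape(2, 1, 1)           # (label1, label2)

img_f, lab_f = datum_fields(image), datum_fields(label)
print(len(img_f['data']), img_f['float_data'])  # 12 []
print(lab_f['data'], lab_f['float_data'])       # b'' [4.0, 7.0]
```

Because `float_data` is 32-bit, integer class indices of ordinary size survive the round trip exactly.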