caffe实现多标签输入(multilabel、multitask)

caffe里自带的convert_imageset.cpp直接生成一个data和label都集成在Datum的lmdb(Datum数据结构见最
后),只能集成一个label。而我们平时遇到的分类问题可能会有多个label比如颜色,种类等。

目前网上有多种解决方法:
1. 修改caffe代码,步骤繁琐,但是对于理解代码有帮助
2. 加入多个data和label层作为输入,简单可行,但是需要准备多个lmdb,较为麻烦
3. 等等等

特别推荐我的另一篇博客caffe实现多标签输入修改源码版,实现起来比这一篇还要简单和通俗易懂

注意:caffe的数据层输入格式(batch_size, c, h, w),通道是BGR

本文采用多label的lmdb+Slice Layer的方法

  • 生成多label的lmdb
import numpy as np
import os
import lmdb
from PIL import Image 
import numpy as np 
import sys
# Make sure that caffe is on the python path:
caffe_root = 'your caffe root path'
sys.path.insert(0, caffe_root + '/python')
import caffe
####################pre-treatment############################
#txt with labels eg. (0001.jpg 2 5)
file_input=open('your label txt','r')
img_list=[]
label1_list=[]
label2_list=[]
for line in file_input:
    content=line.strip()
    content=content.split(' ')
    img_list.append(int(content[0]))
    label1_list.append(int(content[1]))
    label2_list.append(int(content[2]))
    del content
file_input.close() 
####################train data(images)############################
#your data lmdb path
#注意一定要先删除之前生成的lmdb,因为lmdb会在之前的数据基础上新增数据,而不会先清空
#os.system('rm -rf  ' + your data(images) lmdb path)
in_db=lmdb.open('your data(images) lmdb path',map_size=int(1e12))
with in_db.begin(write=True) as in_txn:
    for in_idx,in_ in enumerate(img_list):         
        im_file='your images path'+in_
        im=Image.open(im_file)
        im = im.resize((w,h),Image.BILINEAR)#放缩图片,分类一般用
        #双线性BILINEAR,分割一般用最近邻NEAREST,**注意准备测试数据时一定要一致**
        im=np.array(im) # im: (w,h)RGB->(h,w,3)RGB
        im=im[:,:,::-1]#把im的RGB调整为BGR
        im=im.transpose((2,0,1))#把height*width*channel调整为channel*height*width
        im_dat=caffe.io.array_to_datum(im)
        in_txn.put('{:0>10d}'.format(in_idx),im_dat.SerializeToString())   
        print 'data train: {} [{}/{}]'.format(in_, in_idx+1, len(file_list))        
        del im_file, im, im_dat
in_db.close()
print 'train data(images) are done!'
######train data of label################    
#your labels lmdb path
in_db=lmdb.open('your labels lmdb path',map_size=int(1e12))
with in_db.begin(write=True) as in_txn:
    for in_idx,in_ in enumerate(img_list):
        target_label=np.zeros((2,1,1))# 2种label
        target_label[0,0,0]=label1_list[in_idx]
        target_label[1,0,0]=label2_list[in_idx]
        label_data=caffe.io.array_to_datum(target_label)
        in_txn.put('{:0>10d}'.format(in_idx),label_data.SerializeToString())
        print 'label train: {} [{}/{}]'.format(in_, in_idx+1, len(file_list))
        del target_label, label_data    
in_db.close()
print 'train labels are done!'
  • 修改prototxt
layer {
  name: "data"
  type: "Data"
  include {
    phase: TRAIN
  }
 data_param {
    source: "your data(images) lmdb"
    batch_size: 10
    backend: LMDB
  }
  top: "data"
}
layer {
  name: "label"  
  type: "Data"
  top: "label"
  include {    
    phase: TRAIN  
  }
  data_param {    
    source: "your labels lmdb"    
    batch_size: 10    
    backend: LMDB  
  }
}
layer {
  name: "slice"
  type: "Slice"
  bottom: "label"
  top: "label_1"
  top: "label_2"
  slice_param {
    axis: 1
    slice_point: 1
  }
  # 假如有3个label,n个同理
  #  slice_point: 1
  #  slice_point: 2
  # axis指定目标轴,batch_size,c,h,w 刚刚生成label时c就是label种类数,c在第1个维度(batch_size是0)
  # slice_point指定选定维数的索引(索引的数量必须等于blob数量减去一)。
  # 比如10维,输出3个label 且slice_point: 2 slice_point: 6
  # 1 2 | 3 4 5 6 | 7 8 9 10  '|' 为切点(1,2)(3,4,5,6)(7,8,9)
  # 这种切点我们暂时可忽略
}

layer可视化

测试

测试网络

#deploy.prototxt
......
layer {
  name: "fc8_1"
  type: "InnerProduct"
  bottom: "fc7_1"
  top: "fc8_1"
  inner_product_param {
    num_output: 26 #第1个属性的类别数
  }
}
layer {
  name: "prob_1"
  type: "Softmax"
  bottom: "fc8_1"
  top: "prob_1"
}
layer {
  name: "fc8_2"
  type: "InnerProduct"
  bottom: "fc7_2"
  top: "fc8_2"
  inner_product_param {
    num_output: 12 #第2个属性的类别数
  }
}
layer {
  name: "prob_2"
  type: "Softmax"
  bottom: "fc8_2"
  top: "prob_2"
}

测试代码: 完整代码见另一篇博客caffe利用训练好的模型进行test

image_names = [图片list]
for image_name in image_names:
    image = caffe.io.load_image(image_name)# 用的是skimage库,见附录
    # 利用刚刚的设置进行图片预处理
    transformed_image = transformer.preprocess('data', image)
    net.blobs['data'].data[...] = transformed_image
    # 网络前传(测试无后传)
    output = net.forward()
    output_prob_1= output['prob_1'][0].argmax()  # 第1个属性概率最大的label
    output_prob_2= output['prob_2'][0].argmax()  # 第2个属性概率最大的label
    print image_name,output_prob_1,output_prob_2

总结

本文把多个label集成到一个lmdb中,最后在prototxt中通过Slice层分离开,避免了多个lmdb的繁琐,整体操作起来简单易懂。

知识延伸

  Datum数据结构定义如下,上述代码直接用caffe.io.array_to_datum转化成Datum,其实可以根据它的结构一一
  设置,比如datum.channels = 3,datum.data = im.tobytes() (im为图像数组shape为(3,224,224)),根据caffe
  说明,图像数据也可以放在float_data中)
  因为caffe中的图像样例是单个label,所以可以把图像的label存储于Datum中的label,但是多label的话就不行
  了,需要用上述的方法把多label存储于Datum中的data
message Datum {
  optional int32 channels = 1;
  optional int32 height = 2;
  optional int32 width = 3;
  // the actual image data, in bytes
  optional bytes data = 4;
  optional int32 label = 5;
  // Optionally, the datum could also hold float data.
  repeated float float_data = 6;
  // If true data contains an encoded image that need to be decoded
  optional bool encoded = 7 [default = false];
}
  • 7
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 74
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 74
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值