用自己的数据训练Faster-RCNN,tensorflow版本(二)

本文档详细介绍了如何使用tensorflow版本的Faster-RCNN训练自定义数据集ID_card,包括数据格式、读写接口编写、factory.py、config.py、VGG_train.py和VGG_test.py的修改,以及训练和测试模型的步骤。重点在于自定义数据读取接口ID_card.py和模型配置的调整。
摘要由CSDN通过智能技术生成

我用的Faster-RCNN是tensorflow版本,fork自githubFaster-RCNN_TF
参考博客http://www.cnblogs.com/CarryPotMan/p/5390336.html

用自己的数据训练Faster-RCNN,tensorflow版本(一)中我们详细介绍了Faster-rcnn_TF中pascal_voc数据的读写接口,接下来介绍一下,如何编写自己的数据读写接口。

3、编写自己的数据读写接口

我们要用自己的数据进行训练,就得编写自己数据的读写接口,下面参考pascal_voc.py来编写。根据用自己的数据训练Faster-RCNN,tensorflow版本(一)中对pascal_voc.py文件的分析,发现,pascal_voc.py用了非常多的路径拼接,很麻烦,我们不用这么麻烦,简单一点就可以。

3.1、介绍一下我自己的训练数据集格式

我主要是从自然图片中检测出文本,因此我只有background 和text两类物体,我并没有像pascal_voc数据集里面一样每个图像用一个xml来标注,先说一下我的数据格式:

所有需要用到的数据我都放在了目录Data/ID_card/下面。
目录Data/ID_card/下面包含2个文件夹,分别是train,test。
先介绍train,目录Data/ID_card/train/里面包含:
1、所有的训练图片
2、gt_ID_card.txt
3、train.txt

我把train集合中所有图片的gt,集中放在了一个gt_ID_card.txt文件里面,gt_ID_card.txt格式如下:
gt_ID_card.txt

以第一行为例:
ID_card/back_1.jpg: 是图片的名字;
数字1:代表该张图片上只有一个文本(text);
后面的四个数值:分别是文本框左上角和右下角的坐标。我的图片里面只有一行文本,所以只有一组文本框的坐标。

train.txt文件存放的是所有图片的名字,没有后缀,如下图:
train.txt

3.2、编写自己的数据读写接口ID_card.py

主要修改的关键函数就是:def _load_annotation(self)——读取图片gt。

编写自己的数据读写接口ID_card.py,内容如下:

#coding:utf-8
# --------------------------------------------------------
# 
# Written by lisiqi
# --------------------------------------------------------

import datasets
import os
import datasets.imdb
import xml.dom.minidom as minidom
import numpy as np
import scipy.sparse
import scipy.io as sio
import utils.cython_bbox
import cPickle
import subprocess

class ID_card(datasets.imdb):
    def __init__(self, image_set, data_path=None):
        datasets.imdb.__init__(self, 'ID_card_' + image_set) #image_set 为train或者val或者trainval或者test。
        self._image_set = image_set # image_set以train为例
        self._data_path = data_path # 数据所在的路径,根据传进来的参数data_path而定。传进来的参数data_path在我这里就是Data/ID_card/
        self._classes = ('__background__','text') #object的类别,只有两类:背景和文本
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes))) #构成字典{'__background__':'0','text':'1'}
        self._image_ext = '.jpg' #图片后缀
        self._image_index = self._load_image_set_index() #读取train.txt,获取图片名称(该图片名称没有后缀.jpg)
        # Default to roidb handler
        self._roidb_handler = self.gt_roidb #获取图片的gt
        # PASCAL specific config options
        self.config = {
  'cleanup'  : True,
                       'use_salt' : True,
                       'top_k'    : 2000}

        assert os.path.exists(self._data_path), \ #如果路径Data/ID_card不存在,退出
                'Image Path does not exist: {}'.format(self._data_path)

    def image_path_at(self, i):#获得_image_index 下标为i的图像的路径
        """
        Return the absolute path to image i in the image sequence.
        """
        return self.image_path_from_index(self._image_index[i])

    def image_path_from_index(self, index):#根据_image_index获取图像路径
        """
        Construct an image path from the image's "index" identifier.
        """
        image_path = os.path.join(self._data_path, index, self._image_ext)
        assert os.path.exists(image_path), \
                'Path does not exist: {}'.format(image_path)
        return image_path

    def _load_image_set_index(self):#已做修改
        """
        Load the indexes listed in this dataset's image set file.
        得到图片名称的list。这个list里面是集合self._image_set=train中所有图片的名字(注意,图片名字没有后缀.jpg)
        """
        image_set_file = os.path.join(self._data_path, self._image_set, self._image_set + '.txt') 
        #image_set_file是Data/ID_card/train/train.txt
        #之所以要读这个train.txt文件,是因为train.txt文件里面写的是集合train中所有图片的名字(没有后缀.jpg)
        assert os.path.exists(image_set_file), \
                'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f: #读取train.txt,获取图片名称(没有后缀.jpg)
            image_index = [x.strip() for x in
# 工程内容 这个程序是基于tensorflow的tflearn库实现部分RCNN功能。 # 开发环境 windows10 + python3.5 + tensorflow1.2 + tflearn + cv2 + scikit-learn # 数据集 采用17flowers据集, 官网下载:http://www.robots.ox.ac.uk/~vgg/data/flowers/17/ # 程序说明 1、setup.py---初始化路径 2、config.py---配置 3、tools.py---进度条和显示带框图像工具 4、train_alexnet.py---大数据集预训练Alexnet网络,140个epoch左右,bitch_size为64 5、preprocessing_RCNN.py---图像的处理(选择性搜索、数据存取等) 6、selectivesearch.py---选择性搜索源码 7、fine_tune_RCNN.py---小数据集微调Alexnet 8、RCNN_output.py---训练SVM并测试RCNN(测试的时候测试图片选择第7、16类中没有参与训练的,单朵的花效果好,因为训练用的都是单朵的) # 文件说明 1、train_list.txt---预训练数据数据在17flowers文件夹中 2、fine_tune_list.txt---微调数据2flowers文件夹中 3、1.png---直接用选择性搜索的区域划分 4、2.png---通过RCNN后的区域划分 # 程序问题 1、由于数据集小的原因,在微调时候并没有像论文一样按一个bitch32个正样本,128个负样本输入,感觉正样本过少; 2、还没有懂最后是怎么给区域打分的,所有非极大值抑制集合canny算子没有进行,待续; 3、对选择的区域是直接进行缩放的; 4、由于数据集合论文采用不一样,但是微调和训练SVM时采用的IOU阈值一样,有待调参。
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值