faster-rcnn tensorflow windows版本训练人脸检测。

项目下载

github faster-rcnn windows项目:https://github.com/dBeker/Faster-RCNN-TensorFlow-Python3.5
按照项目说明,配置项目。过程中会遇到问题:
pil 库无法安装,因为当前python版本为3.5,pil库最多支持到2.7,pip无法安装,conda安装会覆盖当前python版本为2.7 切记!!解决方法:32.7版本后直接下载pillow库即可。
博主环境:python3.5,tensorflow1.13-gpu环境,

数据说明

VOCDevkit2007 文件夹为当前项目使用到的数据
Annotations中的xml为每张图片的标记数据,分类名,检测框坐标。文件名对于图片名。
JPEGImages 为所有图片原数据
ImageSets\Main\trainval.txt 为图片名 索引文件。数据加载 图片名从这个文件中读取。

项目训练

运行train.py文件进行 训练。config/config.py为配置文件模型保存和batch_size 学习率等都保存在这。模型输出文件夹在default\voc_2007_trainval\default,训练完毕需要将模型文件复制到\output\vgg16\voc_2007_trainval+voc_2012_trainval\default下,demo运行使用模型

Demo运行

demo.py文件运行,直接运行报错,找不到模型文件。demo中用到的数据在data/demo/下。
错误解决,修改demo文件:
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_40000.ckpt',), 'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
代码中根据网络模式 将模型名修改成自己的模型名,迭代训练次数不一样,模型名不一样。
parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]', choices=NETS.keys(), default='vgg16')
default 修改成vgg16。
其他错误,一些缺包的错误,安装包即可。一些无法安装的包,一般都是替代包。跟pil一样。仔细看说明,错误提示中会有提示替代包
到此,整个原生faster-rcnn 已经算是完成了。

改造成人脸检测

人脸检测数据说明

WIDER FACE 图片库,下载地址http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/
在这里插入图片描述
这三个文件需要下载。训练数据和验证数据annotations是标记数据。测试数据看自己的需求。

为了方便faster rcnn做训练,必须先把wider face数据库转为voc2007格式,也就是和原生数据一样的格式最后会附代码。直接转换。

数据准备

将data\VOCDevkit2007\VOC2007 下所有文件删除

下载的数据解压到data\VOCDevkit2007\Wider_face\文件夹下
在这里插入图片描述
新建数据格式转换文件
在这里插入图片描述
文件代码在最后。
运行脚本缺包需要安装,需要等待一段时间,复制图片,制作标签等。
脚本运行完毕data\VOCDevkit2007\Wider_face 文件夹下得到三个文件夹就是voc格式
在这里插入图片描述
将三个文件夹移动到\data\VOCDevkit2007\VOC2007文件夹下
进入到data\VOCDevkit2007\VOC2007\ImageSets\Main下
var.txt 文件中的内容复制到train.txt中 。修改train.txt文件名为:trainval.txt
到此数据已完全符合训练数据格式。

开始训练

修改lib/datasets/pascal_voc.py 中self._classes,添加自己的分类比如,face,
开始训练,如果还是报错 可尝试删除data/cache文件夹内容再次运行。

运行

/demo.py 文件中修改 CLASSES 添加分类 face
修改分类数量:net.create_architecture(sess, “TEST”, 22,

完毕!!!

总结训练过程中遇到的一些问题:

在这里插入图片描述
解决方法:报错原因是fg_inds和bg_inds的数量都小于0,这张图片没办法训练了,所以直接跳过这张图。办法是调整config.py里的roi_bg_threshold_high和roi_bg_threshold_low,一般把roi_bg_threshold_low改成0.0就不会出现这个问题。
在这里插入图片描述
解决方法:打开lib/database/pascal_voc.py文件,每一行后面的-1删除。原因是因为我们制作的xml文件中有些框的坐标是从左上角开始的,也就是(0,0)如果再减一就会出现log(-1)的情况。
改完之后就不会出现RuntimeWarning: invalid value encountered in log targets_dw = np.log(gt_widths / ex_widths)这个问题了,loss也不会出现等于nan了,如果还出现loss=nan,可以再试试调小学习率以及各个损失项的占比重。亲测有效。。

derface_to_voc.py 文件代码:

"""
Created on 19-4-18

@author: 段大帅
"""
from skimage import io
import shutil
import random
import os
import string

headstr = """\
<annotation>
    <folder>VOC2007</folder>
    <filename>%06d.jpg</filename>
    <source>
        <database>My Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
        <flickrid>NULL</flickrid>
    </source>
    <owner>
        <flickrid>NULL</flickrid>
        <name>company</name>
    </owner>
    <size>
        <width>%d</width>
        <height>%d</height>
        <depth>%d</depth>
    </size>
    <segmented>0</segmented>
"""
objstr = """\
    <object>
        <name>%s</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>%d</xmin>
            <ymin>%d</ymin>
            <xmax>%d</xmax>
            <ymax>%d</ymax>
        </bndbox>
    </object>
"""

tailstr = '''\
</annotation>
'''

def all_path(filename):
    return os.path.join('Wider_face', filename)

def writexml(idx, head, bbxes, tail):
    filename = all_path("Annotations/%06d.xml" % (idx))
    f = open(filename, "w")
    f.write(head)
    for bbx in bbxes:
        f.write(objstr % ('face', bbx[0], bbx[1], bbx[0] + bbx[2], bbx[1] + bbx[3]))
    f.write(tail)
    f.close()


def clear_dir():
    if shutil.os.path.exists(all_path('Annotations')):
        shutil.rmtree(all_path('Annotations'))
    if shutil.os.path.exists(all_path('ImageSets')):
        shutil.rmtree(all_path('ImageSets'))
    if shutil.os.path.exists(all_path('JPEGImages')):
        shutil.rmtree(all_path('JPEGImages'))

    shutil.os.mkdir(all_path('Annotations'))
    shutil.os.makedirs(all_path('ImageSets/Main'))
    shutil.os.mkdir(all_path('JPEGImages'))


def excute_datasets(idx, datatype):
    f = open(all_path('ImageSets/Main/' + datatype + '.txt'), 'a')
    f_bbx = open(all_path('wider_face_split/wider_face_' + datatype + '_bbx_gt.txt'), 'r')

    while True:
        filename = f_bbx.readline().strip('\n')
        if not filename:
            break
        try:
            im = io.imread(all_path('WIDER_' + datatype + '/images/'+filename))
        except IOError:
            print('错误文件名已跳过,',filename)
            continue
        head = headstr % (idx, im.shape[1], im.shape[0], im.shape[2])
        nums = f_bbx.readline().strip('\n')
        bbxes = []
        for ind in range(int(nums)):
            bbx_info = f_bbx.readline().strip(' \n').split(' ')
            bbx = [int(bbx_info[i]) for i in range(len(bbx_info))]
            #x1, y1, w, h, blur, expression, illumination, invalid, occlusion, pose
            if bbx[7]==0:
                bbxes.append(bbx)
        writexml(idx, head, bbxes, tailstr)
        shutil.copyfile(all_path('WIDER_' + datatype + '/images/'+filename), all_path('JPEGImages/%06d.jpg' % (idx)))
        f.write('%06d\n' % (idx))
        idx +=1
    f.close()
    f_bbx.close()
    return idx


# 打乱样本
def shuffle_file(filename):
    f = open(filename, 'r+')
    lines = f.readlines()
    random.shuffle(lines)
    f.seek(0)
    f.truncate()
    f.writelines(lines)
    f.close()


if __name__ == '__main__':
    clear_dir()
    idx = 1
    idx = excute_datasets(idx, 'train')
    idx = excute_datasets(idx, 'val')
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值