使用Retinanet训练自己的数据集

目录

 

目录

1 构建Retinanet环境

2 生成CSV文件

3训练

4.转化模型

5.测试

6.评测

loss可视化

ap,precision-recall


数据集什么的看我之前博客,资源里也有标记好的数据集,这里主要写一下我配置使用训练过程。

1 构建Retinanet环境

1.代码库下载地址https://github.com/fizyr/keras-retinanet,或git命令:

git clone https://github.com/fizyr/keras-retinanet.git

2.获得代码库后进入keras-retinanet文件夹,确认有未安装numpy
 

cd keras-retinanet

pip install numpy --user

在这个文件夹内运行下面代码来安装keras-retinanet库,确认你已经根据自己的系统需求安装了tensorflow

pip install . --user
python setup.py build_ext --inplace

2 生成CSV文件

训练自己的数据集需要至少两个CSV文件,一个文件包含标注数据,另一个则包含各个类别名及其对应的ID序号映射。

先抛出我的文件位置,新建一个csv文件夹,data文件里放置的是训练图片及标签

三个csv就是我们要生成的

参考博客https://blog.csdn.net/qq_27171347/article/details/88878346

"""
进入到csv文件夹下
运行方式:命令行 python xml2csv.py -i indir(图片及标注的母目录)
      注:必须参数: -i 指定包含有图片及标注的母文件夹,图片及标注可不在同一子目录里,但名称必须一一对应
                     (图片格式默认.jpg,若为其他格式可见代码中注释自行修改)
          可选参数: -p 交叉验证集拆分比,默认0.05
                   -t 生成训练集CSV文件名称,默认train.csv
                   -v 生成交叉验证集CSV文件名称,默认val.csv
                   -c 生成类别映射CSV文件名称,默认class.csv
"""

import os
import xml.etree.ElementTree as ET
import random
import math
import argparse


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--indir', type=str)
    parser.add_argument('-p', '--percent', type=float, default=0.1) #0.05
    parser.add_argument('-t', '--train', type=str, default='train.csv')
    parser.add_argument('-v', '--val', type=str, default='val.csv')
    parser.add_argument('-c', '--classes', type=str, default='class.csv')
    args = parser.parse_args()
    return args

#获取特定后缀名的文件列表
def get_file_index(indir, postfix):
    file_list = []
    for root, dirs, files in os.walk(indir):
        for name in files:
            if postfix in name:
                file_list.append(os.path.join(root, name))
    return file_list

#写入标注信息
def convert_annotation(csv, address_list):
    cls_list = []
    with open(csv, 'w') as f:
        for i, address in enumerate(address_list):
            in_file = open(address, encoding='utf8')
            strXml =in_file.read()
            in_file.close()
            root=ET.XML(strXml)
            for obj in root.iter('object'):
                cls = obj.find('name').text
                cls_list.append(cls)
                xmlbox = obj.find('bndbox')
                b = (int(xmlbox.find('xmin').text), int(xmlbox.find('ymin').text), 
                     int(xmlbox.find('xmax').text), int(xmlbox.find('ymax').text))
                f.write(file_dict[address_list[i]])
                f.write( "," + ",".join([str(a) for a in b]) + ',' + cls)
                f.write('\n')
    return cls_list


if __name__ == "__main__":
    args = parse_args()
    file_address = args.indir
    test_percent = args.percent
    train_csv = args.train
    test_csv = args.val
    class_csv = args.classes
    Annotations = get_file_index(file_address, '.xml')
    Annotations.sort()
    JPEGfiles = get_file_index(file_address, '.JPG') #可根据自己数据集图片后缀名修改
    JPEGfiles.sort()
    assert len(Annotations) == len(JPEGfiles) #若XML文件和图片文件名不能一一对应即报错
    file_dict = dict(zip(Annotations, JPEGfiles))
    num = len(Annotations)
    test = random.sample(k=math.ceil(num*test_percent), population=Annotations)
    train = list(set(Annotations) - set(test))

    cls_list1 = convert_annotation(train_csv, train)
    cls_list2 = convert_annotation(test_csv, test)
    cls_unique = list(set(cls_list1+cls_list2))

    with open(class_csv, 'w') as f:
        for i, cls in enumerate(cls_unique):
            f.write(cls + ',' + str(i) + '\n')

进入csv文件夹下,python xml2csv.py -i /home/zbb/keras-retinanet/CSV/data

class.csv:

#类别,序号(从0开始)
#class_name,id

plane,0

train.csv:

#路径,xmin,ymin,xmax,ymax,类别名
#path/to/image.jpg,x1,y1,x2,y2,class_name
/data/imgs/img_001.jpg,837,346,981,456,plane

 

3训练

csv 后第一个参数接标注csv文件路径 , 第二个接类别映射csv文件路径, 第三个参数可选择添加交叉验证集
示例:retinanet-train csv ./train.csv ./class.csv --val-annotations ./val.csv

一般还需指定 --epochs 训练轮数 默认值50
            --batch-size 一批训练多少个 默认值1
            --steps 一轮训练多少步 默认10000 需按照自己数据集size大小计算 steps = size / batch-size

至于是否加载权重训练,backbone选择(默认Resnet50,可选参见keras_retinanet/models),学习率大小等按照自己需要进行指定。
有250个样本, batch_size=1,训练100轮, 则命令如下:

retinanet-train --batch-size 1 --steps 250 --epochs 100 csv ./train.csv ./class.csv --val-annotations ./val.csv

4.转化模型


retinanet-convert-model 训练出的模型地址 转化后的推断模型地址

retinanet-convert-model ./snapshots/resnet50_csv_100.h5 ./model/resnet50_csv_100.h5

5.测试

返回上一层文件夹,即是keras-retinanet下,新建test.py文件运行测试,进行修改,可测试保存多张图片

import keras
from keras_retinanet import models
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image
from keras_retinanet.utils.visualization import draw_box, draw_caption
from keras_retinanet.utils.colors import label_color

import matplotlib.pyplot as plt
import cv2
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import numpy as np
import time

import tensorflow as tf

def get_session():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)

# 设置tensorflow session 为Keras 后端
keras.backend.tensorflow_backend.set_session(get_session())
#转化后的推断模型地址
#model_path = os.path.join('..', 'snapshots', 'predict.h5')
model_path ='/home/zbb/keras-retinanet/CSV/model/resnet50_csv_100.h5'
#加载模型
model = models.load_model(model_path, backbone_name='resnet50')
#建立ID与类别映射字典
labels_to_names = {0: 'plane'}
#加载需要检测的图片
#image_path = '/home/zbb/keras-retinanet/CSV/test/50.JPG'
path='/home/zbb/keras-retinanet/CSV/test/'
save_path='/home/zbb/keras-retinanet/CSV/result/'
image_names = sorted(os.listdir(path))
for image_path in image_names:
	image = read_image_bgr(path+image_path)
	print(path+image_path)
	# copy到另一个对象并转为RGB文件
	draw = image.copy()
	draw = cv2.cvtColor(draw, cv2.COLOR_BGR2RGB)
	# 图像预处理
	image = preprocess_image(image)
	image, scale = resize_image(image)
	# 模型预测
	start = time.time()
	boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))
	print("processing time: ", time.time() - start)
	# 矫正比例
	boxes /= scale
	# 目标检测可视化展示
	for box, score, label in zip(boxes[0], scores[0], labels[0]):
		# 设置预测得分最低阈值
		if score < 0.75:
			break
		color = label_color(label)
		b = box.astype(int)
		draw_box(draw, b, color=color)
		caption = "{} {:.3f}".format(labels_to_names[label], score)
		draw_caption(draw, b, caption)
	#图片展示
	plt.figure(figsize=(15, 15))
	plt.axis('off')
	plt.imshow(draw)
	plt.savefig(save_path+image_path,format='JPG',transparent=True,pad_inches=0,dpi=300,bbox_inches='tight')
	#plt.show()

结果挺好的,检测了100张,没有错误,速度也还不错,大概两个小时

6.评测

loss可视化

tensorboard --logdir='/home/zbb/keras-retinanet/CSV/logs' 

对应是文件夹,结果如下:

ap,precision-recall:

没有找到,有会的可以博客下面留言交流

  • 4
    点赞
  • 55
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值