前言:关于数据的对接,这一部分有很大的不同,因为使用的代码不一样,在此做一下使用记录。初始版本的数据接口在这一句:
all_imgs, classes_count, class_mapping = get_data(options.train_path)
我在get_data.py中发现all_imgs包含图片的宽和高这两句,因此,在标注框生成的时候,加入了宽和高的返回值。
每个变量的数据构成如下:
- all_imgs:路径,宽,高,标注框的分类、属性,所归属的训练集
- classes_count:总共的类别及每一个类别所包含的图片数
- class_mapping:类别及其数字标签
据此,我根据pascal_voc_parser.py重新写了一个读取数据的文件,具体怎么用,你理解一下train.py代码,很容易就明白了。
准备:
初始版本源码地址:https://github.com/yhenon/keras-frcnn
最新版本源码地址https://github.com/fizyr/keras-retinanet.
数据:使用的生成标注框后的分类数据
代码:
注:本文中的注释可能不是很规范,具体的话可以自己理解一下。
import os
import numpy as np
import tensorflow as tf
def get_data(input_path):
all_imgs = []
classes_count = {}
class_mapping = {}
# 解析train_lable_rpn图片数据
print('Parsing train_lable_rpn.txt')
# 遍历读取数据
hd = tf.gfile.FastGFile(input_path, "r")
for line in hd.readlines():
# 对所取得的信息按空格进行切分
lineinfo = line.split(" ")
# 取取图片路径
pic_path = lineinfo[0]
# 取图片的类别
pic_class = lineinfo[1]
# 取图片的宽和高
pic_scale = lineinfo[2].split(",")
# 将图片信息存储在annotation_data中
annotation_data = {'filepath': pic_path, 'width': int(pic_scale[0]),
'height': int(pic_scale[1]), 'bboxes': [], 'imageset':'trainval'}
# 统计类别信息及每一类的数量
if pic_class not in classes_count:
classes_count[pic_class] = 1
else:
classes_count[pic_class] += 1
# 统计类别及每一个类所对应的标签
if pic_class not in class_mapping:
class_mapping[pic_class] = len(class_mapping)
# 提取预选框信息
num_objs = int(lineinfo[3])
objs = range(num_objs)
boxes = np.zeros((num_objs, 4), dtype=np.uint16)
# 遍历提取预选框
for ix, obj in enumerate(objs):
bbox = lineinfo[4+obj] # 预选框从第三个位置开始
bbox = bbox.split(",")
x1, y1, x2, y2 = [float(pos)-1 if int(pos) >= 1 else float(pos) for pos in bbox]
boxes[ix, :] = [x1, y1, x2, y2]
difficulty = 1
annotation_data['bboxes'].append(
{'class': pic_class, 'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'difficult': difficulty})
all_imgs.append(annotation_data)
return all_imgs, classes_count, class_mapping
# if __name__ == '__main__':
# get_data('E:/wangr/ZOUZHEN/WorkSpace/VSCode/MachineLearning/DeepLearning/Pet_Dog_Identify/data_bases/enhance/train_lable_rpn.txt')
至于train.py中的代码,我并没有做太大的修改,仅仅添加了默认的读取训练数据的路径,修改了解析读取数据的get_data函数,也就是我现在编写的这个函数。当然,我借鉴了很多大神的代码,并不完全是我自己写的—。—,
如有错误,还请告知改正!