目标检测Tensorflow:Yolo v3代码详解 (2)
三、解析Dataset()数据预处理部分
有了网络结构,我们还不能直接训练,因为还缺少对数据的处理:如何把数据送入网络、ground truth 又该如何处理等。这些工作就交给 dataset.py 来完成。
import os
import cv2
import numpy as np
import tensorflow as tf
import core.utils as utils
from config import cfg
class Dataset(object):
def __init__(self, train_flag=True):
    """Load configuration and annotations for one data split.

    :param train_flag: True (default) to use the training data file and
        batch size, False to use the validation ones.
    """
    self.train_flag = train_flag

    # Pick the annotation file and batch size for the requested split.
    train_cfg = cfg.TRAIN
    self.data_file_path = train_cfg.TRAIN_DATA_PATH if train_flag else train_cfg.VAL_DATA_PATH
    self.batch_size = train_cfg.TRAIN_BATCH_SIZE if train_flag else train_cfg.VAL_BATCH_SIZE

    # Network-related settings shared by both splits.
    self.train_input_size_list = train_cfg.INPUT_SIZE_LIST
    self.strides = np.array(cfg.YOLO.STRIDES)

    self.classes = utils.read_class_names(cfg.COMMON.CLASS_FILE_PATH)
    self.class_num = len(self.classes)

    self.anchor_list = utils.get_anchors(cfg.COMMON.ANCHOR_FILE_PATH)
    self.anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE
    self.max_bbox_per_scale = cfg.COMMON.MAX_BBOX_PER_SCALE

    # Annotation lines (already shuffled by read_annotations).
    self.annotations = self.read_annotations()
    self.sample_num = len(self.annotations)

    # Number of batches per epoch, rounding up so a partial last batch counts.
    batches_per_epoch = np.ceil(self.sample_num / self.batch_size)
    self.batch_num = int(batches_per_epoch)

    # Index of the next batch to serve within the current epoch.
    self.batch_count = 0
# 迭代器
def __iter__(self):
return self
# 使用迭代器 Dataset() 进行迭代,类似于 for 循环
def __next__(self):
    """Build and return the next batch of images and labels.

    A random input size is drawn from ``train_input_size_list`` on every
    call; the three output grid sizes are that size divided by the strides.

    :return: tuple ``(batch_image, batch_label_sbbox, batch_label_mbbox,
        batch_label_lbbox, batch_sbboxes, batch_mbboxes, batch_lbboxes)``
        of NumPy arrays whose first dimension is ``batch_size``.
    :raises StopIteration: when all ``batch_num`` batches of the current
        epoch have been served. Before raising, the batch counter is reset
        and the annotations are reshuffled for the next epoch.
    """
    # NOTE(review): only NumPy work is visible inside this device scope, so
    # it probably has no effect -- confirm before changing or removing it.
    with tf.device("/gpu:0"):
        # Randomly pick one value from train_input_size_list for this batch.
        self.train_input_size = np.random.choice(self.train_input_size_list)
        self.train_output_size = self.train_input_size // self.strides

        # End of epoch: reset, reshuffle and stop before allocating buffers.
        if self.batch_count >= self.batch_num:
            self.batch_count = 0
            np.random.shuffle(self.annotations)
            raise StopIteration

        batch_size = self.batch_size
        input_size = self.train_input_size

        def label_buffer(grid_size):
            # One label grid: [batch, grid, grid, anchors, 5 + classes].
            return np.zeros((batch_size, grid_size, grid_size,
                             self.anchor_per_scale, 5 + self.class_num))

        def bbox_buffer():
            # Per-image box list: [batch, max boxes per scale, 4].
            return np.zeros((batch_size, self.max_bbox_per_scale, 4))

        batch_image = np.zeros((batch_size, input_size, input_size, 3))

        batch_label_sbbox = label_buffer(self.train_output_size[0])
        batch_label_mbbox = label_buffer(self.train_output_size[1])
        batch_label_lbbox = label_buffer(self.train_output_size[2])

        batch_sbboxes = bbox_buffer()
        batch_mbboxes = bbox_buffer()
        batch_lbboxes = bbox_buffer()

        first_index = self.batch_count * batch_size
        for num in range(batch_size):
            # Wrap around with modulo so the last batch is padded from the
            # start of the list. A single subtraction would still overflow
            # when batch_size is more than twice sample_num.
            index = (first_index + num) % self.sample_num

            annotation = self.annotations[index]
            image, bboxes = self.parse_annotation(annotation)
            (label_sbbox, label_mbbox, label_lbbox,
             sbboxes, mbboxes, lbboxes) = self.preprocess_true_boxes(bboxes)

            batch_image[num, :, :, :] = image

            batch_label_sbbox[num, :, :, :, :] = label_sbbox
            batch_label_mbbox[num, :, :, :, :] = label_mbbox
            batch_label_lbbox[num, :, :, :, :] = label_lbbox

            batch_sbboxes[num, :, :] = sbboxes
            batch_mbboxes[num, :, :] = mbboxes
            batch_lbboxes[num, :, :] = lbboxes

        self.batch_count += 1
        return batch_image, batch_label_sbbox, batch_label_mbbox, batch_label_lbbox, \
            batch_sbboxes, batch_mbboxes, batch_lbboxes
# 可以让 len(Dataset()) 返回 self.batch_num 的值
def __len__(self):
return self.batch_num
# 获取 annotations.txt 文件信息
def read_annotations(self):
    """Read the annotation file and return its usable lines, shuffled.

    A line is kept only when, after stripping, it has at least one
    whitespace-separated field following the first one (i.e. at least one
    box entry after the image path). The kept lines are shuffled in place
    with ``np.random.shuffle``.

    :return: list of stripped annotation strings in random order.
    """
    kept_lines = []
    with open(self.data_file_path) as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            # More than one field means the path is followed by box data.
            if len(stripped.split()) > 1:
                kept_lines.append(stripped)

    np.random.shuffle(kept_lines)
    return kept_lines
# 根据 annotation 信息 获取 image 和 bounding boxes
def parse_annotation(self, annotation):
    """Load the image and bounding boxes described by one annotation line.

    The line is split on whitespace: the first field is the image path and
    every following field is a comma-separated list of integers, e.g.
    ``"./data/images/Anime_180.jpg 388,532,588,729,0 917,154,1276,533,0"``
    yields boxes ``[[388, 532, 588, 729, 0], [917, 154, 1276, 533, 0]]``.

    When ``self.train_flag`` is set, random horizontal flip, random crop and
    random translate are applied. The result is then passed to
    ``utils.image_preporcess`` together with the target size
    ``[self.train_input_size, self.train_input_size]``.

    :param annotation: one annotation line as described above.
    :return: ``(image, bboxes)`` as returned by ``utils.image_preporcess``.
    :raises KeyError: if the image path does not exist.
    :raises ValueError: if the file exists but cannot be decoded as an image.
    """
    fields = annotation.split()
    image_path = fields[0]

    if not os.path.exists(image_path):
        raise KeyError("%s does not exist ... " % image_path)

    # cv2.imread signals an unreadable file by returning None rather than
    # raising, so fail here with a clear message.
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError("%s could not be read as an image" % image_path)

    # One row of integers per box field.
    bboxes = np.array([list(map(int, box.split(','))) for box in fields[1:]])

    # Training data only: random augmentations on copies of the inputs.
    if self.train_flag:
        image, bboxes = self.random_horizontal_flip(np.copy(image), np.copy(bboxes))
        image, bboxes = self.random_crop(np.copy(image), np.copy(bboxes))
        image, bboxes = self.random_translate(np.copy(image), np.copy(bboxes))

    target_size = [self.train_input_size, self.train_input_size]
    image, bboxes = utils.image_preporcess(np.copy(image), target_size, np.copy(bboxes))
    return image, bboxes
# 随机水平翻转
def random_horizontal_flip(self, image, bboxes):
    """Mirror the image left-right with probability 0.5.

    When the flip happens, box columns 0 and 2 (the x coordinates) are
    mirrored and swapped in place so column 0 stays the smaller x value.

    :param image: array of shape (height, width, channels).
    :param bboxes: 2-D integer array, one box per row; modified in place.
    :return: ``(image, bboxes)``, flipped or unchanged.
    """
    if np.random.random() >= 0.5:
        return image, bboxes

    _rows, width, _channels = image.shape
    mirrored = image[:, ::-1, :]
    # New xmin = width - old xmax, new xmax = width - old xmin.
    bboxes[:, [0, 2]] = width - bboxes[:, [2, 0]]
    return mirrored, bboxes
# 随机裁剪
def random_crop(self, image, bboxes):
    """Randomly crop the image while keeping every bounding box inside it.

    With probability 0.5 a crop window is drawn that contains the smallest
    rectangle enclosing all boxes, extended by a random amount on each side
    and clamped to the image bounds. Box coordinates are shifted in place to
    the cropped image's origin.

    :param image: array of shape (height, width, channels).
    :param bboxes: 2-D integer array with columns 0..3 holding
        xmin, ymin, xmax, ymax; modified in place.
    :return: ``(image, bboxes)``, cropped or unchanged.
    """
    if np.random.random() < 0.5:
        h, w, _ = image.shape

        # Smallest rectangle enclosing all boxes: [xmin, ymin, xmax, ymax].
        max_bbox = np.concatenate(
            [np.min(bboxes[:, 0:2], axis=0), np.max(bboxes[:, 2:4], axis=0)],
            axis=-1,
        )

        # Free margin between the enclosing rectangle and each image edge.
        max_l_trans = max_bbox[0]
        max_u_trans = max_bbox[1]
        max_r_trans = w - max_bbox[2]
        max_d_trans = h - max_bbox[3]

        crop_xmin = max(0, int(max_bbox[0] - np.random.uniform(0, max_l_trans)))
        crop_ymin = max(0, int(max_bbox[1] - np.random.uniform(0, max_u_trans)))
        # Clamp the far edges with min(): max(w, ...) / max(h, ...) always
        # evaluates to w / h for boxes inside the image, so the right and
        # bottom sides would never be cropped.
        crop_xmax = min(w, int(max_bbox[2] + np.random.uniform(0, max_r_trans)))
        crop_ymax = min(h, int(max_bbox[3] + np.random.uniform(0, max_d_trans)))

        image = image[crop_ymin:crop_ymax, crop_xmin:crop_xmax]

        # Move the boxes into the cropped image's coordinate system.
        bboxes[:, [0, 2]] = bboxes[:, [0, 2]] - crop_xmin
        bboxes[:, [1, 3]] = bboxes[:, [1, 3]] - crop_ymin

    return image, bboxes
# 随机平移: 水平和竖直 方向移动变化,被移走后的位置,数值为0,显示为黑色
def random_translate(self, image, bboxes):
if np.random.random() < 0.5:
h, w, _ = image.shape
max_bbox = np.concatenate([np.min(bboxes[:, 0:2], axis=0), np.max(bbox