目标检测Tensorflow:Yolo v3代码详解 (2)

本文详细解析YOLOv3在TensorFlow中的数据预处理过程,包括Dataset()的功能,以及模型训练的两个阶段。在训练完成后,介绍了模型冻结、测试和转换为PB格式的步骤,以供部署使用。
摘要由CSDN通过智能技术生成

三、解析Dataset()数据预处理部分

有了网络结构,我们还不能直接训练,因为,还缺乏对数据的操作,即,我们要如何对网络灌入数据,ground truth 又如何处理等问题,这时候,我们就需要 dataset.py 来为我们分工了。
在这里插入图片描述

import os
import cv2
import numpy as np
import tensorflow as tf
import core.utils as utils
from config import cfg
 
 
class Dataset(object):
 
    def __init__(self, train_flag=True):
        """
        :param train_flag: 是否是训练,默认训练
        """
        self.train_flag = train_flag
 
        # 训练数据
        if train_flag:
            self.data_file_path = cfg.TRAIN.TRAIN_DATA_PATH
            self.batch_size = cfg.TRAIN.TRAIN_BATCH_SIZE
            pass
        # 验证数据
        else:
            self.data_file_path = cfg.TRAIN.VAL_DATA_PATH
            self.batch_size = cfg.TRAIN.VAL_BATCH_SIZE
            pass
 
        self.train_input_size_list = cfg.TRAIN.INPUT_SIZE_LIST
        self.strides = np.array(cfg.YOLO.STRIDES)
        self.classes = utils.read_class_names(cfg.COMMON.CLASS_FILE_PATH)
        self.class_num = len(self.classes)
        self.anchor_list = utils.get_anchors(cfg.COMMON.ANCHOR_FILE_PATH)
        self.anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE
        self.max_bbox_per_scale = cfg.COMMON.MAX_BBOX_PER_SCALE
 
        self.annotations = self.read_annotations()
        self.sample_num = len(self.annotations)
        self.batch_num = int(np.ceil(self.sample_num / self.batch_size))
        self.batch_count = 0
        pass
 
    # 迭代器
    def __iter__(self):
        return self
 
    # 使用迭代器 Dataset() 进行迭代,类似于 for 循环
    def __next__(self):
        with tf.device("/gpu:0"):
            # 从 train_input_size_list 中随机获取一个数值 作为 train_input_size
            self.train_input_size = np.random.choice(self.train_input_size_list)
            self.train_output_size = self.train_input_size // self.strides
 
            # 构建 输入图像 计算图
            batch_image = np.zeros((self.batch_size, self.train_input_size, self.train_input_size, 3))
 
            # 构建 3 个尺度预测图
            batch_label_sbbox = np.zeros((self.batch_size, self.train_output_size[0], self.train_output_size[0],
                                          self.anchor_per_scale, 5 + self.class_num))
            batch_label_mbbox = np.zeros((self.batch_size, self.train_output_size[1], self.train_output_size[1],
                                          self.anchor_per_scale, 5 + self.class_num))
            batch_label_lbbox = np.zeros((self.batch_size, self.train_output_size[2], self.train_output_size[2],
                                          self.anchor_per_scale, 5 + self.class_num))
 
            # 构建每个尺度上最多的 bounding boxes 的图
            batch_sbboxes = np.zeros((self.batch_size, self.max_bbox_per_scale, 4))
            batch_mbboxes = np.zeros((self.batch_size, self.max_bbox_per_scale, 4))
            batch_lbboxes = np.zeros((self.batch_size, self.max_bbox_per_scale, 4))
 
            num = 0
            # 是否还在当前的 epoch
            if self.batch_count < self.batch_num:
                # 这个 while 用于一个 epoch 中的数据一条一条凑够一个 batch_size
                while num < self.batch_size:
                    index = self.batch_count * self.batch_size + num
                    # 如果最后一个 batch 不够数据,则 从头拿数据来凑
                    if index >= self.sample_num:
                        index -= self.sample_num
                    annotation = self.annotations[index]
                    image, bboxes = self.parse_annotation(annotation)
                    label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes = self.preprocess_true_boxes(
                        bboxes)
 
                    batch_image[num, :, :, :] = image
 
                    # [batch_size, x_scope, y_scope, iou_flag, 5 + classes]
                    batch_label_sbbox[num, :, :, :, :] = label_sbbox
                    batch_label_mbbox[num, :, :, :, :] = label_mbbox
                    batch_label_lbbox[num, :, :, :, :] = label_lbbox
 
                    batch_sbboxes[num, :, :] = sbboxes
                    batch_mbboxes[num, :, :] = mbboxes
                    batch_lbboxes[num, :, :] = lbboxes
 
                    num += 1
 
                self.batch_count += 1
 
                return batch_image, batch_label_sbbox, batch_label_mbbox, batch_label_lbbox, \
                       batch_sbboxes, batch_mbboxes, batch_lbboxes
            # 下一个 epoch
            else:
                self.batch_count = 0
                np.random.shuffle(self.annotations)
                raise StopIteration
            pass
        pass
 
    # 可以让 len(Dataset()) 返回 self.batch_num 的值
    def __len__(self):
        return self.batch_num
 
    # 获取 annotations.txt 文件信息
    def read_annotations(self):
        with open(self.data_file_path) as file:
            file_info = file.readlines()
            annotation = [line.strip() for line in file_info if len(line.strip().split()[1:]) != 0]
            np.random.shuffle(annotation)
            return annotation
        pass
 
    # 根据 annotation 信息 获取 image 和 bounding boxes
    def parse_annotation(self, annotation):
        # 将 "./data/images\Anime_180.jpg 388,532,588,729,0 917,154,1276,533,0"
        # 根据空格键切成 ['./data/images\\Anime_180.jpg', '388,532,588,729,0', '917,154,1276,533,0']
        line = annotation.split()
        image_path = line[0]
        if not os.path.exists(image_path):
            raise KeyError("%s does not exist ... " % image_path)
        image = np.array(cv2.imread(image_path))
        # 将 bboxes 做成 [[388, 532, 588, 729, 0], [917, 154, 1276, 533, 0]]
        bboxes = np.array([list(map(int, box.split(','))) for box in line[1:]])
 
        # 训练数据,进行仿射变换,让训练模型更好
        if self.train_flag:
            image, bboxes = self.random_horizontal_flip(np.copy(image), np.copy(bboxes))
            image, bboxes = self.random_crop(np.copy(image), np.copy(bboxes))
            image, bboxes = self.random_translate(np.copy(image), np.copy(bboxes))
 
        image, bboxes = utils.image_preporcess(np.copy(image), [self.train_input_size, self.train_input_size],
                                               np.copy(bboxes))
        return image, bboxes
 
    # 随机水平翻转
    def random_horizontal_flip(self, image, bboxes):
 
        if np.random.random() < 0.5:
            _, w, _ = image.shape
            image = image[:, ::-1, :]
            bboxes[:, [0, 2]] = w - bboxes[:, [2, 0]]
 
        return image, bboxes
 
    # 随机裁剪
    def random_crop(self, image, bboxes):
 
        if np.random.random() < 0.5:
            h, w, _ = image.shape
            max_bbox = np.concatenate([np.min(bboxes[:, 0:2], axis=0), np.max(bboxes[:, 2:4], axis=0)], axis=-1)
 
            max_l_trans = max_bbox[0]
            max_u_trans = max_bbox[1]
            max_r_trans = w - max_bbox[2]
            max_d_trans = h - max_bbox[3]
 
            crop_xmin = max(0, int(max_bbox[0] - np.random.uniform(0, max_l_trans)))
            crop_ymin = max(0, int(max_bbox[1] - np.random.uniform(0, max_u_trans)))
            crop_xmax = max(w, int(max_bbox[2] + np.random.uniform(0, max_r_trans)))
            crop_ymax = max(h, int(max_bbox[3] + np.random.uniform(0, max_d_trans)))
 
            image = image[crop_ymin: crop_ymax, crop_xmin: crop_xmax]
 
            bboxes[:, [0, 2]] = bboxes[:, [0, 2]] - crop_xmin
            bboxes[:, [1, 3]] = bboxes[:, [1, 3]] - crop_ymin
 
        return image, bboxes
 
    # 随机平移: 水平和竖直 方向移动变化,被移走后的位置,数值为0,显示为黑色
    def random_translate(self, image, bboxes):
 
        if np.random.random() < 0.5:
            h, w, _ = image.shape
            max_bbox = np.concatenate([np.min(bboxes[:, 0:2], axis=0), np.max(bbox
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值