YOLO数据读取代码解析

源码来源

https://www.cnblogs.com/zkweb/p/14403833.html

原文作者的文章很nice,受益匪浅,我画蛇添足一下,对数据读取部分的代码做一个更详细地解析。

def prepare():
    """准备训练"""
    # 数据集转换到 tensor 以后会保存在 data 文件夹下
    if not os.path.isdir("data"):
        os.makedirs("data")

    # 加载图片和图片对应的区域与分类列表
    # { (路径, 是否左右翻转): [ 区域与分类, 区域与分类, .. ] }
    # 同一张图片左右翻转可以生成一个新的数据,让数据量翻倍
    box_map = defaultdict(lambda: []) ##defaultdict可以在没key的时候,查询key的时候返回的是[]
    for filename in os.listdir(DATASET_1_IMAGE_DIR):
        # 从第一个数据集加载
        xml_path = os.path.join(DATASET_1_ANNOTATION_DIR, filename.split(".")[0] + ".xml")
        if not os.path.isfile(xml_path):
            continue
        tree = ET.ElementTree(file=xml_path)
        objects = tree.findall("object")
        path = os.path.join(DATASET_1_IMAGE_DIR, filename)
        for obj in objects:
            class_name = obj.find("name").text
            x1 = int(obj.find("bndbox/xmin").text)
            x2 = int(obj.find("bndbox/xmax").text)
            y1 = int(obj.find("bndbox/ymin").text)
            y2 = int(obj.find("bndbox/ymax").text)
            if class_name == "mask_weared_incorrect":
                # 佩戴口罩不正确的样本数量太少 (只有 123),模型无法学习,这里全合并到戴口罩的样本
                class_name = "with_mask"
            box_map[(path, False)].append((x1, y1, x2-x1, y2-y1, CLASSES_MAPPING[class_name]))
            box_map[(path, True)].append((x1, y1, x2-x1, y2-y1, CLASSES_MAPPING[class_name]))
    df = pandas.read_csv(DATASET_2_BOX_CSV_PATH)
    for row in df.values:
        # 从第二个数据集加载,这个数据集只包含没有带口罩的图片
        filename, width, height, x1, y1, x2, y2 = row[:7]
        path = os.path.join(DATASET_2_IMAGE_DIR, filename)
        ## False和True注意上方的注释,True的话就将这张图左右翻转一下
        box_map[(path, False)].append((x1, y1, x2-x1, y2-y1, CLASSES_MAPPING["without_mask"]))
        box_map[(path, True)].append((x1, y1, x2-x1, y2-y1, CLASSES_MAPPING["without_mask"]))
    # 打乱数据集 (因为第二个数据集只有不戴口罩的图片)
    box_list = list(box_map.items())
    random.shuffle(box_list)
    print(f"found {len(box_list)} images")

    # 保存图片和图片对应的分类与区域列表
    batch_size = 20
    batch = 0
    image_tensors = [] # 图片列表
    result_tensors = [] # 图片对应的输出结果列表,包含 [ 是否对象中心, 区域偏移, 各个分类的可能性 ]
    result_isobject_masks = [] # 各个图片的包含对象的区域在 Anchors 中的索引
    result_nonobject_masks = [] # 各个图片不包含对象的区域在 Anchors 中的索引 (重叠率低于阈值的区域)
    for (image_path, flip), original_boxes_labels in box_list:
        with Image.open(image_path) as img_original: # 加载原始图片
            sw, sh = img_original.size # 原始图片大小
            if flip:
                ## 自定义resize_image方法,主要做的内容是用空白填充出一个符合(256,192)比例的图
                ## 然后再用reshape来进行缩放
                ## 这样子缩放的时候可以保证原图片的比例没有变化
                img = resize_image(img_original.transpose(Image.FLIP_LEFT_RIGHT)) # 翻转然后缩放图片
            else:
                img = resize_image(img_original) # 缩放图片
            image_tensors.append(image_to_tensor(img)) # 添加图片到列表
        # 生成输出结果的 tensor
        ## size是(锚点数,1+4+分类数量)
        result_tensor = torch.zeros((len(MyModel.Anchors), MyModel.AnchorOutputs), dtype=torch.float)
        result_tensor[:,5] = 1 # 默认分类为 other
        result_tensors.append(result_tensor)
        # 包含对象的区域在 Anchors 中的索引
        result_isobject_mask = []
        result_isobject_masks.append(result_isobject_mask)
        # 不包含对象的区域在 Anchors 中的索引
        result_nonobject_mask = []
        result_nonobject_masks.append(result_nonobject_mask)
        # 根据真实区域定位所属的锚点,然后设置输出结果
        negative_mapping = [1] * len(MyModel.Anchors)
        for box_label in original_boxes_labels:
            x, y, w, h, label = box_label
            if flip: # 翻转坐标
                x = sw - x - w
            ## 前面提到原图需要缩放到256,198的尺寸
            ## 而此时的标记框xywh是原图的,需要将其映射回缩放后的
            x, y, w, h = map_box_to_resized_image((x, y, w, h), sw, sh) # 缩放实际区域
            if w < 20 or h < 20:
                continue # 缩放后区域过小
            # 检查计算是否有问题
            # child_img = img.copy().crop((x, y, x+w, y+h))
            # child_img.save(f"{os.path.basename(image_path)}_{x}_{y}_{w}_{h}_{label}.png")
            # 定位所属的锚点
            # 要求:
            # - 中心点落在锚点对应的区域中
            # - 重叠率超过一定值
            x_center = x + w // 2
            y_center = y + h // 2
            matched_anchors = []
            for index, anchor in enumerate(MyModel.Anchors):
                ax, ay, aw, ah = anchor
                is_center = (x_center >= ax and x_center < ax + aw and
                    y_center >= ay and y_center < ay + ah)
                iou = calc_iou(anchor, (x, y, w, h))
                if is_center and iou > IOU_POSITIVE_THRESHOLD:
                    matched_anchors.append((index, anchor)) # 区域包含对象中心并且重叠率超过一定值
                    negative_mapping[index] = 0
                elif iou > IOU_NEGATIVE_THRESHOLD:
                    negative_mapping[index] = 0 # 区域与某个对象重叠率超过一定值,不应该当作负样本
            for matched_index, matched_box in matched_anchors:
                # 计算区域偏移
                offset = calc_box_offset(matched_box, (x, y, w, h))
                # 修改输出结果的 tensor
                result_tensor[matched_index] = torch.tensor((
                    1, # 是否对象中心
                    *offset, # 区域偏移
                    *[int(c == label) for c in range(len(CLASSES))] # 对应分类
                ), dtype=torch.float)
                # 添加索引值
                # 注意如果两个对象同时定位到相同的锚点,那么只有一个对象可以被识别,这里后面的对象会覆盖前面的对象
                if matched_index not in result_isobject_mask:
                    result_isobject_mask.append(matched_index)
        # 没有找到可识别的对象时跳过图片
        if not result_isobject_mask:
            image_tensors.pop()
            result_tensors.pop()
            result_isobject_masks.pop()
            result_nonobject_masks.pop()
            continue
        # 添加不包含对象的区域在 Anchors 中的索引
        for index, value in enumerate(negative_mapping):
            if value:
                result_nonobject_mask.append(index)
        # 排序索引列表
        result_isobject_mask.sort()
        # 保存批次
        if len(image_tensors) >= batch_size:
            prepare_save_batch(batch, image_tensors, result_tensors,
                result_isobject_masks, result_nonobject_masks)
            image_tensors.clear()
            result_tensors.clear()
            result_isobject_masks.clear()
            result_nonobject_masks.clear()
            batch += 1
    # 保存剩余的批次
    if len(image_tensors) > 10:
        prepare_save_batch(batch, image_tensors, result_tensors,
            result_isobject_masks, result_nonobject_masks)

def resize_image(img):
    """缩放图片,比例不一致时填充"""
    sw, sh = img.size
    sw_new, sh_new, pad_w, pad_h = calc_resize_parameters(sw, sh)
    img_new = Image.new("RGB", (sw_new, sh_new))
    ## 从pad_w,pad_h这个坐标开始粘贴img
    img_new.paste(img, (pad_w, pad_h))
    img_new = img_new.resize(IMAGE_SIZE)
    return img_new
def calc_resize_parameters(sw, sh):
    """计算缩放图片的参数"""
    sw_new, sh_new = sw, sh
    dw, dh = IMAGE_SIZE
    pad_w, pad_h = 0, 0
    if sw / sh < dw / dh:
        sw_new = int(dw / dh * sh)
        pad_w = (sw_new - sw) // 2 # 填充左右
    else:
        sh_new = int(dh / dw * sw)
        pad_h = (sh_new - sh) // 2 # 填充上下
    return sw_new, sh_new, pad_w, pad_h
def map_box_to_resized_image(box, sw, sh):
    """把原始区域转换到缩放后的图片对应的区域"""
    x, y, w, h = box
    sw_new, sh_new, pad_w, pad_h = calc_resize_parameters(sw, sh)
    scale = IMAGE_SIZE[0] / sw_new
    x = int((x + pad_w) * scale)
    y = int((y + pad_h) * scale)
    w = int(w * scale)
    h = int(h * scale)
    if x + w > IMAGE_SIZE[0] or y + h > IMAGE_SIZE[1] or w == 0 or h == 0:
        return 0, 0, 0, 0
    return x, y, w, h
def calc_iou(rect1, rect2):
    """计算两个区域重叠部分 / 合并部分的比率 (intersection over union)"""
    x1, y1, w1, h1 = rect1
    x2, y2, w2, h2 = rect2
    xi = max(x1, x2)
    yi = max(y1, y2)
    wi = min(x1+w1, x2+w2) - xi
    hi = min(y1+h1, y2+h2) - yi
    if wi > 0 and hi > 0: # 有重叠部分
        area_overlap = wi*hi
        area_all = w1*h1 + w2*h2 - area_overlap
        iou = area_overlap / area_all
    else: # 没有重叠部分
        iou = 0
    return iou

代码中双##的部分就是我的注释,原文的作者代码已经足够详细了,但是作为初学者还是有些地方看不懂,所以就再补充一点。

  • 0
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

rglkt

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值