【代码复盘】 找出inference结果中检测困难的图片 find_hard_pics.py

任务

阈值划到0.3或0.35,把所有图片检测一下,找出那些检测困难的图片,把txt收集起来。
检测困难,是指一张图片中,所有目标(或大部分)都检测不出来,这样的图片可能采集来源有点问题,或者有其他问题。我们需要把这样的图片找出来。

代码

思路

如果一张图片的所有目标或大部分目标检测不出来,那么这张图片的Positive中,TP很少,所以我们需要计算出一张图片的TP占gt label的比例,如果比例很小(如小于20%),那么这张图片就属于检测困难的图片。

算法流程如下:对图片img_path的列表进行迭代,对于一张图片,读取图片的预测label txt,得到所有的Positive box,保存在nd array pred中。读取图片的gt label txt,获得图片的gt box,保存在nd array labels中。从labels中得到这张图片的所有gt box的类别。对gt的所有类别进行迭代,对于一个类别,计算该类别的pred box和gt box的IoU。对于每个pred box,将与它IoU最大且IoU大于IoU阈值(这里取0.3)的gt box加入detected,视为已检测。直到所有gt box都已检测时跳出循环。

实现的细节:

  1. 从txt中读取的原始的gt 标签的格式是yolo格式,即中心点的(x, y, w, h),x y为中心点坐标。在pytorch_yolov4代码中,模型输出的box格式是(x1, y1, x2, y2),即左上角坐标和右下角坐标。我们在将output写入nj_dir中时存储的box格式也是(x1, y1, x2, y2)格式。utils.py中计算IoU的代码使用的box格式是(x1, y1, x2, y2)格式。因此需要将gt box由(x, y, w, h)转化为(x1, y1, x2, y2)。
    gt labels的box是x y w h相对于图片的width和height的比值,因此需要乘以w h w h得到真实像素值。而模型输出和写入预测labels txt的box是真实的像素值。
    tbox = utils.xywh2xyxy(labels[:, 1:5]) * whwh

数据结构(参考代码中的实现):

  1. 读取图片的预测label txt后将预测box存储在ndarray pred中。pred的每行表示一个box,这里pred每列存储什么参考了yolov4的test()
  2. 读取图片的gt label txt后将gt box存储在ndarray labels中。labels的每行表示一个box,labels的列代表什么依然参考了yolov4中的数据结构
  3. 对labels切片得到所有gt box的类别array tcls,tcls的意义不仅在于unique过后可以得到该图片包含的所有类别以便进行迭代,对tcls进行索引可以得到某一类别的box所在的行
  4. tbox。对labels切片labels[:, 1:5],并且将gt box的格式和比值都转换成需要的格式,得到tbox。它的作用是,后面拿bounding box计算IoU时使用对tbox的索引。

注意:
使用numpy数组还是使用tensor?一开始我用的是numpy数组,但是中间用了yolov4的一个

代码层面技术细节:

  1. 读取txt的的内容:读取每一行并去掉每一行末尾的'\n'
    只需要读取时open()并不用加’w’,只需以可读模式打开即可
        with open(txt_path) as f:
            lines = f.readlines()
            for i, line in enumerate(lines):
                # print(line)
                if i != 0:
                    pl = line.strip().split(",")

其他mode包括:写入 ‘w’ (截断已经存在的文件);排它性创建 ‘x’ ;追加写 ‘a’
创建txt文件时用open(hard_pics_txt, ‘w’),‘w’在文件不存在时可以创建文件并打开,但是’w’会清空文件原有的内容。所以我们创建文件时可以使用’w’,刚好可以在文件已存在的情况下清空内容。创建文件可以用排它性创建 ‘x’,但是在文件已存在时不做改动,这样就不能清空了。
创建文件,向文件写入用f.write(),读取文件用f.read(),f.readline(), f.readlines(),嗯:

    with open(hard_pics_txt, 'w') as f:
        f.write('Path, TP, num_labels\n') 
        

在文件末尾追加行使用 ‘a’。注意一行的末尾要自己手动加上换行符’\n’,不然不会换行的

        if (num_tp / nl) < hard: 
            with open(hard_pics_txt, 'a') as f:
                f.write("{} {} {}\n".format(img_path, num_tp, nl))

代码

import torch
import numpy as np
import os
import cv2
import utils.utils as utils
print(np.__version__)

iou_thres = 0.3
hard = 0.2

def get_name_id_map(classes):
    name_id_map = {}
    for i, name in enumerate(classes):
        name_id_map[name] = i
    print(name_id_map)

    return name_id_map

if __name__ == "__main__":
    pred_dir = "/home/gpu-server/project/PyTorch_YOLOv4/nj_inference"
    gt_dir = "/home/gpu-server/project/data/game_new/mydataset/labels/inference_dir"
    img_dir = "/home/gpu-server/project/data/game_new/mydataset/inference_dir"
    dest_path = "/home/gpu-server/project/data/game_new/mydataset/hard_pics"
    hard_pics_txt = "/home/gpu-server/project/data/game_new/mydataset/hard_pics.txt"

    classes = ['bj_bpmh', 'bj_bpps', 'bj_wkps', 'bj_bpzc', 'bjdsyc', 'bjdszc', 'jyz_pl', 'sly_dmyw', 'hxq_gjtps', 
    'hxq_gjbs', 'hxq_gjzc', 'ywzt_yfyc', 'xmbhyc', 'xmbhzc', 'yw_gkxfw', 'yw_nc', 'gbps', 'wcaqm', 'aqmzc',
    'wcgz_dxdk', 'gzzc_cxck', 'xy', 'kgg_ybh', 'kgg_ybf', 'backup']

    # if os.path.isdir(dest_path):
    #     os.makedirs(dest_path)

    with open(hard_pics_txt, 'w') as f:
        f.write('Path, TP, num_labels\n') 

    name_id_map = get_name_id_map(classes)   

    pred_list = os.listdir(pred_dir)
    for img_txt in pred_list:
        # print("start processing", img_txt, "..")
        txt_path = os.path.join(pred_dir, img_txt)
        gt_path = os.path.join(gt_dir, "game_" + img_txt)
        file_name = os.path.splitext("game_" + img_txt)[0]
        img_path = os.path.join(img_dir, file_name + '.jpg')
        detected = []       

        with open(txt_path) as f:
            lines = f.readlines()
            pred = np.empty([len(lines)-1, 6])
            for i, line in enumerate(lines):
                # print(line)
                if i != 0:
                    pl = line.strip().split(",")
                    id = name_id_map[pl[2]] # int
                    pred[i-1, 5] = id                    
                    pred[i-1, 4] = pl[3]
                    pred[i-1, :4] = pl[-4:]
            pred = pred.astype(np.float32)
            # print(pred)
            # print("pred:", pred.dtype)
        with open(gt_path) as f:
            lines = f.readlines()
            labels = np.empty([len(lines), 5])
            for i, line in enumerate(lines):
                tl = line.strip().split()
                labels[i, 0] = tl[0]
                labels[i, 1:] = tl[1:]
            labels = labels.astype(np.float32)
            # print("labels:",labels)
            # print("labels:", labels.dtype)

        correct = torch.zeros(pred.shape[0], dtype=torch.bool)
        nl = labels.shape[0]
        # print("nl:",nl)
        # print("correct:", correct)
        # print("correct.shape:", correct.shape)
        tcls = labels[:, 0]
        tcls = torch.tensor(tcls)
        img = cv2.imread(img_path)
        height, width, _ = img.shape
        whwh = [width, height, width, height]
        # print(whwh)
        tbox = utils.xywh2xyxy(labels[:, 1:5]) * whwh
        tbox = tbox.astype(np.float32)
        tbox = torch.tensor(tbox)
        pred = torch.tensor(pred)

        for cls in torch.unique(tcls):
            # print("cls:", cls)
            ti = (cls == tcls).nonzero().view(-1)
            pi = (pred[:, 5] == cls).nonzero().view(-1)
            if pi.shape[0]:
                ious, i= utils.box_iou(pred[pi, :4], tbox[ti]).max(1)            
                for j in (ious > iou_thres).nonzero():
                    # print("ti[i[j]]:", ti[i[j]])
                    d = ti[i[j]]
                    # print("d:", d)
                    if d not in detected:
                        detected.append(d)
                        correct[pi[j]] = True
                        if len(detected) == nl:
                            break
        
        # TP按20%算
        num_tp = len(correct.nonzero())
        if (num_tp / nl) < hard: 
            with open(hard_pics_txt, 'a') as f:
                f.write("{} {} {}\n".format(img_path, num_tp, nl))
        else:
            print(img_txt, num_tp / nl)
        # print('finished processing', img_txt, "..")
            

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值