# find_badcase
#
# 代码注释 (code comments)

###库
import math
import pandas as pd
import os
import re
import json
import numpy as np
from PIL import Image,ImageDraw,ImageFont
from tqdm import tqdm
import time
import shutil
from multiprocessing import Pool
import time
import glob
# NOTE: the `global` statements that used to be here were no-ops (``global`` has
# no effect at module scope). The names assigned below are ordinary module-level
# globals read by main(). Under the multiprocessing "spawn" start method (the
# default on Windows) this module is re-imported in every worker, so this setup
# code runs there as well.

#### Batch-create the output folders
pwd = os.getcwd()
excel_path = os.path.join(pwd, "path.xlsx")
df = pd.read_excel(excel_path)

print("Read path successfully")

# First column of path.xlsx, one path per row. read_excel consumes the sheet's
# first row as the header, so s[0] is the second spreadsheet row.
s = df.iloc[:, 0]

yolo_image_path = s[0]
yolo_txt_path = s[1]
original_image_path = s[2]
json_file_path = s[3]
output_combined_path = s[4]
diff_img = s[5]
diff_label = s[6]

# Rows 4-6 are output folders; create any that are missing.
for key in range(4, 7):
    if not os.path.exists(s[key]):
        # exist_ok guards against another process creating it in between.
        os.makedirs(s[key], exist_ok=True)
        print("创建文件夹成功")



# Label name -> class id, used when encoding the manual (JSON) boxes.
# "exlusion" is a misspelling kept as an alias of "exclusion" (both map to 6),
# so labels using the wrong spelling are still accepted.
classnames = dict(
    car=0,
    truck=1,
    trailer=2,
    bus=3,
    van=4,
    other=5,
    exlusion=6,
    exclusion=6,
    occlusion=7,
    wheel=8,
)
###正则表达          
def get_frame(path):
    """
    Return the last run of six consecutive digits found in ``path``.

    Example:
        FV0180V9_Label_20200730_202059_065.mf4_color16Bit_186590__G3NWide00 -> "186590"

    Raises IndexError when ``path`` contains no six-digit run.
    """
    six_digit_runs = re.findall(r"[0-9]{6}", path)
    return six_digit_runs[-1]
###########################################################################################################






def merge_side_box(upLeft, lowRight, sideUp, sideLow, view):
    '''
    Merge the front-view box and the two side-view points into one enclosing box.

    :param upLeft: (x, y) upper-left corner of the front box
    :param lowRight: (x, y) lower-right corner of the front box
    :param sideUp: (x, y) upper side-view point
    :param sideLow: (x, y) lower side-view point
    :param view: not used; kept for interface compatibility
    :return: ([min_x, min_y], [max_x, max_y]) over all four points
    '''
    points = (upLeft, lowRight, sideUp, sideLow)
    xs = np.array([pt[0] for pt in points])
    ys = np.array([pt[1] for pt in points])
    return [np.min(xs), np.min(ys)], [np.max(xs), np.max(ys)]

def encode_box(upLeft, lowRight, category, using_float=False,
               img_w=1664, img_h=512, class_map=None):
    '''
    Convert one labelled box into ``[class_id, minx, miny, maxx, maxy]``.

    Coordinates are clipped to the image: minx/miny are raised to >= 0,
    maxx is lowered to <= img_w and maxy to <= img_h.

    :param upLeft: (x, y) upper-left corner
    :param lowRight: (x, y) lower-right corner
    :param category: label name, looked up in ``class_map``
    :param using_float: keep coordinates as floats rounded to 2 decimals
        instead of truncating them to int
    :param img_w: image width used to clip maxx (default 1664)
    :param img_h: image height used to clip maxy (default 512)
    :param class_map: label name -> class id mapping; defaults to the
        module-level ``classnames``
    :return: the encoded list, or None when ``category`` is not in ``class_map``
    '''
    if class_map is None:
        class_map = classnames

    # Convert numerically. Formatting to a "12.34" string and parsing it back
    # with int() raises ValueError for floats.
    if using_float:
        convert = lambda v: round(float(v), 2)
    else:
        convert = int
    minx = convert(upLeft[0])
    miny = convert(upLeft[1])
    maxx = convert(lowRight[0])
    maxy = convert(lowRight[1])

    # Unknown categories are skipped by the caller.
    if category not in class_map:
        return None

    return [int(class_map[category]),
            max(0, minx), max(0, miny),
            min(img_w, maxx), min(img_h, maxy)]




def parse_json_files(file_path):
    '''
    Read one manual-label JSON file and return its encoded boxes.

    For "vdet" tasks, an object whose attributes set "lshape" has its side-view
    points ("su"/"sl") merged into the front box. Objects whose class is not
    known to encode_box() (it returns None) are dropped.

    :param file_path: path of the JSON label file
    :return: list of ``[class_id, minx, miny, maxx, maxy]`` boxes
    '''
    boxes = []
    with open(file_path, encoding='utf-8') as jfile:
        data = json.load(jfile)

    task = data.get("Task")
    # Missing lists are treated as empty instead of crashing on None.
    for obj in data.get("objects") or []:
        # ddtypes / ddAttributes are parallel lists of attribute names and values.
        attribute = dict(zip(obj.get("ddtypes") or [], obj.get("ddAttributes") or []))
        upLeft = obj.get("ul")
        lowRight = obj.get("lr")

        if task == "vdet" and attribute.get("lshape"):
            sideUp = obj.get("su")
            sideLow = obj.get("sl")
            # Only merge when both side points are present.
            if sideUp is not None and sideLow is not None:
                upLeft, lowRight = merge_side_box(upLeft, lowRight, sideUp, sideLow,
                                                  attribute.get("view"))

        box = encode_box(upLeft, lowRight, obj.get("class"))
        if box:
            boxes.append(box)

    return boxes
#####################################################################################################################
def bbox_diatance(box1, box2):
    """
    Return the Euclidean distance between the centres of two boxes, truncated to int.

    Both boxes are read as ``[cls, x1, y1, x2, y2, ...]``; anything after the
    fourth coordinate (e.g. a YOLO score) is ignored. Centres are computed with
    floor division.
    """
    ax1, ay1, ax2, ay2 = box1[1], box1[2], box1[3], box1[4]
    bx1, by1, bx2, by2 = box2[1], box2[2], box2[3], box2[4]
    dx = (ax1 + ax2) // 2 - (bx1 + bx2) // 2
    dy = (ay1 + ay2) // 2 - (by1 + by2) // 2
    return int(math.sqrt(dx ** 2 + dy ** 2))


###计算iou
def bbox_iou(box1, box2):
    """
    Return the intersection-over-union of two boxes.

    Both boxes are read as ``[cls, x1, y1, x2, y2, ...]``. Coordinates are
    treated as inclusive pixel indices (hence the ``+ 1`` in every width and
    height). A tiny epsilon in the denominator avoids division by zero.
    """
    ax1, ay1, ax2, ay2 = box1[1], box1[2], box1[3], box1[4]
    bx1, by1, bx2, by2 = box2[1], box2[2], box2[3], box2[4]

    overlap_w = max(min(ax2, bx2) - max(ax1, bx1) + 1, 0)
    overlap_h = max(min(ay2, by2) - max(ay1, by1) + 1, 0)
    overlap = overlap_w * overlap_h

    area_a = (ax2 - ax1 + 1) * (ay2 - ay1 + 1)
    area_b = (bx2 - bx1 + 1) * (by2 - by1 + 1)
    return overlap / (area_a + area_b - overlap + 1e-16)


##判断是否输出
def is_output(json_box, yolo_txt_box, iou_threshold=0.5, distance_threshold=100):
    """
    Return True when the YOLO detections agree with the manual labels.

    Agreement requires the same number of boxes. In addition, for every manual
    box, the YOLO box with the nearest centre (the first one on ties) must:
      * be closer than ``distance_threshold``,
      * carry the same class id, and
      * overlap it with IoU strictly greater than ``iou_threshold``.
    The first failed check returns False. Several manual boxes may be matched
    to the same YOLO box.
    """
    if len(json_box) != len(yolo_txt_box):
        return False

    for manual in json_box:
        # Euclidean centre distance to every detection.
        distances = [bbox_diatance(manual, det) for det in yolo_txt_box]
        nearest = min(range(len(distances)), key=distances.__getitem__)
        if distances[nearest] >= distance_threshold:
            return False
        matched = yolo_txt_box[nearest]
        if manual[0] != matched[0]:
            return False
        if bbox_iou(manual, matched) <= iou_threshold:
            return False

    return True

################################################################################################
# Class names by class id (position i holds the name of id i); used to label
# the manual boxes on the comparison images.
classnames_index = "car truck trailer bus van other exclusion occlusion wheel".split()
# Font for the text drawn on the comparison images (size 30).
# simhei.ttf may not exist outside Windows. ImageFont.truetype raises OSError
# when the file cannot be read, so fall back to Pillow's built-in font; all
# drawn labels are ASCII.
try:
    font = ImageFont.truetype(
        font='simhei.ttf',
        size=np.floor(3e-2 * 1000 + 0.5).astype('int32')
    )  # load the font
except OSError:
    font = ImageFont.load_default()


def hsv_to_rgb(h, s, v):
    """
    Convert an HSV colour (components in [0, 1]) to an (r, g, b) tuple in [0, 1].

    Same algorithm as ``colorsys.hsv_to_rgb``: the hue circle is split into six
    sectors, and each sector picks its channels from (v, p, q, t).
    """
    if s == 0.0:
        return v, v, v
    sector = int(h * 6.0)  # relies on int() truncating
    frac = (h * 6.0) - sector
    p = v * (1.0 - s)
    q = v * (1.0 - s * frac)
    t = v * (1.0 - s * (1.0 - frac))
    by_sector = (
        (v, t, p),
        (q, v, p),
        (p, v, t),
        (p, q, v),
        (t, p, v),
        (v, p, q),
    )
    return by_sector[sector % 6]


# One evenly spaced, fully saturated hue per entry of ``classnames`` (including
# the misspelled alias). Each is converted to a 0-255 RGB tuple and indexed by
# class id when drawing.
hsv_tuples = [(idx / len(classnames), 1., 1.) for idx in range(len(classnames))]
colors = [tuple(int(channel * 255) for channel in hsv_to_rgb(*hsv)) for hsv in hsv_tuples]


def pinjie(img1, img2, save_path, json_box, image_name):
    """
    Stack the manually-labelled image above the YOLO result image and save it.

    The manual boxes are drawn (with class names) onto ``img1``. "Manual" is
    written on the top image and "YOLO" on the bottom one, and the combined
    image is saved as ``save_path/image_name``.

    :param img1: path of the original image; the manual boxes are drawn on it
    :param img2: path of the YOLO result image (only the caption is added)
    :param save_path: output folder
    :param json_box: list of ``[class_id, x1, y1, x2, y2]`` manual boxes
    :param image_name: file name of the saved combined image
    """
    # `with` closes both source files once the combined image is saved.
    with Image.open(img1) as manual_img, Image.open(img2) as yolo_img:
        draw1 = ImageDraw.ImageDraw(manual_img)
        width, height = manual_img.size

        for box in json_box:
            cat = int(box[0])
            x1, y1, x2, y2 = box[1:]
            # Clip to the image bounds.
            x1 = max(int(x1), 0)
            y1 = max(int(y1), 0)
            x2 = min(int(x2), width)
            y2 = min(int(y2), height)
            color = colors[cat]
            draw1.rectangle(((x1, y1), (x2, y2)), fill=None, outline=color, width=5)

            label = classnames_index[cat]
            label_h = _text_size(draw1, label, font)[1]
            # Put the label just above the box when it fits, otherwise inside it.
            if y1 - label_h >= 0:
                text_xy = (x1, y1 - label_h)
            else:
                text_xy = (x1 + 5, y1 + 5)
            draw1.text(text_xy, str(label), fill=color, font=font)

        # Captions are drawn once, even when there are no manual boxes.
        draw1.text((750, 0), str("Manual"), fill=(250, 0, 0), font=font)
        draw2 = ImageDraw.ImageDraw(yolo_img)
        draw2.text((750, 0), str("YOLO"), fill=(250, 0, 0), font=font)

        # Blank canvas tall enough for both images, manual on top.
        yolo_w, yolo_h = yolo_img.size
        result = Image.new(manual_img.mode, (max(width, yolo_w), height + yolo_h))
        result.paste(manual_img, box=(0, 0))
        result.paste(yolo_img, box=(0, height))
        result.save(os.path.join(save_path, image_name))


def _text_size(draw, text, font):
    """Return (width, height) of ``text``; ImageDraw.textsize was removed in Pillow 10."""
    if hasattr(draw, "textbbox"):
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right - left, bottom - top
    return draw.textsize(text, font)


def get_files(path, suffix=".json"):
    '''
    Recursively collect the files under ``path`` whose names end with ``suffix``.

    :param path: root folder to search
    :param suffix: required file-name ending (default ".json")
    :return: list of matching file paths (empty when ``path`` does not exist)
    '''
    file_list = []
    try:
        # os.walk yields nothing for a missing folder, so no exists() check is needed.
        for home, _dirs, files in os.walk(path):
            # endswith, not a substring test: "x.json.bak" must not match.
            file_list.extend(os.path.join(home, name) for name in files if name.endswith(suffix))
    except (OSError, TypeError) as e:
        # e.g. ``path`` is not a str/PathLike (such as NaN read from an empty cell)
        print(e)
    return file_list


######读取json对应的yolo检测结果
def split_squares_in_one_frame_yolo(filepath):
    """
    Parse a YOLO result txt file with one detection per line::

        type left top right bottom prob

    Every field except the last is converted to int; the last one (prob) is
    kept as a stripped string. Blank lines are skipped, and fields may be
    separated by any run of whitespace.

    :param filepath: path of the txt file
    :return: list of ``[type, left, top, right, bottom, prob]`` lists
    :raises ValueError: if a non-blank line has fewer than 6 fields or a
        non-integer field before the last one
    """
    res = []
    with open(filepath, encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, 1):
            fields = line.split()
            if not fields:
                continue
            if len(fields) < 6:
                raise ValueError(f"{filepath}:{lineno}: expected at least 6 fields, got {len(fields)}")
            res.append([int(tok) for tok in fields[:-1]] + [fields[-1]])
    return res


def read(filepath):
    """
    Load the YOLO detections of one frame.

    Thin wrapper around split_squares_in_one_frame_yolo().

    :param filepath: path of the YOLO txt file
    :return: list of ``[type, left, top, right, bottom, prob]`` lists
    """
    return split_squares_in_one_frame_yolo(filepath)











def main(json_file_path):
    """
    Compare one manual-label JSON file with its YOLO result and save bad cases.

    The image and txt names are derived from the JSON file name
    ``<front>.<mid>.<ext>``: the third ``_``-separated field of ``<mid>``
    (presumably the frame number) is zero-padded to 7 digits. When is_output()
    reports a mismatch:
      * a stacked comparison image is written to ``output_combined_path``;
      * the original image is copied to ``diff_img``;
      * the JSON file is copied to ``diff_label``.

    Reads the module-level folder paths loaded from path.xlsx.

    :param json_file_path: path of one JSON label file
    """
    json_box = parse_json_files(json_file_path)  # manual boxes of this image

    # rsplit gives the same result as split for names with exactly two dots and
    # also tolerates extra dots in the leading part.
    front, mid, _ext = os.path.basename(json_file_path).rsplit(".", 2)
    stem = front + ".mf400_remap_4I_screenRGB888_" + f"{int(mid.split('_')[2]):07d}"
    img_name = stem + ".png"
    txt_name = stem + ".txt"

    img_path = os.path.join(original_image_path, img_name)
    yolo_txt_box = read(os.path.join(yolo_txt_path, txt_name))  # YOLO boxes of this image

    # Save the outputs only when the detections disagree with the labels.
    if not is_output(json_box, yolo_txt_box):
        yolo_result_path = os.path.join(yolo_image_path, img_name)
        pinjie(img_path, yolo_result_path, output_combined_path, json_box, img_name)
        shutil.copy(img_path, os.path.join(diff_img, img_name))
        shutil.copy(json_file_path, os.path.join(diff_label, os.path.basename(json_file_path)))


if __name__ == "__main__":
    start = time.time()
    print("running.....")
    filelist = get_files(json_file_path)
    all_img = len(filelist)
    failures = []  # (json path, exception) for every worker call that raised

    pbar = tqdm(total=all_img)
    pbar.set_description(' Flow ')

    def _on_success(_result):
        pbar.update()

    def _on_error(name, exc):
        # apply_async only stores worker exceptions in the AsyncResult, which is
        # never read here, so record them explicitly (and keep the bar moving).
        failures.append((name, exc))
        pbar.update()

    pool = Pool(4)
    for name in filelist:
        pool.apply_async(main, (name,), callback=_on_success,
                         error_callback=lambda exc, name=name: _on_error(name, exc))
    pool.close()
    pool.join()
    pbar.close()

    # Counts every PNG in the output folder, including files left by earlier runs.
    diff_count = len(glob.glob(os.path.join(output_combined_path, "*.png")))

    print("******************done******************")
    print("total_time:", time.time() - start)
    print("all_img_count:", all_img, " diff_count:", diff_count)
    for name, exc in failures:
        print("failed:", name, repr(exc))
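# ---------------------------------------------------------------------------
# The lines below are residue from the blog page this script was copied from
# (like / bookmark / comment counters, a feedback poll and a "red envelope"
# payment dialog). They are not part of the program.
# ---------------------------------------------------------------------------
#   • 0
#     点赞
#   • 0
#     收藏
#     觉得还不错? 一键收藏
#   • 0
#     评论
#
# “相关推荐”对你有帮助么?
#
#   • 非常没帮助
#   • 没帮助
#   • 一般
#   • 有帮助
#   • 非常有帮助
# 提交
# 评论
# 添加红包
#
# 请填写红包祝福语或标题
#
# 红包个数最小为10个
#
# 红包金额最低5元
#
# 当前余额3.43前往充值 >
# 需支付:10.00
# 成就一亿技术人!
# 领取后你会自动成为博主和红包主的粉丝 规则
# hope_wisdom
# 发出的红包
# 实付
# 使用余额支付
# 点击重新获取
# 扫码支付
# 钱包余额 0
#
# 抵扣说明:
#
# 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
# 2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。
#
# 余额充值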