VisDrone dataset | Clustering based on object coordinates

1. Cluster directly, then crop

  • clus_output: binary images after clustering (cluster centers and boxes drawn)
  • img: original images
  • output: binary images generated from the object coordinates and the presence flag
  • txt: original annotations (a sample of the assumed line format is shown below)
  • txt_del: annotations with the trailing presence flag stripped
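
The scripts below assume each annotation txt line stores an object's centre coordinate followed by a 0/1 presence flag, i.e. "x,y,flag". A minimal sketch of the assumed format (values are illustrative, not taken from VisDrone):

523,147,1
88,402,1
660,31,0
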
# -*- coding: UTF-8 -*-
# Clustering code
import cv2
import os
import sys
from sklearn import cluster
import numpy as np
import matplotlib.pyplot as plt


def dot_Visualization(txt_path, img_path, save_path):  # visualize the object pixel locations
    name_id = os.listdir(txt_path)
    img_id = os.listdir(img_path)
    image_total = []
    for i, image_id in enumerate(img_id):
        data = []
        img = cv2.imread(img_path + img_id[i])
        h, w = img.shape[0], img.shape[1]  # shape[0] is rows (height), shape[1] is columns (width)
        image = np.zeros((h, w), np.uint8)
        txt_file = open(txt_path + name_id[i], "r")
        for j, line in enumerate(txt_file):
            data.append(line)
            x = int(data[j].split(',')[0])
            y = int(data[j].split(',')[1])
            if data[j].split(',')[2] == '1\n' or data[j].split(',')[2] == '1':
                cv2.circle(image, (x, y), 7, (255,255,255), -1)
                cv2.imwrite(save_path + img_id[i], image)
            else:
                continue
        image_total.append(image)
    return image_total

def del_zeros(txt_path, save_del_path):  # strip the trailing 0/1 flag so that only the coordinates are used for clustering
    name_id = os.listdir(txt_path)
    for i, image_id in enumerate(name_id):
        file_ori = open(txt_path + name_id[i], "r")
        file_new = open(save_del_path + name_id[i], "w")
        lines = file_ori.readlines()
        for line in lines:
            if line.split(',')[2] == '0' or line.split(',')[2] == '0\n':
                continue  # skip objects whose presence flag is 0
            if lines[-1] == line:
                file_new.write(line[:-2])         # last line: drop the trailing ",1"
            else:
                file_new.write(line[:-3] + '\n')  # drop the trailing ",1\n", keep the newline
        file_ori.close()
        file_new.close()
                


def cluster_method(save_del_path, img_path, save_cluster_path, image):  # cluster the coordinates of each image with k-means
    name_id = os.listdir(save_del_path)
    img_id = os.listdir(img_path)

    for i, image_id in enumerate(name_id):
        # collect the (x, y) coordinates of the current image only
        data = []
        data_change = []
        txt_file = open(save_del_path + name_id[i], "r")
        for j, xy in enumerate(txt_file):
            data.append(xy)
        txt_file.close()
        for val in data:
            arr = val.split(',')
            arr = [int(i) for i in arr]
            data_change.append(arr)
        data_arr = np.array(data_change)

        [centroid, label, inertia] = cluster.k_means(data_arr, n_clusters=3)
        cluster_1 = centroid[0].astype(int).tolist()
        cluster_2 = centroid[1].astype(int).tolist()
        
        # cluster_1_right = cluster_1[0] + 300
        # cluster_1_left = cluster_1[0] - 300
        # cluster_1_top = cluster_1[1] + 300
        # cluster_1_bottom = cluster_1[1] - 300

        # cluster_2_right = cluster_2[0] + 300
        # cluster_2_left = cluster_2[0] - 300
        # cluster_2_top = cluster_2[1] + 300
        # cluster_2_bottom = cluster_2[1] - 300

        cluster_1_x1 = cluster_1[0] - 300
        cluster_1_y1 = cluster_1[1] + 300
        cluster_1_x2 = cluster_1[0] + 300
        cluster_1_y2 = cluster_1[1] - 300

        cluster_2_x1 = cluster_2[0] - 300
        cluster_2_y1 = cluster_2[1] + 300
        cluster_2_x2 = cluster_2[0] + 300
        cluster_2_y2 = cluster_2[1] - 300

        cv2.circle(image[i], (cluster_1[0], cluster_1[1]), 5, (125,125,125), -1)
        cv2.circle(image[i], (cluster_2[0], cluster_2[1]), 5, (125,125,125), -1)
        cv2.rectangle(image[i], (cluster_1_x1, cluster_1_y1), (cluster_1_x2, cluster_1_y2), (75,75,75), 2)
        cv2.rectangle(image[i], (cluster_2_x1, cluster_2_y1), (cluster_2_x2, cluster_2_y2), (75,75,75), 2)
        cv2.imwrite(save_cluster_path + img_id[i], image[i])
        

if __name__ == '__main__':
    # directory containing the annotation txt files
    txt_path = "D:/code/cluster/txt/"
    # directory of the original images
    img_path = "D:/code/cluster/img/"
    # directory where the drawn images are saved
    save_path = "D:/code/cluster/output/"
    save_cluster_path = "D:/code/cluster/clus_output/"
    save_del_path = "D:/code/cluster/txt_del/"

    image = dot_Visualization(txt_path, img_path, save_path)
    del_zeros(txt_path, save_del_path)
    cluster_method(save_del_path, img_path, save_cluster_path, image)
    print("All Done....")

2. Step-by-step clustering, cropping, and visualization

(1) Clustering with DBSCAN

final_cluster.py

# -*- coding: UTF-8 -*-
# Clustering code
import cv2
import os
import sys
from sklearn import cluster
import numpy as np
import matplotlib.pyplot as plt
import time
import copy
from xml.dom.minidom import Document
import scipy.misc as misc  # scipy.misc.imread requires scipy < 1.2; imageio.imread can be used instead


def dot_Visualization(img_data, txt, save_path, idx):  # visualize the object pixel locations
    image_total = []
    h, w = img_data.shape[0], img_data.shape[1]
    image = np.zeros((h, w), np.uint8)
    for line in txt:
        if line.isspace():  # skip empty lines before parsing
            continue
        x = int(line.split(',')[0])
        y = int(line.split(',')[1])
        cv2.circle(image, (x, y), 2, (255, 255, 255), -1)
    cv2.imwrite(save_path + idx, image)
    image_total.append(image)
    return image_total

def cluster_method(txt, image, img_data, idx):  # cluster the coordinates with DBSCAN
    start_time = time.time()
    data = []
    data_change = []
    data0_xy = []
    data1_xy = []
    data2_xy = []
    data3_xy = []
    # data4_xy = []
    # data5_xy = []
    center0 = []
    center1 = []
    center2 = []
    center3 = []
    # center4 = []
    # center5 = []
    data_xy = []
    
    for j, xy in enumerate(txt):
        if xy.isspace():  # skip empty lines
            continue
        data.append(xy)
    for val in data:
        arr = val.split(',')
        arr = [int(i) for i in arr]
        data_change.append(arr)
    data_arr = np.array(data_change)
    # randomly sample up to 1000 points to keep DBSCAN fast
    row_rand_arr = np.arange(data_arr.shape[0])
    np.random.shuffle(row_rand_arr)
    row_rand = data_arr[row_rand_arr[0:1000]]
    # k-means
    # [centroid, label, inertial] = cluster.k_means(data_arr, n_clusters=3)
    # print(centroid)
    # centroid = centroid.tolist()

    # mini-batch kmeans
    # centroid = cluster.MiniBatchKMeans(n_clusters=2, batch_size=16).fit(data_arr)
    # centroid = centroid.cluster_centers_

    # ----------------dbscan——start---------------------
    # [centroid, label, inertial] = cluster.dbscan(data_arr)
    # centroid = cluster.DBSCAN(eps=40, min_samples=50).fit_predict(row_rand)  # clus_output  eps=40, min_samples=50
    centroid = cluster.DBSCAN(eps=40, min_samples=50).fit_predict(row_rand)  
    centroid = centroid.tolist()
    print(centroid)
    data = row_rand.tolist()
    for i in range(len(centroid)):
        if centroid[i] == 0:
            center0.append(data[i])
        elif centroid[i] == 1:
            center1.append(data[i])
        elif centroid[i] == 2:
            center2.append(data[i])
        elif centroid[i] == 3:
            center3.append(data[i])
        # elif centroid[i] == 4:
        #     center4.append(data[i])
        # elif centroid[i] == 5:
        #     center5.append(data[i])

    cen0_x, cen0_y, cen1_x, cen1_y = 0, 0, 0, 0
    cen2_x, cen2_y, cen3_x, cen3_y = 0, 0, 0, 0
    # cen4_x, cen4_y, cen5_x, cen5_y = 0, 0, 0, 0

    if center0:  # guard against fewer than two clusters, as done for center2/center3 below
        for c0 in center0:
            cen0_x += c0[0]
            cen0_y += c0[1]
        cen0_x = cen0_x / len(center0)
        cen0_y = cen0_y / len(center0)

    if center1:
        for c1 in center1:
            cen1_x += c1[0]
            cen1_y += c1[1]
        cen1_x = cen1_x / len(center1)
        cen1_y = cen1_y / len(center1)

    if center2:
        for c2 in center2:
            cen2_x += c2[0]
            cen2_y += c2[1]
        cen2_x = cen2_x / len(center2)
        cen2_y = cen2_y / len(center2)

    if center3:
        for c3 in center3:
            cen3_x += c3[0]
            cen3_y += c3[1]
        cen3_x = cen3_x / len(center3)
        cen3_y = cen3_y / len(center3)
    
    # if center4:
    #     for c4 in center4:
    #         cen4_x += c4[0]
    #         cen4_y += c4[1]
    #     cen4_x = cen4_x / len(center4)
    #     cen4_y = cen4_y / len(center4)

    # if center5:
    #     for c5 in center5:
    #         cen5_x += c5[0]
    #         cen5_y += c5[1]
    #     cen5_x = cen5_x / len(center5)
    #     cen5_y = cen5_y / len(center5)

    x01 = int(cen0_x) - 300
    y01 = int(cen0_y) - 300
    x02 = int(cen0_x) + 300
    y02 = int(cen0_y) + 300
    data0_xy.append([x01, y01, x02, y02])
    data0_xy = [x for z in data0_xy for x in z]
    print(data0_xy)

    x11 = int(cen1_x) - 300
    y11 = int(cen1_y) - 300
    x12 = int(cen1_x) + 300
    y12 = int(cen1_y) + 300
    data1_xy.append([x11, y11, x12, y12])
    data1_xy = [x for z in data1_xy for x in z]
    print(data1_xy)

    x21 = int(cen2_x) - 300
    y21 = int(cen2_y) - 300
    x22 = int(cen2_x) + 300
    y22 = int(cen2_y) + 300
    data2_xy.append([x21, y21, x22, y22])
    data2_xy = [x for z in data2_xy for x in z]
    print(data2_xy)

    x31 = int(cen3_x) - 300
    y31 = int(cen3_y) - 300
    x32 = int(cen3_x) + 300
    y32 = int(cen3_y) + 300
    data3_xy.append([x31, y31, x32, y32])
    data3_xy = [x for z in data3_xy for x in z]
    print(data3_xy)

    # x41 = int(cen4_x) - 300
    # y41 = int(cen4_y) - 300
    # x42 = int(cen4_x) + 300
    # y42 = int(cen4_y) + 300
    # data4_xy.append([x41, y41, x42, y42])
    # data4_xy = [x for z in data4_xy for x in z]
    # print(data4_xy)

    # x51 = int(cen5_x) - 300
    # y51 = int(cen5_y) - 300
    # x52 = int(cen5_x) + 300
    # y52 = int(cen5_y) + 300
    # data5_xy.append([x51, y51, x52, y52])
    # data5_xy = [x for z in data5_xy for x in z]
    # print(data5_xy)

    cv2.circle(image[0], (int(cen0_x), int(cen0_y)), 5, (125,125,125), -1)
    cv2.rectangle(image[0], (x01, y01), (x02, y02), (75,75,75), 2)
    cv2.circle(image[0], (int(cen1_x), int(cen1_y)), 5, (125,125,125), -1)
    cv2.rectangle(image[0], (x11, y11), (x12, y12), (75,75,75), 2)
    cv2.circle(image[0], (int(cen2_x), int(cen2_y)), 5, (125,125,125), -1)
    cv2.rectangle(image[0], (x21, y21), (x22, y22), (75,75,75), 2)
    cv2.circle(image[0], (int(cen3_x), int(cen3_y)), 5, (125,125,125), -1)
    cv2.rectangle(image[0], (x31, y31), (x32, y32), (75,75,75), 2)
    # cv2.circle(image[0], (int(cen4_x), int(cen4_y)), 5, (125,125,125), -1)
    # cv2.rectangle(image[0], (x41, y41), (x42, y42), (75,75,75), 2)
    # cv2.circle(image[0], (int(cen5_x), int(cen5_y)), 5, (125,125,125), -1)
    # cv2.rectangle(image[0], (x51, y51), (x52, y52), (75,75,75), 2)
    cv2.imwrite(save_cluster_path + idx, image[0])

    data_xy.append(data0_xy)
    data_xy.append(data1_xy)
    data_xy.append(data2_xy)
    data_xy.append(data3_xy)
    # data_xy.append(data4_xy)
    # data_xy.append(data5_xy)
    # ----------------dbscan——end---------------------

    print("ok!")
    end_time = time.time()   
    print("take %f second" % (end_time - start_time))
    return np.array(data_xy)

def text_save(filename, data):  # filename: path of the output txt file; data: list of crop boxes to write
    from itertools import chain
    file = open(filename, 'w')
    s = '\n'.join([' '.join(chain([filename[-10:-4] + '0' + str(i)], map(str, j))) for i, j in enumerate(data)]) + '\n'
    file.write(s)
    file.close()
    print("%s saved successfully" % filename[-10:])

if __name__ == '__main__':
    # directory containing the annotation txt files
    txt_path = "/home/jjliao/cluster/gt_map_copy/"
    # directory of the original images
    img_path = "/home/jjliao/cluster/gt_img_copy/"
    # directory where the drawn images are saved
    save_path = "/home/jjliao/cluster/output/"
    save_cluster_path = "/home/jjliao/cluster/clus_output/"
    output_path = "/home/jjliao/cluster/cluster_output/"
    
    class_list = ['pedestrian','people','bicycle','car','van','truck','tricycle','awning-tricycle','bus','motor']
    images = [i for i in os.listdir(img_path) if '.jpg' in i]
    labels = [i for i in os.listdir(txt_path) if 'txt' in i]
    print('find image', len(images))
    print('find label', len(labels))

    width, height = 600, 600
    for idx, img in enumerate(images):
        start_time = time.time()   
        print(idx, 'read image', img)
        img_data = misc.imread(os.path.join(img_path, img))
        
        txt = open(os.path.join(txt_path, img.replace('jpg', 'txt')), 'r').readlines()
        image = dot_Visualization(img_data, txt, save_path, img)
        
        box = cluster_method(txt, image, img_data, img)
        text_save(output_path + img.replace('jpg', 'txt'), box)
        # clip_image(img.strip('jpg'), img_data, box, width, height)  # optional: crop inside this script
        end_time = time.time()   
        print(img, "take %f second" % (end_time - start_time))

    print("All Done....")

Visualization of the cluster centers and boxes (figure omitted).
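
Note that DBSCAN, unlike k-means, does not take a fixed number of clusters: fit_predict returns one label per point, with -1 marking noise, and the number of clusters found depends on eps and min_samples. The script above assumes at most four clusters (labels 0-3) and silently drops noise points. A minimal sketch for inspecting the labels before relying on that assumption (the random points are only a stand-in for the sampled coordinates):

import numpy as np
from sklearn import cluster

points = np.random.randint(0, 1000, size=(500, 2))  # stand-in for the sampled (x, y) coordinates
labels = cluster.DBSCAN(eps=40, min_samples=50).fit_predict(points)
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)  # -1 denotes noise, not a cluster
print("clusters found:", n_clusters, "noise points:", int(np.sum(labels == -1)))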

(2) Crop from the original image: cut a 600x600 window around each cluster center and save the annotations in XML format
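
crop_visdrone.py reads the crop windows written by text_save, cuts each region out of the original image, and shifts every VisDrone box (x, y, w, h, score, category, ...) into the crop's coordinate frame. A minimal sketch of that translation, with illustrative values (not the script's own variable names):

crop_x1, crop_y1 = 420, 180          # top-left of the 600x600 window in the original image
x, y, w, h = 650, 300, 40, 25        # one VisDrone box, original-image coordinates
xmin, ymin = x - crop_x1, y - crop_y1
xmax, ymax = x + w - crop_x1, y + h - crop_y1
print(xmin, ymin, xmax, ymax)        # box in crop coordinates: 230 120 270 145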

crop_visdrone.py

import os
from xml.dom.minidom import Document
import copy
import numpy as np
from scipy import misc  # scipy.misc.imread requires scipy < 1.2; imageio.imread can be used instead
import cv2

class_list = [
    'ignored regions', 'pedestrian', 'people', 'bicycle', 'car', 'van',
    'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor', 'others'
]

raw_data = '/home/jjliao/Visdrone_yolo_cluster/VisDrone2019-DET-train/'
raw_images_dir = os.path.join(raw_data, 'images')
raw_label_dir = os.path.join(raw_data, 'annotations')
input_path = "/home/jjliao/cluster/cluster_output/"
out_images_dir = os.path.join(raw_data, 'images_cluster')
out_label_dir = os.path.join(raw_data, 'annotations_cluster_xml')


def format_label(txt_list):
    format_data = []
    for i in txt_list[0:]:
        format_data.append([int(xy) for xy in i.split(',')[:8]])
    return np.array(format_data)


def save_to_xml(save_path,
                im_width,
                im_height,
                objects_axis,
                label_name,
                name,
                hbb=True):
    im_depth = 0
    object_num = len(objects_axis)
    doc = Document()

    annotation = doc.createElement('annotation')
    doc.appendChild(annotation)

    folder = doc.createElement('folder')
    folder_name = doc.createTextNode('Visdrone')
    folder.appendChild(folder_name)
    annotation.appendChild(folder)

    filename = doc.createElement('filename')
    filename_name = doc.createTextNode(name)
    filename.appendChild(filename_name)
    annotation.appendChild(filename)

    source = doc.createElement('source')
    annotation.appendChild(source)

    database = doc.createElement('database')
    database.appendChild(doc.createTextNode('The Visdrone Database'))
    source.appendChild(database)

    annotation_s = doc.createElement('annotation')
    annotation_s.appendChild(doc.createTextNode('Visdrone'))
    source.appendChild(annotation_s)

    image = doc.createElement('image')
    image.appendChild(doc.createTextNode('flickr'))
    source.appendChild(image)

    flickrid = doc.createElement('flickrid')
    flickrid.appendChild(doc.createTextNode('322409915'))
    source.appendChild(flickrid)

    owner = doc.createElement('owner')
    annotation.appendChild(owner)

    flickrid_o = doc.createElement('flickrid')
    flickrid_o.appendChild(doc.createTextNode('knautia'))
    owner.appendChild(flickrid_o)

    name_o = doc.createElement('name')
    name_o.appendChild(doc.createTextNode('yang'))
    owner.appendChild(name_o)

    size = doc.createElement('size')
    annotation.appendChild(size)
    width = doc.createElement('width')
    width.appendChild(doc.createTextNode(str(im_width)))
    height = doc.createElement('height')
    height.appendChild(doc.createTextNode(str(im_height)))
    depth = doc.createElement('depth')
    depth.appendChild(doc.createTextNode(str(im_depth)))
    size.appendChild(width)
    size.appendChild(height)
    size.appendChild(depth)
    segmented = doc.createElement('segmented')
    segmented.appendChild(doc.createTextNode('0'))
    annotation.appendChild(segmented)
    for i in range(object_num):
        objects = doc.createElement('object')
        annotation.appendChild(objects)
        object_name = doc.createElement('name')
        object_name.appendChild(
            doc.createTextNode(label_name[int(objects_axis[i][5])]))
        objects.appendChild(object_name)
        pose = doc.createElement('pose')
        pose.appendChild(doc.createTextNode('Unspecified'))
        objects.appendChild(pose)
        truncated = doc.createElement('truncated')
        truncated.appendChild(doc.createTextNode('1'))
        objects.appendChild(truncated)
        difficult = doc.createElement('difficult')
        difficult.appendChild(doc.createTextNode('0'))
        objects.appendChild(difficult)
        bndbox = doc.createElement('bndbox')
        objects.appendChild(bndbox)
        if hbb:
            x0 = doc.createElement('xmin')
            x0.appendChild(doc.createTextNode(str((objects_axis[i][0]))))
            bndbox.appendChild(x0)
            y0 = doc.createElement('ymin')
            y0.appendChild(doc.createTextNode(str((objects_axis[i][1]))))
            bndbox.appendChild(y0)
            x1 = doc.createElement('xmax')
            x1.appendChild(doc.createTextNode(str((objects_axis[i][2]))))
            bndbox.appendChild(x1)
            y1 = doc.createElement('ymax')
            y1.appendChild(doc.createTextNode(str((objects_axis[i][3]))))
            bndbox.appendChild(y1)
        else:

            x0 = doc.createElement('x0')
            x0.appendChild(doc.createTextNode(str((objects_axis[i][0]))))
            bndbox.appendChild(x0)
            y0 = doc.createElement('y0')
            y0.appendChild(doc.createTextNode(str((objects_axis[i][1]))))
            bndbox.appendChild(y0)

            x1 = doc.createElement('x1')
            x1.appendChild(doc.createTextNode(str((objects_axis[i][2]))))
            bndbox.appendChild(x1)
            y1 = doc.createElement('y1')
            y1.appendChild(doc.createTextNode(str((objects_axis[i][3]))))
            bndbox.appendChild(y1)

            x2 = doc.createElement('x2')
            x2.appendChild(doc.createTextNode(str((objects_axis[i][4]))))
            bndbox.appendChild(x2)
            y2 = doc.createElement('y2')
            y2.appendChild(doc.createTextNode(str((objects_axis[i][5]))))
            bndbox.appendChild(y2)

            x3 = doc.createElement('x3')
            x3.appendChild(doc.createTextNode(str((objects_axis[i][6]))))
            bndbox.appendChild(x3)
            y3 = doc.createElement('y3')
            y3.appendChild(doc.createTextNode(str((objects_axis[i][7]))))
            bndbox.appendChild(y3)

    f = open(save_path, 'w')
    f.write(doc.toprettyxml(indent=''))
    f.close()


def clip_image(name_new, path_new_xml, img_old, data_new, boxes_all):
    name_new, x1, y1, x2, y2 = data_new
    if len(boxes_all) > 0:
        shape = img_old.shape
        width, height = x2 - x1, y2 - y1
        # print(width, height)
        assert (width == 600 and height == 600)
        boxes = copy.deepcopy(boxes_all)
        boxes_new = np.zeros_like(boxes_all)
        top_left_col, top_left_row = max(x1, 0), max(y1, 0)
        bottom_right_col, bottom_right_row = max(x2, 0), max(y2, 0)

        img_new = img_old[top_left_row:bottom_right_row, top_left_col:bottom_right_col]

        boxes_new[:, 0] = boxes[:, 0] - top_left_col
        boxes_new[:, 2] = boxes[:, 0] + boxes[:, 2] - top_left_col
        boxes_new[:, 4] = boxes[:, 4]

        boxes_new[:, 1] = boxes[:, 1] - top_left_row
        boxes_new[:, 3] = boxes[:, 1] + boxes[:, 3] - top_left_row
        boxes_new[:, 5] = boxes[:, 5]

        center_y = 0.5 * (boxes_new[:, 1] + boxes_new[:, 3])
        center_x = 0.5 * (boxes_new[:, 0] + boxes_new[:, 2])

        cond1 = np.intersect1d(
            np.where(center_y[:] >= 0)[0],
            np.where(center_x[:] >= 0)[0])
        cond2 = np.intersect1d(
            np.where(center_y[:] <= (bottom_right_row - top_left_row))[0],
            np.where(center_x[:] <= (bottom_right_col - top_left_col))[0])
        idx = np.intersect1d(cond1, cond2)
        if len(idx) > 0:
            save_to_xml(path_new_xml, img_new.shape[1], img_new.shape[0], boxes_new[idx,:], class_list,
                                name_new + '.jpg')
            if img_new.shape[0] > 5 and img_new.shape[1] > 5:
                img = os.path.join(out_images_dir, name_new + '.jpg')
                cv2.imwrite(img, img_new)
            # return img_new.shape[1], img_new.shape[0], boxes_new


if __name__ == '__main__':
    for i in [i for i in os.listdir(input_path) if i[-4:] == '.txt']:
        print(i)
        with open(os.path.join(input_path, i), 'r', encoding='utf8') as f:
            lines = [i.split() for i in f.readlines()]
            drawer = {}
            for line in lines:
                name_new, x1, y1, x2, y2 = line
                x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                name_old = name_new[:6]
                if name_old in drawer:
                    drawer[name_old].append((name_new, x1, y1, x2, y2))
                else:
                    drawer[name_old] = [(name_new, x1, y1, x2, y2)]
            for name_old, datas in drawer.items():
                path_old = os.path.join(raw_images_dir, name_old + '.jpg')
                img_data = misc.imread(path_old)
                txt_data = open(os.path.join(raw_label_dir, name_old + '.txt'),
                                'r').readlines()
                boxes = format_label(txt_data)
                for data_new in datas:
                    name_new, x1, y1, x2, y2 = data_new
                    path_new = os.path.join(out_images_dir, name_new + '.jpg')
                    path_new_xml = os.path.join(out_label_dir,name_new + '.xml')
                    clip_image(name_new, path_new_xml, img_data, data_new, boxes)
                    # save_to_xml(path_new_xml, w, h, boxes_new, class_list,
                    #             name_new + '.jpg')


(3) Visualize the cropped annotations to verify the format is correct

xml_drawbox_cluster.py

import os
import os.path
import numpy as np
import xml.etree.ElementTree as xmlET
from PIL import Image, ImageDraw

classes = ('__background__', # always index 0
           'ignored regions','pedestrian','people','bicycle','car','van','truck','tricycle','awning-tricycle','bus','motor','others')

# change the paths below to your own
file_path_img = '/home/jjliao/Visdrone_yolo_cluster/VisDrone2019-DET-train/images_cluster'
file_path_xml = '/home/jjliao/Visdrone_yolo_cluster/VisDrone2019-DET-train/annotations_cluster_xml'
save_file_path = '/home/jjliao/Visdrone_yolo_cluster/VisDrone2019-DET-train/visual_cluster'

pathDir = os.listdir(file_path_xml)
for idx in range(len(pathDir)): 
    filename = pathDir[idx]
    tree = xmlET.parse(os.path.join(file_path_xml, filename))
    objs = tree.findall('object')        
    num_objs = len(objs)
    boxes = np.zeros((num_objs, 5), dtype=np.uint16)

    for ix, obj in enumerate(objs):
        bbox = obj.find('bndbox')
        
        # Make pixel indexes 0-based
        x1 = max(float(bbox.find('xmin').text), 0) 
        y1 = max(float(bbox.find('ymin').text), 0) 
        x2 = float(bbox.find('xmax').text) 
        y2 = float(bbox.find('ymax').text)

        cla = obj.find('name').text
        label = classes.index(cla)

        boxes[ix, 0:4] = [x1, y1, x2, y2]
        boxes[ix, 4] = label

    image_name = os.path.splitext(filename)[0]
    # if image_name == '10380201':
    #     import pdb
    #     pdb.set_trace()
    img = Image.open(os.path.join(file_path_img, image_name + '.jpg'))

    draw = ImageDraw.Draw(img)
    for ix in range(len(boxes)):
        xmin = int(boxes[ix, 0])
        ymin = int(boxes[ix, 1])
        xmax = int(boxes[ix, 2])
        ymax = int(boxes[ix, 3])
        draw.rectangle([xmin, ymin, xmax, ymax], outline=(255, 0, 0))
        draw.text([xmin, ymin], classes[boxes[ix, 4]], (255, 0, 0))

    img.save(os.path.join(save_file_path, image_name + '.jpg'))


3. Run YOLOv4 on the cropped images to obtain a result.json file, then splice the predicted objects back onto the original images.
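
splice.py does the reverse mapping: it adds the crop window's top-left corner back onto each predicted box so the detection lands in the original image's coordinate frame. Only the box's top-left needs shifting when the predictions are stored COCO-style as [x, y, w, h], which is what results_cluster.json appears to contain. A minimal sketch with illustrative values:

crop_x1, crop_y1 = 420, 180          # top-left of the 600x600 crop in the original image
x_hat, y_hat, w, h = 35, 60, 28, 19  # predicted bbox in crop coordinates
x, y = crop_x1 + x_hat, crop_y1 + y_hat
print([x, y, w, h])                  # bbox in original-image coordinates: [455, 240, 28, 19]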

splice.py

# coding:utf-8
import json
import shutil
import cv2
import os


json_old_path = "/home/jjliao/code/PyTorch_yolov4_visdrone_cluster/results_cluster.json"
json_new_path = "/home/jjliao/code/PyTorch_yolov4_visdrone_cluster/results_cluster_new.json"
txt_path = "/home/jjliao/cluster/cluster_output/"

def select(json_old, json_new, txt):
    with open(json_old, 'r') as f:
        infos = json.load(f)
    new_json = []

    for i in [i for i in os.listdir(txt) if i[-4:] == '.txt']:
        print(i)
        with open(os.path.join(txt, i), 'r', encoding='utf8') as f:
            lines = [i.split() for i in f.readlines()]
        for line in lines:
            name, x1, y1, x2, y2 = line  # crop name and its window in the original image
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            for j in infos:
                if j["image_id"] == name:
                    # shift the predicted box by the crop's top-left corner
                    x1_hat, y1_hat = int(j['bbox'][0]), int(j['bbox'][1])
                    j['bbox'][0], j['bbox'][1] = x1 + x1_hat, y1 + y1_hat
                    j["image_id"] = name[:-2]  # map the crop id back to the original image id
                    new_json.append(j)
    print(len(new_json))
    # write the merged detections once, after all crops have been processed
    with open(json_new, 'w', encoding='utf-8') as ff:
        json.dump(new_json, ff, indent=4)



if __name__ == "__main__":
    
    select(json_old_path, json_new_path, txt_path)