聚类算法
1. 直接聚类后裁剪
- clus_output:聚类后的二值图片(在点图上标出聚类中心及其周围的框)
- img:原始图片
- output:根据目标坐标绘制的二值图片(只绘制"有无目标"标志为 1 的点)
- txt:原始标注(每行为 x,y,有无目标标志)
- txt_del:删除无目标行并去掉"有无目标"标志列后的标注(每行为 x,y)
# -*- coding: UTF-8 -*-
# 聚类代码
import cv2
import os
import sys
from sklearn import cluster
import numpy as np
import matplotlib.pyplot as plt
def dot_Visualization(txt_path, img_path, save_path):  # visualize annotated points
    """Draw every positive annotation point of each image onto a blank mask.

    Each annotation file holds one ``x,y,flag`` record per line.  Points whose
    flag is ``1`` are drawn as filled white circles (radius 7) on a
    single-channel black image with the same rows/cols as the matching source
    image.  The mask is written to ``save_path + <image name>`` when at least
    one point was drawn.

    NOTE(review): images and annotation files are paired by their position in
    ``os.listdir`` output, as before.  This assumes both directories list in
    the same order -- confirm for your filesystem.

    Args:
        txt_path: directory (with trailing separator) of annotation files.
        img_path: directory (with trailing separator) of source images.
        save_path: directory (with trailing separator) for the output masks.

    Returns:
        List of the generated masks, one per image, in listing order.

    Raises:
        IOError: if a source image cannot be read.
    """
    name_id = os.listdir(txt_path)
    img_id = os.listdir(img_path)
    image_total = []
    for i, image_name in enumerate(img_id):
        img = cv2.imread(img_path + image_name)
        if img is None:
            raise IOError("cannot read image: " + img_path + image_name)
        rows, cols = img.shape[0], img.shape[1]
        image = np.zeros((rows, cols), np.uint8)
        drawn = False
        with open(txt_path + name_id[i], "r") as txt_file:
            for line in txt_file:
                fields = line.strip().split(',')
                if len(fields) < 3:
                    continue  # blank or malformed line
                # strip() accepts '1', '1\n', '1\r\n', ' 1', ...
                if fields[2].strip() == '1':
                    x, y = int(fields[0]), int(fields[1])
                    cv2.circle(image, (x, y), 7, (255, 255, 255), -1)
                    drawn = True
        # Encode the mask once after all points are drawn, instead of once per
        # annotation line.
        if drawn:
            cv2.imwrite(save_path + image_name, image)
        image_total.append(image)
    return image_total
def del_zeros(txt_path, save_del_path):  # drop the 0/1 flag, keep coordinates only
    """Copy annotation files, keeping only positive points without their flag.

    Every ``x,y,flag`` line whose flag is ``0`` is dropped.  Remaining lines
    are written as ``x,y`` (one per line, no trailing newline) to
    ``save_del_path + <same file name>``, so clustering only sees coordinates.
    Blank or malformed (fewer than 3 fields) lines are ignored.

    Args:
        txt_path: directory (with trailing separator) of source annotations.
        save_del_path: directory (with trailing separator) for the output.
    """
    for name in os.listdir(txt_path):
        kept = []
        with open(txt_path + name, "r") as file_ori:
            for line in file_ori:
                fields = line.strip().split(',')
                if len(fields) < 3 or fields[2].strip() == '0':
                    continue
                # Parse fields instead of slicing characters: the old
                # line[:-2] / line[:-3] logic left a trailing comma when the
                # last line ended with '\n'.
                kept.append(','.join(fields[:2]))
        with open(save_del_path + name, "w") as file_new:
            file_new.write('\n'.join(kept))
def cluster_method(save_del_path, img_path, save_cluster_path, image):  # clustering
    """Cluster each image's points with k-means and draw the result.

    For every coordinate file in ``save_del_path``, that file's ``x,y`` points
    are clustered with ``sklearn.cluster.k_means`` (3 clusters).  The first two
    centroids are drawn on ``image[i]`` as grey dots, each with a 600x600
    rectangle around it.  The result is written to
    ``save_cluster_path + img_id[i]``.  Files with fewer than 3 points are
    skipped with a message.

    NOTE(review): ``image`` and the image names are paired with the coordinate
    files by ``os.listdir`` position, as before -- confirm the directory orders
    match.

    Args:
        save_del_path: directory (trailing separator) of ``x,y`` files.
        img_path: directory of source images (used only for output names).
        save_cluster_path: directory (trailing separator) for output images.
        image: list of masks (e.g. from ``dot_Visualization``), indexed like
            the coordinate files.
    """
    n_clusters = 3
    half = 300  # half side of the square drawn around each centroid
    name_id = os.listdir(save_del_path)
    img_id = os.listdir(img_path)
    for i, name in enumerate(name_id):
        # Collect points per file.  They used to accumulate across files, so
        # each image was clustered together with all earlier images' points.
        with open(save_del_path + name, "r") as txt_file:
            points = [[int(v) for v in line.split(',')]
                      for line in txt_file if line.strip()]
        if len(points) < n_clusters:
            print("skip %s: fewer points than clusters" % name)
            continue
        centroid, _label, _inertia = cluster.k_means(np.array(points),
                                                     n_clusters=n_clusters)
        # Builtin int instead of np.int (removed in NumPy 1.24); truncates the
        # same way.  Only the first two centroids are drawn, as before.
        centres = centroid[:2, :2].astype(int).tolist()
        for cx, cy in centres:
            cv2.circle(image[i], (cx, cy), 5, (125, 125, 125), -1)
        for cx, cy in centres:
            cv2.rectangle(image[i], (cx - half, cy + half),
                          (cx + half, cy - half), (75, 75, 75), 2)
        cv2.imwrite(save_cluster_path + img_id[i], image[i])
if __name__ == '__main__':
    # Directory holding the raw annotation txt files
    txt_path = "D:/code/cluster/txt/"
    # Directory holding the original images
    img_path = "D:/code/cluster/img/"
    # Directory for the rendered point masks
    save_path = "D:/code/cluster/output/"
    save_cluster_path = "D:/code/cluster/clus_output/"
    save_del_path = "D:/code/cluster/txt_del/"
    # Neither open(..., 'w') nor cv2.imwrite creates missing directories
    # (imwrite only reports failure through its return value), so create
    # every output directory up front.
    for out_dir in (save_path, save_cluster_path, save_del_path):
        os.makedirs(out_dir, exist_ok=True)
    image = dot_Visualization(txt_path, img_path, save_path)
    del_zeros(txt_path, save_del_path)
    cluster_method(save_del_path, img_path, save_cluster_path, image)
    print("All Done....")
2. 分步骤进行聚类、裁剪、可视化图片
(1)使用DBSCAN进行聚类
final_cluster.py
# -*- coding: UTF-8 -*-
# 聚类代码
import cv2
import os
import sys
from sklearn import cluster
import numpy as np
import matplotlib.pyplot as plt
import time
import copy
from xml.dom.minidom import Document
import scipy.misc as misc
def dot_Visualization(img_data, txt, save_path, idx):  # visualize annotated points
    """Render the annotation points of one image onto a blank mask.

    Each non-blank ``x,y[,...]`` line becomes a filled white circle of radius 2
    on a single-channel black image.  The mask is written to
    ``save_path + idx`` when at least one point was drawn.

    Args:
        img_data: source image array; only its first two dimensions
            (rows, cols) are used, to size the mask.
        txt: iterable of text lines (e.g. ``readlines()`` output).
        save_path: directory (with trailing separator) for the mask.
        idx: output file name (the image name).

    Returns:
        A one-element list holding the mask (callers read it as ``image[0]``).
    """
    rows, cols = img_data.shape[0], img_data.shape[1]
    image = np.zeros((rows, cols), np.uint8)
    drawn = False
    for line in txt:
        # Skip blank lines *before* parsing.  The old check ran after int(),
        # so a blank line raised ValueError instead of being skipped.
        if not line.strip():
            continue
        fields = line.split(',')
        x, y = int(fields[0]), int(fields[1])
        cv2.circle(image, (x, y), 2, (255, 255, 255), -1)
        drawn = True
    # Encode once after drawing every point (was rewritten once per line).
    if drawn:
        cv2.imwrite(save_path + idx, image)
    return [image]
def _parse_points(txt):
    """Parse non-blank comma-separated integer lines into an int ndarray."""
    rows = [[int(v) for v in line.split(',')] for line in txt if line.strip()]
    return np.array(rows, dtype=int)


def _cluster_boxes(points, labels, max_clusters, half_size):
    """Return (centres, boxes) for each label 0..max_clusters-1 that has members.

    A centre is the mean of the members' x/y columns, truncated with int().
    Its box is ``[cx - half_size, cy - half_size, cx + half_size,
    cy + half_size]``.  Labels with no members (and DBSCAN's noise label -1)
    produce nothing.
    """
    centres, boxes = [], []
    for k in range(max_clusters):
        members = points[labels == k]
        if len(members) == 0:
            continue  # cluster not found: emit no (fake) box
        cx = int(members[:, 0].sum() / len(members))
        cy = int(members[:, 1].sum() / len(members))
        centres.append((cx, cy))
        boxes.append([cx - half_size, cy - half_size,
                      cx + half_size, cy + half_size])
    return centres, boxes


def cluster_method(txt, image, img_data, idx, eps=40, min_samples=50,
                   sample_size=1000, max_clusters=4, half_size=300,
                   save_dir=None):  # clustering
    """Cluster one image's points with DBSCAN and return a crop box per cluster.

    At most ``sample_size`` randomly chosen points (``np.random.shuffle``,
    unseeded, so results can vary between runs) are clustered with
    ``DBSCAN(eps, min_samples)``.  For each of the labels
    ``0..max_clusters-1`` that DBSCAN actually produced, the cluster centre is
    drawn on ``image[0]`` as a grey dot, with a ``2*half_size`` square around
    it.  The annotated image is written to ``save_dir + idx``.

    Previously a missing cluster 0/1 raised ZeroDivisionError, and missing
    clusters 2/3 produced a bogus box around (0, 0).  Only real clusters are
    returned now.

    Args:
        txt: iterable of ``x,y[,...]`` text lines for this image.
        image: list whose first element is the mask to draw on.
        img_data: source image (unused; kept for interface compatibility).
        idx: output file name.
        eps, min_samples: DBSCAN parameters.
        sample_size: maximum number of points passed to DBSCAN.
        max_clusters: only labels below this value are turned into boxes.
        half_size: half side length of each square box.
        save_dir: output directory with trailing separator; defaults to the
            module-level ``save_cluster_path`` set by the script entry point.

    Returns:
        int ndarray of shape ``(k, 4)`` with ``[x1, y1, x2, y2]`` per cluster
        (``k`` may be 0).
    """
    start_time = time.time()
    if save_dir is None:
        save_dir = save_cluster_path  # module-level global set in __main__
    points = _parse_points(txt)
    if points.size == 0:
        print("no points in %s" % idx)
        return np.zeros((0, 4), dtype=int)
    # Cluster a random subset to bound DBSCAN's cost.
    order = np.arange(points.shape[0])
    np.random.shuffle(order)
    sample = points[order[:sample_size]]
    labels = cluster.DBSCAN(eps=eps, min_samples=min_samples).fit_predict(sample)
    print(labels.tolist())
    centres, boxes = _cluster_boxes(sample, labels, max_clusters, half_size)
    for (cx, cy), (x1, y1, x2, y2) in zip(centres, boxes):
        print([x1, y1, x2, y2])
        cv2.circle(image[0], (cx, cy), 5, (125, 125, 125), -1)
        cv2.rectangle(image[0], (x1, y1), (x2, y2), (75, 75, 75), 2)
    cv2.imwrite(save_dir + idx, image[0])
    print("ok!")
    print("take %f second" % (time.time() - start_time))
    return np.array(boxes, dtype=int).reshape(-1, 4)
def text_save(filename, data):
    """Write one ``<name><NN> v1 v2 ...`` line per row of ``data``.

    ``<name>`` is the 6 characters before the 4-character extension of
    ``filename`` (``000001`` for ``.../000001.txt``).  ``NN`` is the row index
    zero-padded to two digits, so every crop name is the base name plus
    exactly two characters.  The old ``'0' + str(i)`` gave three digits from
    row 10 on.

    Args:
        filename: path of the txt file to write.
        data: iterable of rows (e.g. the ``(k, 4)`` box array); every value
            is written with ``str``.
    """
    base = filename[-10:-4]
    lines = [' '.join([base + '%02d' % i] + [str(v) for v in row])
             for i, row in enumerate(data)]
    with open(filename, 'w') as file:
        file.write('\n'.join(lines) + '\n')
    print(" %s 保存文件成功" % filename[-10:])
if __name__ == '__main__':
    # Directory of the per-image point annotation txt files
    txt_path = "/home/jjliao/cluster/gt_map_copy/"
    # Directory of the original images
    img_path = "/home/jjliao/cluster/gt_img_copy/"
    # Directory for the rendered point masks
    save_path = "/home/jjliao/cluster/output/"
    # Read as a module-level global by cluster_method
    save_cluster_path = "/home/jjliao/cluster/clus_output/"
    output_path = "/home/jjliao/cluster/cluster_output/"
    for out_dir in (save_path, save_cluster_path, output_path):
        os.makedirs(out_dir, exist_ok=True)
    images = [i for i in os.listdir(img_path) if '.jpg' in i]
    labels = [i for i in os.listdir(txt_path) if 'txt' in i]
    print('find image', len(images))
    print('find label', len(labels))
    for idx, img in enumerate(images):
        start_time = time.time()
        print(idx, 'read image', img)
        # scipy.misc.imread was removed in SciPy 1.2.  Only the image shape
        # is used downstream, so OpenCV's reader is a drop-in replacement.
        img_data = cv2.imread(os.path.join(img_path, img))
        if img_data is None:
            print('cannot read', img)
            continue
        txt_name = img.replace('jpg', 'txt')
        with open(os.path.join(txt_path, txt_name), 'r') as f:
            txt = f.readlines()
        image = dot_Visualization(img_data, txt, save_path, img)
        box = cluster_method(txt, image, img_data, img)
        text_save(output_path + txt_name, box)
        end_time = time.time()
        print(img, "take %f second" % (end_time - start_time))
    print("All Done....")
可视化聚类中心及框:
(2)在原图上裁剪:以每个聚类中心为中心框出 600×600 的矩形区域进行裁剪,并将对应标注保存为 xml 格式
crop_visdrone.py
import os
from xml.dom.minidom import Document
import copy
import numpy as np
from scipy import misc
import cv2
# Category names; the crop code looks them up with column 5 (the 6th field)
# of each annotation line as the index.
class_list = [
    'ignored regions',
    'pedestrian',
    'people',
    'bicycle',
    'car',
    'van',
    'truck',
    'tricycle',
    'awning-tricycle',
    'bus',
    'motor',
    'others',
]

# Source dataset root with its image / annotation sub-directories.
raw_data = '/home/jjliao/Visdrone_yolo_cluster/VisDrone2019-DET-train/'
raw_images_dir, raw_label_dir = (
    os.path.join(raw_data, sub) for sub in ('images', 'annotations'))
# Cluster windows produced by the clustering step.
input_path = "/home/jjliao/cluster/cluster_output/"
# Destinations for the cropped images and their XML annotations.
out_images_dir, out_label_dir = (
    os.path.join(raw_data, sub)
    for sub in ('images_cluster', 'annotations_cluster_xml'))
def format_label(txt_list):
    """Parse comma-separated annotation lines into an integer array.

    The first 8 fields of every non-blank line are converted to ``int``, so
    trailing commas or extra fields are tolerated.  Blank lines, which used to
    crash ``int('')``, are skipped.  Downstream, columns 0-3 are used as left,
    top, width, height and column 5 as the category index.

    Args:
        txt_list: iterable of text lines (e.g. ``readlines()`` output).

    Returns:
        ndarray with one row per non-blank line (``np.array([])`` if none).
    """
    return np.array([[int(v) for v in line.split(',')[:8]]
                     for line in txt_list if line.strip()])
def save_to_xml(save_path,
                im_width,
                im_height,
                objects_axis,
                label_name,
                name,
                hbb=True):
    """Write a Pascal-VOC style XML annotation for one image.

    The root element is now spelled ``annotation``; it used to be the typo
    ``annotataion``.

    Args:
        save_path: output XML path.
        im_width: value written to ``size/width``.
        im_height: value written to ``size/height``.
        objects_axis: iterable of per-object rows.  With ``hbb=True``,
            columns 0-3 are written as ``xmin, ymin, xmax, ymax``; otherwise
            columns 0-7 are written as corner points ``x0, y0, ..., x3, y3``.
            Column 5 is always used as the index into ``label_name`` for the
            object's name.
        label_name: sequence of class names.
        name: image file name written to ``filename``.
        hbb: ``True`` for axis-aligned boxes, ``False`` for 4-point boxes.
    """
    im_depth = 0  # written as-is into size/depth, unchanged from before
    doc = Document()

    def add(parent, tag, text=None):
        # Create <tag>, optionally holding str(text), and attach it to parent.
        node = doc.createElement(tag)
        if text is not None:
            node.appendChild(doc.createTextNode(str(text)))
        parent.appendChild(node)
        return node

    annotation = add(doc, 'annotation')
    add(annotation, 'folder', 'Visdrone')
    add(annotation, 'filename', name)
    source = add(annotation, 'source')
    add(source, 'database', 'The Visdrone Database')
    add(source, 'annotation', 'Visdrone')
    add(source, 'image', 'flickr')
    add(source, 'flickrid', '322409915')
    owner = add(annotation, 'owner')
    add(owner, 'flickrid', 'knautia')
    add(owner, 'name', 'yang')
    size = add(annotation, 'size')
    add(size, 'width', im_width)
    add(size, 'height', im_height)
    add(size, 'depth', im_depth)
    add(annotation, 'segmented', '0')

    if hbb:
        coord_tags = ('xmin', 'ymin', 'xmax', 'ymax')
    else:
        coord_tags = ('x0', 'y0', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3')
    for row in objects_axis:
        obj = add(annotation, 'object')
        add(obj, 'name', label_name[int(row[5])])
        add(obj, 'pose', 'Unspecified')
        add(obj, 'truncated', '1')
        add(obj, 'difficult', '0')
        bndbox = add(obj, 'bndbox')
        for col, tag in enumerate(coord_tags):
            add(bndbox, tag, row[col])

    with open(save_path, 'w') as f:
        f.write(doc.toprettyxml(indent=''))
def clip_image(name_new, path_new_xml, img_old, data_new, boxes_all):
    """Cut one 600x600 cluster window out of an image and save it with its XML.

    Args:
        name_new: crop name (ignored; the name inside ``data_new`` is used,
            as before).
        path_new_xml: output path of the crop's XML annotation.
        img_old: full source image array, indexed ``[row, col]``.
        data_new: ``(name, x1, y1, x2, y2)`` window in source-image pixels.
            ``x1``/``y1`` may be negative; the window is then cut at 0.
        boxes_all: array from ``format_label``.  Columns 0-3 are left, top,
            width, height; column 4 is copied as-is; column 5 is the category
            index.

    Nothing is written when ``boxes_all`` is empty or the actual crop is 5 px
    or smaller on a side.  Previously the XML was still written in that case,
    without its image.  Otherwise the crop is written to
    ``out_images_dir/<name>.jpg``.  Objects whose box centre lies inside the
    crop are shifted into crop coordinates, clipped to the crop borders and
    saved via ``save_to_xml`` (only if at least one object is kept).

    Raises:
        ValueError: if the window is not exactly 600x600 (was an ``assert``,
            which is stripped under ``python -O``).
    """
    name_new, x1, y1, x2, y2 = data_new
    if len(boxes_all) == 0:
        return
    if x2 - x1 != 600 or y2 - y1 != 600:
        raise ValueError('expected a 600x600 window, got %r' % (data_new,))
    left, top = max(x1, 0), max(y1, 0)
    right, bottom = max(x2, 0), max(y2, 0)
    img_new = img_old[top:bottom, left:right]
    # Actual crop size: slicing also truncates at the image's right/bottom edge.
    crop_h, crop_w = img_new.shape[0], img_new.shape[1]
    if crop_h <= 5 or crop_w <= 5:
        return

    boxes = np.asarray(boxes_all)
    boxes_new = np.zeros_like(boxes)
    boxes_new[:, 0] = boxes[:, 0] - left
    boxes_new[:, 1] = boxes[:, 1] - top
    boxes_new[:, 2] = boxes[:, 0] + boxes[:, 2] - left
    boxes_new[:, 3] = boxes[:, 1] + boxes[:, 3] - top
    boxes_new[:, 4:6] = boxes[:, 4:6]

    center_x = 0.5 * (boxes_new[:, 0] + boxes_new[:, 2])
    center_y = 0.5 * (boxes_new[:, 1] + boxes_new[:, 3])
    inside = ((center_x >= 0) & (center_x <= crop_w) &
              (center_y >= 0) & (center_y <= crop_h))
    kept = boxes_new[inside]
    if len(kept) > 0:
        # Clip partially visible boxes so XML coordinates stay inside the
        # saved image (no negative or out-of-range values).
        kept[:, [0, 2]] = np.clip(kept[:, [0, 2]], 0, crop_w)
        kept[:, [1, 3]] = np.clip(kept[:, [1, 3]], 0, crop_h)
        save_to_xml(path_new_xml, crop_w, crop_h, kept, class_list,
                    name_new + '.jpg')
    cv2.imwrite(os.path.join(out_images_dir, name_new + '.jpg'), img_new)
if __name__ == '__main__':
    os.makedirs(out_images_dir, exist_ok=True)
    os.makedirs(out_label_dir, exist_ok=True)
    for txt_name in [n for n in os.listdir(input_path) if n[-4:] == '.txt']:
        print(txt_name)
        with open(os.path.join(input_path, txt_name), 'r', encoding='utf8') as f:
            # Skip blank lines (an image with no clusters yields one).
            lines = [ln.split() for ln in f if ln.strip()]
        # Group the crop windows by source image (first 6 chars of the name).
        drawer = {}
        for name_new, x1, y1, x2, y2 in lines:
            drawer.setdefault(name_new[:6], []).append(
                (name_new, int(x1), int(y1), int(x2), int(y2)))
        for name_old, datas in drawer.items():
            # Read with OpenCV: clip_image saves crops with cv2.imwrite (BGR).
            # scipy.misc.imread returned RGB, which swapped the red and blue
            # channels, and it no longer exists in SciPy.
            img_data = cv2.imread(os.path.join(raw_images_dir, name_old + '.jpg'))
            if img_data is None:
                print('cannot read image', name_old)
                continue
            with open(os.path.join(raw_label_dir, name_old + '.txt'), 'r') as f:
                boxes = format_label(f.readlines())
            for data_new in datas:
                name_new = data_new[0]
                path_new_xml = os.path.join(out_label_dir, name_new + '.xml')
                clip_image(name_new, path_new_xml, img_data, data_new, boxes)
(3)可视化裁剪的标注格式是否正确
xml_drawbox_cluster.py
import os
import os.path
import numpy as np
import xml.etree.ElementTree as xmlET
from PIL import Image, ImageDraw
classes = ('__background__',  # always index 0
           'ignored regions', 'pedestrian', 'people', 'bicycle', 'car', 'van',
           'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor', 'others')
# Change the paths below to your own paths
file_path_img = '/home/jjliao/Visdrone_yolo_cluster/VisDrone2019-DET-train/images_cluster'
file_path_xml = '/home/jjliao/Visdrone_yolo_cluster/VisDrone2019-DET-train/annotations_cluster_xml'
save_file_path = '/home/jjliao/Visdrone_yolo_cluster/VisDrone2019-DET-train/visual_cluster'
os.makedirs(save_file_path, exist_ok=True)
# Only XML files; stray files in the directory used to crash the parser.
pathDir = [f for f in os.listdir(file_path_xml) if f.endswith('.xml')]
for filename in pathDir:
    tree = xmlET.parse(os.path.join(file_path_xml, filename))
    objs = tree.findall('object')
    # Signed ints: with uint16, negative coordinates wrapped to ~65535.
    boxes = np.zeros((len(objs), 5), dtype=np.int32)
    for ix, obj in enumerate(objs):
        bbox = obj.find('bndbox')
        # Clamp the top-left corner to the image origin for drawing.
        x1 = max(float(bbox.find('xmin').text), 0)
        y1 = max(float(bbox.find('ymin').text), 0)
        x2 = float(bbox.find('xmax').text)
        y2 = float(bbox.find('ymax').text)
        boxes[ix, 0:4] = [x1, y1, x2, y2]
        boxes[ix, 4] = classes.index(obj.find('name').text)
    image_name = os.path.splitext(filename)[0]
    with Image.open(os.path.join(file_path_img, image_name + '.jpg')) as src:
        # RGB so the red outline works for grayscale inputs as well.
        img = src.convert('RGB')
    draw = ImageDraw.Draw(img)
    for xmin, ymin, xmax, ymax, label in boxes.tolist():
        draw.rectangle([xmin, ymin, xmax, ymax], outline=(255, 0, 0))
        draw.text([xmin, ymin], classes[label], (255, 0, 0))
    img.save(os.path.join(save_file_path, image_name + '.jpg'))
3. 用 YOLOv4 对裁剪后的图片进行预测,得到 results_cluster.json 文件;再将预测出的目标框坐标映射(拼接)回原图坐标系。
splice.py
# coding:utf-8
import json
import shutil
import cv2
import os
_results_dir = "/home/jjliao/code/PyTorch_yolov4_visdrone_cluster/"
# Detections made on the cropped images (input) and their remapped copy (output).
json_old_path = _results_dir + "results_cluster.json"
json_new_path = _results_dir + "results_cluster_new.json"
# Crop-window txt files produced by the clustering step.
txt_path = "/home/jjliao/cluster/cluster_output/"
def select(json_old, json_new, txt):
    """Map detections made on cluster crops back to original-image coordinates.

    Args:
        json_old: path of the detections on the crops.  It must hold a JSON
            list of dicts with at least ``image_id`` (the crop name, e.g.
            ``00000100``) and ``bbox`` whose first two entries are the
            box's x/y in crop pixels (presumably COCO ``[x, y, w, h]`` --
            confirm against the detector output).
        json_new: path the remapped detections are written to (always
            created/overwritten; this argument used to be ignored in favour
            of a global, and nothing was written unless the file already
            existed).
        txt: directory of ``<crop name> x1 y1 x2 y2`` window files.

    Each detection whose ``image_id`` matches a crop name gets
    ``max(x1, 0)``/``max(y1, 0)`` added to its bbox x/y.  That is the origin
    ``clip_image`` actually crops from; the raw ``x1``/``y1`` shifted
    detections wrongly for windows starting at negative coordinates.  Its
    ``image_id`` is then replaced by the crop name without its last two
    characters.  Output order follows the txt files and crop lines, as before.
    """
    with open(json_old, 'r') as f:
        infos = json.load(f)
    # Index detections by crop name once instead of rescanning all of them
    # for every crop line.
    by_crop = {}
    for det in infos:
        by_crop.setdefault(det["image_id"], []).append(det)
    new_json = []
    for txt_name in [n for n in os.listdir(txt) if n[-4:] == '.txt']:
        print(txt_name)
        with open(os.path.join(txt, txt_name), 'r', encoding='utf8') as f:
            lines = [ln.split() for ln in f if ln.strip()]
        for name, x1, y1, _x2, _y2 in lines:
            x_off, y_off = max(int(x1), 0), max(int(y1), 0)
            # pop(): every detection is remapped at most once.
            for det in by_crop.pop(name, []):
                det['bbox'][0] += x_off  # keep sub-pixel precision (no int())
                det['bbox'][1] += y_off
                det["image_id"] = name[:-2]
                new_json.append(det)
    # Write once at the end (was rewritten after every matched detection).
    with open(json_new, 'w', encoding='utf-8') as ff:
        json.dump(new_json, ff, indent=4)
    print(len(new_json))
if __name__ == "__main__":
    # Remap the crop-level detections into original-image coordinates.
    select(json_old=json_old_path, json_new=json_new_path, txt=txt_path)