一、挂载数据
筛选小技巧:
可以全部选择数据集,构建数据集,在详情页可以看到每类有多少张图片,不需要一类一类的去查看。
数据集的筛选很重要,注意数据是否存在编码混乱和标注错误
check一下标签是否正确:
# Visualize the platform's JSON annotations on their images so labeling
# errors (wrong class codes, misplaced boxes) can be spotted by eye.
# Output images with boxes/labels drawn are written to `outdir`.
import cv2
import os
import json

classes = ["0D0004", "0D0005", "0D2007", "0D200B", "0D2010", "0D201A", "0D6201", "0D7001", "0D7098","0D0000","0DOther"]
image_dir = "/root/store-data_new/crop1125/images"
outdir = "/root/store-data_new/data_process/check_label/check_json_images"
label_dir = "/root/store-data_new/crop1125/labels"

for i in os.listdir(image_dir):
    image = cv2.imread(os.path.join(image_dir, i))
    if image is None:
        # corrupt or non-image file: skip instead of crashing on .shape
        print(i, "could not be read, skipped")
        continue
    json_file = os.path.join(label_dir, i.split(".")[0] + ".json")
    if not os.path.exists(json_file):
        # originally an image without a matching json crashed the loop
        print(i, "has no annotation json, skipped")
        continue
    with open(json_file) as ann_f:
        annos = json.load(ann_f)
    font = cv2.FONT_HERSHEY_SIMPLEX
    for obj in annos['objects']:
        # obj_points[0] holds (x, y, w, h); convert to corner coordinates
        pt = obj['obj_points'][0]
        x1, y1 = int(pt['x']), int(pt['y'])
        x2, y2 = int(pt['x'] + pt['w']), int(pt['y'] + pt['h'])
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 0, 255), 2)
        cv2.putText(image, obj['f_code'], (x1 + 10, y1 + 10), font, 1, (0, 0, 255), 2)
    cv2.imwrite(os.path.join(outdir, i), image)
    # message fixed: these are json labels, not xml
    print(i, "json check successfully")
二、数据预处理
1、所有类别数据整合
1.1 检查图片和标签的后缀是否正确
执行脚本,这个脚本在AI平台被注释掉了。常见问题是图片后缀".jpg"被误写为".JPG";遇到新的数据集最好先运行一下该脚本,以防还存在其他错误后缀。判断后缀的脚本如下:
# Scan every data_xxx/{images,labels} directory and collect all file-name
# suffixes other than the expected "jpg"/"json", so stray extensions
# (e.g. ".JPG") are discovered before the gather step runs.
import os

root_dir = "/root/data1125/"

unexpected_suffixes = []
for data_class in os.listdir(root_dir):                    # e.g. data_0001
    class_dir = os.path.join(root_dir, data_class)
    for target in os.listdir(class_dir):                   # images / labels
        target_dir = os.path.join(class_dir, target)
        for my_object in os.listdir(target_dir):           # xxx.jpg / xxx.json
            suffix = my_object.split(".")[-1]
            # record each unexpected suffix once
            if suffix not in ("jpg", "json") and suffix not in unexpected_suffixes:
                unexpected_suffixes.append(suffix)
print(unexpected_suffixes)
1.2 将所有类别整合到一起
gather.py的内容如下:
# Copy every per-class images/labels file into the merged crop1125
# directories, normalizing the ".JPG" extension to ".jpg" on the way.
import os
import shutil

root_dir = "/root/data1125/"
dst_dir = "/root/store-data_new/crop1125/"

for data_class in os.listdir(root_dir):                    # e.g. data_0001
    class_dir = os.path.join(root_dir, data_class)
    for target in os.listdir(class_dir):                   # images / labels
        target_dir = os.path.join(class_dir, target)
        for my_object in os.listdir(target_dir):           # xxx.jpg / xxx.json
            src_path = os.path.join(target_dir, my_object)
            suffix = my_object.split(".")[-1]
            if suffix == "jpg":
                dst_path = os.path.join(dst_dir, "images", my_object)
            elif suffix == "json":
                dst_path = os.path.join(dst_dir, "labels", my_object)
            elif suffix == "JPG":
                # normalize the extension while copying
                my_object = my_object.rsplit(".", 1)[0] + ".jpg"
                dst_path = os.path.join(dst_dir, "images", my_object)
            else:
                # BUG FIX: in the original, any other suffix left dst_path
                # unset (NameError) or stale from the previous file, so the
                # file was copied to the wrong destination. Skip explicitly.
                print("skip unexpected suffix:", src_path)
                continue
            shutil.copy(src_path, dst_path)
            print(dst_path, "copy successfully")
1.3 查看整合前后图片标签数量的变化
一般整合后的数目要小于整合前,因为同一张图片可能包含多个类别的目标
# Print the per-class image/label counts before gathering, the summed image
# count, and the counts in the merged destination directories, so the two
# totals can be compared (merged total may be smaller due to shared images).
import os

root_dir = "/root/data1125/"
image_dst_dir = "/root/store-data_new/crop1125/images/"
json_dst_dir = "/root/store-data_new/crop1125/labels/"

sum_num = 0
for test_class in os.listdir(root_dir):
    second_dir = os.path.join(root_dir, test_class)
    image_dir = os.path.join(second_dir, "images")
    # skip entries that are not class directories with an images/ folder
    if not os.path.exists(image_dir):
        continue
    images = os.listdir(image_dir)
    jsons = os.listdir(os.path.join(second_dir, "labels"))
    sum_num += len(images)
    print(test_class, "images num", len(images), "jsons num", len(jsons))
print(sum_num)
print(len(os.listdir(image_dst_dir)), len(os.listdir(json_dst_dir)))
2、格式转换
需要将平台的json格式转换为txt格式
2.1 json转换成VOC格式的xml
config.yaml的内容:
INPUT_IMAGES_DIR: '/root/store-data_new/crop1125/images/'
INPUT_LABELS_DIR: '/root/store-data_new/crop1125/labels/'
OUTPUT_DIR: '/root/store-data_new/crop1125/'
transform2voc.py的内容:
# Convert the platform's JSON annotations to VOC-format XML using the
# aicloud toolkit. Input/output directories are read from the YAML file
# referenced below (INPUT_IMAGES_DIR / INPUT_LABELS_DIR / OUTPUT_DIR).
from aicloud import json_to_voc
yamlPath = '/root/store-data_new/data_process/transform_data/config.yaml'
# NOTE(review): return value semantics of json2voc are not documented here;
# presumably a status/result object — confirm against the aicloud docs.
result = json_to_voc.json2voc(yamlPath)
2.2 xml转换成txt
注意:
1、如果后续需要裁剪,则不需要执行归一化函数
2、如果后续不需要裁剪,直接参与训练,则要归一化
import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join
# VOC year/split pairs to process, and the detector's class list
# (index in this list becomes the YOLO class id).
sets = [('2007', 'train'), ('2007', 'val'), ('2007', 'test')]
classes = ["0D0001", "0D0004", "0D0005", "0D2007", "0D200B", "0D2010", "0D201A", "0D6201", "0D7001", "0D7098"]

def convert(size, box):
    """Convert a VOC pixel box to normalized YOLO coordinates.

    size: (img_w, img_h); box: (xmin, xmax, ymin, ymax) in pixels.
    Returns (x_center, y_center, width, height), each divided by the
    corresponding image dimension.
    """
    dw, dh = 1. / size[0], 1. / size[1]
    # the "- 1" offset mirrors darknet's original voc_label.py
    cx = ((box[0] + box[1]) / 2.0 - 1) * dw
    cy = ((box[2] + box[3]) / 2.0 - 1) * dh
    bw = (box[1] - box[0]) * dw
    bh = (box[3] - box[2]) * dh
    return (cx, cy, bw, bh)
def convert_annotation(year, image_id):
    """Convert one VOC XML annotation to a YOLO-format .txt label file.

    year/image_id select Annotations/<image_id>.xml under VOC<year>; the
    output goes to the sibling labels/<image_id>.txt. Objects whose class
    is not in `classes` or that are marked difficult are skipped.
    """
    xml_path = '/root/store-data_new/crop1125/VOCdevkit/VOC%s/Annotations/%s.xml' % (year, image_id)
    txt_path = '/root/store-data_new/crop1125/VOCdevkit/VOC%s/labels/%s.txt' % (year, image_id)
    # ET.parse accepts a path directly; the original opened a file handle
    # here and never closed it
    tree = ET.parse(xml_path)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)
    # FIX: the original also leaked the output handle; use a context manager
    with open(txt_path, 'w') as out_file:
        for obj in root.iter('object'):
            difficult_node = obj.find('difficult')
            # robustness: a missing <difficult> tag counts as "not difficult"
            difficult = int(difficult_node.text) if difficult_node is not None else 0
            cls = obj.find('name').text
            if cls not in classes or difficult == 1:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            # convert() expects (xmin, xmax, ymin, ymax)
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                 float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
            bb = convert((w, h), b)
            out_file.write(str(cls_id) + " " + " ".join(str(a) for a in bb) + '\n')
# Driver: for every (year, split), create the labels/ directory, read the
# image-id list, and convert each annotation to a YOLO txt label.
wd = getcwd()
for year, image_set in sets:
    labels_dir = '/root/store-data_new/crop1125/VOCdevkit/VOC%s/labels/' % year
    # exist_ok replaces the original exists()/makedirs() pair
    os.makedirs(labels_dir, exist_ok=True)
    ids_path = '/root/store-data_new/crop1125/VOCdevkit/VOC%s/ImageSets/Main/%s.txt' % (year, image_set)
    # FIX: the original leaked this handle (open().read() with no close)
    with open(ids_path) as ids_f:
        image_ids = ids_f.read().strip().split()
    # FIX: list_file was opened but its close() was commented out
    with open('%s_%s.txt' % (year, image_set), 'w') as list_file:
        for image_id in image_ids:
            # list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg\n'%(wd, year, image_id))
            convert_annotation(year, image_id)
# os.system("cat 2007_train.txt 2007_val.txt 2012_train.txt 2012_val.txt > train.txt")
# os.system("cat 2007_train.txt 2007_val.txt 2007_test.txt 2012_train.txt 2012_val.txt > train.all.txt")
有可能会遇到xml解析报错,这一般是因为报错的xml文件是空的:多半是解压时磁盘空间不足,导致解压出的xml内容为空。可以尝试将数据压缩成tar/zip以节省空间,或者扩大磁盘空间、腾出空间后重新解压。
3、图片裁剪(可选)
4、check处理过的图片和txt标签
# Draw the converted YOLO txt labels back onto the (possibly cropped)
# images to verify boxes and class ids survived the conversion.
import cv2
import os

classes = ["0D0004", "0D0005", "0D2007", "0D200B", "0D2010", "0D201A", "0D6201", "0D7001", "0D7098","0D0000","0DOther"]
image_dir = "/root/store-data_new/crop1125/VOCdevkit/VOC2007/JPEGImages"
outdir = "/root/store-data_new/data_process/check_label/check_images"
label_dir = "/root/store-data_new/crop1125/VOCdevkit/VOC2007/labels"

for i in os.listdir(image_dir):
    image = cv2.imread(os.path.join(image_dir, i))
    if image is None:
        # corrupt or non-image file: skip instead of crashing on .shape
        print(i, "could not be read, skipped")
        continue
    h, w = image.shape[:2]
    label_path = os.path.join(label_dir, i.split(".")[0] + ".txt")
    if not os.path.exists(label_path):
        # originally an image without a matching txt crashed the loop
        print(i, "has no txt label, skipped")
        continue
    font = cv2.FONT_HERSHEY_SIMPLEX
    with open(label_path, 'r') as f:
        for line in f:
            parts = line.split()
            if not parts:
                # tolerate blank lines in the label file
                continue
            cls_id = int(parts[0])
            # labels are normalized (cx, cy, w, h); scale back to pixels
            cx, cy = float(parts[1]) * w, float(parts[2]) * h
            bw, bh = float(parts[3]) * w, float(parts[4]) * h
            xmin, xmax = cx - bw / 2, cx + bw / 2
            ymin, ymax = cy - bh / 2, cy + bh / 2
            cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 0, 255), 2)
            # text at the top-left corner, consistent with the json checker
            cv2.putText(image, classes[cls_id], (int(xmin) + 10, int(ymin) + 10), font, 1, (0, 0, 255), 2)
    cv2.imwrite(os.path.join(outdir, i), image)
    print(i, "check successfully")
三、训练模型
1、crop.cfg的修改
1.1 修改anchor
./darknet detector calc_anchors crop/crop.data -num_of_clusters 9 -width 608 -height 608
1.2 修改类别
同时修改filters和class,filters = 3 * (class+5),有三处
2、crop.data的修改
3、crop.names的修改
4、开始训练
./darknet detector train crop/crop.data crop/crop.cfg backup/yolov4.conv.137 -gpus 0,1 -dont_show -mjpeg_port 8092 -map
./darknet detector train crop/crop.data crop/crop.cfg backup/yolov4.conv.137 -gpus 0 -dont_show -mjpeg_port 8092 -map
5、测试模型
(1) 测试单张图片:
./darknet detector test crop/crop.data crop/crop.cfg backup/crop_best.weights test.jpg -dont_show
(2) 计算mAP:
./darknet detector map crop/crop.data crop/crop.cfg backup/v1.0/crop_best.weights
默认是阈值0.5
计算mAP:
./darknet detector map crop/crop.data crop/crop.cfg backup/v1.0/crop_best.weights -iou_thresh 0.5
(3) 计算recall:
./darknet detector recall crop/crop.data crop/crop.cfg backup/v1.0/crop_best.weights
(4) 批量测试
./darknet detector valid crop/crop.data crop/crop.cfg backup/v1.0/crop_best.weights -dont_show -thresh 0.25
./darknet detector test crop/crop.data crop/crop.cfg crop/crop_best.weights -ext_output -dont_show -out result.json < /root/data/empty/order/voc_2012.txt
测试结果默认保存在results文件夹内,把测试结果可视化为图片的代码如下: