1、安装 MMDetection
查看官方的安装文档:https://mmdetection.readthedocs.io/zh_CN/latest/get_started.html#id2
通过 git 克隆源码安装时,需要注意最后一条命令 `pip install -v -e .`:末尾的点号表示当前目录,不能省略。
2、准备voc格式数据集
MMDetection 支持 COCO 和 VOC 两种数据集格式,也可以使用自定义数据集。推荐使用 COCO 格式;我最初使用的是 VOC 格式,所以这里先介绍 VOC 格式。
首先需要了解一下 VOC 数据集的目录结构:`Annotations/` 存放 xml 标注文件,`JPEGImages/` 存放图片,`ImageSets/Main/` 存放 train.txt、val.txt 等数据集划分文件。
将自己的数据集(YOLO 格式的 txt 标注)转换为 VOC 格式(xml 标注)的代码:
import os
from glob import glob
import cv2
from lxml.etree import Element, SubElement, tostring
import numpy as np
# Convert YOLO-format txt labels into VOC-format xml annotations.
def convert(img, box):
    """Turn one normalised YOLO box into a pixel-space top-left/size box.

    Args:
        img: Decoded image array; only ``img.shape[0]`` (rows) and
            ``img.shape[1]`` (columns) are read.
        box: Five-item sequence ``(name, x, y, w, h)``. ``x, y`` is the box
            centre and ``w, h`` its size; each must be accepted by ``float``
            and is scaled by the image size.

    Returns:
        Tuple ``(name, left, top, width, height)``. ``name`` is passed through
        untouched; the other four values are floats in pixels.
    """
    label, cx, cy, bw, bh = box
    rows = img.shape[0]
    cols = img.shape[1]

    # Horizontal quantities scale with the column count, vertical with rows.
    cx_px = float(cx) * cols
    width_px = float(bw) * cols
    cy_px = float(cy) * rows
    height_px = float(bh) * rows

    # Move the reference point from the box centre to its top-left corner.
    left = (cx_px * 2 - width_px) / 2
    top = (cy_px * 2 - height_px) / 2
    return label, left, top, width_px, height_px
def _add_text_node(parent, tag, text):
    """Append a ``<tag>`` child holding ``text`` to ``parent`` and return it."""
    node = SubElement(parent, tag)
    node.text = text
    return node


# Convert a single file.
def txt_xml(img_name):
    """Write the VOC xml annotation for one image from its YOLO txt labels.

    Reads the module-level ``img_path``, ``txt_path``, ``xml_path`` and
    ``classes``. The label file is ``<txt_path>/<stem>.txt`` and the result is
    written to ``<xml_path>/<stem>.xml``, where ``stem`` is ``img_name``
    without its extension.

    Args:
        img_name: File name (not a full path) of an image inside ``img_path``.

    Raises:
        ValueError: If the image file cannot be decoded.
    """
    # splitext instead of [:-4] so extensions such as ".jpeg" also work.
    stem = os.path.splitext(img_name)[0]
    img_file = os.path.join(img_path, img_name)

    # Decoding from a byte buffer is kept from the original; presumably it is
    # there to cope with paths that cv2.imread cannot open -- TODO confirm.
    img = cv2.imdecode(np.fromfile(img_file, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
    if img is None:
        raise ValueError("cannot decode image: {}".format(img_file))
    # Use the real image size: the boxes below are scaled by img.shape, so the
    # <size> element has to describe the same image.
    imh, imw = img.shape[0:2]

    boxes = []
    with open(os.path.join(txt_path, stem + '.txt'), "r") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank line (e.g. trailing newline): nothing to convert.
                continue
            boxes.append(convert(img, fields))

    node_root = Element('annotation')
    _add_text_node(node_root, 'folder', 'imge')
    _add_text_node(node_root, 'filename', img_name)
    # os.path.join rather than a hard-coded "\\" so the path is valid on any OS.
    _add_text_node(node_root, 'path', img_file)
    node_size = SubElement(node_root, 'size')
    _add_text_node(node_size, 'width', str(imw))
    _add_text_node(node_size, 'height', str(imh))
    # Depth stays fixed at 3 as before; it is not derived from the image.
    _add_text_node(node_size, 'depth', '3')

    for class_id, xmin, ymin, box_w, box_h in boxes:
        node_object = SubElement(node_root, 'object')
        _add_text_node(node_object, 'name', str(classes[int(class_id)]))
        _add_text_node(node_object, 'pose', "Unspecified")
        # VOC stores truncated as a 0/1 flag; YOLO labels carry no such
        # information, so default to 0.
        _add_text_node(node_object, 'truncated', '0')
        _add_text_node(node_object, 'difficult', '0')
        node_bndbox = SubElement(node_object, 'bndbox')
        _add_text_node(node_bndbox, 'xmin', str(xmin))
        _add_text_node(node_bndbox, 'ymin', str(ymin))
        _add_text_node(node_bndbox, 'xmax', str(xmin + box_w))
        _add_text_node(node_bndbox, 'ymax', str(ymin + box_h))

    # pretty_print puts each element on its own indented line.
    xml = tostring(node_root, pretty_print=True)
    if xml_path:
        # Create the output directory on first use instead of failing on open().
        os.makedirs(xml_path, exist_ok=True)
    with open(os.path.join(xml_path, stem + '.xml'), 'wb') as file_object:
        file_object.write(xml)
# Batch conversion.
def generate_label_file():
    """Run ``txt_xml`` on every ``*.jpg`` found directly inside ``img_path``.

    ``img_path`` is read from module scope. Each match is reduced to its bare
    file name before being handed to ``txt_xml``.
    """
    pattern = "{}/*.jpg".format(img_path)
    for jpg_file in glob(pattern):
        # e.g. TOD/images/test/014348.jpg -> 014348.jpg
        file_name = jpg_file.rsplit(os.sep, 1)[-1]
        txt_xml(file_name)
if __name__ == "__main__":
    # Class names; the integer class id in each label line indexes this list.
    classes = ["a", "b"]
    # Image directory, YOLO txt label directory, and xml output directory.
    img_path, txt_path, xml_path = (
        "TOD/JPEGImages",
        "TOD/Annotations",
        "TOD/Annotations2",
    )
    # Run the conversion over the whole image directory.
    generate_label_file()
3、修改配置文件
(1)修改数据集配置文件:
(2)修改模型配置文件:这里的模型指你本次实验要训练的模型。
4、开始训练