-------Start
1. 下载代码,数据集,预训练权重
1.1 代码下载

文件解压

1.2 数据集下载

下载以下2个文件,并解压。
xml为数据的标签文件,共8929个文件,至于xml文件中的内容大家可以百度学习下,这里不做讲解。

jpg文件为检测的X光图像,共8929张图像,一张图像对应一个xml文件。

1.3 预训练权重下载

2. 数据集整理
2.1 首先新建SIXray文件,里面包含如下内容:Annotations文件夹里为所有的xml文件,Images文件夹里为所有的jpg文件,ImageSets文件夹里再新建一个main文件夹,labels为空文件夹。


2.2 先修改以下代码的convert_annotation函数中in_file ,out_file 文件路径为你自己数据集所对应的路径。
import xml.etree.ElementTree as ET
import os
from os import getcwd
sets = ['train', 'val', 'test']
classes = ["Gun", "Knife","Wrench","Pliers", "Scissors"]
def convert(size, box):
dw = 1. / (size[0])
dh = 1. / (size[1])
x = (box[0] + box[1]) / 2.0 - 1
y = (box[2] + box[3]) / 2.0 - 1
w = box[1] - box[0]
h = box[3] - box[2]
x = x * dw
w = w * dw
y = y * dh
h = h * dh
return x, y, w, h
def convert_annotation(image_id):
in_file = open('E:/chongda/SIXray/Annotations/%s.xml' % (image_id), encoding='UTF-8')
out_file = open('E:/chongda/SIXray/labels/%s.txt' % (image_id), 'w')
tree = ET.parse(in_file)
root = tree.getroot()
w = int(root.find('size').find('width').text)
h = int(root.find('size').find('height').text)
for obj in root.iter('object'):
try:
cls = obj.find('name').text
except:
continue
if cls not in classes == 1:
continue
cls_id = classes.index(cls)
xmlbox = obj.find('bndbox')
b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
float(xmlbox.find('ymax').text))
b1, b2, b3, b4 = b
if b2 > w:
b2 = w
if b4 > h:
b4 = h
b = (b1, b2, b3, b4)
bb = convert((w, h), b)
out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
def test():
import os
import random
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--xml_path', default=r'E:\chongda\SIXray\labels', type=str, help='input xml label path')
parser.add_argument('--txt_path', default=r'E:\chongda\SIXray\ImageSets\Main', type=str, help='output txt label path')
opt = parser.parse_args()
trainval_percent = 0.9
train_percent = 0.9
xmlfilepath = opt.xml_path
txtsavepath = opt.txt_path
total_xml = os.listdir(xmlfilepath)
if not os.path.exists(txtsavepath):
os.makedirs(txtsavepath)
random.seed(2022)
num = len(total_xml)
list_index = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(list_index, tv)
train = random.sample(trainval, tr)
file_trainval = open(txtsavepath + '/trainval.txt', 'w')
file_test = open(txtsavepath + '/test.txt', 'w')
file_train = open(txtsavepath + '/train.txt', 'w')
file_val = open(txtsavepath + '/val.txt', 'w')
for i in list_index:
name = total_xml[i][:-4] + '\n'
if i in trainval:
file_trainval.write(name)
if i in train:
file_train.write(name)
else:
file_val.write(name)
else:
file_test.write(name)
file_trainval.close()
file_train.close()
file_val.close()
file_test.close()
def test2():
for image_set in sets:
image_ids = open(r'E:\chongda\SIXray\ImageSets\Main\%s.txt' % (image_set)).read().strip().split()
list_file = open(r'E:\chongda\SIXray\%s.txt' % (image_set), 'w')
for image_id in image_ids:
list_file.write('E:\chongda\SIXray\images\%s.jpg\n' % (image_id))
list_file.close()
def test3():
path = r'E:\chongda\SIXray\Annotations'
for image_id in os.listdir(path):
convert_annotation(image_id[:-4])
然后执行test3函数,执行完后会在labels文件夹里生成8929个txt文件,如下图所示。这一步操作是将xml标签文件转化成我们所需格式的文件。

2.3 修改以上代码的test函数中的两个路径值为你自己数据集相对应的位置。

修改完毕后,执行test函数,将产生以下文件,这一步操作是划分数据集为训练集,测试集,验证集。
2.4 修改test2函数中的三处路径为你自己数据集所对应的路径。

修改完成后会在SIXray文件夹下生成如下文件。

检查其中文件的内容是否是图像的完整路径。

3. 代码整理
3.1 用pycharm导入解压后的代码文件

在data目录下新建my.yaml文件,文件内容如下所示,此处需要修改train,val,test的路径,值为你的数据集下txt文件的路径。

3.2 修改models下yolov5s.yaml文件中的nc值为5,代表要检测的5个类别。

3.3 将下载好的YOLOv5s模型预训练权重,并放在项目根目录下。

3.4 修改train.py文件中的相应参数,batchsize的值根据显卡的算力进行调整,通常取值为偶数。值越大,要求显卡的算力越高。

3.5 至此,修改完毕,执行train.py进行网络的训练。

3.6 控制台出现以下训练输出日志,表示训练完毕。程序会自动保存最后一次训练完毕的模型权重和最优模型权重。

运行完毕后,会在runs/train/的文件夹下保存每次训练的相关文件,weights存放的是模型的权重。best.pt是最优模型权重,last.pt是训练完毕的权重,其他文件可自行打开查看。

PR图像

训练过程指标变化

混淆矩阵

模型的测试过程请参考下一篇博文。
-------End