01官方下载SSD模型
在GitHub上搜索SSD,pytorch框架下
GitHub - lufficc/SSD: High quality, fast, modular reference implementation of SSD in PyTorch
02配置SSD环境
01自备一个cuda环境
01创建环境
conda create -n torch python=3.8 ---torch是环境名
02激活环境
conda activate torch
03安装cuda
版本网站 https://pytorch.org/get-started/previous-versions/
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge
04检查cuda
import torch
torch.cuda.is_available()
02安装SSD环境
先打开我们创建的torch环境,然后安装SSD环境。
pip install -r requirements.txt
如果哪个包出问题,就把他拿出来单独装。
03设置数据集
01数据文件介绍
新建一个datasets文件夹,路径在SSD-master\datasets。其中子文件夹VOC2007,VOC2007包括以下三个文件夹Annotations,JPEGImages, ImageSets。
数据文件夹如下,
其中Annotations中是标签文件,放置xml文件。
JPEGImages中是影像,放置jpg影像。
ImageSets中是txt索引文件,放置训练验证的文件名称。
分别把所有的xml和jpg文件放到对应文件夹中。
02索引文件生成
这个是生成索引文件的代码,输入xml文件的路径,生成训练测试索引文件。保存在datasets/VOC2007/ImageSets/Main路径下。
import os
import random
# 定义百分比
trainval_percent = 1
train_percent = 0.8
# 路径
xmlfilepath = "Annotations"
txtsavepath = 'ImageSets/Main'
xml_files = os.listdir(xmlfilepath)
num = len(xml_files)
file_list = list(range(num))
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
# 随机划分数据
trainval = random.sample(file_list, tv)
train = random.sample(trainval, tr)
# 创建并写入文本文件
with open(os.path.join(txtsavepath, 'trainval.txt'), 'w') as ftrainval, \
open(os.path.join(txtsavepath, 'test.txt'), 'w') as ftest, \
open(os.path.join(txtsavepath, 'train.txt'), 'w') as ftrain, \
open(os.path.join(txtsavepath, 'val.txt'), 'w') as fval:
for i in file_list:
name = xml_files[i][:-4] + '\n'
if i in trainval:
ftrainval.write(name)
if i in train:
ftrain.write(name)
else:
fval.write(name)
else:
ftest.write(name)
代码命名为txt.py,放在和jpg,xml文件夹同级目录下。datasets/VOC2007/txt.py
结果生成,
03修改类别名称
路径在SSD-master/ssd/data/datasets/voc.py。修改成自己xml所有的。
04修改配置文件
配置文件路径在configs/vgg_ssd300_voc0712.yaml。修改自己的类别数,其他参数可以任意改。
05 开始训练
终端输入以下训练代码。
python train.py --config-file configs/vgg_ssd300_voc0712.yaml
训练显示开始
报错
D:\learn\sdxx\mbjc\dbsy\SSD-master\ssd\utils\nms.py:10: UserWarning: No NMS is available. Please upgrade torchvision to 0.3.0+
warnings.warn('No NMS is available. Please upgrade torchvision to 0.3.0+')
点击蓝色部分,找到报错位置,将其中的0.3.0改成0.0.0.
if torchvision.__version__ >= '0.3.0':
_nms = torchvision.ops.nms
else:
warnings.warn('No NMS is available. Please upgrade torchvision to 0.3.0+')
sys.exit(-1)