SSD-Pytorch模型训练自己的数据集

1.下载SSD-Pytorch代码

SSD-pytorch代码链接: https://github.com/amdegroot/ssd.pytorch

git clone https://github.com/amdegroot/ssd.pytorch
  1. 运行该代码下载到本地(如果下载太慢可以上传到码云,然后git clone码云地址)

2.准备数据集

  1. 没有数据集的同学可以下载代码自带的VOC和COCO数据集(./data/scripts目录下)

在这里插入图片描述

  1. 有自己的数据集请将数据集放置在./data目录下,例如VOC格式数据集,新建VOCdevkit文件夹,如下图所示,可以参考:https://blog.csdn.net/qq_34806812/article/details/81673798.
  2. 在Annotations中放置所有的标签,在JPEGimages中放置所有的图片,在ImagesSets/Main中放置train.txt/val.txt/test.txt(内容只有图片的名字,例如:00001,00002,不能带后缀jpg或者png)等,可以用脚本自己生成:https://blog.csdn.net/GeekYao/article/details/105074574.

在这里插入图片描述
3.根据自己的数据集修改代码

  1. 博主用的VOC格式的数据集,下面修改都是以VOC格式为例

3.1 config.py

  1. 找到config.py文件,

  2. 打开修改VOC中的num_classes,根据自己的情况修改:classes+1(背景算一类),

  3. 我这里就只有一类,所有是2

  4. 第一次调试最好修改一下max_iter,不然迭代次数太大,要好长时间,其他都是一些超参数,可以占时不修改

博主用的VOC格式的数据集,下面修改都是以VOC格式为例

在这里插入图片描述3.2VOC0712.py
在这里插入图片描述

  1. 根据自己的标签进行修改,博主这里只有一类,所以只有一个dargon fruit(注:如果只有一类,需要加上[ ])

在这里插入图片描述

  1. image_sets中修改一下,根据自己的设置的数据集修改,我这里只有train和val

3.3 train.py
在这里插入图片描述下载预训练模型。VGG16_reducedfc.pth
链接: https://pan.baidu.com/s/1EW9qT0nJkE2dK7thn_kPVw 密码: nw6t
–来自百度网盘超级会员V1的分享
在这里插入图片描述

  1. 根据自己的显存修改batch_size,建议一开始修改小一点,博主1660ti 6G显存

在这里插入图片描述

  1. 将保存训练模型的参数调低一点,之前iter设置的1000,这里设置为500,之后根据自己情况在设置
  2. 顺便修改一下保存的模型名字,也可以之后修改,把COCO改成VOC,博主这里没修改

3.4 eval.py
在这里插入图片描述添加训练好的模型到eval.py,对模型进行验证,我这里训练好的是ssd300_VOC_500.pth
将下面的

args = parser.parse_args()

修改为

args,unknow= parser.parse_known_args()

3.5 SSD.py
在这里插入图片描述

  1. 修改num_classes,跟上面config.py中的一致就行
  2. 修改完成后,运行train.py,完成训练之后,博主运行eval.py验证了训练的模型,AP只有63%,可能是博主数据集太少了

运行eval.py只能看到AP值,想要测试自己的图片,在jupyter notebook中运行demo.ipynb

将对应部分的代码,修改为以下这样即可,注意正确添加图片的路径

image = cv2.imread(’…/data/example3.jpg’, cv2.IMREAD_COLOR) # uncomment if dataset not downloaded
from matplotlib import pyplot as plt
from data import VOCDetection, VOC_ROOT, VOCAnnotationTransform

here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)

#testset = VOCDetection('./data/example1.jpg', [('2020', 'val')], None, VOCAnnotationTransform())
#img_id = 13
#image = testset.pull_image(img_id)
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10,10))
plt.imshow(rgb_image)
plt.show()

在这里插入图片描述在这里插入图片描述可能会存在的问题:

bug1:出现维度不匹配的情况
loc_loss += loss_l.data[0] 报错
在这里插入图片描述解决方法:

  1. 将.data[0]改为.item(),下面print中的也改为loss.item()
  2. 建议参考:https://github.com/amdegroot/ssd.pytorch/issues/421

bug2:自动停止训练
解决方法:
在这里插入图片描述
3. load train data部分修改为如上图所示

bug3:可能会出现pytorch版本带来的影响问题
解决方法:根据提示语句,百度修改即可

bug4:运行eval.py可能会出现pytest这种情况
解决方法:将eval.py中的test_net函数名字修改即可,不能出现test关键字,博主修改为set_net成功运行

bug5:训练出现-nan
解决方法:降低学习率

bug6:出现显存不足的问题Runtimeout
解决方法:降低batch_size

bug7:出现数组索引过多的情况
IndexError: too many indices for array
解决方法:因为有些标注的标签没有数据,所有会出现数组索引出错

如果数据比较多,可以用如下脚本排查是哪个标签出现问题(注意修改自己的标签路径)

import argparse
import sys
import cv2
import os

import os.path          as osp
import numpy            as np

if sys.version_info[0] == 2:
    import xml.etree.cElementTree as ET
else:
    import xml.etree.ElementTree  as ET


parser    = argparse.ArgumentParser(
            description='Single Shot MultiBox Detector Training With Pytorch')
train_set = parser.add_mutually_exclusive_group()

parser.add_argument('--root', default='data/VOCdevkit/VOC2020' , help='Dataset root directory path')

args = parser.parse_args()

CLASSES = [(  # always index 0
    'dargon fruit')]

annopath = osp.join('%s', 'Annotations', '%s.{}'.format("xml"))
imgpath  = osp.join('%s', 'JPEGImages',  '%s.{}'.format("jpg"))

def vocChecker(image_id, width, height, keep_difficult = False):
    target   = ET.parse(annopath % image_id).getroot()
    res      = []

    for obj in target.iter('object'):

        difficult = int(obj.find('difficult').text) == 1

        if not keep_difficult and difficult:
            continue

        name = obj.find('name').text.lower().strip()
        bbox = obj.find('bndbox')

        pts    = ['xmin', 'ymin', 'xmax', 'ymax']
        bndbox = []

        for i, pt in enumerate(pts):

            cur_pt = int(bbox.find(pt).text) - 1
            # scale height or width
            cur_pt = float(cur_pt) / width if i % 2 == 0 else float(cur_pt) / height

            bndbox.append(cur_pt)

        print(name)
        label_idx =  dict(zip(CLASSES, range(len(CLASSES))))[name]
        bndbox.append(label_idx)
        res += [bndbox]  # [xmin, ymin, xmax, ymax, label_ind]
        # img_id = target.find('filename').text[:-4]
    print(res)
    try :
        print(np.array(res)[:,4])
        print(np.array(res)[:,:4])
    except IndexError:
        print("\nINDEX ERROR HERE !\n")
        exit(0)
    return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]

if __name__ == '__main__' :

    i = 0

    for name in sorted(os.listdir(osp.join(args.root,'Annotations'))):
    # as we have only one annotations file per image
        i += 1

        img    = cv2.imread(imgpath  % (args.root,name.split('.')[0]))
        height, width, channels = img.shape
        print("path : {}".format(annopath % (args.root,name.split('.')[0])))
        res = vocChecker((args.root, name.split('.')[0]), height, width)
    print("Total of annotations : {}".format(i))

之作为学习使用不商用
ref:https://blog.csdn.net/weixin_42447868/article/details/105675158

  • 8
    点赞
  • 76
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值