SSD pytorch训练自己的数据集(windows+colab)

对于一个小白,想对自己数据的训练实在不容易,花了好几天时间,翻阅了很多资料,在此做个总结。
我的环境是windows+cpu,没有N卡!本文并将讲解用Colab进行训练
默认已配置好Pytorch环境,目录结构如下,便于核对路径问题
在这里插入图片描述

下载

1、SSD pytorch代码下载https://github.com/amdegroot/ssd.pytorch
github下载慢的话,可以在码云导入再下载

2、VGG预训练模型下载(https://99baiduyun.com/file/1AVCZSsm52-NA4A_uleXYSQ.html)

数据集

使用LabelImg标注数据集github地址有详细教程不再赘述。
我手上的数据集是cancer医学图像,label是txt格式的,转换的方法可见https://blog.csdn.net/weixin_43289424/article/details/106371995

训练代码修改

config.py

因为我的数据集只有一个类别,但还要算上背景,类别+1即共2类
加入以下代码

cancer = {
    'num_classes': 2,
    'lr_steps': (40000, 50000, 60000),#调整学习率的步数
    'max_iter': 60000,#迭代次数,可以先设小测试下能否运行
    'feature_maps': [38, 19, 10, 5, 3, 1],
    'min_dim': 300,
    'steps': [8, 16, 32, 64, 100, 300],
    'min_sizes': [30, 60, 111, 162, 213, 264],
    'max_sizes': [60, 111, 162, 213, 264, 315],
    'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
    'variance': [0.1, 0.2],
    'clip': True,
    'name': 'CANCER',
}

新建cancer.py作为数据读入

对于源代码修改的地方标了######

import os.path as osp
import sys
import torch
import torch.utils.data as data
import cv2
import numpy as np
if sys.version_info[0] == 2:
    import xml.etree.cElementTree as ET
else:
    import xml.etree.ElementTree as ET


CANCER_CLASSES =  ['cancer']#######

CANCER_ROOT = "data/cancer_or_not/"#######


class CANCERAnnotationTransform(object):######
    """Transforms a CANCER annotation into a Tensor of bbox coords and label index
    Initilized with a dictionary lookup of classnames to indexes

    Arguments:
        class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
            (default: alphabetic indexing of MASK's 2 classes)
        keep_difficult (bool, optional): keep difficult instances or not
            (default: False)
        height (int): height
        width (int): width
    """

    def __init__(self, class_to_ind=None, keep_difficult=False):
        # self.class_to_ind = class_to_ind or dict(
        #     zip(CANCER_CLASSES, range(len(CANCER_CLASSES))))
        self.class_to_ind = class_to_ind or dict(cancer=0)#####我这里是一个类别就直接创建字典了
        self.keep_difficult = keep_difficult

    def __call__(self, target, width, height):
        """
        Arguments:
            target (annotation) : the target annotation to be made usable
                will be an ET.Element
        Returns:
            a list containing lists of bounding boxes  [bbox coords, class name]
        """
        res = []
        for obj in target.iter('object'):
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')

            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            bndbox = []
            for i, pt in enumerate(pts):
                cur_pt = int(bbox.find(pt).text) - 1
                # scale height or width
                cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)
            label_idx = self.class_to_ind[name]
            bndbox.append(label_idx)
            res += [bndbox]  # [xmin, ymin, xmax, ymax, label_ind]
            # img_id = target.find('filename').text[:-4]

        return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]


class CANCERDetection(data.Dataset):########
    """VOC Detection Dataset Object

    input is image, target is annotation

    Arguments:
        root (string): filepath to VOCdevkit folder.
        image_set (string): imageset to use (eg. 'train', 'val', 'test')
        transform (callable, optional): transformation to perform on the
            input image
        target_transform (callable, optional): transformation to perform on the
            target `annotation`
            (eg: take in caption string, return tensor of word indices)
        dataset_name (string, optional): which dataset to load
            (default: 'VOC2007')
    """
    #image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
    def __init__(self, root,
                 image_sets='trainval',######
                 transform=None, target_transform=CANCERAnnotationTransform(),#######
                 dataset_name='CANCER'):#######
        self.root = root
        self.image_set = image_sets
        self.transform = transform
        self.target_transform = target_transform
        self.name = dataset_name
        self._annopath = osp.join('%s', 'Annotations', '%s.xml')
        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
        self.ids = list()
        for line in open(CANCER_ROOT+'/ImageSets/Main/'+self.image_set+'.txt'):
          self.ids.append((CANCER_ROOT, line.strip()))#######

    def __getitem__(self, index):
        im, gt, h, w = self.pull_item(index)

        return im, gt

    def __len__(self):
        return len(self.ids)

    def pull_item(self, index):
        img_id = self.ids[index]

        target = ET.parse(self._annopath % img_id).getroot()
        img = cv2.imread(self._imgpath % img_id)
        height, width, channels = img.shape

        if self.target_transform is not None:
            target = self.target_transform(target, width, height)

        if self.transform is not None:
            target = np.array(target)
            img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])
            # to rgb
            img = img[:, :, (2, 1, 0)]
            # img = img.transpose(2, 0, 1)
            target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
        return torch.from_numpy(img).permute(2, 0, 1), target, height, width
        # return torch.from_numpy(img), target, height, width

    def pull_image(self, index):
        '''Returns the original image object at index in PIL form

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to show
        Return:
            PIL img
        '''
        img_id = self.ids[index]
        return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)

    def pull_anno(self, index):
        '''Returns the original annotation of image at index

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to get annotation of
        Return:
            list:  [img_id, [(label, bbox coords),...]]
                eg: ('001718', [('dog', (96, 13, 438, 332))])
        '''
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        gt = self.target_transform(anno, 1, 1)
        return img_id[1], gt

    def pull_tensor(self, index):
        '''Returns the original image at an index in tensor form

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to show
        Return:
            tensorized version of img, squeezed
        '''
        return torch.Tensor(self.pull_image(index)).unsqueeze_(0)

__init__文件

# from .voc0712 import VOCDetection, VOCAnnotationTransform, VOC_CLASSES, VOC_ROOT
from .cancer import CANCERDetection, CANCERAnnotationTransform, CANCER_CLASSES, CANCER_ROOT
# from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT, get_label_map

ssd.py

在这里插入图片描述
在这里插入图片描述

train.py

在这里插入图片描述
在train函数的if结构中加入

    elif args.dataset == 'CANCER':
        # if args.dataset_root == CANCER_ROOT:
        #     parser.error('Must specify dataset if specifying dataset_root')
        cfg = cancer
        dataset = CANCERDetection(root=args.dataset_root,
                               transform=SSDAugmentation(cfg['min_dim'],
                                                         MEANS))

报错:StopIteration,将images, targets = next(batch_iterator)改成

        try:
            images, targets = next(batch_iterator)
        except StopIteration:
            batch_iterator = iter(data_loader)
            images, targets = next(batch_iterator)

报错:IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number,反向传播处将data[0]改成data.item()
在这里插入图片描述
报错xavier_uniform已经被弃用,使用xavier_uniform_代替
在这里插入图片描述

multibox_loss.py

报错:IndexError: The shape of the mask [32, 2990] at index 0 does not match the shape of the indexed tensor [95680, 1] at index 0类似的,解决方法:在这里加上
在这里插入图片描述
这样仍然会报错AttributeError: ‘Tensor’ object has no attribute ‘bool’
注意要使用torch1.3版本以上

报错:UserWarning: size_average and reduce args will be deprecated, please use reduction=‘sum’ instead. warnings.warn(warning.format(ret)),将loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)改成
loss_c = F.cross_entropy(conf_p, targets_weighted, reduction=‘sum’)

验证eval.py代码修改

在这里插入图片描述
在这里插入图片描述

爬坑实录!

RuntimeError: Error(s) in loading state_dict for SSD:
size mismatch for conf.0.weight: copying a param with shape torch.Size([804, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([20, 512, 3, 3]).

在config文件中修改了,还是不管用,没搞懂这份代码怎么调用类别的,无奈之下!直接把build_ssd中的类别修改成数字
在train.py文件下
在这里插入图片描述
eval.py文件下
在这里插入图片描述
AttributeError: ‘Tensor’ object has no attribute 'bool’
这个问题解决只需使用torch1.3版本以上即可!

Legacy autograd function with non-static forward method is deprecated and will be removed in 1.3.
调用eval.py报错!调用torch1.2版本即可!

真的牛批!总而言之!训练用torch1.3以上,验证用torch
1.2版本

使用Colab

把代码文件上传至Google的云盘,直接传输大量文件会卡死,所以先压缩
在这里插入图片描述
在Colab中,先加载进云盘

from google.colab import drive
drive.mount('/content/drive')

进入云盘把压缩文件加载进来

%cd /content/drive/My Drive
!mv ssd_pytorch_cancer.zip /content

解压

%cd /content
!unzip ssd_pytorch_cancer.zip

在这里插入图片描述
把文件移入云盘,方便修改代码,

!mv ssd_pytorch_cancer /content/drive/My\ Drive

运行训练代码

import os
os.chdir("/content/drive/My Drive/ssd_pytorch_cancer")
!python train.py --learning-rate 1e-5

参考https://www.cnblogs.com/xiximayou/p/12546556.html

  • 0
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值