Training and Testing with torchvision's Built-in Faster R-CNN Model in PyTorch (Pascal VOC and COCO Datasets)

Reference project: https://github.com/lpuglia/torchvision_voc
References:
[1]https://github.com/pytorch/vision/issues/1116
[2]https://pytorch.org/docs/stable/_modules/torchvision/models/detection/faster_rcnn.html
[3]https://pytorch.org/tutorials/beginner/data_loading_tutorial.html


Update 2020-08-16
Training a custom dataset (in COCO format) with the torchvision Faster R-CNN model has been updated; the code is hosted at https://github.com/ouening/torchvision-FasterRCNN. The modifications described below were made to support the Pascal VOC format; the repository has since been updated to also support the COCO format, and a Pascal VOC-to-COCO conversion script is provided. The advantage of the COCO format is that you can use the richer evaluation metrics from pycocotools.

This project: https://github.com/ouening/MLPractice
Project file structure:
(screenshot: project file structure)
The original project only provides training code for the standard Pascal VOC and COCO datasets. To support a custom dataset in Pascal VOC format, some extra functions are needed. The additions and other changes are:
① voc_eval.py: custom_voc_eval() (this function is redundant; the default voc_eval() works too), _do_python_eval_custom_voc()
② engine.py: custom_voc_evaluate()
③ train.py: add two argument options, --train-data-path and --test-data-path, for setting the custom dataset paths:

  • parser.add_argument('--train-data-path', help='train dataset path for custom voc dataset')
  • parser.add_argument('--test-data-path', help='test dataset path for custom voc dataset')

④ voc_utils.py: class ConvertCustomVOCtoCOCO(), get_custom_voc(), class VOCCustomData()

Of these changes, the two new classes ConvertCustomVOCtoCOCO and VOCCustomData are the key code for loading a custom dataset. They are modeled on the original project's ConvertVOCtoCOCO class and PyTorch's official VOCDetection class, and are described in detail in the sections below.

1. Dataset and File Modifications

The dataset used for training contains 2 object classes.

1.1 Changes to voc_utils.py

1.1.1 Add the VOCCustomData class

# imports needed by the code added to voc_utils.py
import os
import collections
import xml.etree.ElementTree as ET

import pandas as pd
import torch
import torchvision
from PIL import Image

import transforms as T  # the project's transforms module


class VOCCustomData(torchvision.datasets.vision.VisionDataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.

    Args:
        root (string): Root directory of the custom VOC Dataset which includes directories
            Annotations and JPEGImages

        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 transforms=None):
        super(VOCCustomData, self).__init__(root, transforms, transform, target_transform)
        self.root = root
        self._transforms = transforms

        voc_root = self.root
        self.image_dir = os.path.join(voc_root, 'JPEGImages')
        self.annotation_dir = os.path.join(voc_root, 'Annotations')

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted. '
                               'Please verify the dataset path!')
        file_names = []

        for imgs in os.listdir(self.image_dir):
            file_names.append(os.path.splitext(imgs)[0])  # strip the extension only
        
        images_file = pd.DataFrame(file_names, index=None)
        # save the image list; file names only, no extension or directory
        images_file.to_csv(voc_root+'/imagesetfile.txt', header=False, index=False)

        self.images = [os.path.join(self.image_dir, x + ".jpg") for x in file_names]
        self.annotations = [os.path.join(self.annotation_dir, x + ".xml") for x in file_names]
        assert (len(self.images) == len(self.annotations))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        img = Image.open(self.images[index]).convert('RGB')
        
        target = self.parse_voc_xml(
            ET.parse(self.annotations[index]).getroot())
        
        target = dict(image_id=index, annotations=target['annotation'])

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.images)

    def parse_voc_xml(self, node):
        voc_dict = {}
        children = list(node)
        if children:
            def_dic = collections.defaultdict(list)
            for dc in map(self.parse_voc_xml, children):
                for ind, v in dc.items():
                    def_dic[ind].append(v)
            voc_dict = {
                node.tag:
                    {ind: v[0] if len(v) == 1 else v
                     for ind, v in def_dic.items()}
            }
        if node.text:
            text = node.text.strip()
            if not children:
                voc_dict[node.tag] = text
        return voc_dict
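
For reference, a minimal usage sketch of the class on its own (the dataset root below is hypothetical; it must contain the JPEGImages and Annotations folders):

# minimal sketch: load one sample without any transforms
dataset = VOCCustomData(root='/path/to/custom_voc')
print(len(dataset))                     # number of images
img, target = dataset[0]                # PIL image and a dict parsed from the XML
print(target['annotations']['size'])    # values are strings, e.g. {'width': '500', 'height': '375', 'depth': '3'}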

1.1.2 Add the ConvertCustomVOCtoCOCO class

class ConvertCustomVOCtoCOCO(object):
    # classes of the custom dataset; adapt these to your own data
    # (the "__background__" entry must stay first)
    CLASSES = (
        "__background__", "lost", "normal"
    )
    def __call__(self, image, target):
        # return image, target
        anno = target['annotations']
        filename = anno["filename"].split('.')[0]
        h, w = anno['size']['height'], anno['size']['width']
        boxes = []
        classes = []
        ishard = []
        objects = anno['object']
        if not isinstance(objects, list):
            objects = [objects]
        for obj in objects:
            bbox = obj['bndbox']
            bbox = [int(bbox[n]) - 1 for n in ['xmin', 'ymin', 'xmax', 'ymax']]
            boxes.append(bbox)
            classes.append(self.CLASSES.index(obj['name']))
            ishard.append(int(obj['difficult']))

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        classes = torch.as_tensor(classes)
        ishard = torch.as_tensor(ishard)

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        target["ishard"] = ishard
        target['name'] = torch.tensor([ord(i) for i in list(filename)], dtype=torch.int8)  # encode the filename as an int8 tensor

        return image, target
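
The filename is encoded as an int8 tensor so that every value in the target dict is a tensor, which lets the target pass through .to(device) and distributed gathering unchanged; custom_voc_evaluate() decodes it back later. A round-trip sketch (int8 only covers ASCII code points, so ASCII file names are assumed):

import torch

name = 'img_0001'   # hypothetical file name
encoded = torch.tensor([ord(c) for c in name], dtype=torch.int8)
decoded = ''.join(chr(i) for i in encoded.tolist())
assert decoded == name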

def get_custom_voc(root, transforms):
    t = [ConvertCustomVOCtoCOCO()]

    if transforms is not None:
        t.append(transforms)
    transforms = T.Compose(t)

    dataset = VOCCustomData(root=root,transforms=transforms)

    return dataset
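
A minimal sketch of wiring the dataset into a DataLoader. Detection models take lists of images and targets, so the batch must be collated with a tuple(zip(*batch))-style function, which is what this project's utils.collate_fn does; the dataset path is hypothetical and get_transform is assumed to come from the project's train.py:

import torch
from train import get_transform  # defined in the project's train.py

dataset = get_custom_voc('/path/to/custom_voc/train', get_transform(train=True))
loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True,
    collate_fn=lambda batch: tuple(zip(*batch)))  # same behavior as utils.collate_fn
images, targets = next(iter(loader))  # tuple of images, tuple of target dicts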

1.2 Changes to voc_eval.py

1.2.1 Add the custom_voc_eval() function

def custom_voc_eval(classname,
             detpath,
             imagesetfile,
             annopath='',
             ovthresh=0.5,
             use_07_metric=False):
    """rec, prec, ap = voc_eval(detpath,
                                annopath,
                                imagesetfile,
                                classname,
                                [ovthresh],
                                [use_07_metric])

    Top level function that does the PASCAL VOC evaluation.

    detpath: Path to detections
        detpath.format(classname) should produce the detection results file.
    annopath: Path to annotations (the XML annotation files, usually under Annotations)
        annopath.format(imagename) should be the xml annotations file.
    imagesetfile: Text file containing the list of images, one image name per line.
    classname: Category name
    [ovthresh]: Overlap threshold (default = 0.5)
    [use_07_metric]: Whether to use VOC07's 11 point AP computation
        (default False)
    """
    # assumes detections are in detpath.format(classname)
    # assumes annotations are in annopath.format(imagename)
    # assumes imagesetfile is a text file with each line an image name

    recs = {}
    # read list of images
    with open(imagesetfile, 'r') as f:
        lines = f.readlines()
        imagenames = [x.strip() for x in lines]

        # load annotations
        for i, imagename in enumerate(imagenames):
            recs[imagename] = parse_rec(annopath.format(imagename))

    # extract gt objects for this class
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj['name'] == classname]
        bbox = np.array([x['bbox'] for x in R])
        difficult = np.array([x['difficult'] for x in R]).astype(bool)  # np.bool is deprecated in recent NumPy
        det = [False] * len(R)
        npos = npos + sum(~difficult)
        class_recs[imagename] = {'bbox': bbox,
                                 'difficult': difficult,
                                 'det': det}

    # read dets
    detfile = detpath.format(classname)
    with open(detfile, 'r') as f:
        lines = f.readlines()

    splitlines = [x.strip().split(' ') for x in lines]
    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    BB = np.array([[float(z) for z in x[2:]] for x in splitlines])

    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)

    if BB.shape[0] > 0:
      # sort by confidence
      sorted_ind = np.argsort(-confidence)
      sorted_scores = np.sort(-confidence)
      BB = BB[sorted_ind, :]
      image_ids = [image_ids[x] for x in sorted_ind]

      # go down dets and mark TPs and FPs
      for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R['bbox'].astype(float)

        if BBGT.size > 0:
            # compute overlaps
            # intersection
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1., 0.)
            ih = np.maximum(iymax - iymin + 1., 0.)
            inters = iw * ih

            # union
            uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
                 (BBGT[:, 2] - BBGT[:, 0] + 1.) *
                 (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)

            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)

        if ovmax > ovthresh:
            if not R['difficult'][jmax]:
                if not R['det'][jmax]:
                    tp[d] = 1.
                    R['det'][jmax] = 1
                else:
                    fp[d] = 1.
        else:
            fp[d] = 1.

    # compute precision recall
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap
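
Note that the overlap computation above follows the VOC convention of inclusive integer pixel coordinates, hence the +1 terms in the width and height. A small worked example with one detection and one ground-truth box:

import numpy as np

bb   = np.array([10., 10., 49., 49.])   # detection, 40x40 pixels (inclusive coords)
bbgt = np.array([30., 30., 69., 69.])   # ground truth, 40x40 pixels
iw = max(min(bb[2], bbgt[2]) - max(bb[0], bbgt[0]) + 1., 0.)   # 20.0
ih = max(min(bb[3], bbgt[3]) - max(bb[1], bbgt[1]) + 1., 0.)   # 20.0
inters = iw * ih                                               # 400.0
uni = 40*40 + 40*40 - inters                                   # 2800.0
print(inters / uni)   # ~0.143, below the default ovthresh of 0.5, so this would count as a false positive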

1.2.2 Add the _write_custom_voc_results_file function

def _write_custom_voc_results_file(data_loader, all_boxes, image_index, root, classes, thresh=0.3):
    if os.path.exists('/tmp/results'):
        shutil.rmtree('/tmp/results')
    os.makedirs('/tmp/results')
    print('Writing results file', end='\r')

    os.makedirs("output", exist_ok=True)    # 创建output目录,存储图片检测结果
    # Bounding-box colors
    # cmap = plt.get_cmap("tab20b")
    # colors = [cmap(i) for i in np.linspace(0, 1, 20)]
    colors = [(255,0,0),(0,255,0),(0,0,255)]

    for cls_ind, cls  in enumerate(classes):
        # DistributedSampler happens to clone the inputs to make the task
        # lengths even among the nodes:
        # https://github.com/pytorch/pytorch/issues/22584
        # Boxes can be duplicated in the process since the same image may be
        # evaluated multiple times; duplicate boxes in the same location lower
        # the final mAP, so repeated image_index entries are discarded later
        # thanks to the sorting
        new_image_index, all_boxes[cls_ind] = zip(*sorted(zip(image_index,
                                 all_boxes[cls_ind]), key=lambda x: x[0]))
        if cls == '__background__':
            continue
        images_dir = data_loader.dataset.image_dir
        filename = '/tmp/results/det_test_{:s}.txt'.format(cls)
        

        with open(filename, 'wt') as f:
            prev_index = ''
            for im_ind, index in enumerate(new_image_index):
                # read the image with OpenCV
                img = cv2.imread(os.path.join(images_dir, index+'.jpg'))
                h, w, _ = img.shape

                # check for repeated input and discard
                if prev_index == index: continue
                prev_index = index
                dets = all_boxes[cls_ind][im_ind]
                if dets == []:
                    continue
                dets = dets[0]
                

                # the VOCdevkit expects 1-based indices
                for k in range(dets.shape[0]):
                    f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                            format(index, dets[k, -1],
                                   dets[k, 0] + 1, dets[k, 1] + 1,
                                   dets[k, 2] + 1, dets[k, 3] + 1))
                    if dets[k, -1] < thresh:  # skip low-confidence boxes when drawing
                        continue
                    # print("\t+ Label: %s, Conf: %.5f" % (cls, dets[k, -1]))
                    # cv2 drawing functions expect integer pixel coordinates
                    x1, x2 = int(dets[k, 0]), int(dets[k, 2])
                    y1, y2 = int(dets[k, 1]), int(dets[k, 3])

                    color = colors[cls_ind]
                    thick = int((h + w) / 300)
                    cv2.rectangle(img,
                                    (x1, y1), (x2, y2),
                                    color, thick)
                    mess = '%s: %.3f' % (cls, dets[k, -1])
                    cv2.putText(img, mess, (x1, y1 - 12),
                                0, 1e-3 * h, color, thick // 3)
                
                cv2.imwrite(f"output/output-{index}.png", img)  # save the rendered image

1.2.3 Add the _do_python_eval_custom_voc function

def _do_python_eval_custom_voc(data_loader,use_07_metric=True):

    imagesetfile = os.path.join(data_loader.dataset.root,'imagesetfile.txt')
    annopath = os.path.join(data_loader.dataset.annotation_dir,'{:s}.xml')

    classes = data_loader.dataset._transforms.transforms[0].CLASSES
    aps = []
    fig = plt.figure()

    for cls in classes:
        if cls == '__background__':    
            continue    
        filename = '/tmp/results/det_test_{:s}.txt'.format(cls)    
        rec, prec, ap = custom_voc_eval(cls, filename, imagesetfile, annopath,
                            ovthresh=0.5, use_07_metric=use_07_metric)    
        print('+ Class {} - AP: {}'.format(cls, ap))
        plt.plot(rec, prec, label='{}'.format(cls))
        aps += [ap]
    plt.xlabel('recall')
    plt.ylabel('precision')
    plt.legend()
    plt.show()
    print('Mean AP = {:.4f}'.format(np.mean(aps)))

1.3 Changes to engine.py

1.3.1 Add the custom_voc_evaluate() function

@torch.no_grad()
def custom_voc_evaluate(model, data_loader, device):
    n_threads = torch.get_num_threads()
    torch.set_num_threads(1)
    cpu_device = torch.device("cpu")
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # derive the number of classes (including background) from the dataset's
    # CLASSES instead of hardcoding 21 (the standard VOC value)
    num_classes = len(data_loader.dataset._transforms.transforms[0].CLASSES)
    all_boxes = [[] for i in range(num_classes)]
    image_index = []
    for image, targets in metric_logger.log_every(data_loader, 100, header):
        image = list(img.to(device) for img in image)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        torch.cuda.synchronize()
        model_time = time.time()
        outputs = model(image)

        name = ''.join([chr(i) for i in targets[0]['name'].tolist()])
        image_index.append(name)

        outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]

        image_boxes = [[] for i in range(num_classes)]  # one slot per class, derived above from CLASSES
        for o in outputs:
            for i in range(o['boxes'].shape[0]):
                image_boxes[o['labels'][i]].extend([
                    torch.cat([o['boxes'][i],o['scores'][i].unsqueeze(0)], dim=0)
                ])

        # make sure all_boxes gets an empty entry when a class has no
        # boxes in image_boxes
        for i in range(num_classes):
            if image_boxes[i] != []:
                all_boxes[i].append([torch.stack(image_boxes[i])])
            else:
                all_boxes[i].append([])

        model_time = time.time() - model_time

    metric_logger.synchronize_between_processes()

    all_boxes_gathered = utils.all_gather(all_boxes)
    image_index_gathered = utils.all_gather(image_index)
    
    # results from all processes are gathered here
    if utils.is_main_process():
        all_boxes = [[] for i in range(num_classes)]
        for abgs in all_boxes_gathered:
            for ab,abg in zip(all_boxes,abgs):
                ab += abg
        image_index = []
        for iig in image_index_gathered:
            image_index+=iig

        _write_custom_voc_results_file(data_loader, all_boxes,image_index, data_loader.dataset.root, 
                                data_loader.dataset._transforms.transforms[0].CLASSES,)
        _do_python_eval_custom_voc(data_loader)
    torch.set_num_threads(n_threads)

1.4 Changes to train.py

1.4.1 Modify the get_dataset function

def get_dataset(name, image_set, transform, data_path):
    paths = {
        "coco": (data_path, get_coco, 91),
        "coco_kp": (data_path, get_coco_kp, 2),
        "voc": (data_path, get_voc, 21),
        "custom_voc": (data_path, get_custom_voc, 3)
    }
    p, ds_fn, num_classes = paths[name]

    if name=='custom_voc':  # load a custom Pascal VOC-format dataset
        ds = ds_fn(p, transforms=transform)
        return ds, num_classes
    else:    
        ds = ds_fn(p, image_set=image_set, transforms=transform)
        return ds, num_classes

1.4.2 Modify the main function

def main(args):
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print("Loading data")

    # support loading a custom Pascal VOC-format dataset: set --dataset to custom_voc
    if args.dataset=='custom_voc':
        # a custom Pascal dataset needs no image_set argument, so pass None here
        dataset, num_classes = get_dataset(args.dataset, None, get_transform(train=True), args.train_data_path)
        dataset_test, _ = get_dataset(args.dataset, None, get_transform(train=False), args.test_data_path)
    else:
        dataset, num_classes = get_dataset(args.dataset, "train" if args.dataset=='coco' else 'trainval', 
            get_transform(train=True), args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "test" if args.dataset=='coco' else 'val', 
                    get_transform(train=False), args.data_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(
            train_sampler, args.batch_size, drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1,
        sampler=test_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn)

    print("Creating model")
    model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes,
                                                              pretrained=args.pretrained)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])  # for resuming training, the optimizer and LR schedule are needed besides the model
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

    # For test-only runs, note that the checkpoint is passed via --resume. The original author only
    # mentions --resume for resuming training, but per the official docs it can also be used for inference:
    # see https://pytorch.org/tutorials/beginner/saving_loading_models.html
    if args.test_only:  
        if not args.resume:
            raise Exception('A checkpoint model is required for inference!')
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
            model_without_ddp.load_state_dict(checkpoint['model'])

            if 'coco' == args.dataset:
                coco_evaluate(model_without_ddp, data_loader_test, device=device)
            elif 'voc' == args.dataset:
                voc_evaluate(model_without_ddp, data_loader_test, device=device)
            elif 'custom_voc' == args.dataset:
                custom_voc_evaluate(model_without_ddp, data_loader_test, device=device)
            else:
                print(f'No evaluation method available for the dataset {args.dataset}')
            # evaluate(model, data_loader_test, device=device)
            return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
        lr_scheduler.step()
        if args.output_dir:
            # model.save('./checkpoints/model_{}_{}.pth'.format(args.dataset, epoch))
            utils.save_on_master({
                'model': model_without_ddp.state_dict(), # store only the weights (not the full module)
                # 'model': model_without_ddp, # store the entire network
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'args': args},
                os.path.join(args.output_dir, 'model_{}_{}.pth'.format(args.dataset, epoch)))

        # evaluate after every epoch
        if args.dataset=='coco':
            coco_evaluate(model, data_loader_test, device=device)
        elif 'voc'==args.dataset:
            voc_evaluate(model, data_loader_test, device=device)
        elif 'custom_voc' == args.dataset:
            custom_voc_evaluate(model, data_loader_test, device=device)
        else:
            print(f'No evaluation method available for the dataset {args.dataset}')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(
        description=__doc__)

    parser.add_argument('--data-path', default='./', help='dataset path used for coco and voc (default: "./")')
    parser.add_argument('--train-data-path',  help='train dataset path for custom voc dataset')
    parser.add_argument('--test-data-path',  help='test dataset path for custom voc dataset')
    parser.add_argument('--dataset', default='coco',
                        help='dataset type; options are "coco", "coco_kp", "voc" and "custom_voc" (default: "coco")')
    parser.add_argument('--model', default='fasterrcnn_resnet50_fpn', help='model, default="fasterrcnn_resnet50_fpn"')
    parser.add_argument('--device', default='cuda', help='device, default is cuda')
    parser.add_argument('-b', '--batch-size', default=2, type=int,
                        help='batch size (default: 2)')
    parser.add_argument('--epochs', default=13, type=int, metavar='N',
                        help='number of total epochs to run (default: 13)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--lr', default=0.02, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
    parser.add_argument('--lr-steps', default=[8, 11], nargs='+', type=int, help='decrease lr every step-size epochs')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
    parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='./', help='path where to save (default: "./")')
    parser.add_argument('--resume', default='', help='resume from checkpoint (default: none)')
    parser.add_argument('--aspect-ratio-group-factor', default=0, type=int)
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )

    # distributed training parameters
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')

    args = parser.parse_args()

    if args.output_dir:
        utils.mkdir(args.output_dir)

    main(args)

2. Dataset Loading

Loading the dataset is the first thing any machine-learning task has to get done. It involves file operations on the dataset, so you need to be fluent with file handling in Python. For classification tasks, data loading is not too hard, but for object detection it involves many more details. Looking at PyTorch's official loader for the Pascal VOC dataset, you can see that the XML annotation files have to be parsed and the parsed content stored in a dictionary. Besides the classic Pascal VOC format, other common annotation formats are COCO and YOLO; different models require different formats, so converting between them is an important and tricky part of practical machine learning (a small conversion sketch is given at the end of this section). If the dataset cannot be loaded properly, the network training that follows is a non-starter.
The original project https://github.com/lpuglia/torchvision_voc only implements training and testing on the standard VOC and COCO datasets; loading a custom dataset in Pascal VOC format has to be implemented separately. Section 1 described the code changes in each file; see the source for the full implementation.
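
As a small taste of such a conversion (my own illustration, not part of the project), here is a sketch that turns one absolute VOC box into the normalized center-based form YOLO expects:

# VOC box (absolute xmin, ymin, xmax, ymax) -> YOLO box (cx, cy, w, h in [0, 1])
def voc_box_to_yolo(xmin, ymin, xmax, ymax, img_w, img_h):
    cx = (xmin + xmax) / 2.0 / img_w   # normalized box center x
    cy = (ymin + ymax) / 2.0 / img_h   # normalized box center y
    w = (xmax - xmin) / img_w          # normalized width
    h = (ymax - ymin) / img_h          # normalized height
    return cx, cy, w, h

print(voc_box_to_yolo(48, 240, 195, 371, img_w=500, img_h=375))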

3. Network Training

$ python3 train.py --dataset custom_voc --train-data-path /data/to/train --test-data-path /data/to/test -b 2 --output-dir ./checkpoints

During training, the model is evaluated on the test set after every epoch: mAP values are printed and PR curves are plotted.

4. Model Evaluation and Testing

The test set is already evaluated automatically during training, but the inference/evaluation step can also be run on its own. In a Linux terminal, run:

$ python3 train.py --dataset custom_voc --test-only --train-data-path /data/to/train --test-data-path /data/to/test --resume model_custom_voc_11.pth

Note that here the --resume argument is used for model inference. The output is:

Test: Total time: 0:00:10 (0.0506 s / it)
+ Class lost - AP: 0.8858719783518023
+ Class normal - AP: 0.887533003893689
Mean AP = 0.8867 

5. Notes

The custom dataset paths used in this post are my own local paths, kept so I can reproduce the results quickly later; there is nothing sensitive in them, so I left them unchanged. Currently only custom datasets in Pascal VOC format are supported, with the following directory structure:
(screenshot: dataset directory structure)
The folder names and their contents must match exactly, because the dataset-loading code in voc_utils.py relies on them:
(screenshot: dataset-loading code in voc_utils.py)
Another point to note: when using your own dataset you must also modify the ConvertCustomVOCtoCOCO class in voc_utils.py:
(screenshot: the CLASSES tuple in ConvertCustomVOCtoCOCO)
The highlighted part must be adapted to your own dataset; mine has only two classes, lost and normal (the __background__ entry stays unchanged). This approach follows the original project and is not very flexible; it can be improved later if needed. A concrete sketch of the two places to keep in sync follows.
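
For example, for a hypothetical dataset with three object classes, the two edits would look like this (class names here are made up):

# voc_utils.py -- class list for your own dataset ("__background__" stays first)
class ConvertCustomVOCtoCOCO(object):
    CLASSES = ("__background__", "cat", "dog", "bird")   # hypothetical classes
    ...  # the rest of the class stays unchanged

# train.py, get_dataset() -- keep num_classes equal to len(CLASSES):
#     "custom_voc": (data_path, get_custom_voc, 4),   # 3 object classes + background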
