Detectron2 训练自己的数据集

本文使用气球(balloon)数据集,下载地址:

链接:https://pan.baidu.com/s/1IHetqrgJB9vhZNrCv_Pc5A 
提取码:e2ol


import json
import logging
import os
import random
from collections import OrderedDict

import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.nn.parallel import DistributedDataParallel

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
from detectron2.config import get_cfg
from detectron2.data import (
    MetadataCatalog,
    build_detection_test_loader,
    build_detection_train_loader, DatasetCatalog,
)
from detectron2.engine import default_argument_parser, default_setup, default_writers, launch, DefaultPredictor
from detectron2.evaluation import (
    CityscapesInstanceEvaluator,
    CityscapesSemSegEvaluator,
    COCOEvaluator,
    COCOPanopticEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    SemSegEvaluator,
    inference_on_dataset,
    print_csv_format,
)
from detectron2.model_zoo import model_zoo
from detectron2.modeling import build_model
from detectron2.solver import build_lr_scheduler, build_optimizer
from detectron2.utils.events import EventStorage
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.visualizer import ColorMode
logger = logging.getLogger("detectron2")

#如果是coco数据集
# from detectron2.data.datasets import register_coco_instances
# register_coco_instances(
#     "my_dataset_train",
#     {},
#     "D:/Desktop/yolo/detectron2-main/detectron2/data/datasets/balloon/annotations/train_region_data.json",
#     "D:/Desktop/yolo/detectron2-main/detectron2/data/datasets/balloon/train"
#                        )
# register_coco_instances(
# 	"my_dataset_val",
#     {},
#     "D:/Desktop/yolo/detectron2-main/detectron2/data/datasets/balloon/annotations/val_region_data.json",
#     "D:/Desktop/yolo/detectron2-main/detectron2/data/datasets/balloon/val"
# )
#********************如果是自己的数据集*************************
from detectron2.structures import BoxMode

def get_balloon_dicts(img_dir):
    """Load VIA polygon annotations from *img_dir* as detectron2 dataset dicts.

    Reads ``via_region_data.json`` located in *img_dir* and converts every
    image entry into one record. Every object is assigned ``category_id`` 0
    (single-class dataset).

    Args:
        img_dir: Directory containing the images and ``via_region_data.json``.

    Returns:
        list[dict]: one record per image with keys ``file_name``,
        ``image_id`` (enumeration index), ``height``, ``width`` and
        ``annotations``. Each annotation holds ``bbox`` (XYXY, absolute
        pixels), ``bbox_mode``, ``segmentation`` (one flattened polygon)
        and ``category_id``.

    Raises:
        FileNotFoundError: if an image listed in the JSON cannot be read.
        ValueError: if a region has non-empty ``region_attributes``; this
            loader only supports attribute-free, single-class annotations.
    """
    json_file = os.path.join(img_dir, "via_region_data.json")
    # JSON is UTF-8 by specification; be explicit so a Windows locale
    # codepage (e.g. GBK) is not used to decode it.
    with open(json_file, encoding="utf-8") as f:
        imgs_anns = json.load(f)

    dataset_dicts = []
    for idx, v in enumerate(imgs_anns.values()):
        filename = os.path.join(img_dir, v["filename"])
        img = cv2.imread(filename)
        # cv2.imread returns None instead of raising on a missing/unreadable
        # file; fail with a clear message rather than an AttributeError.
        if img is None:
            raise FileNotFoundError(f"Cannot read image: {filename}")
        height, width = img.shape[:2]

        # Per-image fields: file name, numeric id, height and width.
        record = {
            "file_name": filename,
            "image_id": idx,
            "height": height,
            "width": width,
        }

        regions = v["regions"]
        # Older VIA exports (like this balloon dataset) store regions as a
        # dict keyed by index; VIA 2.x exports use a plain list.
        if isinstance(regions, dict):
            regions = regions.values()

        objs = []
        for anno in regions:
            if anno["region_attributes"]:
                raise ValueError(
                    f"Unsupported region_attributes in {filename}: "
                    f"{anno['region_attributes']!r}"
                )
            shape = anno["shape_attributes"]
            px = shape["all_points_x"]
            py = shape["all_points_y"]
            # Offset each vertex by +0.5 and flatten [(x, y), ...] into
            # [x0, y0, x1, y1, ...].
            poly = [c for x, y in zip(px, py) for c in (x + 0.5, y + 0.5)]

            objs.append({
                # The polygon's extent doubles as its bounding box.
                "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [poly],
                "category_id": 0,
            })
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts

# Register the train/val splits with detectron2. The data is read lazily from
# "balloon/<split>" relative to the current working directory: put your
# dataset there or change the path below.
for d in ["train", "val"]:
    # `d=d` binds the current split name into the lambda (avoids the
    # late-binding closure pitfall where both entries would load "val").
    DatasetCatalog.register("balloon_" + d, lambda d=d: get_balloon_dicts("balloon/" + d))
    # detectron2 components such as Visualizer look up class names under
    # the `thing_classes` key.
    MetadataCatalog.get("balloon_" + d).set(thing_classes=["balloon"])

# Name kept as-is (sic) for compatibility with code that references it.
ballon_metadata = MetadataCatalog.get("balloon_train")
# ********************如果是自己的数据集*************************

# Eagerly load the training split (get_balloon_dicts reads every image once
# via cv2.imread) so it can be eyeballed with the optional check below.
# NOTE: this absolute path is independent of the relative "balloon/train"
# path registered with DatasetCatalog; edit it to match your dataset location.
_train_split_dir = "D:/Desktop/yolo/detectron2-main/detectron2/data/datasets/balloon/train"
dataset_dicts = get_balloon_dicts(_train_split_dir)

# Optional sanity check: draw the ground-truth annotations of three random
# training images. (The cv2_imshow helper is not part of OpenCV, so plain
# cv2.imshow is used instead.)
# for d in random.sample(dataset_dicts, 3):
#     img = cv2.imread(d["file_name"])
#     visualizer = Visualizer(img[:, :, ::-1], metadata=ballon_metadata, scale=0.5)
#     out = visualizer.draw_dataset_dict(d)
#     cv2.imshow("window", out.get_image()[:, :, ::-1])
#     cv2.waitKey()

# Run-mode switch: truthy -> train, then print a detection-count ratio on the
# val split; falsy (0) -> inference and visualisation only.
train = 0

from detectron2.engine import DefaultTrainer

# Entry point: configure Mask R-CNN for the single-class balloon dataset, then
# either train (when `train` is truthy) or visualise predictions of an
# already-trained model stored in cfg.OUTPUT_DIR.
if __name__ == "__main__":
    cfg = get_cfg()
    # Base config from the detectron2 model zoo (use the commented
    # faster_rcnn lines instead for box-only detection).
    # cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
    cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.DATASETS.TRAIN = ('balloon_train',)  # training split
    cfg.DATASETS.TEST = ('balloon_val',)  # evaluation split
    # Transfer learning: start from the model-zoo checkpoint for that config.
    # cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")

    cfg.DATALOADER.NUM_WORKERS = 0  # data-loading worker processes; 0 = load in the main process
    cfg.SOLVER.IMS_PER_BATCH = 1  # images per training batch
    cfg.SOLVER.BASE_LR = 0.001  # learning rate
    cfg.SOLVER.MAX_ITER = 1000  # total training iterations
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # RoIs sampled per image for the ROI heads (default 512)
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # single class: balloon
    # cfg.MODEL.DEVICE = 'cpu'  # uncomment to force CPU; otherwise the default CUDA device is used
    cfg.OUTPUT_DIR = 'D:/temp_model'

    if train:
        os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load(resume=False)
        trainer.train()

        # Reload the freshly trained weights for a quick sanity check.
        cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, 'model_final.pth')
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
        predictor = DefaultPredictor(cfg)
        val_dicts = DatasetCatalog.get('balloon_val')

        # Rough check only (NOT accuracy/AP): ratio of predicted instances
        # above the score threshold to ground-truth instances on the val set.
        # Ground truth is counted from the registered dataset dicts (one
        # annotation per VIA region) instead of re-reading the JSON from a
        # second hard-coded path.
        n_pred, n_gt = 0, 0
        for d in val_dicts:
            im = cv2.imread(d['file_name'])
            outputs = predictor(im)
            n_pred += len(outputs['instances'].get("pred_classes"))
            n_gt += len(d['annotations'])
        if n_gt:
            print(n_pred / n_gt)
        else:
            print("balloon_val has no ground-truth instances; ratio is undefined.")

    else:
        # Inference only: visualise predictions on a few random val images.
        cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, 'model_final.pth')
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8
        predictor = DefaultPredictor(cfg)
        val_dicts = DatasetCatalog.get('balloon_val')
        balloon_metadata = MetadataCatalog.get('balloon_val')

        # random.sample raises ValueError when asked for more items than exist.
        for d in random.sample(val_dicts, min(3, len(val_dicts))):
            im = cv2.imread(d["file_name"])
            outputs = predictor(im)
            # Output format: https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
            v = Visualizer(im[:, :, ::-1],  # OpenCV BGR -> RGB
                           metadata=balloon_metadata,
                           scale=0.5,
                           instance_mode=ColorMode.IMAGE_BW,  # grey out pixels outside predicted masks
                           )
            out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
            cv2.imshow("window", out.get_image()[:, :, ::-1])  # RGB -> BGR for display
            cv2.waitKey()
        cv2.destroyAllWindows()

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值