# Balloon dataset used by this script.
# Download link: https://pan.baidu.com/s/1IHetqrgJB9vhZNrCv_Pc5A
# Extraction code: e2ol
import json
import logging
import os
import random
from collections import OrderedDict
import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.nn.parallel import DistributedDataParallel
import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
from detectron2.config import get_cfg
from detectron2.data import (
MetadataCatalog,
build_detection_test_loader,
build_detection_train_loader, DatasetCatalog,
)
from detectron2.engine import default_argument_parser, default_setup, default_writers, launch, DefaultPredictor
from detectron2.evaluation import (
CityscapesInstanceEvaluator,
CityscapesSemSegEvaluator,
COCOEvaluator,
COCOPanopticEvaluator,
DatasetEvaluators,
LVISEvaluator,
PascalVOCDetectionEvaluator,
SemSegEvaluator,
inference_on_dataset,
print_csv_format,
)
from detectron2.model_zoo import model_zoo
from detectron2.modeling import build_model
from detectron2.solver import build_lr_scheduler, build_optimizer
from detectron2.utils.events import EventStorage
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.visualizer import ColorMode
logger = logging.getLogger("detectron2")
#如果是coco数据集
# from detectron2.data.datasets import register_coco_instances
# register_coco_instances(
# "my_dataset_train",
# {},
# "D:/Desktop/yolo/detectron2-main/detectron2/data/datasets/balloon/annotations/train_region_data.json",
# "D:/Desktop/yolo/detectron2-main/detectron2/data/datasets/balloon/train"
# )
# register_coco_instances(
# "my_dataset_val",
# {},
# "D:/Desktop/yolo/detectron2-main/detectron2/data/datasets/balloon/annotations/val_region_data.json",
# "D:/Desktop/yolo/detectron2-main/detectron2/data/datasets/balloon/val"
# )
#********************如果是自己的数据集*************************
from detectron2.structures import BoxMode
def get_balloon_dicts(img_dir):
    """Load VIA-format balloon annotations into Detectron2's dataset-dict format.

    Args:
        img_dir: directory containing the images and a ``via_region_data.json``
            file exported from the VIA annotation tool.

    Returns:
        A list of dicts, one per image, each with "file_name", "image_id",
        "height", "width" and an "annotations" list of polygon instances.

    Raises:
        FileNotFoundError: if an image referenced by the JSON cannot be read.
    """
    json_file = os.path.join(img_dir, "via_region_data.json")
    with open(json_file, encoding="utf-8") as f:
        imgs_anns = json.load(f)

    dataset_dicts = []
    for idx, v in enumerate(imgs_anns.values()):
        record = {}
        filename = os.path.join(img_dir, v["filename"])
        # cv2.imread returns None (instead of raising) for a missing or
        # corrupt file; fail early with a clear message rather than an
        # opaque AttributeError on ".shape" below.
        img = cv2.imread(filename)
        if img is None:
            raise FileNotFoundError(f"cannot read image: {filename}")
        height, width = img.shape[:2]

        # Per-image bookkeeping Detectron2 requires:
        # file name, image id, image height and width.
        record["file_name"] = filename
        record["image_id"] = idx
        record["height"] = height
        record["width"] = width

        annos = v["regions"]
        objs = []
        # One object entry per annotated region of this image.
        for _, anno in annos.items():
            assert not anno["region_attributes"], "unexpected region_attributes"
            anno = anno["shape_attributes"]
            px = anno["all_points_x"]
            py = anno["all_points_y"]
            # Shift polygon vertices to pixel centers, then flatten to
            # [x0, y0, x1, y1, ...] as Detectron2 expects.
            poly = [(x + 0.5, y + 0.5) for x, y in zip(px, py)]
            poly = [coord for point in poly for coord in point]
            obj = {
                # Tight axis-aligned bounding box derived from the polygon.
                "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [poly],
                "category_id": 0,  # single "balloon" class
            }
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts
#同时这里也是读取的目录,你的数据集的位置要和下面的目录位置相同,或者修改目录位置
# Register the balloon train/val splits with Detectron2's catalogs.
# NOTE: the dataset must live under "balloon/<split>" relative to the
# working directory, or the path below must be adjusted.
for d in ["train", "val"]:
    # "lambda d=d" binds the current split name (avoids the late-binding
    # closure pitfall where both lambdas would see the last value of d).
    DatasetCatalog.register("balloon_" + d, lambda d=d: get_balloon_dicts("balloon/" + d))
    # Fix: Detectron2's metadata key is "thing_classes" (not "things_classes");
    # the Visualizer reads this key to label predicted instances.
    MetadataCatalog.get("balloon_" + d).set(thing_classes=["balloon"])
ballon_metadata = MetadataCatalog.get("balloon_train")
# ******************** end of custom-dataset section *************************
dataset_dicts = get_balloon_dicts("D:/Desktop/yolo/detectron2-main/detectron2/data/datasets/balloon/train")
# Sanity-check the dataset by visualizing a few random samples:
# for d in random.sample(dataset_dicts, 3):
#     img = cv2.imread(d["file_name"])
#     visualizer = Visualizer(img[:, :, ::-1], metadata=ballon_metadata, scale=0.5)
#     out = visualizer.draw_dataset_dict(d)
#     cv2.imshow("window", out.get_image()[:, :, ::-1])
#     cv2.waitKey()
# cv2_imshow is a Colab helper, not part of OpenCV; use cv2.imshow locally.
train = 0  # 0 = run inference/visualization; 1 = train first, then evaluate
from detectron2.engine import DefaultTrainer
if __name__ == "__main__":
    cfg = get_cfg()
    # Base config: Mask R-CNN R50-FPN, 3x schedule (instance segmentation).
    cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.DATASETS.TRAIN = ('balloon_train',)
    cfg.DATASETS.TEST = ('balloon_val',)
    # Start from COCO-pretrained weights (transfer learning).
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
    cfg.DATALOADER.NUM_WORKERS = 0  # 0 = load data in the main process
    cfg.SOLVER.IMS_PER_BATCH = 1
    cfg.SOLVER.BASE_LR = 0.001
    cfg.SOLVER.MAX_ITER = 1000
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # default: 512
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # single "balloon" class
    # cfg.MODEL.DEVICE = 'cpu'  # leave commented to use the default (CUDA) device
    cfg.OUTPUT_DIR = 'D:/temp_model'

    if train:
        # --- train, then run a rough detection-count evaluation ---
        os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load(resume=False)
        trainer.train()

        cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, 'model_final.pth')
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
        predictor = DefaultPredictor(cfg)
        val_dicts = DatasetCatalog.get('balloon_val')
        balloon_metadata = MetadataCatalog.get('balloon_val')

        # s1: total number of predicted instances over the validation set.
        s1 = 0
        for d in val_dicts:
            im = cv2.imread(d['file_name'])
            outputs = predictor(im)
            s1 += len(outputs['instances'].get("pred_classes"))

        # s2: total number of ground-truth regions. Fix: read the annotation
        # file ONCE, outside the per-image loop — the original re-accumulated
        # it for every validation image, inflating the denominator by a
        # factor of len(val_dicts).
        with open('./balloon/val/via_region_data.json') as f:
            im_js = json.load(f)
        s2 = sum(len(v['regions']) for v in im_js.values())
        # Ratio of predicted to annotated instances (a crude recall proxy);
        # guard against an empty annotation file.
        print(s1 / s2 if s2 else 0)
    else:
        # --- inference / visualization with previously trained weights ---
        cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, 'model_final.pth')
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8
        predictor = DefaultPredictor(cfg)
        val_dicts = DatasetCatalog.get('balloon_val')
        balloon_metadata = MetadataCatalog.get('balloon_val')
        for d in random.sample(val_dicts, 3):
            im = cv2.imread(d["file_name"])
            outputs = predictor(im)
            # Output format: https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
            v = Visualizer(im[:, :, ::-1],
                           metadata=balloon_metadata,
                           scale=0.5,
                           instance_mode=ColorMode.IMAGE_BW)
            out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
            cv2.imshow("window", out.get_image()[:, :, ::-1])
            cv2.waitKey()