Detectron2 maskRCNN训练自己的数据集

摘要:使用 Detectron2 训练一个 Mask R-CNN 实例分割模型。数据集用 labelme 标注,再转换为 COCO 格式用于训练。

参考:
安装detectron2
labelme标注格式转为coco格式

数据准备

用labelme标注的分割数据,按照参考链接“labelme标注格式转为coco格式”中的方法转换为coco格式。detectron2注册数据集时需要指定标注文件(coco格式的所有标注内容都在一个json文件内)和图片存储路径,例如这里训练集的图片路径是“K:\imageData\golden_pad\mask_bond\label\bondOnly\train”,与下文第2节“注册数据集”中使用的路径一致。
在这里插入图片描述

1. 导入依赖库

# Environment sanity check: print the installed PyTorch version and whether CUDA is usable.
import torch, torchvision
print(torch.__version__, torch.cuda.is_available())
# IPython/Jupyter shell escape (not plain Python): print the gcc version.
!gcc --version
1.6.0 True
gcc (x86_64-win32-sjlj-rev0, Built by MinGW-W64 project) 8.1.0
Copyright (C) 2018 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# Some basic setup:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
#from google.colab.patches import cv2_imshow

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

2. 注册数据集

from detectron2.data.datasets import register_coco_instances

# Training split: the COCO-format json holding the annotations, plus the
# directory that stores the images.
register_coco_instances(
    name="bondOnlyDataset_train",
    metadata={},
    json_file=r"K:\imageData\golden_pad\mask_bond\label\bondOnly\train\train_bondOnly.json",
    image_root=r"K:\imageData\golden_pad\mask_bond\label\bondOnly\train",
)

# Validation split, registered the same way.
register_coco_instances(
    name="bondOnlyDataset_val",
    metadata={},
    json_file=r"K:\imageData\golden_pad\mask_bond\label\bondOnly\val\val_bondOnly.json",
    image_root=r"K:\imageData\golden_pad\mask_bond\label\bondOnly\val",
)

# Metadata handle of the training split; it is passed to the Visualizer later.
my_bond_metadata = MetadataCatalog.get("bondOnlyDataset_train")
# Bare expression so the notebook cell displays the metadata.
my_bond_metadata
namespace(name='bondOnlyDataset_train',
          json_file='K:\\imageData\\golden_pad\\mask_bond\\label\\bondOnly\\train\\train_bondOnly.json',
          image_root='K:\\imageData\\golden_pad\\mask_bond\\label\\bondOnly\\train',
          evaluator_type='coco')

3. Train

from detectron2.engine import DefaultTrainer

# Start from the model-zoo Mask R-CNN (R50-FPN, 3x) config and override the
# settings this small custom dataset needs.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
# DATASETS.TRAIN is a tuple of dataset names. The trailing comma is required:
# ("name") without it is just a parenthesized string, not a 1-tuple.
cfg.DATASETS.TRAIN = ("bondOnlyDataset_train",)
cfg.DATASETS.TEST = ()  # no evaluation during training
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  # Let training initialize from model zoo
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
cfg.SOLVER.MAX_ITER = 300    # 300 iterations
cfg.SOLVER.STEPS = []        # do not decay learning rate
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # faster, and good enough for this toy dataset (default: 512)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # the number of classes in the registered dataset

# Checkpoints and logs are written under cfg.OUTPUT_DIR; create it up front.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
# resume=False: load cfg.MODEL.WEIGHTS and start from iteration 0 instead of
# continuing from a previous checkpoint in OUTPUT_DIR.
trainer.resume_or_load(resume=False)
trainer.train()
[32m[07/01 15:41:12 d2.data.datasets.coco]: [0mLoaded 10 images in COCO format from K:\imageData\golden_pad\mask_bond\label\bondOnly\train\train_bondOnly.json
[32m[07/01 15:41:12 d2.data.build]: [0mRemoved 0 images with no usable annotations. 10 images left.
[32m[07/01 15:41:12 d2.data.build]: [0mDistribution of instances among all 1 categories:
[36m| category   | #instances   |
|:-----------|:-------------|
| 0          | 109          |
|            |              |[0m
[32m[07/01 15:41:12 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in training: [ResizeShortestEdge(short_edge_length=(640, 672, 704, 736, 768, 800), max_size=1333, sample_style='choice'), RandomFlip()]
[32m[07/01 15:41:12 d2.data.build]: [0mUsing training sampler TrainingSampler
[32m[07/01 15:41:12 d2.data.common]: [0mSerializing 10 elements to byte tensors and concatenating them all ...
[32m[07/01 15:41:12 d2.data.common]: [0mSerialized dataset takes 0.03 MiB


Skip loading parameter 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (2, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (2,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 1024) in the checkpoint but (4, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (4,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.mask_head.predictor.weight' to the model due to incompatible shapes: (80, 256, 1, 1) in the checkpoint but (1, 256, 1, 1) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.mask_head.predictor.bias' to the model due to incompatible shapes: (80,) in the checkpoint but (1,) in the model! You might want to double check if this is expected.
Some model parameters or buffers are not found in the checkpoint:
[34mroi_heads.box_predictor.bbox_pred.{bias, weight}[0m
[34mroi_heads.box_predictor.cls_score.{bias, weight}[0m
[34mroi_heads.mask_head.predictor.{bias, weight}[0m
The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


[32m[07/01 15:41:32 d2.engine.train_loop]: [0mStarting training from iteration 0


g:\mydoc\ml\detection2\detectron2\detectron2\modeling\roi_heads\fast_rcnn.py:103: UserWarning: This overload of nonzero is deprecated:
	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  ..\torch\csrc\utils\python_arg_parser.cpp:766.)
  num_fg = fg_inds.nonzero().numel()


[32m[07/01 15:43:09 d2.utils.events]: [0m eta: 0:17:17  iter: 19  total_loss: 2.29  loss_cls: 0.6478  loss_box_reg: 0.2737  loss_mask: 0.6865  loss_rpn_cls: 0.5353  loss_rpn_loc: 0.07  time: 3.6469  data_time: 0.3210  lr: 1.6068e-05  max_mem: 2762M
[32m[07/01 15:44:27 d2.utils.events]: [0m eta: 0:16:48  iter: 39  total_loss: 2.004  loss_cls: 0.5684  loss_box_reg: 0.5949  loss_mask: 0.6479  loss_rpn_cls: 0.09013  loss_rpn_loc: 0.06019  time: 3.7739  data_time: 0.0017  lr: 3.2718e-05  max_mem: 2762M
[32m[07/01 15:45:44 d2.utils.events]: [0m eta: 0:15:41  iter: 59  total_loss: 1.922  loss_cls: 0.4842  loss_box_reg: 0.7446  loss_mask: 0.5819  loss_rpn_cls: 0.03776  loss_rpn_loc: 0.05224  time: 3.7993  data_time: 0.0013  lr: 4.9367e-05  max_mem: 2762M
[32m[07/01 15:46:59 d2.utils.events]: [0m eta: 0:14:07  iter: 79  total_loss: 1.769  loss_cls: 0.4234  loss_box_reg: 0.7561  loss_mask: 0.5092  loss_rpn_cls: 0.0316  loss_rpn_loc: 0.04771  time: 3.7876  data_time: 0.0016  lr: 6.6017e-05  max_mem: 2762M
[32m[07/01 15:48:14 d2.utils.events]: [0m eta: 0:12:50  iter: 99  total_loss: 1.622  loss_cls: 0.3621  loss_box_reg: 0.7352  loss_mask: 0.4412  loss_rpn_cls: 0.0186  loss_rpn_loc: 0.04654  time: 3.7840  data_time: 0.0012  lr: 8.2668e-05  max_mem: 2762M
[32m[07/01 15:49:30 d2.utils.events]: [0m eta: 0:11:33  iter: 119  total_loss: 1.481  loss_cls: 0.3132  loss_box_reg: 0.7195  loss_mask: 0.3788  loss_rpn_cls: 0.0193  loss_rpn_loc: 0.03893  time: 3.7889  data_time: 0.0012  lr: 9.9318e-05  max_mem: 2762M
[32m[07/01 15:50:45 d2.utils.events]: [0m eta: 0:10:17  iter: 139  total_loss: 1.342  loss_cls: 0.2538  loss_box_reg: 0.7044  loss_mask: 0.3372  loss_rpn_cls: 0.01795  loss_rpn_loc: 0.03062  time: 3.7814  data_time: 0.0013  lr: 0.00011597  max_mem: 2762M
[32m[07/01 15:52:01 d2.utils.events]: [0m eta: 0:09:00  iter: 159  total_loss: 1.211  loss_cls: 0.2095  loss_box_reg: 0.6424  loss_mask: 0.2873  loss_rpn_cls: 0.01189  loss_rpn_loc: 0.04023  time: 3.7833  data_time: 0.0014  lr: 0.00013262  max_mem: 2762M
[32m[07/01 15:53:20 d2.utils.events]: [0m eta: 0:07:47  iter: 179  total_loss: 0.9882  loss_cls: 0.1662  loss_box_reg: 0.5488  loss_mask: 0.2426  loss_rpn_cls: 0.004047  loss_rpn_loc: 0.02753  time: 3.8005  data_time: 0.0011  lr: 0.00014927  max_mem: 2762M
[32m[07/01 15:54:38 d2.utils.events]: [0m eta: 0:06:32  iter: 199  total_loss: 0.7963  loss_cls: 0.1555  loss_box_reg: 0.4163  loss_mask: 0.1969  loss_rpn_cls: 0.002383  loss_rpn_loc: 0.02136  time: 3.8124  data_time: 0.0012  lr: 0.00016592  max_mem: 2762M
[32m[07/01 15:55:55 d2.utils.events]: [0m eta: 0:05:13  iter: 219  total_loss: 0.6485  loss_cls: 0.121  loss_box_reg: 0.33  loss_mask: 0.1632  loss_rpn_cls: 0.003209  loss_rpn_loc: 0.02633  time: 3.8126  data_time: 0.0011  lr: 0.00018257  max_mem: 2762M
[32m[07/01 15:57:13 d2.utils.events]: [0m eta: 0:03:55  iter: 239  total_loss: 0.5863  loss_cls: 0.1117  loss_box_reg: 0.2936  loss_mask: 0.1408  loss_rpn_cls: 0.002255  loss_rpn_loc: 0.01813  time: 3.8213  data_time: 0.0011  lr: 0.00019922  max_mem: 2762M
[32m[07/01 15:58:29 d2.utils.events]: [0m eta: 0:02:37  iter: 259  total_loss: 0.5784  loss_cls: 0.1194  loss_box_reg: 0.2973  loss_mask: 0.1357  loss_rpn_cls: 0.001713  loss_rpn_loc: 0.03597  time: 3.8197  data_time: 0.0011  lr: 0.00021587  max_mem: 2762M
[32m[07/01 15:59:46 d2.utils.events]: [0m eta: 0:01:18  iter: 279  total_loss: 0.489  loss_cls: 0.1003  loss_box_reg: 0.2314  loss_mask: 0.1291  loss_rpn_cls: 0.00143  loss_rpn_loc: 0.02953  time: 3.8201  data_time: 0.0011  lr: 0.00023252  max_mem: 2762M
[32m[07/01 16:01:04 d2.utils.events]: [0m eta: 0:00:00  iter: 299  total_loss: 0.4941  loss_cls: 0.09173  loss_box_reg: 0.2375  loss_mask: 0.1203  loss_rpn_cls: 0.001047  loss_rpn_loc: 0.02824  time: 3.8256  data_time: 0.0020  lr: 0.00024917  max_mem: 2762M
[32m[07/01 16:01:05 d2.engine.hooks]: [0mOverall training speed: 298 iterations in 0:19:00 (3.8256 s / it)
[32m[07/01 16:01:05 d2.engine.hooks]: [0mTotal training time: 0:19:01 (0:00:01 on hooks)

4. 推理

4.1 读取训练的模型生成一个预测器

# Inference reuses the cfg built for training so the model is constructed the
# same way; only the weights path and the score threshold are overridden.
trained_weights = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.WEIGHTS = trained_weights  # the model we just trained
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.3  # custom testing threshold
predictor = DefaultPredictor(cfg)

4.2 读取一张图片预测,并用detectron2可视化结果

from detectron2.utils.visualizer import ColorMode

img_path = r"K:\imageData\golden_pad\bond2\006004_2.bmp"
im = cv2.imread(img_path)
if im is None:
    # cv2.imread does not raise on a missing/unreadable file, it returns None;
    # fail here with a clear message instead of deep inside the predictor.
    raise FileNotFoundError("cannot read image: {}".format(img_path))
outputs = predictor(im)  # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
v = Visualizer(
    im[:, :, ::-1],  # reverse the channel order (OpenCV loads BGR, Visualizer expects RGB)
    metadata=my_bond_metadata,
    scale=1,
    # Remove the colors of unsegmented pixels (segmentation models only).
    # Must be the enum member: ColorMode is a plain Enum, so the raw int 2
    # never compares equal to ColorMode.IMAGE_BW and the mode is silently ignored.
    instance_mode=ColorMode.IMAGE_BW,
)
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
cv2.namedWindow("image", cv2.WINDOW_NORMAL)  # WINDOW_NORMAL == 0: user-resizable window
cv2.imshow("image", out.get_image()[:, :, ::-1])  # back to BGR for OpenCV display
cv2.waitKey(0)  # block until a key is pressed
cv2.destroyWindow("image")

在这里插入图片描述

4.3 自定义可视化mask

def showMask(outputs):
    """Show the union of all predicted instance masks in a window named "mask".

    Args:
        outputs: predictor output; ``outputs["instances"]`` must provide a
            ``pred_masks`` field that converts to an array indexed as
            (instance, row, column).

    The window is refreshed with ``cv2.waitKey(1)`` and the function returns
    immediately; the caller is responsible for keeping the GUI alive and for
    destroying the window.
    """
    mask = outputs["instances"].to("cpu").get("pred_masks").numpy()
    # The mask has one channel per detected instance; merge all channels into
    # a single binary image (255 where any instance covers the pixel, else 0).
    # The original computed np.where(...) but discarded the result, so the
    # un-binarised float accumulator was displayed instead.
    merged = np.where(mask.any(axis=0), 255, 0).astype(np.uint8)
    cv2.namedWindow("mask", cv2.WINDOW_NORMAL)  # WINDOW_NORMAL == 0: user-resizable window
    cv2.imshow("mask", merged)
    cv2.waitKey(1)  # let the GUI paint without blocking
    
img_path = r"K:\imageData\golden_pad\bond2\006004_2.bmp"
im = cv2.imread(img_path)
if im is None:
    # cv2.imread does not raise on a missing/unreadable file, it returns None;
    # fail here with a clear message instead of deep inside the predictor.
    raise FileNotFoundError("cannot read image: {}".format(img_path))
outputs = predictor(im)
showMask(outputs)  # opens the "mask" window without blocking
# Show the input image next to the mask for comparison.
cv2.namedWindow("image", cv2.WINDOW_NORMAL)  # WINDOW_NORMAL == 0: user-resizable window
cv2.imshow("image", im)
cv2.waitKey(0)  # block until a key is pressed
cv2.destroyAllWindows()  # closes both the "image" and the "mask" windows

在这里插入图片描述

4.4 推理时间测试

import time

def speedTest(predictor, img, num_runs=20):
    """Print the average wall-clock time of one ``predictor(img)`` call.

    Args:
        predictor: callable invoked as ``predictor(img)``; its result is discarded.
        img: input passed unchanged to every call; must expose ``.shape``,
            which is included in the printed report.
        num_runs: number of calls to average over (default 20, the value that
            used to be hard-coded).

    Raises:
        ValueError: if ``num_runs`` is smaller than 1.

    No warm-up call is made, so any one-off cost of the first call is part of
    the reported average.
    """
    if num_runs < 1:
        raise ValueError("num_runs must be >= 1")
    # perf_counter is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() can jump and may have coarse resolution.
    start = time.perf_counter()
    for _ in range(num_runs):
        predictor(img)
    elapsed = time.perf_counter() - start
    print("run time:{:.3}s/per image,imageShape:{}".format(elapsed / num_runs, img.shape))
# Benchmark the trained predictor on the image loaded earlier.
speedTest(predictor,im)
    
run time:0.804s/per image,imageShape:(275, 350, 3)

  • 3
    点赞
  • 40
    收藏
    觉得还不错? 一键收藏
  • 13
    评论
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值