Mask R-CNN model training and inference workflow

This article walks through building a Mask R-CNN model with the mmdetection library in a PyTorch environment, covering data annotation, dataset splitting, model configuration, and training. First, annotate with labelme and convert the annotations; then modify the source code to use custom classes; finally, configure and train the Mask R-CNN model.

1. Source code: the mmdetection library, a PyTorch implementation of Mask R-CNN

GitHub: https://github.com/open-mmlab/mmdetection

2. Environment setup

        mmdetection requires PyTorch 1.6 or later.
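
Once the dependencies are installed, a quick sanity check such as the following confirms the versions (a minimal sketch; it assumes torch, mmcv, and mmdet are already installed, e.g. via pip or mim):

import torch
import mmcv
import mmdet

# Confirm the interpreter sees the expected versions and that CUDA is usable
print('torch:', torch.__version__, '| CUDA available:', torch.cuda.is_available())
print('mmcv:', mmcv.__version__)
print('mmdet:', mmdet.__version__)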

3. Data annotation

        (1) Under dataset/, create two folders: images and label.

        (2) Annotate with labelme; put the annotation files into label/total and the original images into images/total.

        (3) Run creat_txt.py to split the dataset; the split .txt files are saved under the dataset folder.

#!/usr/bin/env python
# -*- coding:utf-8 -*-
import os
import random

trainval_percent = 0.8  # fraction of the data used for train + val; the remainder becomes the test set
train_percent = 0.8     # fraction of train + val used for training
jsonfilepath = './dataset/label/total'  # labelme JSON files from step (2)
txtsavepath = './dataset'
total_json = os.listdir(jsonfilepath)

num = len(total_json)
indices = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(indices, tv)
train = random.sample(trainval, tr)

ftrainval = open(os.path.join(txtsavepath, 'trainval.txt'), 'w')
ftest = open(os.path.join(txtsavepath, 'test.txt'), 'w')
ftrain = open(os.path.join(txtsavepath, 'train.txt'), 'w')
fval = open(os.path.join(txtsavepath, 'val.txt'), 'w')

for i in indices:
    name = total_json[i][:-5] + '\n'  # strip the '.json' extension
    if i in trainval:
        ftrainval.write(name)
        if i in train:
            ftrain.write(name)
        else:
            fval.write(name)
    else:
        ftest.write(name)

ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
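
With trainval_percent = 0.8 and train_percent = 0.8, the resulting split is 64% train, 16% val, and 20% test (0.8 × 0.8 = 0.64 of all files go to train). Each .txt file holds one file stem per line, e.g. a label file named img_001.json (hypothetical name) contributes the line img_001.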

 (4) Run classify.py to move the images and annotation files into the per-split folders created in step (3).

#!/usr/bin/env python
# -*- coding:utf-8 -*-
import os
import shutil
from shutil import rmtree

import cv2 as cv

sets = ['train', 'val', 'test']
root = './dataset'
image_root = os.path.join(root, 'images')
label_root = os.path.join(root, 'label')

def mk_file(file_path: str):
    # If the folder already exists, remove it and recreate it empty
    if os.path.exists(file_path):
        rmtree(file_path)
    os.makedirs(file_path)

# Create the train/val/test folders under images/ and label/
for image_set in sets:
    mk_file(os.path.join(image_root, image_set))
    mk_file(os.path.join(label_root, image_set))

for image_set in sets:
    image_ids = open('%s/%s.txt' % (root, image_set)).read().strip().split()
    for image_id in image_ids:
        img = cv.imread('%s/total/%s.jpg' % (image_root, image_id))
        json_path = '%s/total/%s.json' % (label_root, image_id)

        # The image goes to images/<set>/, the labelme JSON to label/<set>/
        cv.imwrite('%s/%s/%s.jpg' % (image_root, image_set, image_id), img)
        shutil.copy(json_path, '%s/%s/%s.json' % (label_root, image_set, image_id))
print('finish')

 (5) Run labelme2coco.py to convert the labelme annotations into COCO-style segmentation annotations. Edit the input_dir and output_dir paths and run it three times, converting train, val, and test in turn.

Prepare labels.txt before running.

Put the following into labels.txt, adding one line per class:

__ignore__
_background_
car

#!/usr/bin/env python

import argparse
import collections
import datetime
import glob
import json
import os
import os.path as osp
import sys
import uuid
from shutil import copy, rmtree

import imgviz
import numpy as np

import labelme

try:
    import pycocotools.mask
except ImportError:
    print("Please install pycocotools:\n\n    pip install pycocotools\n")
    sys.exit(1)


def main():

    root = './dataset'
    input_dir = os.path.join(root, 'label/train/')         # switch to label/val/ and label/test/ on the other runs
    output_dir = os.path.join(root, 'annotations/train/')  # keep in sync with input_dir
    labels = os.path.join(root, 'labels.txt')
    noviz = False

    if osp.exists(output_dir):
        #print("Output directory already exists:", output_dir)
        rmtree(output_dir)

    os.makedirs(output_dir)
    os.makedirs(osp.join(output_dir, "JPEGImages"))
    if not noviz:
        os.makedirs(osp.join(output_dir, "Visualization"))
    print("Creating dataset:", output_dir)

    now = datetime.datetime.now()

    data = dict(
        info=dict(
            description=None,
            url=None,
            version=None,
            year=now.year,
            contributor=None,
            date_created=now.strftime("%Y-%m-%d %H:%M:%S.%f"),
        ),
        licenses=[dict(url=None, id=0, name=None, )],
        images=[
            # license, url, file_name, height, width, date_captured, id
        ],
        type="instances",
        annotations=[
            # segmentation, area, iscrowd, image_id, bbox, category_id, id
        ],
        categories=[
            # supercategory, id, name
        ],
    )

    class_name_to_id = {}
    for i, line in enumerate(open(labels).readlines()):
        class_id = i - 1  # starts with -1
        class_name = line.strip()
        if class_id == -1:
            assert class_name == "__ignore__"
            continue
        class_name_to_id[class_name] = class_id
        data["categories"].append(
            dict(supercategory=None, id=class_id, name=class_name, )
        )

    out_ann_file = osp.join(output_dir, "annotations.json")
    label_files = glob.glob(osp.join(input_dir, "*.json"))
    for image_id, filename in enumerate(label_files):
        print("Generating dataset from:", filename)

        label_file = labelme.LabelFile(filename=filename)

        base = osp.splitext(osp.basename(filename))[0]
        out_img_file = osp.join(output_dir, "JPEGImages", base + ".jpg")

        img = labelme.utils.img_data_to_arr(label_file.imageData)
        imgviz.io.imsave(out_img_file, img)
        data["images"].append(
            dict(
                license=0,
                url=None,
                # file_name=osp.relpath(out_img_file, osp.dirname(out_ann_file)),
                file_name=base + ".jpg",
                height=img.shape[0],
                width=img.shape[1],
                date_captured=None,
                id=image_id,
            )
        )

        masks = {}  # for area
        segmentations = collections.defaultdict(list)  # for segmentation
        for shape in label_file.shapes:
            points = shape["points"]
            label = shape["label"]
            group_id = shape.get("group_id")
            shape_type = shape.get("shape_type", "polygon")
            mask = labelme.utils.shape_to_mask(
                img.shape[:2], points, shape_type
            )

            if group_id is None:
                group_id = uuid.uuid1()

            instance = (label, group_id)

            if instance in masks:
                masks[instance] = masks[instance] | mask
            else:
                masks[instance] = mask

            if shape_type == "rectangle":
                (x1, y1), (x2, y2) = points
                x1, x2 = sorted([x1, x2])
                y1, y2 = sorted([y1, y2])
                points = [x1, y1, x2, y1, x2, y2, x1, y2]
            if shape_type == "circle":
                (x1, y1), (x2, y2) = points
                r = np.linalg.norm([x2 - x1, y2 - y1])
                # r(1-cos(a/2))<x, a=2*pi/N => N>pi/arccos(1-x/r)
                # x: tolerance of the gap between the arc and the line segment
                n_points_circle = max(int(np.pi / np.arccos(1 - 1 / r)), 12)
                i = np.arange(n_points_circle)
                x = x1 + r * np.sin(2 * np.pi / n_points_circle * i)
                y = y1 + r * np.cos(2 * np.pi / n_points_circle * i)
                points = np.stack((x, y), axis=1).flatten().tolist()
            else:
                points = np.asarray(points).flatten().tolist()

            segmentations[instance].append(points)
        segmentations = dict(segmentations)

        for instance, mask in masks.items():
            cls_name, group_id = instance
            if cls_name not in class_name_to_id:
                continue
            cls_id = class_name_to_id[cls_name]

            mask = np.asfortranarray(mask.astype(np.uint8))
            mask = pycocotools.mask.encode(mask)
            area = float(pycocotools.mask.area(mask))
            bbox = pycocotools.mask.toBbox(mask).flatten().tolist()

            data["annotations"].append(
                dict(
                    id=len(data["annotations"]),
                    image_id=image_id,
                    category_id=cls_id,
                    segmentation=segmentations[instance],
                    area=area,
                    bbox=bbox,
                    iscrowd=0,
                )
            )

        if not noviz:
            viz = img
            if masks:
                listdata_labels = []
                listdata_captions = []
                listdata_masks = []

                # Keep only instances whose class appears in labels.txt
                for (cnm, gid), msk in masks.items():
                    if cnm in class_name_to_id:
                        listdata_labels.append(class_name_to_id[cnm])
                        listdata_captions.append(cnm)
                        listdata_masks.append(msk)

                labels = listdata_labels
                captions = listdata_captions
                masks = listdata_masks

            
                viz = imgviz.instances2rgb(
                    image=img,
                    labels=labels,
                    masks=masks,
                    captions=captions,
                    font_size=15,
                    line_width=2,
                )
            out_viz_file = osp.join(
                output_dir, "Visualization", base + ".jpg"
            )
            imgviz.io.imsave(out_viz_file, viz)

    with open(out_ann_file, "w") as f:
        json.dump(data, f)


if __name__ == "__main__":
    main()

The dataset is now ready.
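
After the three conversion runs, the dataset directory should look roughly like this (a sketch reconstructed from the scripts above):

dataset/
├── images/           # original JPEGs: total/, train/, val/, test/
├── label/            # labelme JSONs: total/, train/, val/, test/
├── annotations/
│   ├── train/        # JPEGImages/, Visualization/, annotations.json
│   ├── val/
│   └── test/
├── labels.txt
└── trainval.txt, train.txt, val.txt, test.txt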

4. Modify the model source code

        (1) Modify the class names in:

        ./mmdet/evaluation/functional/class_names.py

        ./mmdetection-master/mmdet/datasets/coco.py

        (2) Modify the Mask R-CNN config file:

        ./mmdetection-master/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py

This config inherits from four base files; modify each of them:

./mmdetection-master/configs/_base_/models/mask-rcnn_r50_fpn.py

Change num_classes (it appears in 2 places) to your own number of classes, e.g. via an override as shown below.
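
Rather than editing the base file in place, an equivalent option is a small override config (a minimal sketch; the _base_ path and the class count of 1 are assumptions for the single-class car example):

_base_ = './mask_rcnn_r50_fpn_1x_coco.py'

# Override only the two num_classes entries (bbox head and mask head)
model = dict(
    roi_head=dict(
        bbox_head=dict(num_classes=1),
        mask_head=dict(num_classes=1)))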

 

./mmdetection-master/configs/_base_/datasets/coco_instance.py

Change the dataset paths to your own dataset folder; the annotation file locations are set as shown below.
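
With the layout produced by labelme2coco.py above, the paths in coco_instance.py end up roughly as follows (a sketch in the mmdetection 3.x config style; adjust the keys to your mmdetection version, and remember the evaluator's ann_file as well):

data_root = 'dataset/'
metainfo = dict(classes=('car',))  # assumption: the single class from labels.txt

train_dataloader = dict(
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        ann_file='annotations/train/annotations.json',
        data_prefix=dict(img='annotations/train/JPEGImages/')))
val_dataloader = dict(
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        ann_file='annotations/val/annotations.json',
        data_prefix=dict(img='annotations/val/JPEGImages/')))
# repeat for test_dataloader with the test split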

./mmdetection-master/configs/_base_/schedules/schedule_1x.py

Adjust the number of training epochs and the learning rate, for example as follows.
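
The defaults in schedule_1x.py are 12 epochs of SGD at lr 0.02, tuned for 8 GPUs; typical edits look like this (a sketch in the 3.x config style):

# 12 epochs by default; raise max_epochs for longer training
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)

# lr=0.02 assumes 8 GPUs x 2 images each; scale it down for fewer GPUs (e.g. ~0.0025 on one GPU)
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001))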

Also modify configs/_base_/default_runtime.py:

set load_from to the path of the pretrained weight file, and tune the hooks as needed:

logger=dict(type='LoggerHook', interval=50)         # log every 50 iterations; lower this to monitor training more closely
checkpoint=dict(type='CheckpointHook', interval=1)  # save weights every epoch; if storage is tight, raise it (e.g. save every 5 epochs)
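
Pointing load_from at a checkpoint from the mmdetection model zoo enables fine-tuning from COCO weights (the path below is a placeholder for whichever checkpoint you downloaded):

load_from = 'checkpoints/mask_rcnn_r50_fpn_1x_coco.pth'  # placeholder local path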

         (3) Modify train.py

Set the default config path to the location of mask-rcnn_r50_fpn_1x_coco.py, as in the sketch below.
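
One common way to do this is to give the positional config argument a default inside tools/train.py (a sketch; the argument name follows mmdetection's train.py, and the default path is your own config):

# In tools/train.py, inside parse_args():
parser.add_argument(
    'config',
    nargs='?',
    default='configs/mask_rcnn/mask-rcnn_r50_fpn_1x_coco.py',
    help='train config file path')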

5. The model can now be trained. After training, the scripts under tools/analysis_tools can be used to analyze the results, as in the sketch below.
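
For example, the loss curve can be plotted from the JSON log written into the work dir (a minimal sketch; the log location is an assumption and varies by version, e.g. work_dirs/<config>/<timestamp>/vis_data/scalars.json in mmdetection 3.x):

import json
import matplotlib.pyplot as plt

log_path = 'work_dirs/mask-rcnn_r50_fpn_1x_coco/vis_data/scalars.json'  # adjust to your run directory
iters, losses = [], []
with open(log_path) as f:
    for line in f:            # one JSON record per line
        rec = json.loads(line)
        if 'loss' in rec:     # skip records without a training loss
            iters.append(rec.get('iter', len(iters)))
            losses.append(rec['loss'])

plt.plot(iters, losses)
plt.xlabel('iteration')
plt.ylabel('total loss')
plt.savefig('loss_curve.png')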


                