Mask Scoring RCNN训练自己的数据

本文档详细介绍了如何使用PyTorch训练Mask Scoring RCNN模型,包括环境配置、数据准备、参数修改和模型训练。首先,从GitHub获取相关代码库,然后创建conda环境并安装依赖。数据通过labelme标注并转为COCO格式。接着,调整配置文件以适应自定义数据集,并启动训练。训练完成后,进行预测操作,修改yaml权重和预测脚本来使用训练好的模型。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一. 代码准备

    基于pytorch。

    mask scoring rcnn 代码参考:【github

    mask rcnn benchmark 【github

二. 环境安装

1. 基于conda创建pytorch环境: 

conda create -n pytorch python=3.7.4
conda install ipython
conda install -c pytorch pytorch-nightly torchvision=0.2.1 cudatoolkit=9.0 # 注:必须9.0
conda activate pytorch
pip install numpy scipy ninja yacs cython matplotlib tqdm opencv-python

注:pip install torchvision==0.2.1,否则会出现 AttributeError: 'list' object has no attribute 'resize' #45

参考 https://github.com/zjhuang22/maskscoring_rcnn/issues/45

2. 安装cocoapi & apex:

export INSTALL_DIR=$PWD

# install pycocotools
git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI
python setup.py build_ext install

# install apex
cd $INSTALL_DIR
git clone https://github.com/NVIDIA/apex.git
cd apex
python setup.py install --cuda_ext --cpp_ext

3. 安装detection:

# install PyTorch Detection
cd $INSTALL_DIR

#maskrcnn-benchmark
#git clone https://github.com/facebookresearch/maskrcnn-benchmark.git

git clone https://github.com/zjhuang22/maskscoring_rcnn

cd maskscoring_rcnn
python setup.py build develop

三. 数据准备

    训练数据基于labelme标注,需要转换成coco格式,目前常用的是 labelme2coco.py,可以找到比较多的code,直接用就好了。

# -*- coding:utf-8 -*-
import os, sys
import argparse
import json
import matplotlib.pyplot as plt
import skimage.io as io
from labelme import utils
import numpy as np
import glob
import PIL.Image

class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super(MyEncoder, self).default(obj)


class labelme2coco(object):
    def __init__(self, labelme_json=[], save_json_path='./tran.json'):
        '''
        :param labelme_json: 所有labelme的json文件路径组成的列表
        :param save_json_path: json保存位置
        '''
        self.labelme_json = labelme_json
        self.save_json_path = save_json_path
        self.images = []
        self.categories = []
        self.annotations = []
        # self.data_coco = {}
        self.label = []
        self.annID = 1
        self.height = 0
        self.width = 0

        self.save_json()

    def data_transfer(self):

        for num, json_file in enumerate(self.labelme_json):
            with open(json_
评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值