win10安装mmdetection以及训练私有数据集

1、安装mmdetection

参考:https://www.huaweicloud.com/articles/b0247a0d742f451efbf435acfc79a40d.html
默认已经安装过anaconda、cuda、cudnn、pytorch等等一系列的包,安装mmdetection主要有两个模块,一个是mmcv一个是mmdet,mmcv可以直接pip install。mmdet稍微复杂点,貌似我用的git clone,git clone安装请自行百度,如果不行的话试试pip install mmdet。哈哈哈抱歉我也不太记得了,反正就是尝试。安装好了之后可能会报错,貌似是要更新一下mmcv还是mmdet的版本。(下次我一定做笔记)总之win10是可以安装mmdetection,虽然官网没有给教程。

2、将自己的数据集制作成coco格式

楼主表达能力很差,直接上代码吧。参考:https://blog.csdn.net/qq_15969343/article/details/80848175。做了一些小修改

import json
import os
import cv2
import shutil

dataset = {'categories': [], 'images': [], 'annotations': []}

# 根路径,里面包含images(图片文件夹),annos.txt(bbox标注),classes.txt(类别标签),以及annotations文件夹(如果没有则会自动创建,用于保存最后的json)
root_path = r'D:\PycharmProjects\GitHubProjects\yolov5-master\PigDetection\dataprocess\新建文件夹'
img_path = r'D:\PycharmProjects\GitHubProjects\yolov5-master\PigDetection\train_img_bbox'

# 用于创建训练集或验证集
phase = 'instances_val2017'
# 训练集和验证集划分的界线
split = 399

# 打开类别标签
with open(os.path.join(root_path, 'classes.txt')) as f:
    classes = f.read().strip().split()


# 建立类别标签和数字id的对应关系
for i, cls in enumerate(classes, 1):
    dataset['categories'].append({'supercategory': 'mark', 'id': i, 'name': cls})
# dataset['categories'].append({'supercategory': 'mark', 'id': 1, 'name': 'pig'})

# 读取images文件夹的图片名称
indexes = [f for f in os.listdir(os.path.join(root_path, 'images'))]

# 判断是建立训练集还是验证集
if phase == 'instances_train2017':
    indexes = [line for i, line in enumerate(indexes) if i <= split]
    train_path = os.path.join(root_path, 'coco/train2017')
    if not os.path.exists(train_path):
        os.makedirs(train_path)
    for i in indexes:
        shutil.copy(os.path.join(img_path, i), train_path)

elif phase == 'instances_val2017':
    indexes = [line for i, line in enumerate(indexes) if i > split]
    val_path = os.path.join(root_path, 'coco//val2017')
    if not os.path.exists(val_path):
        os.makedirs(val_path)
    for i in indexes:
        print(os.path.join(img_path, i))
        shutil.copy(os.path.join(img_path, i), val_path)

# 读取Bbox信息
with open(os.path.join(root_path, 'annos.txt')) as tr:
    annos = tr.readlines()

all = 0
for k, index in enumerate(indexes):
    # 用opencv读取图片,得到图像的宽和高
    # print(os.path.join(img_path, index))
    img = cv2.imread(os.path.join(img_path, index))
    height, width = img.shape[:2]

    # 添加图像的信息到dataset中
    dataset['images'].append({'file_name': index,
                              'id': k,
                              'width': width,
                              'height': height})

    for ii, anno in enumerate(annos):
        parts = anno.strip().split()
        # print('parts:', parts)

        # 如果图像的名称和标记的名称对上,则添加标记
        if parts[0] == index:
            # 类别
            cls_id = parts[1]
            # x_min
            x1 = float(parts[2])
            # y_min
            y1 = float(parts[3])
            # x_max
            x2 = float(parts[4])
            # y_max
            y2 = float(parts[5])
            width = max(0, x2 - x1)
            height = max(0, y2 - y1)
            dataset['annotations'].append({
                'area': width * height,
                'bbox': [x1, y1, width, height],
                'category_id': int(cls_id),
                'id': all,
                'image_id': k,
                'iscrowd': 0,
                # mask, 矩形是从左上角点按顺时针的四个顶点
                'segmentation': [[x1, y1, x2, y1, x2, y2, x1, y2]]
            })
            all += 1

# 保存结果的文件夹
folder = os.path.join(root_path, 'annotations')

if not os.path.exists(folder):
    os.makedirs(folder)
json_name = os.path.join(root_path, 'annotations/{}.json'.format(phase))
with open(json_name, 'w') as f:
    json.dump(dataset, f)

3、训练自己的数据集

首先下载mmdetection代码到本地,官网:https://github.com/open-mmlab/mmdetection。貌似官网教程说的挺详细的,我在这里仅仅说一些要改动的坑。以configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py为例
1、修改class_names。路径C:\ProgramData\Anaconda3\Lib\site-packages\mmdet-2.14.0-py3.8.egg\mmdet\core\evaluation(博主之前修改的是下载的mmdetection里面的mmdet包一直报错)

def coco_classes():
    # return [
    #     'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
    #     'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign',
    #     'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
    #     'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
    #     'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
    #     'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard',
    #     'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork',
    #     'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
    #     'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair',
    #     'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv',
    #     'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
    #     'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
    #     'scissors', 'teddy_bear', 'hair_drier', 'toothbrush'
    # ]
    return ['pig']

2、修改coco.py。路径C:\ProgramData\Anaconda3\Lib\site-packages\mmdet-2.14.0-py3.8.egg\mmdet\datasets

@DATASETS.register_module()
class CocoDataset(CustomDataset):

    # CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    #            'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
    #            'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
    #            'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
    #            'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    #            'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    #            'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    #            'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    #            'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    #            'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    #            'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
    #            'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
    #            'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
    #            'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
    CLASSES = ('pig',)

假如只有一个类别就需要加上逗号。不加会报错让你加逗号
3、修改faster_rcnn_r50_fpn.py,(因为博主的config参数是faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py,在configs里面打开这个文件可以看到models的路径)之后找到faster_rcnn_r50_fpn.py文件把num_classes这个参数全部都改为你的类别个数。这里只需要改动一个即可。路径D:\PycharmProjects\GitHubProjects\mmdetection-master\mmdetection-master\tools\configs_base_\models\faster_rcnn_r50_fpn.py。这里说明一下为什么configs在tools文件夹下。貌似win10系统不放到这里会报错,所以我直接把configs文件夹复制到这里了。
4、运行tools包下面的trian.py。配置参数是configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py。成功!(写的有点粗糙,慢慢再补充)

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值