学习记录(一)之语义分割数据集转为cocojson标注文件

每张图像的标注信息(img_info)存储的内容包括:

img_info = dict(
file_name=osp.basename(img_file),
height=segm_img.shape[0],
width=segm_img.shape[1],
anno_info=anno_info,
segm_file=osp.basename(segm_file)
)

import argparse
import glob
import os
import os.path as osp

import cityscapesscripts.helpers.labels as CSLabels
import cv2
import mmcv
import numpy as np
import pycocotools.mask as maskUtils
import torch
from mmengine.fileio import dump
from mmengine.utils import (Timer, mkdir_or_exist, track_parallel_progress,
                            track_progress)
#
#
# # bj20.8内网数据结构
"""
--BJ2
-----images
------------______img.tif
-----masks
------------______mask.png
"""
# def collect_files(img_dir, gt_dir):
#     files = []
#     # 根据自己的数据集路径修改(根据内容的要素识别修改格式)
#     img_files = glob.glob(osp.join(img_dir, 'image/*.tif'))  # 读取图像文件
#     for img_file in img_files:
#         # 如果图片和标签的文件后缀不一样,要重新定义标签的读取方式
#         label_filename = os.path.basename(img_file).replace("_img.tif","_mask.png")
#         segm_file = gt_dir + '/mask/' + os.path.basename(label_filename)
#         files.append((img_file, segm_file))
#     assert len(files), f'No images found in {img_dir}'
#     print(f'Loaded {len(files)} images from {img_dir}')
#
#     return files

# 将WHU转为coco格式
"""
-WHU 
    ---train
        -----image
            -------_.tif
        -----label
            -------_.tif
    ---val
        -----image
            -------_.tif
        -----label
            -------_.tif
    ---test
        -----image
            -------_.tif
        -----label
            -------_.tif
"""
def collect_files(img_dir, gt_dir):
    """Pair every ``.tif`` image of one split with its label-mask path.

    Images are looked up in ``<img_dir>/image/*.tif``; the label of an image
    is expected under ``<gt_dir>/label/`` with the same file name. The label
    file itself is not checked for existence here.

    Args:
        img_dir (str): Split directory that contains the ``image`` folder.
        gt_dir (str): Split directory that contains the ``label`` folder.

    Returns:
        list[tuple[str, str]]: ``(img_file, segm_file)`` pairs, sorted by
        image path.

    Raises:
        AssertionError: If no ``.tif`` image is found in ``img_dir``.
    """
    # glob returns files in filesystem order; sort so the returned order is
    # reproducible across runs and platforms.
    img_files = sorted(glob.glob(osp.join(img_dir, 'image/*.tif')))

    # Raised explicitly (instead of a bare ``assert``) so the check still
    # runs under ``python -O`` while keeping the same exception type.
    if not img_files:
        raise AssertionError(f'No images found in {img_dir}')

    files = [(img_file, osp.join(gt_dir, 'label', osp.basename(img_file)))
             for img_file in img_files]
    print(f'Loaded {len(files)} images from {img_dir}')

    return files


def collect_annotations(files, nproc=1):
    """Run ``load_img_info`` over all file pairs with a progress bar.

    Args:
        files (list): Items passed one by one to ``load_img_info``.
        nproc (int): Number of worker processes. Values above 1 use
            ``track_parallel_progress``; otherwise ``track_progress`` is
            used.

    Returns:
        The collected results of ``load_img_info``, as returned by the
        progress helper.
    """
    print('Loading annotation images')
    if nproc <= 1:
        return track_progress(load_img_info, files)
    return track_parallel_progress(load_img_info, files, nproc=nproc)


def load_img_info(files):
    """Build the per-image record (with instance annotations) from a mask.

    Every positive pixel value in the label mask is treated as a class id.
    Each class region is split into 4-connected components, and every
    component becomes one annotation with an RLE mask, an XYWH bounding box
    and an area.

    Args:
        files (tuple[str, str]): ``(img_file, segm_file)`` pair.

    Returns:
        dict: ``file_name``, ``height``, ``width``, ``segm_file`` and
        ``anno_info``, a list of dicts with the keys ``iscrowd``,
        ``category_id``, ``bbox``, ``area`` and ``segmentation``.
    """
    # Pixel value -> category description. Built once per call rather than
    # once per class id. Extend it for multi-class datasets, e.g.
    #   2: {"id": 2, "name": 'Building-non-flooded', "category": "flat"},
    label_mapping = {
        1: {"id": 1, "name": 'building', "category": "flat"},
    }

    img_file, segm_file = files
    # Read the mask unchanged through the pillow backend.
    segm_img = mmcv.imread(segm_file, flag='unchanged', backend='pillow')

    anno_info = []
    # Pixel value 0 is excluded, so it never produces annotations.
    # To restrict to a class range use e.g.
    #   segm_img[(segm_img > 0) & (segm_img <= 6)]
    for inst_id in np.unique(segm_img[segm_img > 0]):
        label = label_mapping.get(int(inst_id))
        if label is None:
            # Previously an unmapped value crashed with KeyError; skip it
            # and report it so the mapping can be extended.
            print(f'Ignore unknown label id: {inst_id} in {segm_file}')
            continue
        category_id = label["id"]

        # Split all pixels of this class into separate connected regions.
        class_mask = (segm_img == inst_id).astype(np.uint8)
        num_components, labeled_masks, _, _ = \
            cv2.connectedComponentsWithStats(class_mask, connectivity=4)

        # Component 0 is skipped: OpenCV reserves it for the background.
        for component_id in range(1, num_components):
            component_mask = np.asarray(
                labeled_masks == component_id, dtype=np.uint8, order='F')

            mask_rle = maskUtils.encode(component_mask[:, :, None])[0]
            area = maskUtils.area(mask_rle)
            # COCO style XYWH box derived from the RLE mask.
            bbox = maskUtils.toBbox(mask_rle)
            # RLE counts are bytes; decode them for JSON encoding.
            mask_rle['counts'] = mask_rle['counts'].decode()

            anno_info.append(
                dict(
                    iscrowd=0,
                    category_id=category_id,
                    # Plain Python values so any JSON encoder can dump them.
                    bbox=bbox.tolist(),
                    area=area.tolist(),
                    segmentation=mask_rle))

    img_info = dict(
        # Only base names are stored, without directory prefixes.
        file_name=osp.basename(img_file),
        height=segm_img.shape[0],
        width=segm_img.shape[1],
        anno_info=anno_info,
        segm_file=osp.basename(segm_file))
    return img_info
#
def cvt_annotations(image_infos, out_json_name):
    """Assemble the per-image records into one COCO dict and dump it.

    Images and annotations receive consecutive ids starting at 0. Each
    ``image_info`` is modified in place: it gains an ``id`` and loses its
    ``anno_info`` entry; every annotation gains ``image_id`` and ``id``.
    The ``annotations`` key is left out when there are no annotations.

    Args:
        image_infos (list[dict]): Records that each hold an ``anno_info``
            list.
        out_json_name (str): Destination passed to ``dump``.

    Returns:
        dict: The COCO-style dict that was written.
    """
    # Category id -> category name written to the ``categories`` section.
    label_mapping = {
        1: 'building',
    }

    images = []
    annotations = []
    for img_id, image_info in enumerate(image_infos):
        image_info['id'] = img_id
        per_image_annos = image_info.pop('anno_info')
        images.append(image_info)
        for anno_info in per_image_annos:
            anno_info['image_id'] = img_id
            # The running annotation id equals the number collected so far.
            anno_info['id'] = len(annotations)
            annotations.append(anno_info)

    categories = [
        dict(id=category_id, name=category_name)
        for category_id, category_name in label_mapping.items()
    ]

    out_json = {'images': images, 'categories': categories}
    if annotations:
        out_json['annotations'] = annotations

    dump(out_json, out_json_name)
    return out_json


def parse_args():
    """Parse the command-line options of the converter.

    Returns:
        argparse.Namespace: With attributes ``cityscapes_path``, ``img_dir``,
        ``gt_dir``, ``out_dir`` and ``nproc``.
    """
    cli = argparse.ArgumentParser(
        description='Convert Cityscapes annotations to COCO format')

    # (flags, keyword arguments) for every option, in declaration order.
    option_table = (
        (('--cityscapes_path',),
         dict(default=r'E:\github_project\mmdetection-main\data\WHU',
              help='cityscapes data path')),
        (('--img-dir',), dict(default='', type=str)),
        (('--gt-dir',), dict(default='', type=str)),
        (('-o', '--out-dir'),
         dict(default=r'E:\github_project\mmdetection-main\data\WHU\LFW_WHU_COCOJSON',
              help='output path')),
        (('--nproc',), dict(default=1, type=int, help='number of process')),
    )
    for flags, spec in option_table:
        cli.add_argument(*flags, **spec)

    return cli.parse_args()


def main():
    """Convert the train, val and test splits into one COCO JSON each.

    For every split, image/label pairs are collected from
    ``<root>/<img-dir>/<split>`` and ``<root>/<gt-dir>/<split>``, turned into
    annotations and written to the output directory. The output directory
    falls back to the data root when ``--out-dir`` is empty.
    """
    args = parse_args()
    data_root = args.cityscapes_path
    out_dir = args.out_dir or data_root
    mkdir_or_exist(out_dir)

    img_root = osp.join(data_root, args.img_dir)
    gt_root = osp.join(data_root, args.gt_dir)

    # (split folder name, output JSON file name), processed in this order.
    split_outputs = (
        ('train', 'WHU_BULIDING_train.json'),
        ('val', 'WHU_BULIDING_val.json'),
        ('test', 'WHU_BULIDING_test.json'),
    )

    for split, json_name in split_outputs:
        print(f'Converting {split} into {json_name}')
        with Timer(print_tmpl='It took {}s to convert Cityscapes annotation'):
            pairs = collect_files(
                osp.join(img_root, split), osp.join(gt_root, split))
            records = collect_annotations(pairs, nproc=args.nproc)
            cvt_annotations(records, osp.join(out_dir, json_name))


# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()

  • 9
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值