PaddleSeg 中原始 labelme 数据转换 Python 脚本的改进

PaddleSeg 中原始的将 labelme 标注数据转换为可用于训练的数据的脚本不太方便程序化处理，于是我在原文件上做了一些修改，以便准备自己的数据。修改后的程序如下：

# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import glob
import math
import json
import os
import os.path as osp
import numpy as np
import PIL.Image
import PIL.ImageDraw
import cv2

from gray2pseudo_color import get_color_map_list
import shutil


def _collect_class_names(label_files, mergeclsnamesdict, skipclsnames):
    """Return the class names used in ``label_files``, background first.

    Labels listed in ``skipclsnames`` are ignored (checked before merging)
    and labels found in ``mergeclsnamesdict`` are replaced by their target
    name. Names keep their first-seen order.
    """
    class_names = ['_background_']
    for label_file in label_files:
        # labelme writes UTF-8 JSON (possibly with raw non-ASCII labels).
        with open(label_file, encoding='utf-8') as f:
            data = json.load(f)
        for shape in data['shapes']:
            cls_name = shape['label']
            if cls_name in skipclsnames:
                continue
            cls_name = mergeclsnamesdict.get(cls_name, cls_name)
            if cls_name not in class_names:
                class_names.append(cls_name)
    return class_names


def labelmedata2segdata(datadir, dstmaindir, datasetype, mergeclsnamesdict,
                        skipclsnames=(), dstimsize=()):
    """Convert a directory of labelme annotations into a PaddleSeg dataset.

    For every ``*.json`` file in ``datadir`` the referenced image is copied
    (or resized) into ``dstmaindir/datasetype`` and a palette PNG label map
    named ``<json base name>.png`` is written next to it. ``class_names.txt``
    and a ``<subset>_list.txt`` file with one ``image label`` line per sample
    (paths relative to ``dstmaindir``) are written into ``dstmaindir``.

    Args:
        datadir: directory containing the labelme ``*.json`` files; each
            file's ``imagePath`` is resolved relative to that JSON file.
        dstmaindir: root output directory.
        datasetype: subset name (e.g. ``'train'``), used as the output
            sub-directory. A trailing ``'set'`` is stripped when naming the
            list file (``'trainset'`` -> ``train_list.txt``).
        mergeclsnamesdict: mapping ``{source_label: target_label}``; shapes
            with a source label are rasterized with the target's class id.
        skipclsnames: labels that are ignored entirely (checked before
            merging, so they win over ``mergeclsnamesdict``).
        dstimsize: optional ``(height, width)``; when non-empty, images and
            label maps are resized to it. Empty keeps the original size.

    Raises:
        IOError: an image referenced by a JSON file cannot be read.
        ValueError: a class id does not fit in a uint8 PNG, or the label PNG
            would overwrite the copied image (same file name).
    """
    newdatadir = osp.join(dstmaindir, datasetype)
    if not osp.exists(newdatadir):
        os.makedirs(newdatadir)
        print('Creating annotations directory:', newdatadir)

    # glob order depends on the file system; sorting makes class ids
    # (assigned in first-seen order) and the list-file order reproducible.
    # NOTE: separate calls for train/val still derive ids independently.
    label_files = sorted(glob.glob(osp.join(datadir, '*.json')))

    class_names = _collect_class_names(
        label_files, mergeclsnamesdict, skipclsnames)
    # Names are unique and '_background_' is first, so it always gets id 0.
    class_name_to_id = {name: idx for idx, name in enumerate(class_names)}
    class_names = tuple(class_names)
    print('class_names:', class_names)

    out_class_names_file = osp.join(dstmaindir, 'class_names.txt')
    with open(out_class_names_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(class_names))
    print('Saved class_names:', out_class_names_file)

    color_map = get_color_map_list(256)

    if datasetype.endswith('set'):
        list_prefix = datasetype[:-3]
    else:
        list_prefix = datasetype
    datapathtxt = osp.join(dstmaindir, list_prefix + '_list.txt')

    # 'with' guarantees the list file is closed even if a sample fails.
    with open(datapathtxt, 'w', encoding='utf-8') as fp1:
        for label_file in label_files:
            print('Generating dataset from:', label_file)
            with open(label_file, encoding='utf-8') as f:
                data = json.load(f)

            base = osp.splitext(osp.basename(label_file))[0]
            annotname = base + '.png'
            out_png_file = osp.join(newdatadir, annotname)

            img_file = osp.join(osp.dirname(label_file), data['imagePath'])
            imname = osp.basename(img_file)
            # Images and labels share one directory: a PNG image named like
            # its JSON would be silently replaced by the label map.
            if osp.normcase(imname) == osp.normcase(annotname):
                raise ValueError(
                    '[%s] Image and label would both be written to %s.'
                    % (label_file, out_png_file))

            im = cv2.imread(img_file)
            if im is None:
                raise IOError('Cannot read image: %s' % img_file)

            newimpath = osp.join(newdatadir, imname)
            if dstimsize:
                # cv2.resize takes (width, height); dstimsize is (h, w).
                newim = cv2.resize(im, (dstimsize[1], dstimsize[0]), None,
                                   fx=0, fy=0)
                cv2.imwrite(newimpath, newim)
            else:
                shutil.copy(img_file, newdatadir)

            lbl = shape2label(
                im.shape,
                data['shapes'],
                class_name_to_id,
                mergeclsnamesdict,
                skipclsnames,
                dstimsize
            )

            # Palette PNGs store uint8 indices, so ids must be in [0, 255].
            if lbl.min() < 0 or lbl.max() > 255:
                raise ValueError(
                    '[%s] Cannot save the pixel-wise class label as PNG. '
                    'Please consider using the .npy format.' % out_png_file)
            lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
            lbl_pil.putpalette(color_map)
            lbl_pil.save(out_png_file)

            impath = osp.join(datasetype, imname)
            annotpath = osp.join(datasetype, annotname)
            fp1.write('{} {}\n'.format(impath, annotpath))


def shape2mask(img_size, points):
    """Rasterize a polygon into a boolean mask.

    Args:
        img_size: mask size; only ``img_size[:2]`` (height, width) is used.
        points: sequence of at least three ``(x, y)`` vertices in pixel
            coordinates (floats are accepted).

    Returns:
        ``bool`` array of shape ``img_size[:2]`` that is True inside the
        polygon and on its outline.

    Raises:
        ValueError: fewer than three vertices are given.
    """
    points_list = [tuple(point) for point in points]
    # An explicit raise (not assert) so the check survives ``python -O``.
    if len(points_list) <= 2:
        raise ValueError('Polygon must have points more than 2')
    label_mask = PIL.Image.fromarray(np.zeros(img_size[:2], dtype=np.uint8))
    image_draw = PIL.ImageDraw.Draw(label_mask)
    image_draw.polygon(xy=points_list, outline=1, fill=1)
    return np.array(label_mask, dtype=bool)


def _shape_points_as_polygon(shape):
    """Return a labelme shape's outline as a float (N, 2) array of (x, y).

    labelme stores a rectangle as two opposite corners and a circle as its
    centre plus one point on the circumference; both are expanded into a
    polygon so they can be filled like ordinary polygons. Points of every
    other shape type are returned unchanged.
    """
    # float dtype: integer point lists must survive the later rescaling.
    points = np.asarray(shape['points'], dtype=np.float64).reshape(-1, 2)
    shape_type = shape.get('shape_type', None)
    if shape_type == 'rectangle' and len(points) == 2:
        (x1, y1), (x2, y2) = points
        return np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
    if shape_type == 'circle' and len(points) == 2:
        center, edge = points
        radius = np.hypot(*(edge - center))
        angles = np.linspace(0.0, 2.0 * np.pi, 64, endpoint=False)
        ring = np.stack([np.cos(angles), np.sin(angles)], axis=1)
        return center + radius * ring
    return points


def shape2label(img_size, shapes, class_name_mapping, mergeclsnamesdict,
                skipclsnames, dstimsize):
    """Rasterize labelme shapes into a 2-D ``int32`` class-id map.

    Args:
        img_size: shape of the source image; only ``img_size[:2]``
            (height, width) is used.
        shapes: the ``shapes`` list of a labelme JSON file; each entry has
            ``label``, ``points`` (``[x, y]`` pairs in source-image pixels)
            and optionally ``shape_type``.
        class_name_mapping: final class name -> integer id.
        mergeclsnamesdict: source label -> target label, applied before the
            id lookup.
        skipclsnames: labels to ignore (checked before merging).
        dstimsize: optional ``(height, width)`` of the output map; points are
            rescaled from the source size. Empty keeps the source size.

    Returns:
        ``np.int32`` array of shape ``dstimsize`` (or ``img_size[:2]``) in
        which each pixel holds the id of the last shape covering it, and 0
        where no shape covers it.

    Raises:
        KeyError: a (merged) label is missing from ``class_name_mapping``.
    """
    src_hw = np.asarray(img_size[:2], dtype=np.float64)
    out_hw = tuple(dstimsize) if dstimsize else tuple(img_size[:2])
    # Points are (x, y) = (column, row), so the factor must be
    # (width ratio, height ratio): the reverse of the (h, w) size order.
    scale_xy = (np.asarray(out_hw, dtype=np.float64) / src_hw)[::-1]
    label = np.zeros(out_hw, dtype=np.int32)
    for shape in shapes:
        class_name = shape['label']
        if class_name in skipclsnames:
            continue
        class_name = mergeclsnamesdict.get(class_name, class_name)
        class_id = class_name_mapping[class_name]

        # Expand rectangles/circles first, then scale, so non-uniform
        # resizing correctly turns a circle into an ellipse.
        points = (_shape_points_as_polygon(shape) * scale_xy).tolist()
        label_mask = shape2mask(out_hw, points)
        label[label_mask] = class_id
    return label


if __name__ == '__main__':
    # Every default reproduces the original hard-coded configuration, so
    # running the script without arguments behaves as before.
    parser = argparse.ArgumentParser(
        description='Convert labelme annotations into a PaddleSeg dataset.')
    parser.add_argument('--datadir', default='/data/dataset/segdata',
                        help='directory with the labelme *.json and images')
    parser.add_argument('--dstmaindir', default='/PaddleSeg/dataset/mydata',
                        help='root output directory of the dataset')
    parser.add_argument('--datasetype', default='train',
                        help='subset name, e.g. train or val')
    # NOTE: 'h' also appears in --skip; skipping wins, so its merge entry
    # has no effect.
    parser.add_argument('--merge', default='q:b,h:b,x:b,y:b',
                        help='comma-separated src:dst class-name merges')
    parser.add_argument('--skip', default='h',
                        help='comma-separated class names to ignore')
    parser.add_argument('--dstimsize', default='540,960',
                        help='output size "height,width"; empty keeps size')
    args = parser.parse_args()

    mergeclsnamesdict = dict(
        pair.split(':', 1) for pair in args.merge.split(',') if pair)
    skipclsnames = tuple(name for name in args.skip.split(',') if name)
    dstimsize = tuple(int(v) for v in args.dstimsize.split(',') if v)
    if len(dstimsize) not in (0, 2):
        parser.error('--dstimsize must be "height,width" or empty')

    labelmedata2segdata(args.datadir, args.dstmaindir, args.datasetype,
                        mergeclsnamesdict, skipclsnames, dstimsize=dstimsize)

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值