语义分割数据集png和json相互转化

文章介绍了如何使用anylabeling进行图像标注,将json格式的标注转换为PNG格式,以及如何将PNG图像转换回json,以便适应语义分割网络的训练。着重于连通域处理和多边形逼近技术,同时提到了精度损失的问题和讨论空间。
摘要由CSDN通过智能技术生成

为了适配语义分割网络训练数据,我使用anylabeling进行图像标注,标注完之后的图像格式应该为png,而且每个点像素值0,1,2,3,4等对应了某一类。
在这里插入图片描述
我们希望标注好的图像格式为png,需要训练的原图像为jpg,二者文件名相同。
下面就是完成该操作。

json转换为png

先导入需要的包

import os
import numpy as np
import json
from tqdm import tqdm
import cv2

首先给出json转换为png的操作,

Created with Raphaël 2.3.0 开始 读取json文件信息 提取连通域 填充连通域 写入png 结束
def json2png(json_root, class_root, name_classes):

    if not os.path.exists(json_root): 
        print("No such folder!")
        return 
    if not os.path.exists(class_root) and class_root != "":
        os.makedirs(class_root)
    
    print("Start conterting annotations...")
    for root,dirs,files in os.walk(json_root):
        for file in files:
            if not file.endswith(".json"):
                continue

            json_file = file

            # print(json_root+json_file)
            json_data = json.load(open(json_root+json_file,"r",encoding="utf-8"))
            
            # 生成空白图片
            W, H = json_data["imageWidth"], json_data["imageHeight"]
            res = np.zeros((H, W)).astype('uint8')
            
            
            for multi in json_data["shapes"]:
                if multi['label'] not in name_classes:
                    continue
                
                # 获取第几个编号
                fillColor = name_classes.index(multi['label'])
                # fillColor = 255
                pts = np.array(multi['points']).astype('int')
                cv2.fillPoly(res,[pts],fillColor)
            
            cv2.imwrite(class_root+get_root_file_name(file)+'.png', res)
            # Image.fromarray(res).save(class_root+get_root_file_name(file)+'.png')

其中 name_classes 表示语义分割的训练类别,比如:

name_classes = ["background","fire"]

而需要注意的是 name_classes 必须要添加背景 “background” .

png转换为json

为了完成png转换为json,这个相对会麻烦一点。

Created with Raphaël 2.3.0 开始 提取png图像不同像素的层 对每个层提取连通域 判断连通域的包含关系,将相包含的连通域配对 使用多边形逼近连通域 将配对好的相包含的连通域采用分割裁剪的方式转换为单连通域 写入json 结束

判断连通域的包含操作

采用点是否在多边形内部来判断连通域是否包含于另外一个连通域。因为提取好的连通域只可能存在相离和包含两种关系,不可能存在相交关系。因此只需要判断其中一个点是否包含于另外一个连通域即可判断连通域的包含关系。
采用射线交点法,首先,选择一个点P,该点可以是多边形外部的一个点。从点P向任意方向发射一条射线,与多边形的每条边进行求交。然后统计射线与多边形的边的交点个数。如果交点个数为奇数,则点P在多边形内部;如果交点个数为偶数,则点P在多边形外部。
使用python算法完成如下:

def point_in_polygon(point, polygon):
    x, y = point[0]
    n = len(polygon)
    inside = False

    p1x, p1y = polygon[0][0]
    for i in range(n + 1):
        p2x, p2y = polygon[i % n][0]
        if y > min(p1y, p2y):
            if y <= max(p1y, p2y):
                if x <= max(p1x, p2x):
                    if p1y != p2y:
                        xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
                    if p1x == p2x or x <= xinters:
                        inside = not inside
        p1x, p1y = p2x, p2y

    return inside

连通域提取配对

由于标注图像的某个类别区域可能不是单连通区域,因此采用裁剪法完成
在这里插入图片描述

def get_approx_countors(layer, length_p=0.002):
    img_bin, contours = cv2.findContours(layer, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)

    length_p=0.002

    cir_all = []
    approx_all = []
    cir_in_or_out = []

    for ss in img_bin:
        epsilon = length_p * cv2.arcLength(ss, True)
        approx = cv2.approxPolyDP(ss, epsilon, True)
        approx_all.append(approx.tolist())
        cir_all.append(ss.tolist())
        cir_in_or_out.append(-1)

    # 匹配轮廓
    for i in range(len(cir_all)):
        for j in range(len(cir_all)):
            if i == j:
                continue
            if point_in_polygon(cir_all[i][0], cir_all[j]):
                cir_in_or_out[i] = j
                break
    white_cir = []
    black_cir = []
    black_match = []
    for i in range(len(cir_in_or_out)):
        if cir_in_or_out[i] == -1: #白色
            white_cir.append(i)
        else:# 黑色
            black_cir.append(cir_in_or_out[i])
            black_match.append(i)
    cir_matched = []

    for i in range(len(white_cir)):
        cir_matched.append(approx_all[white_cir[i]])
        
    for i in range(len(black_cir)):
        index_1 = white_cir.index(black_cir[i])
        s1 = cir_matched[index_1]  + approx_all[black_match[i]] + [approx_all[black_match[i]][0]] + [cir_matched[index_1][-1]]
        cir_matched[index_1] = s1
        
    # print(cir_matched)
    for i in range(len(cir_matched)):
        s1 = cir_matched[i] + [cir_matched[i][0]]
        cir_matched[i] = s1

    return cir_matched

这个函数完成的是根据某一个二值化图像提取连通域,并配对好,输出每个连通域的点集。其中length_p表示多边形逼近的程度,该值越小表示采样点越多。

转换为json

批量转化的函数如下:

def png2json(png_root, json_root, name_classes, length_p=0.002):
    if not os.path.exists(png_root): 
        print("No such folder!")
        return 
    if not os.path.exists(json_root) and json_root != "":
        os.makedirs(json_root)

    print("Start conterting annotations...")
    for root,dirs,files in os.walk(png_root):
        for file in tqdm(files):
            if not file.endswith(".png"):
                continue
            
            img = cv2.imread(png_root+file)
            img = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)
            H, W = img.shape

            countors_all = []
            labels_all = []
            for i in range(1, len(name_classes)):
                layer = np.where(img==i,1,0).astype('uint8')
                # print(layer)
                countors_l = get_approx_countors(layer, length_p=0.002)
                for j in range(len(countors_l)):
                    countors_all.append(countors_l[j])
                    labels_all.append(name_classes[i])

            with open(json_root+get_root_file_name(file)+".json", 'w', encoding='utf-8') as json_f:
                

                json_f.write("{\n")

                json_f.write("  \"version\": \"0.3.3\",\n")
                json_f.write("  \"flags\": {},\n")
                json_f.write("  \"shapes\": [\n")

            
                # 提取好的点集countors_all[i][j][0][(0~1)],第i个连通域第j个点的坐标(x, y)
                        
                for i in range(len(countors_all)):
                    json_f.write("    {\n")
                    json_f.write("      \"label\": \""+labels_all[i]+"\",\n")
                    json_f.write("      \"text\": \"\",\n")
                    json_f.write("      \"points\": [\n")
                    for j in range(len(countors_all[i])):
                        json_f.write("       [\n")
                        json_f.write("        "+str(countors_all[i][j][0][0])+",\n")
                        json_f.write("        "+str(countors_all[i][j][0][1])+"\n")
                        json_f.write("       ]")   
                        if j != len(countors_all[i])-1:
                            json_f.write(",")
                        json_f.write("\n")

                    json_f.write("      ],\n")
                    json_f.write("      \"shape_type\": \"polygon\",\n")
                    json_f.write("      \"flags\": {}\n")  
                    json_f.write("    }")
                    if i != len(countors_all)-1:
                        json_f.write(",")
                    json_f.write("\n")

                json_f.write("  ],\n")
                json_f.write("  \"imagePath\": \""+get_root_file_name(file)+".jpg\",\n")
                json_f.write("  \"imageData\": null,\n")
                json_f.write("  \"imageHeight\": "+str(H)+",\n")
                json_f.write("  \"imageWidth\": "+str(W)+"\n")
                json_f.write("}\n")

总结

全部代码如下:

import os
import numpy as np
import json
# from shutil import copyfile
from tqdm import tqdm
# from xml.etree.ElementTree import parse
# from PIL import Image
import cv2

def get_root_file_name(root1):

    s0, s1 = -1, -1
    for i in range(len(root1)):
        if root1[i] == "/":
            s0 = i
        if root1[i] == ".":
            s1 = i
    return root1[s0+1:s1]  

def point_in_polygon(point, polygon):
    x, y = point[0]
    n = len(polygon)
    inside = False

    p1x, p1y = polygon[0][0]
    for i in range(n + 1):
        p2x, p2y = polygon[i % n][0]
        if y > min(p1y, p2y):
            if y <= max(p1y, p2y):
                if x <= max(p1x, p2x):
                    if p1y != p2y:
                        xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
                    if p1x == p2x or x <= xinters:
                        inside = not inside
        p1x, p1y = p2x, p2y

    return inside

def get_approx_countors(layer, length_p=0.002):
    img_bin, contours = cv2.findContours(layer, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)

    length_p=0.002

    cir_all = []
    approx_all = []
    cir_in_or_out = []

    for ss in img_bin:
        epsilon = length_p * cv2.arcLength(ss, True)
        approx = cv2.approxPolyDP(ss, epsilon, True)
    
        approx_all.append(approx.tolist())
        # print(ss.tolist())
        cir_all.append(ss.tolist())
        cir_in_or_out.append(-1)

    # 匹配轮廓
    for i in range(len(cir_all)):
        for j in range(len(cir_all)):
            if i == j:
                continue
            if point_in_polygon(cir_all[i][0], cir_all[j]):
                cir_in_or_out[i] = j
                break
    # print(cir_in_or_out)
    white_cir = []
    black_cir = []
    black_match = []
    for i in range(len(cir_in_or_out)):
        if cir_in_or_out[i] == -1: #白色
            white_cir.append(i)
        else:# 黑色
            black_cir.append(cir_in_or_out[i])
            black_match.append(i)
    cir_matched = []

    for i in range(len(white_cir)):
        cir_matched.append(approx_all[white_cir[i]])

    # for i in range(len(cir_matched)):
    #     s1 = cir_matched[i] + [cir_matched[i][0]]
    #     cir_matched[i] = s1

    for i in range(len(black_cir)):
        index_1 = white_cir.index(black_cir[i])
        s1 = cir_matched[index_1]  + approx_all[black_match[i]] + [approx_all[black_match[i]][0]] + [cir_matched[index_1][-1]]
        cir_matched[index_1] = s1


    # print(cir_matched)
    for i in range(len(cir_matched)):
        s1 = cir_matched[i] + [cir_matched[i][0]]
        cir_matched[i] = s1

    return cir_matched

def json2png(json_root, class_root, name_classes):

    if not os.path.exists(json_root): 
        print("No such folder!")
        return 
    if not os.path.exists(class_root) and class_root != "":
        os.makedirs(class_root)
    
    print("Start conterting annotations...")
    for root,dirs,files in os.walk(json_root):
        for file in files:
            if not file.endswith(".json"):
                continue

            json_file = file

            # print(json_root+json_file)
            json_data = json.load(open(json_root+json_file,"r",encoding="utf-8"))
            
            # 生成空白图片
            W, H = json_data["imageWidth"], json_data["imageHeight"]
            res = np.zeros((H, W)).astype('uint8')
            
            
            for multi in json_data["shapes"]:
                if multi['label'] not in name_classes:
                    continue
                
                # 获取第几个编号
                fillColor = name_classes.index(multi['label'])
                # fillColor = 255
                pts = np.array(multi['points']).astype('int')
                cv2.fillPoly(res,[pts],fillColor)
            
            cv2.imwrite(class_root+get_root_file_name(file)+'.png', res)
            # Image.fromarray(res).save(class_root+get_root_file_name(file)+'.png')

def png2json(png_root, json_root, name_classes, length_p=0.002):
    if not os.path.exists(png_root): 
        print("No such folder!")
        return 
    if not os.path.exists(json_root) and json_root != "":
        os.makedirs(json_root)

    print("Start conterting annotations...")
    for root,dirs,files in os.walk(png_root):
        for file in tqdm(files):
            if not file.endswith(".png"):
                continue
            
            img = cv2.imread(png_root+file)
            img = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)
            H, W = img.shape

            countors_all = []
            labels_all = []
            for i in range(1, len(name_classes)):
                layer = np.where(img==i,1,0).astype('uint8')
                # print(layer)
                countors_l = get_approx_countors(layer, length_p=0.002)
                for j in range(len(countors_l)):
                    countors_all.append(countors_l[j])
                    labels_all.append(name_classes[i])

            with open(json_root+get_root_file_name(file)+".json", 'w', encoding='utf-8') as json_f:
                

                json_f.write("{\n")

                json_f.write("  \"version\": \"0.3.3\",\n")
                json_f.write("  \"flags\": {},\n")
                json_f.write("  \"shapes\": [\n")

            
                # 提取好的点集countors_all[i][j][0][(0~1)],第i个连通域第j个点的坐标(x, y)
                        
                for i in range(len(countors_all)):
                    json_f.write("    {\n")
                    json_f.write("      \"label\": \""+labels_all[i]+"\",\n")
                    json_f.write("      \"text\": \"\",\n")
                    json_f.write("      \"points\": [\n")
                    for j in range(len(countors_all[i])):
                        json_f.write("       [\n")
                        json_f.write("        "+str(countors_all[i][j][0][0])+",\n")
                        json_f.write("        "+str(countors_all[i][j][0][1])+"\n")
                        json_f.write("       ]")   
                        if j != len(countors_all[i])-1:
                            json_f.write(",")
                        json_f.write("\n")

                    json_f.write("      ],\n")
                    json_f.write("      \"shape_type\": \"polygon\",\n")
                    json_f.write("      \"flags\": {}\n")  
                    json_f.write("    }")
                    if i != len(countors_all)-1:
                        json_f.write(",")
                    json_f.write("\n")

                json_f.write("  ],\n")
                json_f.write("  \"imagePath\": \""+get_root_file_name(file)+".jpg\",\n")
                json_f.write("  \"imageData\": null,\n")
                json_f.write("  \"imageHeight\": "+str(H)+",\n")
                json_f.write("  \"imageWidth\": "+str(W)+"\n")
                json_f.write("}\n")


    


if __name__ == "__main__":
    name_classes = ["background","fire"]

    json_root = './img/JPEGImages/'
    class_root = './img/SegmentationClass/'
    json2png(json_root, class_root, name_classes)
	
	png_root = './img/SegmentationClass/'
	json_root = './img/JPEGImages/'
	png2json(png_root ,json_root  ,name_classes, length_p=0.001)

    

以火焰数据集标注举例,原图为
在这里插入图片描述
先把图像归一化到0,1,2,3…像素之后用上述函数转化效果如下
在这里插入图片描述
这样就可以用训练好的网络去预测数据集之外的图像,并作适量调整从而添加进数据集完善网络训练!
由于png转换为json只能采用采样的方法,因此这样做必然会导致一定程度的精度降低。如果读者有更好的思路欢迎讨论!

  • 17
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

迟钝皮纳德

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值