Bubbliiiing的深度学习仓库voc_annovation.py详解

一、详解都写在注释上了

import os
import random

import numpy as np
from PIL import Image
from tqdm import tqdm  # 在长循环中添加进度条提示信息

#-------------------------------------------------------#
#   想要增加测试集修改trainval_percent 
#   修改train_percent用于改变验证集的比例 9:1
#   
#   当前该库将测试集当作验证集使用,不单独划分测试集
#-------------------------------------------------------#
trainval_percent    = 1
train_percent       = 0.9
#-------------------------------------------------------#
#   指向VOC数据集所在的文件夹
#   默认指向根目录下的VOC数据集
#-------------------------------------------------------#
VOCdevkit_path      = './VOCdevkit'

if __name__ == "__main__":
    random.seed(0)  # 固定随机种子
    print("Generate txt in ImageSets.")
    segfilepath     = os.path.join(VOCdevkit_path, 'VOC2007/SegmentationClass/cut_video_out_label')  # 标签图路径
    saveBasePath    = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Segmentation') # 生成训练集、验证集序号文件的路径
    
    temp_seg = os.listdir(segfilepath)   # # 返回指定路径下的文件和文件夹列表。
    total_seg = []
    for seg in temp_seg:
        if seg.endswith(".png"):
            total_seg.append(seg)  # 将temp_seg下的所有文件都写入total_seg

    num     = len(total_seg)  # num的值为标签图的个数
    list    = range(num)  # 生成一个长度为num的列表
    tv      = int(num*trainval_percent)  # 划分训练集+验证集的长度, 当trainval_percent不为1时,1-trainval_percent就是测试集的百分比
    tr      = int(tv*train_percent)  # 划分训练集的长度
    trainval= random.sample(list,tv)  # 在list这个列表中无放回的随机采样(不重复抽样),生成一个新的列表, tv为采样长度
    train   = random.sample(trainval,tr)  # 在trainval这个列表中进行随机抽样, tr为采样长度
    
    print("train and val size",tv)
    print("traub suze",tr)
    ftrainval   = open(os.path.join(saveBasePath,'trainval.txt'), 'w')   # 以写的方式创建一个文件,文件的路径由saveBasePath路径和trainval.txt拼接
    ftest       = open(os.path.join(saveBasePath,'test.txt'), 'w')  # 测试集文件路径
    ftrain      = open(os.path.join(saveBasePath,'train.txt'), 'w')  # 训练集文件路径
    fval        = open(os.path.join(saveBasePath,'val.txt'), 'w')  # 验证集文件路径
    
    for i in list:  
        name = total_seg[i][:-4]+'/n'  # 当i=0时,total_seg=video(1).png, 进行切片索引,返回一个新的列表,从video(1).png左边开始的第一个元素到倒数第四个元素(不包含)的所有元素,name='video(1)\n'
        if i in trainval:  # 如果i这个数值在trainval这个列表里
            ftrainval.write(name)  # 就把i对应的name写入到ftrainval这个文件里
            if i in train:  # 如果i在train这个列表里
                ftrain.write(name)  # 就把它也写到ftrain这个文件里
            else:  
                fval.write(name)  # 要是这个i不在train里,就把它写到fval里
        else:  
            ftest.write(name)  #要是i不在trainval里,那就把i对应的name写到ftest文件里
    # 创建完文件并写入内容后应该及时关闭文件,可以释放系统资源
    ftrainval.close()  
    ftrain.close()  
    fval.close()  
    ftest.close()
    print("Generate txt in ImageSets done.")

    print("Check datasets format, this may take a while.")
    print("检查数据集格式是否符合要求,这可能需要一段时间。")
    classes_nums        = np.zeros([256], np.int)
    for i in tqdm(list):  # 遍历list时会显示一个进度条
        name            = total_seg[i]
        png_file_name   = os.path.join(segfilepath, name) # 将标签图路径和i对应的name拼接起来
        if not os.path.exists(png_file_name):
            raise ValueError("未检测到标签图片%s,请查看具体路径下文件是否存在以及后缀是否为png。"%(png_file_name))
        
        png             = np.array(Image.open(png_file_name), np.uint8)  # 打开png_file_name这个文件,并将其转换为一个numpy数组,元素为8位无符号整数
        if len(np.shape(png)) > 2:  # 标签图为灰度图时,只有一个颜色通道, shape只包括高和宽
            print("标签图片%s的shape为%s,不属于灰度图或者八位彩图,请仔细检查数据集格式。"%(name, str(np.shape(png))))
            print("标签图片需要为灰度图或者八位彩图,标签的每个像素点的值就是这个像素点所属的种类。"%(name, str(np.shape(png))))

        classes_nums += np.bincount(np.reshape(png, [-1]), minlength=256)
        '''
        对png数组进行reshape为一维数组, [-1]表示自动计算该维度的大小。
        bincount函数会返回一个长度为minlength的数组,数组的索引表示输入数组中的整数值,对应的数组元素值表示该整数值在输入数组中出现的次数。
        这里,minlength=256表示返回一个长度为256的数组,因为图像的像素值在0-255之间。
        classes_nums += ...:这是一个累加操作,将bincount函数返回的数组累加到classes_nums数组上。
        '''
            
    print("打印像素点的值与数量。")
    print('-' * 37)
    print("| %15s | %15s |"%("Key", "Value"))
    print('-' * 37)
    for i in range(256):
        if classes_nums[i] > 0:
            print("| %15s | %15s |"%(str(i), str(classes_nums[i])))   # 一共就只有两个类别, 0表示背景,它对应输出的像素值, 1表示一个类别,它对应输出像素值
            print('-' * 37)
    
    if classes_nums[255] > 0 and classes_nums[0] > 0 and np.sum(classes_nums[1:255]) == 0:
        print("检测到标签中像素点的值仅包含0与255,数据格式有误。")
        print("二分类问题需要将标签修改为背景的像素点值为0,目标的像素点值为1。")
    elif classes_nums[0] > 0 and np.sum(classes_nums[1:]) == 0:
        print("检测到标签中仅仅包含背景像素点,数据格式有误,请仔细检查数据集格式。")

    print("JPEGImages中的图片应当为.jpg文件、SegmentationClass中的图片应当为.png文件。")
    print("如果格式有误,参考:")
    print("https://github.com/bubbliiiing/segmentation-format-fix")

二、二分类问题

2.1 二分类问题

如果你想将标签改为背景的像素点值为0,目标的像素点值为1,进行二分类处理,可以先运行以下代码:

# --------------------------------------------------------#
#   该文件用于调整标签的格式
# --------------------------------------------------------#
import os

import numpy as np
from PIL import Image
from tqdm import tqdm

# --------------------------------------------------------#
#   Origin_SegmentationClass_path   原始标签所在的路径
#   Out_SegmentationClass_path      输出标签所在的路径
# --------------------------------------------------------#
Origin_SegmentationClass_path = "./unet-pytorch-main/VOCdevkit/VOC2007/SegmentationClass/cut_video_out_label"
Out_SegmentationClass_path = "./unet-pytorch-main/datasets/SegmentationClass"

# -----------------------------------------------------------------------------------#
#   Origin_Point_Value  原始标签对应的像素点值
#   Out_Point_Value     输出标签对应的像素点值
#                       Origin_Point_Value需要与Out_Point_Value一一对应。
#   举例如下,当:
#   Origin_Point_Value = np.array([0, 255]);Out_Point_Value = np.array([0, 1])
#   代表将原始标签中值为0的像素点,调整为0,将原始标签中值为255的像素点,调整为1。
#
#   示例中仅调整了两个像素点值,实际上可以更多个,如:
#   Origin_Point_Value = np.array([0, 128, 255]);Out_Point_Value = np.array([0, 1, 2])
#
#   也可以是数组(当标签值为RGB像素点时),如
#   Origin_Point_Value = np.array([[0, 0, 0], [1, 1, 1]]);Out_Point_Value = np.array([0, 1])
# -----------------------------------------------------------------------------------#
Origin_Point_Value = np.array([0, 255])
Out_Point_Value = np.array([0, 1])

if __name__ == "__main__":
    if not os.path.exists(Out_SegmentationClass_path):
        os.makedirs(Out_SegmentationClass_path)

    # ---------------------------#
    #   遍历标签并赋值
    # ---------------------------#
    png_names = os.listdir(Origin_SegmentationClass_path)  # 返回路径下的文件和文件夹
    print("正在遍历全部标签。")
    for png_name in tqdm(png_names):
        png = Image.open(os.path.join(Origin_SegmentationClass_path, png_name))
        w, h = png.size

        png = np.array(png)
        out_png = np.zeros([h, w])
        '''
        这里的mask是一个布尔数组,其形状与png相同,其中png中的值等于Origin_Point_Value[i]的位置为True,否则为False。
        如果png是一个多通道图像(即,mask的维度大于2),那么mask.all(-1)会将mask沿着最后一个轴进行合并,得到一个二维的布尔数组。
        这个二维的布尔数组然后被用来更新out_png中的值。在布尔数组为True的位置,out_png的值被设置为Out_Point_Value[i]。
        总的来说,mask = mask.all(-1)是用来处理多通道图像的,确保mask始终是一个二维数组,这样就可以正确地更新out_png的值。
        '''
        for i in range(len(Origin_Point_Value)):
            mask = png[:, :] == Origin_Point_Value[i]
            if len(np.shape(mask)) > 2:
                mask = mask.all(-1)
            out_png[mask] = Out_Point_Value[i]

        out_png = Image.fromarray(np.array(out_png, np.uint8))
        out_png.save(os.path.join(Out_SegmentationClass_path, png_name))

    # -------------------------------------#
    #   统计输出,各个像素点的值得个数
    # -------------------------------------#
    print("正在统计输出的图片每个像素点的数量。")
    classes_nums = np.zeros([256], np.int)
    for png_name in tqdm(png_names):
        png_file_name = os.path.join(Out_SegmentationClass_path, png_name)
        if not os.path.exists(png_file_name):
            raise ValueError("未检测到标签图片%s,请查看具体路径下文件是否存在以及后缀是否为png。" % (png_file_name))

        png = np.array(Image.open(png_file_name), np.uint8)
        classes_nums += np.bincount(np.reshape(png, [-1]), minlength=256)
        '''
               对png数组进行reshape为一维数组, [-1]表示自动计算该维度的大小。
               bincount函数会返回一个长度为minlength的数组,数组的索引表示输入数组中的整数值,对应的数组元素值表示该整数值在输入数组中出现的次数。
               这里,minlength=256表示返回一个长度为256的数组,因为图像的像素值在0-255之间。
               classes_nums += ...:这是一个累加操作,将bincount函数返回的数组累加到classes_nums数组上。
               '''

    print("打印像素点的值与数量。")
    print('-' * 37)
    print("| %15s | %15s |" % ("Key", "Value"))
    print('-' * 37)
    for i in range(256):
        if classes_nums[i] > 0:
            print("| %15s | %15s |" % (str(i), str(classes_nums[i])))
            print('-' * 37)

2.2 检查标签值是否由[0,255]改为[0,1]

可以运行如下代码:

from PIL import Image
import numpy as np

# 打开图像
img = Image.open('path_to_your_image.png')

# 将图像转换为NumPy数组
img_array = np.array(img)

# 打印数组中的唯一值
unique_values = np.unique(img_array)
print('Unique pixel values in the image:', unique_values)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值