一、详解都写在注释上了
import os
import random
import numpy as np
from PIL import Image
from tqdm import tqdm # 在长循环中添加进度条提示信息
#-------------------------------------------------------#
# 想要增加测试集修改trainval_percent
# 修改train_percent用于改变验证集的比例 9:1
#
# 当前该库将测试集当作验证集使用,不单独划分测试集
#-------------------------------------------------------#
trainval_percent = 1
train_percent = 0.9
#-------------------------------------------------------#
# 指向VOC数据集所在的文件夹
# 默认指向根目录下的VOC数据集
#-------------------------------------------------------#
VOCdevkit_path = './VOCdevkit'
if __name__ == "__main__":
random.seed(0) # 固定随机种子
print("Generate txt in ImageSets.")
segfilepath = os.path.join(VOCdevkit_path, 'VOC2007/SegmentationClass/cut_video_out_label') # 标签图路径
saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Segmentation') # 生成训练集、验证集序号文件的路径
temp_seg = os.listdir(segfilepath) # # 返回指定路径下的文件和文件夹列表。
total_seg = []
for seg in temp_seg:
if seg.endswith(".png"):
total_seg.append(seg) # 将temp_seg下的所有文件都写入total_seg
num = len(total_seg) # num的值为标签图的个数
list = range(num) # 生成一个长度为num的列表
tv = int(num*trainval_percent) # 划分训练集+验证集的长度, 当trainval_percent不为1时,1-trainval_percent就是测试集的百分比
tr = int(tv*train_percent) # 划分训练集的长度
trainval= random.sample(list,tv) # 在list这个列表中无放回的随机采样(不重复抽样),生成一个新的列表, tv为采样长度
train = random.sample(trainval,tr) # 在trainval这个列表中进行随机抽样, tr为采样长度
print("train and val size",tv)
print("traub suze",tr)
ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w') # 以写的方式创建一个文件,文件的路径由saveBasePath路径和trainval.txt拼接
ftest = open(os.path.join(saveBasePath,'test.txt'), 'w') # 测试集文件路径
ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w') # 训练集文件路径
fval = open(os.path.join(saveBasePath,'val.txt'), 'w') # 验证集文件路径
for i in list:
name = total_seg[i][:-4]+'/n' # 当i=0时,total_seg=video(1).png, 进行切片索引,返回一个新的列表,从video(1).png左边开始的第一个元素到倒数第四个元素(不包含)的所有元素,name='video(1)\n'
if i in trainval: # 如果i这个数值在trainval这个列表里
ftrainval.write(name) # 就把i对应的name写入到ftrainval这个文件里
if i in train: # 如果i在train这个列表里
ftrain.write(name) # 就把它也写到ftrain这个文件里
else:
fval.write(name) # 要是这个i不在train里,就把它写到fval里
else:
ftest.write(name) #要是i不在trainval里,那就把i对应的name写到ftest文件里
# 创建完文件并写入内容后应该及时关闭文件,可以释放系统资源
ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
print("Generate txt in ImageSets done.")
print("Check datasets format, this may take a while.")
print("检查数据集格式是否符合要求,这可能需要一段时间。")
classes_nums = np.zeros([256], np.int)
for i in tqdm(list): # 遍历list时会显示一个进度条
name = total_seg[i]
png_file_name = os.path.join(segfilepath, name) # 将标签图路径和i对应的name拼接起来
if not os.path.exists(png_file_name):
raise ValueError("未检测到标签图片%s,请查看具体路径下文件是否存在以及后缀是否为png。"%(png_file_name))
png = np.array(Image.open(png_file_name), np.uint8) # 打开png_file_name这个文件,并将其转换为一个numpy数组,元素为8位无符号整数
if len(np.shape(png)) > 2: # 标签图为灰度图时,只有一个颜色通道, shape只包括高和宽
print("标签图片%s的shape为%s,不属于灰度图或者八位彩图,请仔细检查数据集格式。"%(name, str(np.shape(png))))
print("标签图片需要为灰度图或者八位彩图,标签的每个像素点的值就是这个像素点所属的种类。"%(name, str(np.shape(png))))
classes_nums += np.bincount(np.reshape(png, [-1]), minlength=256)
'''
对png数组进行reshape为一维数组, [-1]表示自动计算该维度的大小。
bincount函数会返回一个长度为minlength的数组,数组的索引表示输入数组中的整数值,对应的数组元素值表示该整数值在输入数组中出现的次数。
这里,minlength=256表示返回一个长度为256的数组,因为图像的像素值在0-255之间。
classes_nums += ...:这是一个累加操作,将bincount函数返回的数组累加到classes_nums数组上。
'''
print("打印像素点的值与数量。")
print('-' * 37)
print("| %15s | %15s |"%("Key", "Value"))
print('-' * 37)
for i in range(256):
if classes_nums[i] > 0:
print("| %15s | %15s |"%(str(i), str(classes_nums[i]))) # 一共就只有两个类别, 0表示背景,它对应输出的像素值, 1表示一个类别,它对应输出像素值
print('-' * 37)
if classes_nums[255] > 0 and classes_nums[0] > 0 and np.sum(classes_nums[1:255]) == 0:
print("检测到标签中像素点的值仅包含0与255,数据格式有误。")
print("二分类问题需要将标签修改为背景的像素点值为0,目标的像素点值为1。")
elif classes_nums[0] > 0 and np.sum(classes_nums[1:]) == 0:
print("检测到标签中仅仅包含背景像素点,数据格式有误,请仔细检查数据集格式。")
print("JPEGImages中的图片应当为.jpg文件、SegmentationClass中的图片应当为.png文件。")
print("如果格式有误,参考:")
print("https://github.com/bubbliiiing/segmentation-format-fix")
二、二分类问题
2.1 二分类问题
如果你想将标签改为背景的像素点值为0,目标的像素点值为1,进行二分类处理,可以先运行以下代码:
# --------------------------------------------------------#
# 该文件用于调整标签的格式
# --------------------------------------------------------#
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
# --------------------------------------------------------#
# Origin_SegmentationClass_path 原始标签所在的路径
# Out_SegmentationClass_path 输出标签所在的路径
# --------------------------------------------------------#
Origin_SegmentationClass_path = "./unet-pytorch-main/VOCdevkit/VOC2007/SegmentationClass/cut_video_out_label"
Out_SegmentationClass_path = "./unet-pytorch-main/datasets/SegmentationClass"
# -----------------------------------------------------------------------------------#
# Origin_Point_Value 原始标签对应的像素点值
# Out_Point_Value 输出标签对应的像素点值
# Origin_Point_Value需要与Out_Point_Value一一对应。
# 举例如下,当:
# Origin_Point_Value = np.array([0, 255]);Out_Point_Value = np.array([0, 1])
# 代表将原始标签中值为0的像素点,调整为0,将原始标签中值为255的像素点,调整为1。
#
# 示例中仅调整了两个像素点值,实际上可以更多个,如:
# Origin_Point_Value = np.array([0, 128, 255]);Out_Point_Value = np.array([0, 1, 2])
#
# 也可以是数组(当标签值为RGB像素点时),如
# Origin_Point_Value = np.array([[0, 0, 0], [1, 1, 1]]);Out_Point_Value = np.array([0, 1])
# -----------------------------------------------------------------------------------#
Origin_Point_Value = np.array([0, 255])
Out_Point_Value = np.array([0, 1])
if __name__ == "__main__":
if not os.path.exists(Out_SegmentationClass_path):
os.makedirs(Out_SegmentationClass_path)
# ---------------------------#
# 遍历标签并赋值
# ---------------------------#
png_names = os.listdir(Origin_SegmentationClass_path) # 返回路径下的文件和文件夹
print("正在遍历全部标签。")
for png_name in tqdm(png_names):
png = Image.open(os.path.join(Origin_SegmentationClass_path, png_name))
w, h = png.size
png = np.array(png)
out_png = np.zeros([h, w])
'''
这里的mask是一个布尔数组,其形状与png相同,其中png中的值等于Origin_Point_Value[i]的位置为True,否则为False。
如果png是一个多通道图像(即,mask的维度大于2),那么mask.all(-1)会将mask沿着最后一个轴进行合并,得到一个二维的布尔数组。
这个二维的布尔数组然后被用来更新out_png中的值。在布尔数组为True的位置,out_png的值被设置为Out_Point_Value[i]。
总的来说,mask = mask.all(-1)是用来处理多通道图像的,确保mask始终是一个二维数组,这样就可以正确地更新out_png的值。
'''
for i in range(len(Origin_Point_Value)):
mask = png[:, :] == Origin_Point_Value[i]
if len(np.shape(mask)) > 2:
mask = mask.all(-1)
out_png[mask] = Out_Point_Value[i]
out_png = Image.fromarray(np.array(out_png, np.uint8))
out_png.save(os.path.join(Out_SegmentationClass_path, png_name))
# -------------------------------------#
# 统计输出,各个像素点的值得个数
# -------------------------------------#
print("正在统计输出的图片每个像素点的数量。")
classes_nums = np.zeros([256], np.int)
for png_name in tqdm(png_names):
png_file_name = os.path.join(Out_SegmentationClass_path, png_name)
if not os.path.exists(png_file_name):
raise ValueError("未检测到标签图片%s,请查看具体路径下文件是否存在以及后缀是否为png。" % (png_file_name))
png = np.array(Image.open(png_file_name), np.uint8)
classes_nums += np.bincount(np.reshape(png, [-1]), minlength=256)
'''
对png数组进行reshape为一维数组, [-1]表示自动计算该维度的大小。
bincount函数会返回一个长度为minlength的数组,数组的索引表示输入数组中的整数值,对应的数组元素值表示该整数值在输入数组中出现的次数。
这里,minlength=256表示返回一个长度为256的数组,因为图像的像素值在0-255之间。
classes_nums += ...:这是一个累加操作,将bincount函数返回的数组累加到classes_nums数组上。
'''
print("打印像素点的值与数量。")
print('-' * 37)
print("| %15s | %15s |" % ("Key", "Value"))
print('-' * 37)
for i in range(256):
if classes_nums[i] > 0:
print("| %15s | %15s |" % (str(i), str(classes_nums[i])))
print('-' * 37)
2.2 检查标签值是否由[0,255]改为[0,1]
可以运行如下代码:
from PIL import Image
import numpy as np
# 打开图像
img = Image.open('path_to_your_image.png')
# 将图像转换为NumPy数组
img_array = np.array(img)
# 打印数组中的唯一值
unique_values = np.unique(img_array)
print('Unique pixel values in the image:', unique_values)