数据增强的主要目的是扩充图片数量,常用手段包括添加噪声、翻转等。增强时,mask 需要和图像同步进行几何变换(例如一起翻转),但 mask 的像素值(即类别标签)不应被改动:比如 mask 的像素类别是 0、1、2、3 四个类别,增强之后仍然只能包含这四个值。本文的最后附上检查 mask 像素值是否超出类别范围的代码,该检查代码可以帮助定位并解决与 num_classes 相关的报错。
一、数据增强代码:
import imgaug.augmenters as iaa # 导入iaa
import cv2
import glob
import os
import numpy as np
def list_image_mask_pairs(img_dir, msk_dir):
    """Return ``(image_name, mask_name)`` pairs matched by sorted file name.

    ``os.listdir`` gives no ordering guarantee, so both listings are sorted
    before they are paired position by position. If the two directories hold
    a different number of files a warning is printed and only the common
    prefix of the sorted listings is paired.
    """
    img_names = sorted(os.listdir(img_dir))
    msk_names = sorted(os.listdir(msk_dir))
    if len(img_names) != len(msk_names):
        print("warning: {} images but {} masks; pairing the first {} only".format(
            len(img_names), len(msk_names), min(len(img_names), len(msk_names))))
    return list(zip(img_names, msk_names))


def build_output_name(src_name, index, ext):
    """Return ``<stem>_<index><ext>`` for an augmented copy of *src_name*.

    Only the final extension is stripped, so dotted names such as
    ``scan.001.png`` keep their full stem and do not collide.
    """
    stem = os.path.splitext(src_name)[0]
    return "{}_{}{}".format(stem, index, ext)


def build_augmenter():
    """Build the augmentation pipeline applied to every image/mask pair."""
    return iaa.Sequential([
        iaa.Fliplr(0.5),  # horizontal flip
        iaa.Flipud(0.5),  # vertical flip
        iaa.GaussianBlur(sigma=(0, 3.0)),  # Gaussian blur
        iaa.Sharpen(alpha=(0, 0.3), lightness=(0.9, 1.1)),  # sharpening
        iaa.Affine(scale=(0.9, 1), translate_percent=(0, 0.1), rotate=(-40, 40),
                   cval=0, mode='constant'),  # affine transform
        # NOTE(review): newer imgaug releases deprecate ContrastNormalization
        # in favour of iaa.LinearContrast - confirm against the installed version.
        iaa.ContrastNormalization((0.75, 1.5), per_channel=True),  # contrast, alpha in 0.75-1.5 per channel
        iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5),  # Gaussian noise
        iaa.Multiply((0.8, 1.2), per_channel=0.2),  # brightness / colour change
    ])


def augment_dataset(img_dir, msk_dir, img_out_dir, msk_out_dir, copies=8):
    """Write *copies* augmented versions of every image/mask pair.

    Each image and its mask go through the same random transform. Images are
    saved as ``.jpg`` and masks as ``.png`` (lossless, so label values are
    not altered by compression).

    Args:
        img_dir: directory holding the source images.
        msk_dir: directory holding the source masks.
        img_out_dir: directory receiving the augmented images.
        msk_out_dir: directory receiving the augmented masks.
        copies: number of augmented samples generated per input pair.
    """
    # cv2.imwrite only returns False when the target folder is missing,
    # so create the output folders up front.
    os.makedirs(img_out_dir, exist_ok=True)
    os.makedirs(msk_out_dir, exist_ok=True)

    # The pipeline does not depend on the sample; build it once.
    seq = build_augmenter()

    for i, (img_name, msk_name) in enumerate(list_image_mask_pairs(img_dir, msk_dir)):
        img = cv2.imread(os.path.join(img_dir, img_name))
        msk = cv2.imread(os.path.join(msk_dir, msk_name))
        if img is None or msk is None:
            # cv2.imread returns None for missing or undecodable files.
            print("skip unreadable pair: {} / {}".format(img_name, msk_name))
            continue

        # Keep the image as uint8: the noise/contrast parameters above are
        # expressed on the 0-255 scale.
        img_batch = np.expand_dims(img, axis=0)
        # Segmentation maps are passed to imgaug as integer label arrays.
        msk_batch = np.expand_dims(msk, axis=0).astype(np.int32)

        # Augment the image and its segmentation mask together.
        for j in range(copies):
            img_aug, msk_aug = seq(images=img_batch, segmentation_maps=msk_batch)
            img_out = os.path.join(img_out_dir, build_output_name(img_name, j, '.jpg'))
            msk_out = os.path.join(msk_out_dir, build_output_name(msk_name, j, '.png'))
            cv2.imwrite(img_out, img_aug[0])
            # Only the first channel of the mask is written, as 8-bit labels.
            cv2.imwrite(msk_out, msk_aug[0, :, :, 0].astype(np.uint8))
        print("正在进行数据增强{}".format(i))


if __name__ == '__main__':
    img_dir = 'F:/CT_lung_seg_and_class/seg_data/CC-CCI/image'  # input image directory
    msk_dir = 'F:/CT_lung_seg_and_class/seg_data/CC-CCI/mask'  # input mask directory
    img_tmp_dir = 'F:/CT_lung_seg_and_class/seg_data/CC-CCI_AUG/image/'  # output image directory
    msk_tmp_dir = 'F:/CT_lung_seg_and_class/seg_data/CC-CCI_AUG/mask/'  # output mask directory
    augment_dataset(img_dir, msk_dir, img_tmp_dir, msk_tmp_dir)
二、将mask与原图叠加(半透明融合)显示的代码:
import os
from PIL import Image
root_path_background = "F:/CT_lung_seg_and_class/seg_data/CC-CCI_clean/val/images/"
root_path_paste = "F:/CT_lung_seg_and_class/seg_data/CC-CCI_clean/val/mask/"
output_path = "F:/CT_lung_seg_and_class/seg_data/CC-CCI_clean/val/manual/"


def overlay_masks(image_dir, mask_dir, out_dir, alpha=0.3):
    """Blend every mask onto its image and save the result in *out_dir*.

    Images and masks are paired by sorted file name. Each output keeps the
    file name of its source image.

    Args:
        image_dir: directory holding the background images.
        mask_dir: directory holding the masks to overlay.
        out_dir: directory receiving the blended images (created if missing).
        alpha: blend factor passed to ``Image.blend`` (weight of the mask).

    Returns:
        The number of blended images written.
    """
    os.makedirs(out_dir, exist_ok=True)
    img_names = sorted(os.listdir(image_dir))
    label_names = sorted(os.listdir(mask_dir))

    saved = 0
    for img_name, label_name in zip(img_names, label_names):
        with Image.open(os.path.join(image_dir, img_name)) as img, \
                Image.open(os.path.join(mask_dir, label_name)) as label:
            # Image.blend needs both inputs in the same mode, so convert the
            # background as well as the mask to RGB.
            fin = Image.blend(img.convert("RGB"), label.convert("RGB"), alpha)
        fin.save(os.path.join(out_dir, img_name))
        saved += 1
    return saved


if __name__ == "__main__":
    overlay_masks(root_path_background, root_path_paste, output_path)
将mask和原图经过图像重合后图片效果如下:
三、检查mask的像素值是否在限定的类别范围以内。比如mask只应包含0、1两个像素值时,该代码可以找出含有大于1的像素值的图像(下面的示例中 num_classes = 4,即合法的像素值为 0~3)。代码如下:
import os
import os.path as osp
from tqdm import tqdm
import cv2
import numpy as np
num_classes = 4
mask_dir = "F:/CT_lung_seg_and_class/seg_data/CC-CCI_clean/train/mask1/"


def has_out_of_range_labels(mask, num_classes):
    """Return True if *mask* holds any value outside ``0 .. num_classes - 1``."""
    return bool(np.any((mask < 0) | (mask >= num_classes)))


def find_invalid_masks(mask_dir, num_classes):
    """Print and return the names of masks that fail the label range check.

    A mask is reported when it cannot be read or when it holds a pixel value
    outside ``0 .. num_classes - 1``.

    Args:
        mask_dir: directory holding the mask images.
        num_classes: number of valid classes.

    Returns:
        List of offending file names.
    """
    bad_names = []
    for mask_name in tqdm(sorted(os.listdir(mask_dir))):
        mask_path = osp.join(mask_dir, mask_name)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread returns None for missing or undecodable files.
            print("error: cannot read " + mask_name)
            bad_names.append(mask_name)
            continue
        if has_out_of_range_labels(mask, num_classes):
            print("error: " + mask_name)
            bad_names.append(mask_name)
    return bad_names


if __name__ == "__main__":
    find_invalid_masks(mask_dir, num_classes)