# 代码如下 (the code is as follows):
import cv2
import os
import numpy as np
# Output sub-directories written under every split directory, in the same
# order as the (a, b, label) tuples produced below.
_SUBDIRS = ("A", "B", "Label")
_SPLITS = ("train", "val", "test")

# Geometric augmentations. They move pixels, so each one is applied
# identically to A, B and Label to keep the three tiles aligned. The identity
# comes first so the un-augmented tile is always the first sample of its group.
# (flip(img, -1) is omitted because it equals a 180-degree rotation.)
_GEOMETRIC_TRANSFORMS = (
    lambda img: img,
    lambda img: cv2.flip(img, 1),  # horizontal flip
    lambda img: cv2.flip(img, 0),  # vertical flip
    lambda img: cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE),
    lambda img: cv2.rotate(img, cv2.ROTATE_180),
    lambda img: cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE),
)


def _jitter_hsv(img, sat_scale=1.0, hue_shift=0):
    """Return a BGR copy of *img* with saturation scaled and hue rotated.

    For 8-bit images OpenCV stores hue in [0, 180), hence the modulo; the
    saturation channel is clipped back into [0, 255].
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
    hsv[..., 0] = (hsv[..., 0] + hue_shift) % 180
    hsv[..., 1] = np.clip(hsv[..., 1] * sat_scale, 0, 255)
    return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)


# Photometric augmentations. They only change pixel values, so they are applied
# to A and B while the Label tile is reused unchanged (colour-jittering a label
# mask would corrupt it). The 1.3 / 10 constants are chosen jitter amounts.
_PHOTOMETRIC_TRANSFORMS = (
    lambda img: cv2.addWeighted(img, 1.5, img, 0, 0),  # intensity gain x1.5
    lambda img: cv2.addWeighted(img, 1, img, 0, -50),  # intensity offset -50
    lambda img: _jitter_hsv(img, sat_scale=1.3),  # saturation boost
    lambda img: _jitter_hsv(img, hue_shift=10),  # hue rotation
)


def _load_image(path):
    """Read *path* as a 3-channel BGR image; raise instead of returning None."""
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"cv2.imread could not read {path!r}")
    return img


def _tile_origins(height, width, crop_size, step):
    """Return the (y, x) top-left corners of every full tile, in row-major order.

    Only tiles that fit entirely inside the image are returned; a strip at the
    right/bottom border narrower than the grid spacing is left uncovered.
    """
    if crop_size <= 0 or step <= 0:
        raise ValueError("crop_size and step must be positive")
    return [
        (y, x)
        for y in range(0, height - crop_size + 1, step)
        for x in range(0, width - crop_size + 1, step)
    ]


def _split_origins(origins, train_ratio, val_ratio):
    """Split tile positions into contiguous train/val/test runs.

    Splitting positions *before* augmenting keeps every augmented copy of a
    tile in the same split as the tile itself.
    """
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1:
        raise ValueError("ratios must be non-negative and sum to at most 1")
    n_train = int(len(origins) * train_ratio)
    n_val = int(len(origins) * val_ratio)
    return {
        "train": origins[:n_train],
        "val": origins[n_train:n_train + n_val],
        "test": origins[n_train + n_val:],
    }


def _augmented_triplets(crop_a, crop_b, crop_label):
    """Yield aligned (a, b, label) samples: flips/rotations, then colour jitter."""
    for transform in _GEOMETRIC_TRANSFORMS:
        yield transform(crop_a), transform(crop_b), transform(crop_label)
    for transform in _PHOTOMETRIC_TRANSFORMS:
        yield transform(crop_a), transform(crop_b), crop_label


def _write_image(path, img):
    """Write *img* to *path*, raising if OpenCV reports failure."""
    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(path, img):
        raise OSError(f"cv2.imwrite failed for {path!r}")


def main(
    *,
    path_a="a.jpg",
    path_b="b.jpg",
    path_label="label.jpg",
    out_root=".",
    crop_size=256,
    step=128,
    train_ratio=0.8,
    val_ratio=0.1,
    augment_eval=False,
    ext=".jpg",
):
    """Tile an (A, B, Label) image triple into train/val/test crop datasets.

    The three input images are cut into ``crop_size`` x ``crop_size`` tiles
    on a grid with spacing ``step`` in both directions. Tile positions are
    split in row-major order into train/val/test. Each sample is written as
    ``<out_root>/<split>/{A,B,Label}/<index><ext>``, with the same index used
    for the A, B and Label files. Training tiles are additionally augmented
    with flips/rotations (all three images) and colour jitter (A and B only).

    Args:
        path_a, path_b, path_label: Input image paths (must share one shape).
        out_root: Directory under which ``train/``, ``val/``, ``test/`` go.
        crop_size: Tile edge length in pixels.
        step: Grid spacing in pixels. With ``step < crop_size``, neighbouring
            tiles overlap, including tiles on either side of a split boundary.
        train_ratio, val_ratio: Fractions of tile positions for train and
            val; the remainder goes to test.
        augment_eval: Also augment the val/test splits (off by default so
            evaluation data stays un-augmented).
        ext: Output file extension. NOTE: ".jpg" is lossy; ".png" is
            preferable for label masks.

    Raises:
        FileNotFoundError: An input image cannot be read.
        ValueError: Shapes differ, or crop_size/step/ratios are invalid.
        OSError: An output image cannot be written.
    """
    img_a = _load_image(path_a)
    img_b = _load_image(path_b)
    img_label = _load_image(path_label)
    if not img_a.shape == img_b.shape == img_label.shape:
        raise ValueError(
            f"image shapes differ: {img_a.shape}, {img_b.shape}, {img_label.shape}"
        )
    height, width = img_a.shape[:2]

    origins = _tile_origins(height, width, crop_size, step)
    split_origins = _split_origins(origins, train_ratio, val_ratio)
    n_variants = len(_GEOMETRIC_TRANSFORMS) + len(_PHOTOMETRIC_TRANSFORMS)

    for split in _SPLITS:
        out_dirs = [os.path.join(out_root, split, sub) for sub in _SUBDIRS]
        for out_dir in out_dirs:
            os.makedirs(out_dir, exist_ok=True)

        tiles = split_origins[split]
        augment = split == "train" or augment_eval
        per_tile = n_variants if augment else 1
        # Pad to at least 4 digits (the original naming) and widen when needed
        # so file names still sort in index order.
        pad = max(4, len(str(len(tiles) * per_tile - 1)))

        # Stream tiles straight to disk instead of holding every crop and
        # augmentation in memory.
        idx = 0
        for y, x in tiles:
            crops = (
                img_a[y:y + crop_size, x:x + crop_size],
                img_b[y:y + crop_size, x:x + crop_size],
                img_label[y:y + crop_size, x:x + crop_size],
            )
            samples = _augmented_triplets(*crops) if augment else (crops,)
            for sample in samples:
                name = str(idx).zfill(pad) + ext
                for out_dir, img in zip(out_dirs, sample):
                    _write_image(os.path.join(out_dir, name), img)
                idx += 1
if __name__ == "__main__":
    # Build the dataset only when run as a script, never on import.
    main()