使用Albumentations进行数据增强
https://juejin.cn/post/7080060840858091557
from major_utils import *
from PIL import Image
import albumentations as A
import cv2
import numpy as np
def scale_image_max(image):
# 将图像转换为浮点数格式
image_float = image.astype(np.float32)
# 找到像素值范围
min_val, max_val = np.min(image_float), np.max(image_float)
# 将像素值缩放到0-255之间
image_normalized = cv2.normalize(image_float, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
# 返回归一化后的图像
return image_normalized
img_size = [1024, 1024]
data_transforms = {
"train": A.Compose([
# 改变图像大小
A.Resize(*img_size, interpolation=cv2.INTER_NEAREST),
# 水平翻转
A.HorizontalFlip(p=0.5),
# 垂直反转
A.VerticalFlip(p=0.5),
# 转置
A.Transpose(always_apply=False, p=0.3),
# 平移、缩放、旋转
A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=45, interpolation=1, border_mode=4, value=None, mask_value=None, always_apply=False, p=1),
# 随机对比度增强
A.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(0.0,0.2),p=0),
# 弹性变换
A.ElasticTransform(alpha=5, sigma=50, alpha_affine=50, interpolation=1, border_mode=4, always_apply=False, p=0.3)
# 标准化图像
# A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
], p=1.0),
"valid": A.Compose([
A.Resize(*img_size, interpolation=cv2.INTER_NEAREST),
], p=1.0)
}
# 定义数据增强流程
# transform = Compose([
# RandomCrop(height=224, width=224), # 随机裁剪到224x224大小
# RandomRotation(degrees=10), # 随机旋转10度
# Resize(height=224, width=224), # 重新调整图像大小(可选,如果需要的话)
# Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # 标准化图像(可选,如果需要的话)
# ], p=1) # p=1 表示这些操作都会被应用到一个图像上
for i in range(100):
data_transform = data_transforms['train']
img = cv2.imread(r"images\0200.tif", cv2.IMREAD_UNCHANGED)
scaled_image = scale_image_max(img)
scaled_image2 = scaled_image * 2.1 - 90
# scaled_image2 = scaled_image2.astype(np.uint8)
img = np.clip(scaled_image2, 0, 255).astype(np.uint8)
msk = cv2.imread(r"labels\0200.tif", cv2.IMREAD_UNCHANGED)
data = data_transform(image=img, mask=msk)
img = data['image']
msk = data['mask']
cv2.imshow('img', img)
cv2.imshow('msk', msk)
cv2.waitKey(0)
cv2.destroyAllWindows()