# Intended effect: paint random rectangular holes (masks) onto each image in a folder.
import torch
import random
import skimage.io as io
import os
from PIL import Image
import cv2
import numpy as np
def _sample_hole_extent(spec):
    """Return one hole extent (in pixels) from a size specification.

    A two-element tuple or list is read as an inclusive ``(min, max)`` range
    and sampled uniformly; anything else is returned unchanged as a fixed size.
    """
    if isinstance(spec, (tuple, list)) and len(spec) == 2:
        return random.randint(spec[0], spec[1])
    return spec


def gen_input_mask(shape, hole_size, hole_area=None, max_holes=1):
    """Generate a batch of masks containing randomly placed rectangular holes.

    Args:
        shape: 4-element size ``(batch, channels, height, width)`` of the mask.
        hole_size: ``(width_spec, height_spec)``. Each spec is either a fixed
            integer or a two-element ``(min, max)`` tuple/list from which the
            extent is drawn uniformly (both ends inclusive).
        hole_area: optional ``((xmin, ymin), (width, height))`` rectangle that
            every hole must lie inside. When falsy, holes may be placed
            anywhere inside the mask.
        max_holes: upper bound on the number of holes per batch item; the
            actual count is drawn uniformly from ``1..max_holes``.

    Returns:
        A tensor built by ``torch.zeros(shape)`` in which every pixel covered
        by a hole is set to 1.0 on all channels. Holes of the same batch item
        may overlap.

    Raises:
        ValueError: if a sampled hole does not fit inside the placement area.
        IndexError: if ``max_holes`` is smaller than 1.
    """
    mask = torch.zeros(shape)
    bsize, _, mask_h, mask_w = mask.shape

    # Resolve the rectangle in which holes may be placed once, up front.
    if hole_area:
        (area_x, area_y), (area_w, area_h) = hole_area
    else:
        area_x, area_y, area_w, area_h = 0, 0, mask_w, mask_h

    for i in range(bsize):
        n_holes = random.choice(range(1, max_holes + 1))
        for _ in range(n_holes):
            # Draw order (width, height, x, y) is kept stable so seeded runs
            # stay reproducible.
            hole_w = _sample_hole_extent(hole_size[0])
            hole_h = _sample_hole_extent(hole_size[1])
            if hole_w > area_w or hole_h > area_h:
                raise ValueError(
                    "hole of size {}x{} does not fit inside placement area "
                    "of size {}x{}".format(hole_w, hole_h, area_w, area_h)
                )
            # Upper-left corner, chosen so the hole stays inside the area.
            offset_x = random.randint(area_x, area_x + area_w - hole_w)
            offset_y = random.randint(area_y, area_y + area_h - hole_h)
            mask[i, :, offset_y:offset_y + hole_h, offset_x:offset_x + hole_w] = 1.0
    return mask
def gen_hole_area(size, mask_size):
    """Pick a random rectangle of the given size inside a mask.

    Args:
        size: ``(width, height)`` of the rectangle to place.
        mask_size: ``(width, height)`` of the mask it must fit into.

    Returns:
        ``((offset_x, offset_y), (width, height))`` where the offset is the
        upper-left corner, drawn uniformly so the rectangle stays inside the
        mask.

    Raises:
        ValueError: if the rectangle is larger than the mask in either
            dimension.
    """
    mask_w, mask_h = mask_size
    harea_w, harea_h = size
    if harea_w > mask_w or harea_h > mask_h:
        raise ValueError(
            "hole area {}x{} does not fit inside mask {}x{}".format(
                harea_w, harea_h, mask_w, mask_h
            )
        )
    offset_x = random.randint(0, mask_w - harea_w)
    offset_y = random.randint(0, mask_h - harea_h)
    return ((offset_x, offset_y), (harea_w, harea_h))
if __name__ == '__main__':
    # Batch job: for every image file in the input folder, generate a random
    # hole mask, brighten the masked pixels and save the result under the same
    # file name in the output folder.
    dir_path1 = "C:/Users/86152/OneDrive/桌面/label"
    files = os.listdir(dir_path1)
    for file in files:
        src_path = dir_path1 + '/' + file
        if not os.path.isfile(src_path):
            # Sub-directories cannot be read as images.
            continue
        image = io.imread(src_path)

        # Size the mask from the image itself so cv2.add never sees
        # mismatched shapes.
        height, width = image.shape[:2]
        mask = gen_input_mask(
            shape=(1, 1, height, width),
            hole_size=((48, 96), (48, 96)),
            hole_area=gen_hole_area((96, 96), (width, height)),
            max_holes=2,
        )

        # Take the single (H, W) plane directly; reshaping (C, H, W) to
        # (H, W, C) is not a transpose and is only correct when C == 1.
        mask = mask[0, 0].numpy().astype(np.uint8) * 250
        if image.ndim == 3:
            # Replicate the mask on every channel of a colour image.
            mask = np.repeat(mask[:, :, np.newaxis], image.shape[2], axis=2)

        print(mask.dtype)
        print(image.dtype)
        # cv2.add saturates instead of wrapping around on overflow.
        result = cv2.add(image, mask)
        print(result.dtype)
        io.imsave('C:/Users/86152/OneDrive/桌面/picture/' + file, result)