import os
from random import shuffle
import numpy as np
def image2patch(image, patch_size=(10, 10), step_size=None, mask_percent=0, random_mask=True):
    """Cut an image into a grid of (optionally masked) rectangular patches.

    The top-left corner of patch ``(i, j)`` is ``(i * step_h, j * step_w)``;
    every patch that lies completely inside the image is kept, in row-major
    order.  ``int(mask_percent * patch_count)`` of the patches are zeroed out.

    Args:
        image: array of shape ``(H, W)`` or ``(H, W, C)``.  A 2-D image is
            treated as having a single channel.
        patch_size: ``(m, n)`` height and width of each patch.
        step_size: ``(step_h, step_w)`` stride between neighbouring patches.
            When falsy it defaults to ``patch_size`` (non-overlapping patches).
        mask_percent: fraction in ``[0, 1]`` of patches to zero out.
        random_mask: if true, the zeroed patches are chosen at random via
            ``random.shuffle``; otherwise the first ones in row-major order
            are zeroed.

    Returns:
        ``(H_num, W_num, patches)``: the number of patch rows, the number of
        patch columns, and a float array of shape ``(H_num * W_num, m, n, C)``.

    Raises:
        ValueError: if ``image`` is not 2-D or 3-D, if a patch/step size is
            not positive, or if ``mask_percent`` is outside ``[0, 1]``.
    """
    m, n = patch_size
    if len(image.shape) == 2:
        H, W = image.shape
        C = 1
        image = np.reshape(image, [H, W, 1])
    elif len(image.shape) == 3:
        H, W, C = image.shape
    else:
        raise ValueError(f"image must be 2-D or 3-D, got shape {image.shape}")

    step_h, step_w = step_size if step_size else patch_size
    if min(m, n, step_h, step_w) <= 0:
        raise ValueError("patch_size and step_size must be positive")
    if not 0 <= mask_percent <= 1:
        raise ValueError(f"mask_percent must be in [0, 1], got {mask_percent}")

    # Start offsets y = i * step with y + m <= H number (H - m) // step + 1.
    # Clamp at 0 so images smaller than one patch yield no patches instead of
    # a negative count.
    H_num = max(0, (H - m) // step_h + 1)
    W_num = max(0, (W - n) // step_w + 1)
    patch_num = H_num * W_num

    # 0 = masked (patch zeroed), 1 = kept.
    mask_count = int(mask_percent * patch_num)
    patch_flag = [0] * mask_count + [1] * (patch_num - mask_count)
    if random_mask:
        shuffle(patch_flag)

    res = np.zeros([patch_num, m, n, C])
    corners = [(step_h * i, step_w * j) for i in range(H_num) for j in range(W_num)]
    for index, (y, x) in enumerate(corners):
        # Every corner is chosen so the slice is exactly (m, n, C).
        res[index] = image[y:y + m, x:x + n, :] * patch_flag[index]
    return H_num, W_num, res
# Demo / test (测试)
def _demo():
    """Split a sample image into masked patches and plot them in a grid.

    Needs matplotlib and scikit-image; they are imported here so that merely
    importing this module does not require them or open any figures.
    """
    import matplotlib.pyplot as plt
    from skimage import data, img_as_float
    from skimage.transform import rescale, resize, downscale_local_mean  # noqa: F401  (handy for experiments)

    # img = img_as_float(data.camera())  # grayscale alternative
    img = img_as_float(data.astronaut())  # color (RGB)
    print(img.shape)
    # image_downscaled = downscale_local_mean(img, (4, 3))
    rows, cols, patches = image2patch(img, patch_size=(50, 50),
                                      step_size=(40, 60),
                                      mask_percent=0.3,
                                      random_mask=True)
    print(patches.shape)

    plt.figure(figsize=(12, 12))
    plt.imshow(img)

    plt.figure(figsize=(12, 12))
    for i, patch in enumerate(patches):
        plt.subplot(rows, cols, i + 1)
        # imshow rejects (m, n, 1) arrays, so drop the singleton channel of a
        # grayscale image; cmap is ignored for RGB patches.
        if patch.shape[-1] == 1:
            patch = np.squeeze(patch, axis=-1)
        plt.imshow(patch, cmap="gray")
        plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    plt.show()


if __name__ == "__main__":
    _demo()
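# Results (效果):
#   Figure 1 (原图1): the original image.
#   Figure 2 (分块2): the grid of patches, with int(0.3 * patch_count) of them masked to zero.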