import torch as t
import torch.nn as nn
def ReflectPad2D(image, pad_size):
'''
镜像padding
:param image: 图像
:param pad_size: padding大小,我这里简化成正方形图,四个方向全padding同样的值
:return: padding后的图像
'''
original_shape = image.shape
image = image.reshape([1]+list(original_shape)) # reshape成4D
''' 利用torch.nn.functional.pad() 进行padding,由于该函数其只能padding 4D图的后两维,3D的最后一维 ,所以这里进行了两次reshape'''
image = t.tensor(image)
p1d = (pad_size, pad_size, pad_size, pad_size) # pad last dim by (1, 1) and 2nd to last by (2, 2)
pad_image = nn.functional.pad(image, p1d, mode='reflect')
pad_image = pad_image.reshape(original_shape[0], original_shape[1]+2*pad_size, original_shape[2]+2*pad_size) # reshape 回3D
return pad_image