# 矩阵转张量 (numpy ndarray -> torch.Tensor)
class ToTensor:
    r"""Convert a numpy array into a ``torch.Tensor``.

    Supported inputs are 2D (HW), 3D (DHW or CHW) and 4D (CDHW) arrays.

    Args:
        expand_dims: when True, any input that is not already 4D gets a new
            leading axis (e.g. HW -> 1HW, DHW -> 1DHW) before conversion.
        **kwds: extra options. ``dtype``, if given, is the dtype the array is
            cast to before conversion; otherwise the array's own dtype is kept.
    """

    def __init__(self, expand_dims=True, **kwds) -> None:
        self.expand_dims = expand_dims
        self.kwds = kwds

    def __call__(self, ndarray) -> Any:
        r"""Return ``ndarray`` as a tensor (optionally with a new leading axis).

        Raises:
            ValueError: if ``ndarray`` is not 2D, 3D or 4D.
        """
        # numpy exposes the rank as ``ndim`` (there is no ``dims`` attribute).
        # Raise instead of assert so the check survives ``python -O``.
        if ndarray.ndim not in (2, 3, 4):
            raise ValueError("Support 2D(HW), 3D(DHW,CHW), or 4D(CDHW)")
        dtype = self.kwds.get("dtype", ndarray.dtype)
        if self.expand_dims and ndarray.ndim != 4:
            ndarray = np.expand_dims(ndarray, axis=0)
        # astype() copies by default, so the tensor does not share memory
        # with the caller's array.
        return torch.from_numpy(ndarray.astype(dtype))
# 多种裁剪方式：随机、中心、循环（滑窗） (crop modes: random, center, sliding window)
class CropFactor:
    r"""Per-axis planners for cropping one dimension.

    Approach: pad first, then pick the start coordinates. Every helper
    returns ``(total_coors, padLR)``:

    * ``padLR = (left_pad, right_pad)``: padding to apply to this axis.
    * ``total_coors``: candidate window start indices in the *padded* axis.
    """

    @staticmethod
    def _padding_LR(pad_total):
        r"""Split ``pad_total`` into ``(left, right)``; an odd extra unit goes right."""
        left_total = pad_total // 2
        return left_total, pad_total - left_total

    @staticmethod
    def _random_crop_LR(crop_size, total_size):
        r"""Random crop: every start index whose window fits in the axis.

        If the crop is at least as large as the axis, the axis is padded
        up to ``crop_size`` and the only start is 0.
        """
        if crop_size < total_size:
            # Valid starts are 0 .. total_size - crop_size inclusive; range()
            # excludes its end, hence the +1.
            total_coors = list(range(0, total_size - crop_size + 1))
            return total_coors, (0, 0)
        padLR = CropFactor._padding_LR(crop_size - total_size)
        return [0], padLR

    @staticmethod
    def _center_crop_LR(crop_size, total_size):
        r"""Center crop: a single start index that centres the window.

        If the crop is at least as large as the axis, the axis is padded
        up to ``crop_size`` and the only start is 0.
        """
        if crop_size < total_size:
            left_coor = (total_size - crop_size) // 2
            return [left_coor], (0, 0)
        padLR = CropFactor._padding_LR(crop_size - total_size)
        return [0], padLR

    @staticmethod
    def _sequence_crop_LR(crop_size, total_size, stride):
        r"""Sliding-window crop with step ``stride``.

        With ``num = ceil((total_size - crop_size) / stride)``, the axis is
        padded to ``num * stride + crop_size``. Windows then start at
        ``0, stride, ..., num * stride`` (``num + 1`` windows). The last
        window ends exactly at the padded edge, so the whole axis is covered.

        Raises:
            ValueError: if ``stride`` is not positive when windows are needed.
        """
        if crop_size < total_size:
            if stride <= 0:
                raise ValueError("stride must be positive")
            num = int(np.ceil((total_size - crop_size) / stride))  # s*n + k >= size
            pad_total = num * stride + crop_size - total_size
            padLR = CropFactor._padding_LR(pad_total)
            # Include i == num: that window is the one reaching the padded end.
            total_coors = [stride * i for i in range(num + 1)]
            return total_coors, padLR
        padLR = CropFactor._padding_LR(crop_size - total_size)
        return [0], padLR
class RandomCrop:
    r"""Randomly crop 2D/3D images that may carry a leading channel axis.

    Args:
        random_state: a ``np.random.RandomState`` used to draw crop origins.
            When omitted, the class-level shared generator (seed 47) is used.
        **kwds: options. ``crop_size`` (default ``[96, 144, 144]``) is the
            spatial crop size: 3 values for 3D crops, 2 values for 2D crops.
    """

    # Shared default generator, created once when the class is defined and
    # used by every instance that is not given its own ``random_state``.
    _DEFAULT_RANDOM_STATE = np.random.RandomState(47)

    def __init__(self, random_state=None, **kwds) -> None:
        if random_state is None:
            random_state = self._DEFAULT_RANDOM_STATE
        self.random_state = random_state
        self.kwds = kwds

    def __call__(self, ndarray, **kwds) -> Any:
        r"""Crop ``ndarray`` to the configured ``crop_size``.

        A 2-element ``crop_size`` dispatches to ``Img2D_crop``; anything else
        goes to ``Img3D_crop``. ``kwds`` (``mode``, ``constant_values``) are
        forwarded to the crop method.
        """
        crop_size = self.kwds.get("crop_size", [96, 144, 144])
        if len(crop_size) == 2:
            return self.Img2D_crop(ndarray, crop_size, **kwds)
        return self.Img3D_crop(ndarray, crop_size, **kwds)

    # Cropping for images, point sets, and other data types.
    def Img3D_crop(self, ndarray, crop_size, **kwds):
        r"""Randomly crop a 3D (z, y, x) or 4D (c, z, y, x) array.

        ``crop_size`` is ``(crop_z, crop_y, crop_x)``. Axes smaller than the
        crop are padded using ``mode`` (default ``'constant'``) and
        ``constant_values`` (default 0) from ``kwds``. For 4D input, every
        channel is cut with the same window.

        Raises:
            ValueError: if ``crop_size`` does not have 3 entries or the array
                is not 3D/4D.
        """
        if len(crop_size) != 3:
            raise ValueError(f"Img3D_crop expects 3 crop sizes, got {len(crop_size)}")
        return self._crop_spatial(ndarray, crop_size, **kwds)

    def Img2D_crop(self, ndarray, crop_size, **kwds):
        r"""Randomly crop a 2D (y, x) or 3D (c, y, x) array.

        ``crop_size`` is ``(crop_y, crop_x)``; padding options are as in
        ``Img3D_crop``. For 3D input, every channel is cut with the same window.

        Raises:
            ValueError: if ``crop_size`` does not have 2 entries or the array
                is not 2D/3D.
        """
        if len(crop_size) != 2:
            raise ValueError(f"Img2D_crop expects 2 crop sizes, got {len(crop_size)}")
        return self._crop_spatial(ndarray, crop_size, **kwds)

    def _crop_spatial(self, ndarray, crop_size, **kwds):
        r"""Pad and randomly crop the trailing ``len(crop_size)`` axes of ``ndarray``."""
        n_spatial = len(crop_size)
        if ndarray.ndim == n_spatial:
            channel_pad = []
        elif ndarray.ndim == n_spatial + 1:
            channel_pad = [(0, 0)]  # the leading channel axis is never padded or cropped
        else:
            raise ValueError(
                f"expected a {n_spatial}D array or a {n_spatial + 1}D array with a "
                f"leading channel axis, got ndim={ndarray.ndim}"
            )

        spatial_shape = ndarray.shape[-n_spatial:]
        plans = [CropFactor._random_crop_LR(c, s) for c, s in zip(crop_size, spatial_shape)]
        # Draw the start of each axis in order (z, y, x) from the generator.
        starts = [self.random_state.choice(coors) for coors, _ in plans]

        mode = kwds.get("mode", "constant")
        pad_kwargs = {}
        # np.pad rejects ``constant_values`` for modes such as 'reflect' or
        # 'edge', so only forward it when it applies or was explicitly given.
        if mode == "constant" or callable(mode) or "constant_values" in kwds:
            pad_kwargs["constant_values"] = kwds.get("constant_values", 0)
        pad_width = channel_pad + [pad for _, pad in plans]
        ndarray_pad = np.pad(ndarray, pad_width=pad_width, mode=mode, **pad_kwargs)

        window = tuple(slice(st, st + c) for st, c in zip(starts, crop_size))
        if channel_pad:
            window = (slice(None),) + window
        return ndarray_pad[window]

    @staticmethod
    def point_crop():
        r"""Placeholder for point-set cropping; not implemented (returns None)."""
        pass