class RandomCrop(object):
    """
    Randomly crop a 2D region from every image channel and the label of a sample.

    This is usually used for data augmentation and is intended for train mode
    only. The crop position is resampled until the cropped label contains at
    least ``min_pixel`` foreground pixels (label values in [1, 255]). A crop
    that falls below that threshold is still accepted with probability
    ``drop_ratio``, so a fraction of "empty" patches is kept. If the whole
    label has fewer than ``min_pixel`` foreground pixels, the first random
    crop is accepted.

    Args:
        output_size (tuple, list or int): Desired output size ``(size_x, size_y)``.
            If int, a square crop is made.
        drop_ratio (int or float): Probability in [0, 1] of accepting a crop
            whose label has fewer than ``min_pixel`` foreground pixels.
            Defaults to 0.1.
        min_pixel (int): Minimum foreground pixel count for a crop to be
            accepted unconditionally. Defaults to 1.

    Raises:
        RuntimeError: If ``drop_ratio`` is outside [0, 1] or ``min_pixel`` is negative.
    """

    def __init__(self, output_size, drop_ratio=0.1, min_pixel=1):
        self.name = 'Random Crop'

        assert isinstance(output_size, (int, tuple, list))
        if isinstance(output_size, int):
            # The crop is 2D (only two size components are ever used), so an
            # int yields a square crop, consistent with the 2-element tuple form.
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = tuple(output_size)

        assert isinstance(drop_ratio, (int, float))
        if 0 <= drop_ratio <= 1:
            self.drop_ratio = drop_ratio
        else:
            raise RuntimeError('Drop ratio should be between 0 and 1')

        assert isinstance(min_pixel, int)
        if min_pixel >= 0:
            self.min_pixel = min_pixel
        else:
            raise RuntimeError('Min label pixel count should be a non-negative integer')

    def __call__(self, sample):
        """
        Crop ``sample['image']`` (every channel) and ``sample['label']`` at one random position.

        Args:
            sample (dict): ``{'image': sequence of channel images, 'label': label image}``.
                All channels and the label are assumed to share the size of
                ``image[0]`` -- TODO confirm against callers.

        Returns:
            dict: ``{'image': list of cropped channels, 'label': cropped label}``.
            The input ``sample`` is not modified.

        NOTE(review): if the image is smaller than ``output_size`` along an axis,
        the crop starts at 0 but the requested region still exceeds the image;
        presumably the ROI filter raises in that case -- confirm.
        NOTE(review): with ``drop_ratio == 0`` the search loops until a crop with
        ``min_pixel`` foreground pixels is found, which never happens if the
        foreground is too spread out for any single window to hold that many.
        """
        image, label = sample['image'], sample['label']
        size_old = image[0].GetSize()
        size_new = self.output_size

        roiFilter = sitk.RegionOfInterestImageFilter()
        roiFilter.SetSize([size_new[0], size_new[1]])

        # Binarize the label: every value in [1, 255] becomes foreground (1).
        binaryThresholdFilter = sitk.BinaryThresholdImageFilter()
        binaryThresholdFilter.SetLowerThreshold(1)
        binaryThresholdFilter.SetUpperThreshold(255)
        binaryThresholdFilter.SetInsideValue(1)
        binaryThresholdFilter.SetOutsideValue(0)
        label_ = binaryThresholdFilter.Execute(label)

        # The sum of the binary label is the number of foreground pixels.
        statFilter = sitk.StatisticsImageFilter()
        statFilter.Execute(label_)
        # If the whole slice is below the threshold, no crop can reach it: skip
        # the search and accept the first random position (still random).
        search_for_label = statFilter.GetSum() >= self.min_pixel

        while True:
            roiFilter.SetIndex([
                self._random_start(size_old[0], size_new[0]),
                self._random_start(size_old[1], size_new[1]),
            ])
            if not search_for_label:
                break
            statFilter.Execute(roiFilter.Execute(label_))
            if statFilter.GetSum() >= self.min_pixel:
                break
            # Keep an under-threshold ("empty") crop with probability drop_ratio.
            if self.drop(self.drop_ratio):
                break

        # Build new objects instead of overwriting the caller's image list.
        cropped_image = [roiFilter.Execute(channel) for channel in image]
        cropped_label = roiFilter.Execute(label)
        return {'image': cropped_image, 'label': cropped_label}

    @staticmethod
    def _random_start(size_old, size_new):
        """Return a uniform random crop start in [0, size_old - size_new], or 0 if the crop does not fit."""
        if size_old <= size_new:
            return 0
        # np.random.randint excludes `high`; +1 makes the last valid start reachable.
        return int(np.random.randint(0, size_old - size_new + 1))

    def drop(self, probability):
        """
        Return True with probability ``probability``.

        Used to decide whether a crop below ``min_pixel`` is kept. This draws
        from the ``random`` module, which is seeded separately from ``np.random``.
        """
        # random.random() lies in [0.0, 1.0); strict '<' gives exactly p
        # (never True for p == 0, always True for p == 1).
        return random.random() < probability
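# 医学图像预处理----RandomCrop(随机剪裁)
# (Medical image preprocessing: RandomCrop, i.e. random cropping.)
# 最新推荐文章于 2023-12-12 13:20:25 发布
# (Source page footer: latest recommended article published 2023-12-12 13:20:25.)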