This post collects some syntax and functions commonly used for preprocessing in deep learning.
Image transforms with torchvision.transforms
2D and 3D random crop:
import random
import torch  # needed by the tensor-based RandomCrop_3d class below

def random_crop_2d(img, label, crop_size):
    random_x_max = img.shape[0] - crop_size[0]
    random_y_max = img.shape[1] - crop_size[1]
    if random_x_max < 0 or random_y_max < 0:
        return None
    x_random = random.randint(0, random_x_max)
    y_random = random.randint(0, random_y_max)
    crop_img = img[x_random:x_random + crop_size[0], y_random:y_random + crop_size[1]]
    crop_label = label[x_random:x_random + crop_size[0], y_random:y_random + crop_size[1]]
    return crop_img, crop_label

def random_crop_3d(img, label, crop_size):
    random_x_max = img.shape[0] - crop_size[0]
    random_y_max = img.shape[1] - crop_size[1]
    random_z_max = img.shape[2] - crop_size[2]
    if random_x_max < 0 or random_y_max < 0 or random_z_max < 0:
        return None
    x_random = random.randint(0, random_x_max)
    y_random = random.randint(0, random_y_max)
    z_random = random.randint(0, random_z_max)
    crop_img = img[x_random:x_random + crop_size[0], y_random:y_random + crop_size[1],
                   z_random:z_random + crop_size[2]]
    crop_label = label[x_random:x_random + crop_size[0], y_random:y_random + crop_size[1],
                       z_random:z_random + crop_size[2]]
    return crop_img, crop_label
class RandomCrop_3d:
    def __init__(self, slices):
        self.slices = slices

    def _get_range(self, slices, crop_slices):
        if slices < crop_slices:
            start = 0
        else:
            start = random.randint(0, slices - crop_slices)
        end = start + crop_slices
        if end > slices:
            end = slices
        return start, end

    def __call__(self, img, mask):
        # img and mask are expected to be (C, D, H, W); crop along the slice dimension (dim 1)
        ss, es = self._get_range(mask.size(1), self.slices)
        tmp_img = torch.zeros((img.size(0), self.slices, img.size(2), img.size(3)))
        tmp_mask = torch.zeros((mask.size(0), self.slices, mask.size(2), mask.size(3)))
        tmp_img[:, :es - ss] = img[:, ss:es]
        tmp_mask[:, :es - ss] = mask[:, ss:es]
        return tmp_img, tmp_mask
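A quick usage sketch of the two crop utilities above; the shapes and crop sizes below are made up for illustration only:
import numpy as np
import torch

# random_crop_3d works on array-like volumes indexed as (D, H, W)
vol = np.random.rand(64, 256, 256)
seg = np.random.randint(0, 2, size=(64, 256, 256))
crop_vol, crop_seg = random_crop_3d(vol, seg, crop_size=(32, 128, 128))
print(crop_vol.shape, crop_seg.shape)   # (32, 128, 128) (32, 128, 128)

# RandomCrop_3d works on (C, D, H, W) tensors and crops along the slice dimension
crop = RandomCrop_3d(slices=16)
img = torch.rand(1, 48, 256, 256)
mask = torch.randint(0, 2, (1, 48, 256, 256)).float()
img_c, mask_c = crop(img, mask)
print(img_c.shape, mask_c.shape)        # torch.Size([1, 16, 256, 256]) for both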
Some transforms-style image-processing routines (custom data-augmentation classes):
"""
This part is based on the dataset class implemented by pytorch,
including train_dataset and test_dataset, as well as data augmentation
"""
from torch.utils.data import Dataset
import torch
import numpy as np
import random
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import normalize
#----------------------data augment-------------------------------------------
class Resize:
    def __init__(self, scale):
        # self.shape = [shape, shape, shape] if isinstance(shape, int) else shape
        self.scale = scale

    def __call__(self, img, mask):
        img, mask = img.unsqueeze(0), mask.unsqueeze(0).float()
        img = F.interpolate(img, scale_factor=(1, self.scale, self.scale), mode='trilinear',
                            align_corners=False, recompute_scale_factor=True)
        mask = F.interpolate(mask, scale_factor=(1, self.scale, self.scale), mode="nearest",
                             recompute_scale_factor=True)
        return img[0], mask[0]

class RandomResize:
    def __init__(self, s_rank, w_rank, h_rank):
        self.w_rank = w_rank
        self.h_rank = h_rank
        self.s_rank = s_rank

    def __call__(self, img, mask):
        random_w = random.randint(self.w_rank[0], self.w_rank[1])
        random_h = random.randint(self.h_rank[0], self.h_rank[1])
        random_s = random.randint(self.s_rank[0], self.s_rank[1])
        self.shape = [random_s, random_h, random_w]
        img, mask = img.unsqueeze(0), mask.unsqueeze(0).float()
        img = F.interpolate(img, size=self.shape, mode='trilinear', align_corners=False)
        mask = F.interpolate(mask, size=self.shape, mode="nearest")
        return img[0], mask[0].long()

class RandomCrop:
    def __init__(self, slices):
        self.slices = slices

    def _get_range(self, slices, crop_slices):
        if slices < crop_slices:
            start = 0
        else:
            start = random.randint(0, slices - crop_slices)
        end = start + crop_slices
        if end > slices:
            end = slices
        return start, end

    def __call__(self, img, mask):
        ss, es = self._get_range(mask.size(1), self.slices)
        # print(self.shape, img.shape, mask.shape)
        tmp_img = torch.zeros((img.size(0), self.slices, img.size(2), img.size(3)))
        tmp_mask = torch.zeros((mask.size(0), self.slices, mask.size(2), mask.size(3)))
        tmp_img[:, :es - ss] = img[:, ss:es]
        tmp_mask[:, :es - ss] = mask[:, ss:es]
        return tmp_img, tmp_mask

class RandomFlip_LR:
    def __init__(self, prob=0.5):
        self.prob = prob

    def _flip(self, img, prob):
        if prob[0] <= self.prob:
            img = img.flip(2)
        return img

    def __call__(self, img, mask):
        prob = (random.uniform(0, 1), random.uniform(0, 1))
        return self._flip(img, prob), self._flip(mask, prob)

class RandomFlip_UD:
    def __init__(self, prob=0.5):
        self.prob = prob

    def _flip(self, img, prob):
        if prob[1] <= self.prob:
            img = img.flip(3)
        return img

    def __call__(self, img, mask):
        prob = (random.uniform(0, 1), random.uniform(0, 1))
        return self._flip(img, prob), self._flip(mask, prob)

class RandomRotate:
    def __init__(self, max_cnt=3):
        self.max_cnt = max_cnt

    def _rotate(self, img, cnt):
        img = torch.rot90(img, cnt, [1, 2])
        return img

    def __call__(self, img, mask):
        cnt = random.randint(0, self.max_cnt)
        return self._rotate(img, cnt), self._rotate(mask, cnt)
class Center_Crop:
    def __init__(self, base, max_size):
        self.base = base  # base is usually 16, since after 4 downsamplings 16 slices become 1
        self.max_size = max_size
        if self.max_size % self.base:
            # max_size caps the number of sampled slices (to avoid GPU memory overflow); it should also be a multiple of base
            self.max_size = self.max_size - self.max_size % self.base

    def __call__(self, img, label):
        if img.size(1) < self.base:
            return None
        slice_num = img.size(1) - img.size(1) % self.base
        slice_num = min(self.max_size, slice_num)
        left = img.size(1) // 2 - slice_num // 2
        right = img.size(1) // 2 + slice_num // 2
        crop_img = img[:, left:right]
        crop_label = label[:, left:right]
        return crop_img, crop_label
class ToTensor:
    def __init__(self):
        self.to_tensor = transforms.ToTensor()

    def __call__(self, img, mask):
        img = self.to_tensor(img)
        mask = torch.from_numpy(np.array(mask))
        return img, mask[None]

class Normalize:
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, img, mask):
        return normalize(img, self.mean, self.std, False), mask

class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, mask):
        for t in self.transforms:
            img, mask = t(img, mask)
        return img, mask
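Below is a minimal sketch of how these classes can be chained with the Compose defined above; the transform choices, probabilities, and tensor shapes are illustrative, not from the original code:
train_transforms = Compose([
    RandomCrop(slices=48),
    RandomFlip_LR(prob=0.5),
    RandomFlip_UD(prob=0.5),
])

img = torch.rand(1, 64, 256, 256)                      # e.g. a channel-first CT volume (C, D, H, W)
mask = torch.randint(0, 2, (1, 64, 256, 256)).float()
img_aug, mask_aug = train_transforms(img, mask)
print(img_aug.shape, mask_aug.shape)                   # both torch.Size([1, 48, 256, 256])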
Split the dataset into train + Val (the split ratio is user-defined):
import os
import shutil
from random import sample

root = './data/head'
Origin = 'images'
Segmen = 'labels'
n = 0.8  # fraction of the data used for training

# Split the data
## collect the file names in each folder
data_file = os.listdir(f'{root}/{Origin}')
segm_file = os.listdir(f'{root}/{Segmen}')
train_size = int(len(data_file) * n)
train_img_url = sample(data_file, train_size)
val_img_url = list(set(data_file) ^ set(train_img_url))  ## the remaining files (set difference)

## copy the training images
for i in range(len(train_img_url)):
    if not os.path.exists(f'{root}/train'):
        os.mkdir(f'{root}/train')  ## create train/
    if not os.path.exists(f'{root}/train/{Origin}'):
        os.mkdir(f'{root}/train/{Origin}')  ## create the source-image folder
    if not os.path.exists(f'{root}/train/{Segmen}'):
        os.mkdir(f'{root}/train/{Segmen}')  ## create the segmentation folder
    ## copy the source image
    image = os.path.join(f'{root}/{Origin}', train_img_url[i]).replace('\\', '/')
    image_class = os.path.join(f'{root}/train/{Origin}', train_img_url[i]).replace('\\', '/')
    shutil.copy(image, image_class)
    ## copy the segmentation map
    seg = os.path.join(f'{root}/{Segmen}', train_img_url[i].replace('jpg', 'png')).replace('\\', '/')
    seg_class = os.path.join(f'{root}/train/{Segmen}', train_img_url[i].replace('jpg', 'png')).replace('\\', '/')
    shutil.copy(seg, seg_class)

## copy the validation images
for i in range(len(val_img_url)):
    if not os.path.exists(f'{root}/Val'):
        os.mkdir(f'{root}/Val')  ## create Val/
    if not os.path.exists(f'{root}/Val/{Origin}'):
        os.mkdir(f'{root}/Val/{Origin}')  ## create the source-image folder
    if not os.path.exists(f'{root}/Val/{Segmen}'):
        os.mkdir(f'{root}/Val/{Segmen}')  ## create the segmentation folder
    ## copy the source image
    image = os.path.join(f'{root}/{Origin}', val_img_url[i]).replace('\\', '/')
    image_class = os.path.join(f'{root}/Val/{Origin}', val_img_url[i]).replace('\\', '/')
    shutil.copy(image, image_class)
    ## copy the segmentation map
    seg = os.path.join(f'{root}/{Segmen}', val_img_url[i].replace('jpg', 'png')).replace('\\', '/')
    seg_class = os.path.join(f'{root}/Val/{Segmen}', val_img_url[i].replace('jpg', 'png')).replace('\\', '/')
    shutil.copy(seg, seg_class)

## write the path lists
f = open(os.path.join(root, 'train_path_list.txt'), 'w')
for name in train_img_url:
    ct_path = os.path.join(f'{root}/{Origin}', name)
    f.write(ct_path + "\n")
f.close()

f = open(os.path.join(root, 'val_path_list.txt'), 'w')
for name in val_img_url:
    ct_path = os.path.join(f'{root}/{Origin}', name)
    f.write(ct_path + "\n")
f.close()
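To tie the folder layout produced above back to the PyTorch Dataset mentioned in the earlier docstring, here is a minimal, hypothetical paired-segmentation Dataset; the class name, the PIL-based loading, and the transform hook are my own assumptions rather than part of the original script:
import os
from PIL import Image
from torch.utils.data import Dataset

class HeadSegDataset(Dataset):
    """Hypothetical example: pairs {root}/images/xxx.jpg with {root}/labels/xxx.png."""
    def __init__(self, root='./data/head/train', transform=None):
        self.img_dir = os.path.join(root, 'images')
        self.lbl_dir = os.path.join(root, 'labels')
        self.names = sorted(os.listdir(self.img_dir))
        self.transform = transform  # e.g. the ToTensor / Compose classes defined above

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        img = Image.open(os.path.join(self.img_dir, name)).convert('L')
        mask = Image.open(os.path.join(self.lbl_dir, name.replace('jpg', 'png')))
        if self.transform is not None:
            img, mask = self.transform(img, mask)
        return img, mask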
3D reconstruction of NIfTI (.nii) CT medical images:
import vtk

def showNiiVtk3D(niipath):
    render = vtk.vtkRenderer()                 # the "stage": instantiate the renderer
    renWin = vtk.vtkRenderWindow()             # the render window
    ir = vtk.vtkRenderWindowInteractor()       # platform-independent handling of mouse/key/timer events
    ir.SetRenderWindow(renWin)                 # attach the interactor to renWin
    renWin.AddRenderer(render)                 # add the renderer (stage) to the window
    style = vtk.vtkInteractorStyleTrackballCamera()  # trackball-camera style: moving the camera moves everything on screen
    ir.SetInteractorStyle(style)               # attach the style to the interactor

    reader = vtk.vtkNIFTIImageReader()         # NIfTI file reader
    reader.SetFileName(niipath)                # set the file to read

    contourfilter = vtk.vtkContourFilter()     # vtkContourFilter extracts iso-surfaces from the data
    contourfilter.SetInputConnection(reader.GetOutputPort())
    contourfilter.SetValue(0, 250)             # iso-value for the extracted surface

    smooth = vtk.vtkSmoothPolyDataFilter()     # smooth the extracted surface
    smooth.SetInputConnection(contourfilter.GetOutputPort())
    smooth.SetNumberOfIterations(300)

    normal = vtk.vtkPolyDataNormals()          # compute surface normals
    normal.SetInputConnection(smooth.GetOutputPort())
    normal.SetFeatureAngle(60)

    conMapper = vtk.vtkPolyDataMapper()        # mapper for the contour surface
    conMapper.SetInputConnection(normal.GetOutputPort())  # feed the pipeline output into the mapper
    conMapper.ScalarVisibilityOff()
    conActor = vtk.vtkActor()                  # actor for the surface
    conActor.SetMapper(conMapper)              # assign the mapper to the actor
    conActor.GetProperty().SetColor(1, 0, 0)   # red surface
    render.AddActor(conActor)                  # add the actor to the scene

    boxFilter = vtk.vtkOutlineFilter()         # bounding-box outline of the volume
    boxFilter.SetInputConnection(reader.GetOutputPort())
    boxMapper = vtk.vtkPolyDataMapper()
    boxMapper.SetInputConnection(boxFilter.GetOutputPort())
    boxActor = vtk.vtkActor()
    boxActor.SetMapper(boxMapper)
    boxActor.GetProperty().SetColor(0, 1, 0)   # green outline
    render.AddActor(boxActor)                  # add the outline actor to the scene

    camera = vtk.vtkCamera()
    camera.SetViewUp(0, 0, -1)
    camera.SetPosition(0, 1, 0)
    camera.SetFocalPoint(0, 0, 0)
    camera.ComputeViewPlaneNormal()
    camera.Dolly(1.5)
    render.SetActiveCamera(camera)
    render.ResetCamera()

    ir.Initialize()
    renWin.Render()                            # render once before starting the interactor
    ir.Start()

showNiiVtk3D('F:\\data\\diao_0.nii')
An alternative version (not recommended):
import vtk

# Read the NIfTI volume
reader = vtk.vtkNIFTIImageReader()
reader.SetFileName('./fixed_data/ct/volume-27.nii')
reader.Update()

# GPU ray-cast volume mapper
mapper = vtk.vtkGPUVolumeRayCastMapper()
mapper.SetInputData(reader.GetOutput())
volume = vtk.vtkVolume()
volume.SetMapper(mapper)

# Transfer functions: scalar opacity and HSV color
property = vtk.vtkVolumeProperty()
popacity = vtk.vtkPiecewiseFunction()
popacity.AddPoint(1000, 0.0)
popacity.AddPoint(4000, 0.68)
popacity.AddPoint(7000, 0.83)
color = vtk.vtkColorTransferFunction()
color.AddHSVPoint(1000, 0.042, 0.73, 0.55)
color.AddHSVPoint(2500, 0.042, 0.73, 0.55, 0.5, 0.92)
color.AddHSVPoint(4000, 0.088, 0.67, 0.88)
color.AddHSVPoint(5500, 0.088, 0.67, 0.88, 0.33, 0.45)
color.AddHSVPoint(7000, 0.95, 0.063, 1.0)
property.SetColor(color)
property.SetScalarOpacity(popacity)

# Shading / lighting settings
property.ShadeOn()
property.SetInterpolationTypeToLinear()
property.SetShade(0, 1)
property.SetDiffuse(0.9)
property.SetAmbient(0.1)
property.SetSpecular(0.2)
property.SetSpecularPower(10.0)
property.SetComponentWeight(0, 1)
property.SetDisableGradientOpacity(1)
property.DisableGradientOpacityOn()
property.SetScalarOpacityUnitDistance(0.891927)
volume.SetProperty(property)

# Renderer, window and interactor
ren = vtk.vtkRenderer()
ren.AddActor(volume)
ren.SetBackground(0.1, 0.2, 0.4)
renWin = vtk.vtkRenderWindow()
renWin.AddRenderer(ren)
iren = vtk.vtkRenderWindowInteractor()
iren.SetRenderWindow(renWin)
renWin.SetSize(600, 600)
renWin.Render()
iren.Start()
Reading medical images in the .nii format:
## Method 1: SimpleITK
import SimpleITK as sitk

itk_img = sitk.ReadImage('./fixed_data/ct/volume-27.nii')
img = sitk.GetArrayFromImage(itk_img)
print(img.shape)  # (155, 240, 240): the number of slices along each axis (z, y, x)
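SimpleITK also carries the geometric metadata, which is usually needed when resampling or saving results back to disk; a small follow-up sketch (the output file name is just a placeholder):
# Geometric metadata (note: spacing/origin are ordered (x, y, z), opposite to the numpy array)
spacing = itk_img.GetSpacing()
origin = itk_img.GetOrigin()
direction = itk_img.GetDirection()

# Convert a processed numpy array back into an image, keeping the original geometry
out_img = sitk.GetImageFromArray(img)
out_img.CopyInformation(itk_img)
sitk.WriteImage(out_img, './fixed_data/ct/volume-27_copy.nii')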
MONAI framework syntax:
Recommended article: 使用MONAI深度学习框架进行3D图像空间变换 (不入流儿的博客, CSDN)
load_decathlon_datalist:
Reference: 使用 load_decathlon_datalist (MONAI)快速加载JSON数据 (Tina姐的博客, CSDN)
CacheDataset:
Reference: monai.data.CacheDataset vs monai.data.Dataset (Tina姐的博客, CSDN)
sliding_window_inference:
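A minimal sketch of how monai.inferers.sliding_window_inference is typically called for patch-wise 3D inference; the UNet configuration, input shape, and roi_size below are placeholders (the exact UNet constructor arguments can vary between MONAI versions):
import torch
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNet

# Placeholder 3D segmentation model; any callable that maps a patch to a prediction will do.
model = UNet(spatial_dims=3, in_channels=1, out_channels=2,
             channels=(16, 32, 64, 128), strides=(2, 2, 2)).eval()

volume = torch.rand(1, 1, 96, 192, 192)   # (B, C, D, H, W)
with torch.no_grad():
    pred = sliding_window_inference(
        inputs=volume,
        roi_size=(64, 128, 128),   # patch size fed to the model
        sw_batch_size=2,           # number of patches per forward pass
        predictor=model,
        overlap=0.25,              # overlap between neighbouring patches
    )
print(pred.shape)                  # torch.Size([1, 2, 96, 192, 192])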
LoadImage (load a single image):
from monai.transforms import LoadImage, LoadImaged
import numpy as np

dict_loader = LoadImage(dtype=np.float32, image_only=True)
data_dict = dict_loader("../Coronary_Segmentation_deep_learning/dataset/10423186/img.nii.gz")
print(data_dict.shape)
## at this point data_dict is a tensor (a single loaded image, not a dict)
LoadImaged / LoadImageD (dictionary version, loads image/label pairs from a dict of paths):
from monai.transforms import LoadImage, LoadImaged

dict_loader = LoadImaged(keys=("image", "label"), image_only=False)
# data_dict = dict_loader(list_of_dicts[0])
data_dict = dict_loader({"image": "../Coronary_Segmentation_deep_learning/dataset/10423186/img.nii.gz",
                         "label": "../Coronary_Segmentation_deep_learning/dataset/10423186/img.nii.gz"})
The official scikit-image (skimage) website:
Frangi filtering in skimage:
import cv2
import numpy as np
from skimage.filters import frangi

path = 'D:/test/Python/GAN/SegAN-master/VOC/head/images/001.jpg'
img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
cv2.imshow('1', img)
cv2.waitKey(0)
cv2.destroyAllWindows()

## Note: convert the image to float64 here, otherwise the filtered result will not display
img = img.astype(np.float64)
img = frangi(img, sigmas=range(1, 2, 2), black_ridges=False)
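The Frangi response is a float image with small values, so rescaling it helps when checking the result visually; a short continuation of the snippet above:
# Rescale the Frangi response to 0-255 for display (continues from `img` above)
vis = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
cv2.imshow('frangi', vis)
cv2.waitKey(0)
cv2.destroyAllWindows()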