def resample_mesh(img, target_spacing ,mode="bilinear",device='cuda:0',is_seg = False):
data = sitk.GetArrayFromImage(img)
size = img.GetSize()[::-1]
origin = img.GetSpacing()[::-1]
spacing = img.GetSpacing()[::-1]
target_size = np.round(np.array(spacing) / np.array(target_spacing) * np.array(size)).astype(int)
target_origin = np.array(origin) + (
(np.array(size) - 1) * np.array(spacing) - (np.array(target_size) - 1) * np.array(target_spacing)) * 0.5
d = torch.linspace(-1, 1, target_size[0])
h = torch.linspace(-1, 1, target_size[1])
w = torch.linspace(-1, 1, target_size[2])
meshz, meshy, meshx = torch.meshgrid((d, h, w))
grid = torch.stack((meshx, meshy, meshz), 3)
grid = grid.unsqueeze(0).to(device=device) # add batch dim
tensor = torch.from_numpy(data).unsqueeze(0).unsqueeze(0)
tensor = tensor.to(device=device).to(grid.dtype)
out = F.grid_sample(tensor, grid, mode=mode, padding_mode="border", align_corners=True)
out = out.cpu().detach().numpy().squeeze().squeeze()
if is_seg :
out[out>0]=255
out = out.astype(data.dtype)
rmg = sitk.GetImageFromArray(out)
rmg.SetSpacing(target_spacing)
rmg.SetOrigin(target_origin)
return rmg
实际采样结果感觉和simpleitk的结果存在一层的偏差,不知道问题在哪。