from torch.nn import functional as F
import torch
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import math
# Load the DRR image (placeholder path) and display it.
img_path = r'xxxdrr.jpg'
img_itk = sitk.ReadImage(img_path)
img = sitk.GetArrayFromImage(img_itk)  # numpy array indexed [row, col]
plt.close('all')
plt.imshow(img)
plt.show()

# Reference result: rotate by 30 degrees (pi / 6 rad) with SimpleITK.
# Euler2DTransform has exactly three parameters: (angle in radians, tx, ty).
# Its rotation centre defaults to the physical point (0, 0); with an image
# origin of (0, 0) (presumably the case for a JPEG, which carries no geometry
# metadata -- confirm) that is the centre of the top-left pixel.
param = [math.pi / 6, 0, 0]
transform = sitk.Euler2DTransform()
transform.SetParameters(param)
resample = sitk.ResampleImageFilter()
resample.SetInterpolator(sitk.sitkLinear)
resample.SetTransform(transform)
# Output grid (size, spacing, origin, direction) is copied from the input image.
resample.SetReferenceImage(img_itk)
roat_sitk = resample.Execute(img_itk)
rot = sitk.GetArrayFromImage(roat_sitk)
plt.imshow(rot)
plt.show()
def gen_2d_mesh_grid(h, w):
    """Return a homogeneous 2-D sampling mesh in normalised coordinates.

    Args:
        h: number of rows (output height).
        w: number of columns (output width).

    Returns:
        A tensor of shape (3, h, w) in torch's default floating dtype.
        Channel 0 is x (varies along the width), channel 1 is y (varies along
        the height), and channel 2 is all ones, so the mesh can be multiplied
        by a 2x3 affine matrix. x and y span [-1, 1] inclusive, i.e. -1 and +1
        land on the centres of the corner pixels (the ``align_corners=True``
        convention of ``F.grid_sample``).
    """
    ys = torch.linspace(-1, 1, h)
    xs = torch.linspace(-1, 1, w)
    # 'ij' indexing (meshgrid's historical default) yields (h, w) tensors;
    # passing it explicitly avoids PyTorch's meshgrid deprecation warning.
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
    ones = torch.ones_like(grid_x)
    # x before y: the (x, y) order that F.grid_sample expects.
    return torch.stack([grid_x, grid_y, ones])  # 3 x h x w
def affine_2d_grid(theta, size):
    """Build an ``F.grid_sample`` grid that applies ``theta`` about the top-left corner.

    ``F.affine_grid`` applies theta about the normalised origin (0, 0), i.e.
    the image centre. Here the mesh is shifted by +1 so that the corner point
    (-1, -1) becomes the origin, theta is applied, and the result is shifted
    back. For theta = [A | t] and each output point p the grid holds
    ``A @ (p + 1) + t - 1``.

    Args:
        theta: (b, 2, 3) affine matrices mapping output coordinates to the
            input coordinates that are sampled.
        size: (b, c, h, w) size of the output, e.g. ``tensor.size()``.

    Returns:
        A (b, h, w, 2) grid with (x, y) in the last dim. The underlying mesh
        comes from ``gen_2d_mesh_grid`` (``linspace(-1, 1, n)``), so sample it
        with ``F.grid_sample(..., align_corners=True)``.
    """
    b, _, h, w = size
    # Homogeneous mesh, matched to theta's dtype/device so bmm accepts it.
    mesh = gen_2d_mesh_grid(h, w).to(dtype=theta.dtype, device=theta.device)
    mesh = mesh.reshape(1, 3, h * w).expand(b, -1, -1)  # b x 3 x (h*w)
    # Shift only the x/y rows. The homogeneous row must stay 1; otherwise
    # theta's translation column would be applied twice.
    shift = mesh.new_tensor([1.0, 1.0, 0.0]).view(1, 3, 1)
    mesh = torch.bmm(theta, mesh + shift)  # b x 2 x (h*w)
    mesh = mesh - 1
    # b x 2 x (h*w) -> b x (h*w) x 2 -> b x h x w x 2
    return mesh.permute(0, 2, 1).reshape(b, h, w, 2)
# PyTorch versions of the same 30-degree rotation.
# NOTE(review): assumes `img` is a single-channel (H, W) array; a colour JPEG
# would come back as (H, W, 3) and need converting first -- confirm.
img_torch = torch.from_numpy(img).float().unsqueeze(0)  # 1 x H x W
batch = img_torch.unsqueeze(0)  # 1 x 1 x H x W, the layout grid_sample expects
angle = math.pi / 6  # same angle as the SimpleITK transform
theta = torch.tensor([
    [math.cos(angle), -math.sin(angle), 0],
    [math.sin(angle), math.cos(angle), 0],
], dtype=torch.float)
# NOTE: both grids rotate in normalised [-1, 1] coordinates, so for H != W the
# result is not a rigid rotation in pixel space and will differ from SimpleITK.

# 1) F.affine_grid rotates about the image centre. align_corners=False is
#    PyTorch's default; it is passed explicitly so grid generation and
#    sampling agree and no default-change warning is emitted.
grid = F.affine_grid(theta.unsqueeze(0), batch.size(), align_corners=False)
output = F.grid_sample(batch, grid, align_corners=False)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().squeeze(0))
plt.show()

# 2) affine_2d_grid rotates about the top-left pixel, like SimpleITK above.
#    Its mesh is linspace(-1, 1, n), which puts -1/+1 on the corner-pixel
#    centres: that is the align_corners=True convention, so sample with it
#    (with the default False, even an identity theta would rescale the image).
grid = affine_2d_grid(theta.unsqueeze(0), batch.size())
output = F.grid_sample(batch, grid, align_corners=True)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().squeeze(0))
plt.show()