pytorch 将affine_grid的旋转中心改到左上角

from torch.nn import functional as F
import torch
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import math
img_path = r'xxxdrr.jpg'
img_itk = sitk.ReadImage(img_path)
img = sitk.GetArrayFromImage(img_itk)
plt.close('all')

plt.imshow(img)
plt.show()
param = [0.52359877559829887307710723054658,0,0,0]
transform = sitk.Euler2DTransform()
transform.SetParameters(param)
resample = sitk.ResampleImageFilter()
resample.SetInterpolator(sitk.sitkLinear)
resample.SetTransform(transform)
resample.SetReferenceImage(img_itk)
roat_sitk = resample.Execute(img_itk)
rot = sitk.GetArrayFromImage(roat_sitk)
plt.imshow(rot)
plt.show()


def gen_2d_mesh_grid(h, w):
    # move into self to save compute?
    h_s = torch.linspace(-1, 1, h)
    w_s = torch.linspace(-1, 1, w)
    h_s, w_s = torch.meshgrid([ h_s, w_s])
    one_s = torch.ones_like(w_s)
    mesh_grid = torch.stack([w_s, h_s,one_s])
    return mesh_grid  # 3 x h x w
def affine_2d_grid(theta, size):
    b, c, h, w = size
    mesh_grid = gen_2d_mesh_grid(h, w)
    mesh_grid = mesh_grid.unsqueeze(0)
    mesh_grid = mesh_grid.repeat(b, 1,  1, 1)  # channel dim = 4
    mesh_grid = mesh_grid.view(b, 3, -1)
    mesh_grid = mesh_grid+1
    mesh_grid = torch.bmm(theta, mesh_grid)  # channel dim = 3
    mesh_grid = mesh_grid - 1
    mesh_grid = mesh_grid.permute(0, 2, 1)  # move channel to last dim
    mesh_grid = mesh_grid.view(b, h, w, 2)
    return mesh_grid



img_torch = torch.from_numpy(img).float().unsqueeze(0)
angle = 0.52359877559829887307710723054658
theta = torch.tensor([
    [np.cos(angle),-np.sin(angle),0],
    [np.sin(angle),np.cos(angle),0]
], dtype=torch.float)
grid = F.affine_grid(theta.unsqueeze(0), img_torch.unsqueeze(0).size())#
output = F.grid_sample(img_torch.unsqueeze(0), grid)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().squeeze(0))
plt.show()

grid = affine_2d_grid( theta.unsqueeze(0), img_torch.unsqueeze(0).size())
output = F.grid_sample(img_torch.unsqueeze(0), grid)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().squeeze(0))
plt.show()

 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值