# Patch extraction using grid_sample and fold (使用grid_sample和fold提取块的操作).
import torch
import numpy as np
from imageio import imread, imsave
from PIL import Image
import torch.nn.functional as F
from Patch_model.DeformSearch import DeformSearch
# --- Load the reference image and map pixel values from [0, 255] to [-1, 1]. ---
Ref = imread("0.png")
Ref = Ref.astype(np.float32)
Ref = Ref / 127.5 - 1.
# HWC numpy array -> [1, C, H, W] float tensor on the GPU.
# NOTE(review): transpose((2, 0, 1)) assumes a 3-D (H, W, C) image; a
# grayscale (2-D) PNG would fail here — confirm the expected input format.
Ref_t = torch.from_numpy(Ref.transpose((2, 0, 1))).unsqueeze(0).float().cuda()  # [b,c,h,w]
b, c, h, w = Ref_t.size()

# Integer pixel-coordinate grids, reshaped to [1, 1, 1, H, W]. torch.meshgrid's
# default indexing is "ij": y varies along rows, x along columns.
y, x = torch.meshgrid(torch.arange(0, h, device='cuda'), torch.arange(0, w, device='cuda'))
y = y.view(1, 1, 1, h, w)
x = x.view(1, 1, 1, h, w)

kernel_size = 3
padding = 1
num_taps = kernel_size * kernel_size  # 9 sampling points per 3x3 patch

# Shape notes carried over from the original author (not verifiable here):
# offset [B,K,2N,H,W]; coordinate grids [B,K,1,H,W]; reference x [B,C,H2,W2].
# NOTE(review): the meaning of the second constructor argument (3) is defined
# in Patch_model.DeformSearch — confirm there.
DeformSearch3 = DeformSearch(num_taps, 3).cuda()
# All-zero offsets with 2 * N = 18 channels (one (dy, dx) pair per tap), i.e.
# sample the undeformed 3x3 neighbourhood of every pixel.
offset = torch.zeros(1, 1, 2 * num_taps, h, w).cuda()
# Per the original note the search returns [1, 1, 3*9, H*W]; squeeze(1) drops
# the K axis so F.fold receives [1, C*9, L] with one column per pixel.
s = DeformSearch3(offset, Ref_t, y.float(), x.float()).squeeze(1)

# Sum the overlapping patches back onto the H x W grid ...
folded = F.fold(s, output_size=Ref_t.size()[-2:],
                kernel_size=(kernel_size, kernel_size), padding=padding)
# ... then divide by the number of patches that actually cover each pixel.
# With padding=1, interior pixels are covered 9 times, edge pixels only 6 and
# corners only 4, so a constant "/ 9" would darken the image border. Folding a
# tensor of ones yields exactly this per-pixel coverage count (always >= 1).
ones = torch.ones(1, num_taps, s.size(-1), device=s.device, dtype=s.dtype)
coverage = F.fold(ones, output_size=Ref_t.size()[-2:],
                  kernel_size=(kernel_size, kernel_size), padding=padding)
T_lv3 = folded / coverage  # [B,c,H,W]

# --- Map back from [-1, 1] to [0, 255] and write the reconstruction. ---
tp_save = (T_lv3 + 1.) * 127.5
# Clamp before the uint8 cast so out-of-range values saturate instead of
# wrapping around.
tp_save = tp_save.squeeze().round().clamp(0, 255)
tp_save = np.transpose(tp_save.cpu().numpy(), (1, 2, 0)).astype(np.uint8)
save_path = "xiar.png"
imsave(save_path, tp_save)
imsave("xir.jpg", tp_save)