# -*- coding: utf-8 -*-
# @Author: H.Yang
# @Date: 2021-08-14 11:58:04
# @Last Modified by: H.Yang
# @Last Modified time: 2021-08-14 11:59:27
import torch
import torch.nn as nn
def make_grid_point(x, y):
    """Return a (1, 1, 1, 2) sampling grid that holds a single point.

    ``grid_sample`` reads the last grid dimension as ``(x, y)`` in normalized
    coordinates: ``x`` selects along the input width (columns) and ``y`` along
    the input height (rows), with -1 and 1 being the two extremes.

    Args:
        x: normalized horizontal coordinate.
        y: normalized vertical coordinate.

    Returns:
        A float32 tensor of shape (N=1, H_out=1, W_out=1, 2).
    """
    return torch.tensor([[[[x, y]]]], dtype=torch.float32)


# 1x1x3x3 (N, C, H, W) input whose values are 1..9 in row-major order.
test = torch.arange(1, 10, dtype=torch.float32).reshape(1, 1, 3, 3)
print(test)
# tensor([[[[1., 2., 3.],
#           [4., 5., 6.],
#           [7., 8., 9.]]]])

sample_one = make_grid_point(-1, -1)
sample_two = make_grid_point(-2 / 3, -2 / 3)
sample_thr = make_grid_point(-0.8, -0.3)

# align_corners=True: -1 and 1 refer to the CENTERS of the corner pixels, so
# (-1, -1) lands exactly on pixel (row 0, col 0).
result_one = torch.nn.functional.grid_sample(
    test, sample_one, mode="bilinear", padding_mode="zeros", align_corners=True
)
print(result_one)
# tensor([[[[1.]]]])

# align_corners=False: -1 and 1 refer to the outer EDGES of the corner pixels.
# With 3 pixels per side, the center of pixel 0 sits at -1 + 1/3 = -2/3, so
# this grid also lands on pixel (row 0, col 0).
result_two = torch.nn.functional.grid_sample(
    test, sample_two, mode="bilinear", padding_mode="zeros", align_corners=False
)
print(result_two)
# tensor([[[[1.]]]])  (approximately; -2/3 is not exact in float32)

# align_corners=True maps a normalized coordinate c to the pixel position
# (c + 1) / 2 * (size - 1). So x = -0.8 -> column 0.2 and y = -0.3 -> row 0.7.
# Bilinear interpolation weights each neighbour by its closeness:
#   (1*0.8 + 2*0.2)*0.3 + (4*0.8 + 5*0.2)*0.7 = 1.2*0.3 + 4.2*0.7 = 3.3
result_thr = torch.nn.functional.grid_sample(
    test, sample_thr, mode="bilinear", padding_mode="zeros", align_corners=True
)
print(result_thr)
# tensor([[[[3.3000]]]])
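# torch.nn.functional.grid_sample experiment
# (Scraped page footer: "Latest recommended article published 2024-06-20 19:31:58";
#  "Keywords generated by CSDN using intelligent technology".)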