# 双线性插值实现 (Bilinear interpolation implementation)
import torch.nn.functional as F
import torch
def linear(x, x1, x2, v1, v2):
    """Linearly interpolate (or extrapolate) between two samples.

    Given value ``v1`` at position ``x1`` and value ``v2`` at position ``x2``,
    return the value of the straight line through both points at ``x``.
    No clamping is applied, so an ``x`` outside ``[x1, x2]`` extrapolates.

    Raises:
        ZeroDivisionError: if ``x1 == x2`` and the positions are Python numbers.
    """
    t = (x - x1) / (x2 - x1)  # fractional position of x along [x1, x2]
    return v1 + t * (v2 - v1)
def bilinear(x, y, src):
    """Sample a 2-D grid at a fractional position using bilinear interpolation.

    Args:
        x: Row coordinate (indexes the first axis of ``src``); may be fractional.
        y: Column coordinate (indexes the second axis of ``src``); may be fractional.
        src: 2-D indexable grid whose ``shape`` unpacks to ``(h, w)``,
            e.g. a 2-D ``torch.Tensor``.

    Returns:
        The interpolated value, produced by arithmetic on ``src`` elements
        (a 0-d tensor when ``src`` is a tensor).

    The coordinates are first clamped into ``[0, h - 1]`` x ``[0, w - 1]``.
    On the last row/column the "next" neighbour is the same cell, so points on
    or beyond the border reuse the border values instead of indexing past the
    end of ``src``.
    """
    h, w = src.shape
    # Clamp instead of testing `x == h - 1` with float equality: that test
    # missed any coordinate strictly between h-1 and h, which then indexed
    # src[h] and raised IndexError.
    x = min(max(x, 0), h - 1)
    y = min(max(y, 0), w - 1)
    x1, y1 = int(x), int(y)  # equals floor() because x, y >= 0 after clamping
    x2 = min(x1 + 1, h - 1)
    y2 = min(y1 + 1, w - 1)
    # Fractional offsets in [0, 1). They are exactly 0 on the last row/column,
    # where x2 == x1 (or y2 == y1), so the duplicated neighbour gets no weight.
    tx = x - x1
    ty = y - y1

    def lerp(t, a, b):
        # Same arithmetic as linear() over a unit-width interval.
        return t * (b - a) + a

    v_y1 = lerp(tx, src[x1, y1], src[x2, y1])
    v_y2 = lerp(tx, src[x1, y2], src[x2, y2])
    return lerp(ty, v_y1, v_y2)
def get_src_coordinate(x, y, src, des):
    """Map a destination pixel index to its clamped source coordinate.

    Per axis this computes ``scale * (index + 0.5) - 0.5`` with
    ``scale = src_size / des_size``. That is the half-pixel-centre mapping,
    intended to mirror ``F.interpolate(..., align_corners=False)``.
    Each result is then clamped into ``[0, src_size - 1]`` by
    ``get_new_coordinate``.

    Args:
        x: Row index in ``des``.
        y: Column index in ``des``.
        src: Source grid; only its 2-D ``shape`` is read.
        des: Destination grid; only its 2-D ``shape`` is read.

    Returns:
        Tuple ``(src_x, src_y)`` of (possibly fractional) source coordinates.
    """
    (src_h, src_w), (des_h, des_w) = src.shape, des.shape
    row = src_h / des_h * (x + 0.5) - 0.5
    col = src_w / des_w * (y + 0.5) - 0.5
    return get_new_coordinate(row, src_h - 1), get_new_coordinate(col, src_w - 1)
def get_new_coordinate(value, boundary):
    """Clamp ``value`` into the closed interval ``[0, boundary]``.

    Values below 0 become ``0``. Values above ``boundary`` become
    ``boundary`` itself (so the result may be an ``int`` even for a float
    input). Everything in range is returned unchanged.
    """
    return min(max(value, 0), boundary)
def get_des(src, des_shape):
    """Resize a 2-D tensor to ``des_shape`` by bilinear interpolation.

    Each output pixel is mapped back to a source coordinate with
    ``get_src_coordinate`` and sampled with ``bilinear``.

    Args:
        src: 2-D ``torch.Tensor`` to resample.
        des_shape: Target ``(height, width)``.

    Returns:
        A new tensor of shape ``des_shape`` on ``src.device``. A floating-point
        ``src`` keeps its dtype. Any other dtype falls back to torch's default
        floating dtype, so interpolated fractions are not truncated.
    """
    # Follow src's floating dtype and device. Hard-coding torch's default dtype
    # would silently downcast float64 input to float32.
    out_dtype = src.dtype if src.is_floating_point() else torch.get_default_dtype()
    des = torch.zeros(des_shape, dtype=out_dtype, device=src.device)
    h, w = des_shape
    for x in range(h):
        for y in range(w):
            src_x, src_y = get_src_coordinate(x, y, src, des)
            des[x, y] = bilinear(src_x, src_y, src)
    return des
# Self-check: compare the hand-written resize with PyTorch's built-in bilinear
# interpolation on a small 3x3 -> 5x5 example.
a = torch.arange(9, dtype=torch.float32).reshape(1, 1, 3, 3)  # 4-D (N, C, H, W) input for F.interpolate
# align_corners=False is the half-pixel mapping that get_src_coordinate
# implements; pass it explicitly so the compared convention is pinned.
b = F.interpolate(a, size=(5, 5), mode='bilinear', align_corners=False)
a = a.squeeze()
b = b.squeeze()
des = get_des(a, b.shape)
# A signed sum lets positive and negative errors cancel out and can hide a
# mismatch; report the largest absolute deviation and a tolerance verdict.
print(torch.max(torch.abs(b - des)))
print(torch.allclose(b, des))
# 参考链接 (Reference): https://blog.csdn.net/Big_Huang/article/details/106209992