将 PyTorch 模型导出为 ONNX 时，grid_sample 没有对应的 ONNX 算子，导致导出失败。网上的方案都只支持四维（4D）输入，而我的输入是五维（5D），所以在参考现有代码的基础上做了修改。这段代码没有考虑性能，只保证能够正常导出。
import torch
from torch import nn
import torch.nn.functional as F
def _unnormalize(coord, size, align_corners):
    """Map normalized coordinates in [-1, 1] to voxel coordinates along one axis.

    Same formulas as ``F.grid_sample``: with ``align_corners=True`` the values
    -1/1 land on the centres of the first/last voxel (0 and ``size - 1``);
    otherwise they land on the outer edges (-0.5 and ``size - 0.5``).
    """
    if align_corners:
        return ((coord + 1) / 2) * (size - 1)
    return ((coord + 1) * size - 1) / 2


def _clamp_index(coord, size):
    """Clamp integer-valued float coordinates into [0, size - 1] and cast to int64.

    Uses ``torch.where`` comparisons (as the original implementation did) so the
    traced graph only contains Less/Greater/Where/Cast ops.
    """
    coord = torch.where(coord < 0, torch.zeros_like(coord), coord)
    coord = torch.where(coord > size - 1, torch.full_like(coord, size - 1), coord)
    return coord.long()


def grid_sample_3d(input, grid, align_corners, padding_mode="border"):
    """Trilinearly sample a 5-D volume at ``grid`` locations (ONNX-friendly ``grid_sample``).

    Re-implements ``F.grid_sample(input, grid, mode="bilinear", ...)`` for 5-D
    tensors using only floor / where / gather / arithmetic, so a model that uses
    it can be exported to ONNX without a GridSample operator.

    Args:
        input: Volume of shape ``(N, C, ID, IH, IW)``.
        grid: Sampling locations of shape ``(N, D, H, W, 3)``; the last axis
            holds normalized ``(x, y, z)`` coordinates, which index ``IW``,
            ``IH`` and ``ID`` respectively.
        align_corners: Same meaning as in ``F.grid_sample``.
        padding_mode: ``"border"`` (default, the historical behaviour of this
            function) reads out-of-range samples from the nearest edge voxel,
            like ``F.grid_sample(padding_mode="border")``. ``"zeros"`` treats
            voxels outside the volume as 0, like ``F.grid_sample``'s default.

    Returns:
        Tensor of shape ``(N, C, D, H, W)``.

    Raises:
        ValueError: If ``padding_mode`` is neither ``"border"`` nor ``"zeros"``.
    """
    if padding_mode not in ("border", "zeros"):
        raise ValueError(
            f"padding_mode must be 'border' or 'zeros', got {padding_mode!r}")

    N, C, ID, IH, IW = input.shape
    _, D, H, W, _ = grid.shape

    ix = _unnormalize(grid[..., 0], IW, align_corners)
    iy = _unnormalize(grid[..., 1], IH, align_corners)
    iz = _unnormalize(grid[..., 2], ID, align_corners)

    # Lower/upper neighbouring voxel coordinates along each axis. floor() has a
    # zero gradient, so nothing here needs to be tracked by autograd.
    with torch.no_grad():
        ix0, iy0, iz0 = torch.floor(ix), torch.floor(iy), torch.floor(iz)
        ix1, iy1, iz1 = ix0 + 1, iy0 + 1, iz0 + 1

    # Per-axis linear weights (index 0 = lower corner, 1 = upper corner),
    # computed from the *unclamped* coordinates so they keep their gradient
    # w.r.t. ``grid`` and always sum to 1.
    wx = (ix1 - ix, ix - ix0)
    wy = (iy1 - iy, iy - iy0)
    wz = (iz1 - iz, iz - iz0)
    xs, ys, zs = (ix0, ix1), (iy0, iy1), (iz0, iz1)

    # reshape (not view) so non-contiguous inputs are accepted.
    flat = input.reshape(N, C, ID * IH * IW)
    npts = D * H * W

    out = None
    # Visit the 8 corners in the order tnw, tne, tsw, tse, bnw, bne, bsw, bse.
    for dz in (0, 1):
        for dy in (0, 1):
            for dx in (0, 1):
                weight = wx[dx] * wy[dy] * wz[dz]
                if padding_mode == "zeros":
                    # Corners outside the volume contribute nothing.
                    inside = ((xs[dx] >= 0) & (xs[dx] <= IW - 1)
                              & (ys[dy] >= 0) & (ys[dy] <= IH - 1)
                              & (zs[dz] >= 0) & (zs[dz] <= ID - 1))
                    weight = weight * inside.to(weight.dtype)
                with torch.no_grad():
                    # Flat voxel offset computed in int64: float32 cannot
                    # represent every integer above 2**24, which would gather
                    # wrong voxels in large volumes.
                    idx = ((_clamp_index(zs[dz], ID) * IH
                            + _clamp_index(ys[dy], IH)) * IW
                           + _clamp_index(xs[dx], IW))
                    # expand (no copy) instead of repeat: same index per channel.
                    idx = idx.reshape(N, 1, npts).expand(N, C, npts)
                vals = torch.gather(flat, 2, idx).reshape(N, C, D, H, W)
                term = vals * weight.reshape(N, 1, D, H, W)
                out = term if out is None else out + term
    return out
if __name__ == "__main__":
    # Sanity check against PyTorch's reference implementation. grid_sample_3d
    # reads out-of-range samples from the edge voxels, so compare against
    # padding_mode="border"; F.grid_sample's default ("zeros") differs wherever
    # the grid reaches past the volume (with align_corners=False, grid values
    # near 1 do).
    data = torch.rand(1, 1, 256, 104, 80)
    grid = torch.rand(1, 256, 104, 80, 3)
    ret = F.grid_sample(data, grid, padding_mode="border", align_corners=False).squeeze(1)
    print(ret)
    ret2 = grid_sample_3d(data, grid, False).squeeze(1)
    print(ret2)
    print("max abs diff:", (ret - ret2).abs().max().item())
感谢以下 issue 中提供的原始实现：
https://github.com/pytorch/pytorch/issues/34704