在将pth转onnx过程中,报错:RuntimeError: Exporting the operator repeat_interleave to ONNX opset version 11 is not supported.
由于repeat_interleave函数是将张量中的元素按照某一维度复制n次,比如torch.repeat_interleave(x, k, dim=2), 因此可以用repeat和view函数代替进行实现。
(如有问题,请指正!)
例如:
import torch
x = torch.randn(2, 1, 4, 1)
print(x)
print("--------------------------------------------------")
y1 = torch.repeat_interleave(x, 2, dim=2)
print(y1)
print("--------------------------------------------------")
y2 = x.repeat(1, 1, 1, 2).view(x.shape[0], x.shape[1], -1, x.shape[3])
print(y2)
得到的结果是:
tensor([[[[-1.9781],
[-1.3968],
[-0.6847],
[ 1.1089]]],
[[[-0.5929],
[ 1.7383],
[-0.6736],
[-1.2737]]]])
------------------------------------------------------------------------
tensor([[[[-1.9781],
[-1.9781],
[-1.3968],
[-1.3968],
[-0.6847],
[-0.6847],
[ 1.1089],
[ 1.1089]]],
[[[-0.5929],
[-0.5929],
[ 1.7383],
[ 1.7383],
[-0.6736],
[-0.6736],
[-1.2737],
[-1.2737]]]])
------------------------------------------------------------------------
tensor([[[[-1.9781],
[-1.9781],
[-1.3968],
[-1.3968],
[-0.6847],
[-0.6847],
[ 1.1089],
[ 1.1089]]],
[[[-0.5929],
[-0.5929],
[ 1.7383],
[ 1.7383],
[-0.6736],
[-0.6736],
[-1.2737],
[-1.2737]]]])