[机器学习]Pytorch导出onnx报错“RuntimeError: Cannot insert a Tensor that requires grad as a constant...“

出错的代码是这样的:

import torch
import torch.nn as nn

#############################################
# Define 
#############################################

# Custom BI-LSTM, because unity barracuda dont support "bidirectional = True"
class BILSTM():
    def __init__(self, inputSize, hiddenSize, numLayers):
        self.biLayer1 = nn.LSTM(input_size=inputSize,hidden_size=hiddenSize, num_layers=numLayers, batch_first=True, dropout=0.5).cuda()
        self.biLayer2 = nn.LSTM(input_size=inputSize,hidden_size=hiddenSize, num_layers=numLayers, batch_first=True, dropout=0.5).cuda()
 
    def forward(self, x):
        out1, (hidden1, _) = self.biLayer1(x)
        out2, (hidden2, _) = self.biLayer2(torch.flip(x, dims=[1]))
        out2 = torch.flip(out2,dims=[1])
        hidden = torch.cat([hidden1, hidden2], dim = 0)
        return (out1, out2), (hidden, 0)

class Rnn(nn.Module):
    def __init__(self, inputSize, hiddenSize, numLayers):
        super(Rnn, self).__init__()
        self.lstm = BILSTM(inputSize, hiddenSize, numLayers)

    def forward(self, x):
        out, (h_n, c_n) = self.lstm.forward(x)
        return out

net = Rnn(128, 128, 2)
net = net.cuda()

#############################################
# Export 
#############################################

x = torch.randn(1, 233, 128).cuda()
net.eval()
with torch.no_grad(): 
    torch.onnx.export( 
        net, 
        x, 
        "test.onnx", 
        opset_version=11, 
        input_names=['input'], 
        output_names=['output'])

有两个class, class Rnn表示解决问题的模型, 还有一个class BI-LSTM, 为什么不用pytorch自带的BI-LSTM呢, 因为截止目前(2022/8/8), Unity的Barracuda插件不支持LSTM算子的direction参数,也就是说只能使用单向LSTM.

然后导出就报错了:

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient

用这个报错信息, 很难找出解决方法…
最后发现BILSTM也必须 继承nn.Module.

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值