在pt文件转换为onnx过程中:TypeError: forward() missing 2 required positional argument

网络结构中含有参数(h,c)

class Q_net(nn.Module):
    def __init__(self, state_space=None,
                 action_space=None):
        super(Q_net, self).__init__()

        # space size check
        assert state_space is not None, "None state_space input: state_space should be selected."
        assert action_space is not None, "None action_space input: action_space should be selected."

        self.hidden_space = 64
        self.state_space = state_space
        self.action_space = action_space

        self.Linear1 = nn.Linear(self.state_space, self.hidden_space)
        self.lstm    = nn.LSTM(self.hidden_space,self.hidden_space, batch_first=True)
        self.Linear2 = nn.Linear(self.hidden_space, self.action_space)

    def forward(self, x, h, c):
        x = F.relu(self.Linear1(x))
        x, (new_h, new_c) = self.lstm(x,(h,c))
        x = self.Linear2(x)
        return x, new_h, new_c
    def init_hidden_state(self, batch_size, training=None):

        assert training is not None, "training step parameter should be dtermined"

        if training is True:
            return torch.zeros([1, batch_size, self.hidden_space]), torch.zeros([1, batch_size, self.hidden_space])
        else:
            return torch.zeros([1, 1, self.hidden_space]), torch.zeros([1, 1, self.hidden_space])

//参数h、c
 h, c = q_net.init_hidden_state(batch_size=batch_size, training=True)

.pth转.onnx报错: TypeError: forward() missing 2 required positional argument, 报错代码如下:

dummy_input = torch.randn(64,2) 
checkpoing = torch.load('./DRQN_POMDP_Random_SEED_1.pth', 'cpu')
model1.load_state_dict(checkpoing)
torch.onnx.export(checkpoing, dummy_input, "model_best.onnx", export_params=True, verbose=True)  # 将模型保存成.onnx格

最后的解决办法:将h、c代入模型中

dummy_input = torch.randn(1,64,2)  # 要求输入3维的矩阵, why?
model1 = Q
h, c = model1.init_hidden_state(batch_size=batch_size, training=False)
checkpoing = torch.load('./DRQN_POMDP_Random_SEED_1.pth', 'cpu')  # 导入模型参数
model1.load_state_dict(checkpoing)  # 将模型参数赋予自定义的模型
torch.onnx.export(model1, (dummy_input,h,c), "model_best.onnx", export_params=True, verbose=True)  # 将模型保存成.onnx格
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值