混淆ndarray和tensor的flatten方法出现的问题

在导出onnx模型的时候,输入的observation是ndarray字典,将它输入到一个features_extractors的时候会报下面的错误:

Traceback (most recent call last):
  File "E:/Workspace/PybulletVelocityController/Test/testPTh.py", line 48, in <module>
    encoded_tensor_list.append(extractor(dummy_input[key]))
  File "C:\Users\huyu\anaconda3\envs\drones\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\huyu\anaconda3\envs\drones\lib\site-packages\torch\nn\modules\flatten.py", line 45, in forward
    return input.flatten(self.start_dim, self.end_dim)
TypeError: flatten() takes from 0 to 1 positional arguments but 2 were given

 debug到函数tensor的flatten函数那块一直觉得没有错误,但就是报错,不知道怎么回事。

后来发现ndarray也有flatten函数。

 

 它的参数刚好是0 或 1。

这就解释了为什么报上面的错,原因是没有调用tensor的flatten方法,而是去调用了ndarray的flatten方法。

在stable-baselines3 里面应该有将ndarray数据预处理成tensor的方法(之前没有注意到),因此在训练过程中没有报这样的数据类型错误。

 

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值