AttributeError: module ‘torch.onnx‘ has no attribute ‘set_training‘解决方案

前言

使用pytorch可视化网络结构时,遇到了pytorch和tensorboardX版本不兼容问题,又不能轻易降低pytorch版本,最终参考网上文章找到了问题原因。

1. 问题描述

使用TensorboardX可视化网络结构,示例代码如下。

from tensorboardX import SummaryWriter

with SummaryWriter(comment="XXXXXX") as w:
    w.add_graph(model)

报错:AttributeError: module ‘torch.onnx‘ has no attribute ‘set_training‘

2. 问题原因

参考网上文章1,作者发现是pytorch的高版本修改了一个方法名称。

PyTorch 1.6版本中set_training变成了select_model_mode_for_export

而tensorboardX的升级版本中也仍然没有解决这个问题。

问题就出现在下面这个 set_training

def graph(model, args, verbose=False, **kwargs):

    import torch

    with torch.onnx.set_training(model, False): 
        try:
            trace = torch.jit.trace(model, args)
            graph = trace.graph

        except RuntimeError as e:
            print(e)
            print('Error occurs, No graph saved')

3. 解决方法

是临时的解决方案,修改tensorboardX源码并打包。

第一步:从github拉取tensorboardX源码。

git clone https://github.com/lanpa/tensorboardX

第二步:切换到所需版本的标签。

git checkout v1.8

第三步:修改源码。

# 修改 with torch.onnx.set_training(model, False): 为下面语句
with torch.onnx.select_model_mode_for_export(model, False): 

第四步:重新打包。

(注意:以下步骤在Linux下完成,Windows环境请适当修改路径写法)

pip install wheel
# 切换到tensorboardX所在目录
cd tensorboardX
# 打包
pip wheel --wheel-dir=/root/ ./

# 生成 `tensorboardX-1.8+e136d41-py2.py3-none-any.whl`

第五步:安装。

# 将`whl`文件拷贝到项目目录,并安装
pip install tensorboardX-1.8+e136d41-py2.py3-none-any.whl

再次运行,就不再报错了。


  1. 参考文章 https://blog.csdn.net/qq_42730750/article/details/119741621 ↩︎

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值