前言
使用pytorch可视化网络结构时,遇到了pytorch和tensorboardX版本不兼容问题,又不能轻易降低pytorch版本,最终参考网上文章找到了问题原因。
1. 问题描述
使用TensorboardX可视化网络结构,示例代码如下。
from tensorboardX import SummaryWriter
with SummaryWriter(comment="XXXXXX") as w:
w.add_graph(model)
报错:AttributeError: module ‘torch.onnx‘ has no attribute ‘set_training‘
2. 问题原因
参考网上文章1,作者发现是pytorch的高版本修改了一个方法名称。
在
PyTorch 1.6
版本中set_training
变成了select_model_mode_for_export
而tensorboardX的升级版本中也仍然没有解决这个问题。
问题就出现在下面这个 set_training
def graph(model, args, verbose=False, **kwargs):
import torch
with torch.onnx.set_training(model, False):
try:
trace = torch.jit.trace(model, args)
graph = trace.graph
except RuntimeError as e:
print(e)
print('Error occurs, No graph saved')
3. 解决方法
是临时的解决方案,修改tensorboardX源码并打包。
第一步:从github拉取tensorboardX源码。
git clone https://github.com/lanpa/tensorboardX
第二步:切换到所需版本的标签。
git checkout v1.8
第三步:修改源码。
# 修改 with torch.onnx.set_training(model, False): 为下面语句
with torch.onnx.select_model_mode_for_export(model, False):
第四步:重新打包。
(注意:以下步骤在Linux下完成,Windows环境请适当修改路径写法)
pip install wheel
# 切换到tensorboardX所在目录
cd tensorboardX
# 打包
pip wheel --wheel-dir=/root/ ./
# 生成 `tensorboardX-1.8+e136d41-py2.py3-none-any.whl`
第五步:安装。
# 将`whl`文件拷贝到项目目录,并安装
pip install tensorboardX-1.8+e136d41-py2.py3-none-any.whl
再次运行,就不再报错了。