错误提示
最近在跑深度学习的一个模型,运行时遇到了如下报错:
Traceback (most recent call last):
File "demo.py", line 9, in <module>
from builders import model_builder
File "D:\workspace\Graduation-design\aster\builders\model_builder.py", line 5, in <module>
from builders import predictor_builder
File "D:\workspace\Graduation-design\aster\builders\predictor_builder.py", line 9, in <module>
from predictors import attention_predictor
File "D:\workspace\Graduation-design\aster\predictors\attention_predictor.py", line 11, in <module>
from core import sync_attention_wrapper
File "D:\workspace\Graduation-design\aster\core\sync_attention_wrapper.py", line 7, in <module>
class SyncAttentionWrapper(seq2seq.AttentionWrapper):
AttributeError: module 'tensorflow.contrib.seq2seq' has no attribute 'AttentionWrapper'
解决方案:
使用tensorflow1.4及以上版本
问题探索及解决过程
查阅报错信息,很明显是引用了包中没有的属性,我查看了源文件,其中一处代码调用到了seq2seq的AttentionWrapper属性
我在网上查询了此条报错信息,没有可用答复,于是我只检索了“AttentionWrapper”词条,遇到了如下文章:
Tensorflow新版Seq2Seq接口使用:https://blog.csdn.net/thriving_fcl/article/details/74165062
发现seq2seq中存在_allowed_symbols属性,它是个列表,其中包含了“AttentionWrapper”字符串
我查阅了我的虚拟环境中的tensorflow包里的seq2seq文件夹中的__init__.py文件
里面也有_allowed_symbols属性,不过该属性里没有“AttentionWrapper”字符串。
tensorflow1.1
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.seq2seq.python.ops.basic_decoder import *
from tensorflow.contrib.seq2seq.python.ops.decoder import *
from tensorflow.contrib.seq2seq.python.ops.dynamic_attention_wrapper import *
from tensorflow.contrib.seq2seq.python.ops.helper import *
from tensorflow.contrib.seq2seq.python.ops.loss import *
# pylint: enable=unused-import,widcard-import,line-too-long
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ["sequence_loss"]
remove_undocumented(__name__, _allowed_symbols)
我意识到了这是tensorflow版本的问题,我的虚拟环境使用conda指令默认安装的是tensorflow1.1,而我运行的项目是要求1.4及之前版本
我创建了新的虚拟环境,指定安装tensorflow1.4
conda install tensorflow=1.4
PS:记得多加载几个镜像源,不然可能会安装失败
我查看了1.4版本的tensorflow文件夹,其中的seq2seq文件夹里的__init__.py中也有_allowed_symbols属性
tensorflow1.4
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import *
from tensorflow.contrib.seq2seq.python.ops.basic_decoder import *
from tensorflow.contrib.seq2seq.python.ops.beam_search_decoder import *
from tensorflow.contrib.seq2seq.python.ops.beam_search_ops import *
from tensorflow.contrib.seq2seq.python.ops.decoder import *
from tensorflow.contrib.seq2seq.python.ops.helper import *
from tensorflow.contrib.seq2seq.python.ops.loss import *
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,widcard-import,line-too-long
_allowed_symbols = [
"sequence_loss",
"Decoder",
"dynamic_decode",
"BasicDecoder",
"BasicDecoderOutput",
"BeamSearchDecoder",
"BeamSearchDecoderOutput",
"BeamSearchDecoderState",
"Helper",
"CustomHelper",
"FinalBeamSearchDecoderOutput",
"gather_tree",
"GreedyEmbeddingHelper",
"InferenceHelper",
"SampleEmbeddingHelper",
"ScheduledEmbeddingTrainingHelper",
"ScheduledOutputTrainingHelper",
"TrainingHelper",
"BahdanauAttention",
"LuongAttention",
"hardmax",
"AttentionWrapperState",
"AttentionWrapper",
"AttentionMechanism",
"tile_batch",
"safe_cumprod",
"monotonic_attention",
"monotonic_probability_fn",
"BahdanauMonotonicAttention",
"LuongMonotonicAttention",
]
remove_undocumented(__name__, _allowed_symbols)
对比两版本文件可发现细节变动蛮大,运行出错当然不可避免
也提醒小伙伴们跑别人项目时一定要注意环境统一