TensorBoard ImportError: cannot import name 'ReparseException'
Environment:
- Python 3.5.4
- TensorFlow==1.4.1
- html5lib==1.0.1

Error message:
ImportError: cannot import name 'ReparseException'
The error is raised from:
~/anaconda3/envs/tf3/lib/python3.5/site-packages/bleach/sanitizer.py
Open that file and locate the offending import:

from html5lib.constants import (
    entities,
    ReparseException,
    namespaces,
    prefixes,
    tokenTypes,
)

Changing ReparseException to _ReparseException fixes the problem.
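The same edit written as a patch against bleach/sanitizer.py; aliasing the private name back to ReparseException (an addition to the rename described above, not part of the original fix) keeps any later references to ReparseException in the file working unchanged:

```diff
 from html5lib.constants import (
     entities,
-    ReparseException,
+    _ReparseException as ReparseException,
     namespaces,
     prefixes,
     tokenTypes,
 )
```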
Extracting TF model parameters
Models saved with the TensorFlow Saver come in two formats, V1 and V2, as documented for tf.train.Saver:
- write_version: controls what format to use when saving checkpoints. It also affects certain filepath matching logic. The V2 format is the recommended choice: it is much more optimized than V1 in terms of memory required and latency incurred during restore. Regardless of this flag, the Saver is able to restore from both V2 and V1 checkpoints.
In practice, though, a side-by-side comparison showed a model that was 500 MB in V1 format growing to 1.5 GB in V2.
During actual parameter extraction, turning the variables inside a model into numpy arrays that can be manipulated directly is surprisingly awkward. Most answers found by searching say to build the graph first and then eval each variable inside a running session. (Can't the parameters be read directly from the model file?)
For V2 models, the tensorpack package can extract the parameters quickly into the npz format.
Installation
To install tensorpack, make sure git is available, then run:
pip install -U git+https://github.com/ppwwyyxx/tensorpack.git
Script
dump-model-params.py under ./scripts/ in the repository is the parameter-extraction script. Example usage:
./scripts/dump-model-params.py --meta train_log/alexnet-dorefa/graphxxxx.meta train_log/alexnet-dorefa/model-xxx output.npz
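The resulting output.npz can then be inspected with plain numpy; each entry is already an array, no graph or session involved. A small sketch (the in-memory buffer and the layer names are made up to stand in for the real output.npz):

```python
import io

import numpy as np

# Stand-in for the npz written by dump-model-params.py;
# the layer names here are hypothetical.
buf = io.BytesIO()
np.savez(buf, **{'conv1/W': np.zeros((3, 3, 1, 8), np.float32),
                 'fc0/b': np.zeros((10,), np.float32)})
buf.seek(0)

# np.load on an .npz gives dict-like access to plain numpy arrays.
params = np.load(buf)
for name in params.files:
    print(name, params[name].shape)
```

With a real file, replace the buffer with `np.load('output.npz')`.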
Source
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: dump-model-params.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import six
import argparse
import os
import tensorflow as tf
from tensorpack.tfutils import varmanip
from tensorpack.tfutils.common import get_op_tensor_name
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Keep only TRAINABLE and MODEL variables in a checkpoint.')
    parser.add_argument('--meta', help='metagraph file', required=True)
    parser.add_argument(dest='input', help='input model file, has to be a TF checkpoint')
    parser.add_argument(dest='output', help='output model file, can be npz or TF checkpoint')
    args = parser.parse_args()

    # this script does not need GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = ''

    tf.train.import_meta_graph(args.meta, clear_devices=True)

    # loading...
    if args.input.endswith('.npz'):
        dic = np.load(args.input)
    else:
        dic = varmanip.load_chkpt_vars(args.input)
    dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)}

    # save variables that are GLOBAL, and either TRAINABLE or MODEL
    var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
    assert len(set(var_to_dump)) == len(var_to_dump), "TRAINABLE and MODEL variables have duplication!"
    globvarname = [k.name for k in tf.global_variables()]
    var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname])

    for name in var_to_dump:
        assert name in dic, "Variable {} not found in the model!".format(name)

    dic_to_dump = {k: v for k, v in six.iteritems(dic) if k in var_to_dump}
    varmanip.save_chkpt_vars(dic_to_dump, args.output)