cannot import name 'ReparseException' & TF model parameter extraction

Copyright notice: this is an original post by the author; please credit the source when reposting. https://blog.csdn.net/Yan_Joy/article/details/79964647

Tensorboard ImportError: cannot import name 'ReparseException'


Environment:

  • python 3.5.4
  • Tensorflow==1.4.1
  • html5lib==1.0.1

Error message:

ImportError: cannot import name 'ReparseException'

Located at:

~/anaconda3/envs/tf3/lib/python3.5/site-packages/bleach/sanitizer.py

Open that file and, in

from html5lib.constants import (
    entities,
    ReparseException,
    namespaces,
    prefixes,
    tokenTypes,
)

find the offending name and change ReparseException to _ReparseException to resolve the problem (html5lib 1.0 renamed the symbol with a leading underscore, which is what breaks the old import in bleach).
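Instead of editing the file by hand, the same one-character rename can be applied with a small script. This is a minimal sketch, not part of the original post: the `patch_reparse_exception` helper is my own illustration, and it assumes the only problem is the renamed symbol.

```python
import io
import re


def patch_reparse_exception(path):
    """Rewrite a source file so it uses html5lib's renamed symbol.

    html5lib 1.0 made ReparseException private (_ReparseException),
    which breaks the old import in bleach's sanitizer.py.
    """
    with io.open(path, encoding='utf-8') as f:
        src = f.read()
    # Prefix only bare occurrences, so running the patch twice is safe.
    patched = re.sub(r'(?<!_)ReparseException', '_ReparseException', src)
    with io.open(path, 'w', encoding='utf-8') as f:
        f.write(patched)
    return patched
```

You would point it at the sanitizer.py path shown above; it leaves an already-patched file unchanged.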

TF model parameter extraction


Models saved through the TensorFlow saver come in two formats, V1 and V2, as the tf.train.Saver documentation explains:

  • write_version: controls what format to use when saving checkpoints. It also affects certain filepath matching logic. The V2 format is the recommended choice: it is much more optimized than V1 in terms of memory required and latency incurred during restore. Regardless of this flag, the Saver is able to restore from both V2 and V1 checkpoints.

In practice, though, a direct comparison showed the opposite for disk size: a model that was 500 MB in V1 format took 1.5 GB in V2.

During actual parameter extraction, turning the variables in a model into directly usable numpy arrays is rather troublesome. Most answers found by searching say to first build the graph and then eval the variables inside a running session. (Is there really no way to get the parameters directly from the model file?)

For V2-format models, the tensorpack package can extract the parameters quickly into an npz file.

Installing tensorpack

You need git first; then run on the command line:

 pip install -U git+https://github.com/ppwwyyxx/tensorpack.git   

Script

In the repository, ./scripts/dump-model-params.py is the parameter-extraction script.

Example usage:

./scripts/dump-model-params.py --meta train_log/alexnet-dorefa/graphxxxx.meta train_log/alexnet-dorefa/model-xxx output.npz
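Once the dump finishes, the resulting npz file is just a mapping from variable names to arrays that plain numpy can read without TensorFlow. A quick sanity check might look like the sketch below; the file name and variable name are placeholders I made up, not output from the script.

```python
import numpy as np

# Stand-in for a dumped checkpoint: an npz file is essentially a
# dict of variable names to numpy arrays.
np.savez('output_demo.npz', conv1_W=np.zeros((3, 3, 1, 8), dtype=np.float32))

# np.load on an .npz returns a lazy NpzFile; .files lists the keys.
params = np.load('output_demo.npz')
for name in params.files:
    print(name, params[name].shape, params[name].dtype)
```

This is also the format the script below produces when the output path ends in .npz.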

Source code

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: dump-model-params.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import numpy as np
import six
import argparse
import os
import tensorflow as tf

from tensorpack.tfutils import varmanip
from tensorpack.tfutils.common import get_op_tensor_name

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Keep only TRAINABLE and MODEL variables in a checkpoint.')
    parser.add_argument('--meta', help='metagraph file', required=True)
    parser.add_argument(dest='input', help='input model file, has to be a TF checkpoint')
    parser.add_argument(dest='output', help='output model file, can be npz or TF checkpoint')
    args = parser.parse_args()

    # this script does not need GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = ''

    tf.train.import_meta_graph(args.meta, clear_devices=True)

    # loading...
    if args.input.endswith('.npz'):
        dic = np.load(args.input)
    else:
        dic = varmanip.load_chkpt_vars(args.input)
    dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)}

    # save variables that are GLOBAL, and either TRAINABLE or MODEL
    var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
    assert len(set(var_to_dump)) == len(var_to_dump), "TRAINABLE and MODEL variables have duplication!"
    globvarname = [k.name for k in tf.global_variables()]
    var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname])

    for name in var_to_dump:
        assert name in dic, "Variable {} not found in the model!".format(name)

    dic_to_dump = {k: v for k, v in six.iteritems(dic) if k in var_to_dump}
    varmanip.save_chkpt_vars(dic_to_dump, args.output)