TensorFlow中对训练后的神经网络参数(权重、偏置)提取

 基于TensorFlow可以轻而易举搭建一个神经网络,而且很好地支持GPU加速训练。但基于TensorFlow的预测过程,往往需要在嵌入式设备上才能得以应用。对于我目前做的工作而言,用TF搭建神经网络以及用GPU加速训练过程的主要用处就是:获取训练后的参数(权重和偏置),将这些参数直接放到嵌入式板卡如FPGA中,以其低功耗、高性能、低延时等特点完成嵌入式AI工程。那么,提取出TF训练后的参数变成很重要的过程。

        不少IDE可以提供可视化的参数显示,本文介绍的方法是不依赖IDE的神经网络参数提取。我们知道,在training之后,模型将会被保存在一个特定的路径下。model里面包括包括加算图结构、节点信息、参数等数据,那么只要在model里面就一定可以找到参数的信息。

              在TensorFlow里,提供了tf.train.NewCheckpointReader来查看model.ckpt文件中保存的变量信息。

参数就是trainable集合的变量,所以也可以通过这个tf.train.NewCheckpointReader来查看,具体代码如下:

import tensorflow as tf
import numpy as np
reader = tf.train.NewCheckpointReader('llw/MNIST_model/mnist_model-29001')
all_variables = reader.get_variable_to_shape_map()
w1 = reader.get_tensor("layer1/weights")
print(type(w1))
print(w1.shape)
print(w1[0])  
输出为:

<class 'numpy.ndarray'>
(784, 500)
[  2.24018339e-02  -2.00362392e-02  -1.12209506e-02   6.77579222e-03
  -9.59016196e-03   1.21959345e-02  -9.51156951e-03  -1.60046462e-02
  -1.37826744e-02  -1.76466629e-02  -2.11188430e-03   3.54206143e-03
  -2.03107391e-02   2.13961536e-03  -4.41462384e-04  -1.93272587e-02
  -3.71702737e-03   2.22449750e-03   2.98950635e-02  -2.47442089e-02
  -7.97873642e-03   2.99713714e-03  -1.77890640e-02   2.59044971e-02
   9.38970014e-04   1.46359997e-02  -2.18281448e-02   1.55605981e-02
  -2.44196616e-02  -2.03805566e-02  -7.10553257e-03  -8.46040528e-03
  -1.21834688e-02  -1.71028115e-02  -1.73374973e-02   1.58206956e-03
   7.28264870e-03  -2.08463762e-02  -7.46442471e-03   7.55013386e-03
   4.64899749e-05   3.26069025e-03  -1.22860866e-02  -2.33450923e-02
   8.73958052e-04  -2.50798613e-02  -2.91012623e-03   2.18578596e-02

....

上述的文件路径llw/MNIST_model/mnist_model-29001,为checkpoint指定的路径:

model_checkpoint_path: "mnist_model-29001"

上述程序只是输出第一个节点的参数(500个),总参数光第一层参数就有784X500个,不太适合全部打印在屏幕上。

所以可以通过python 的file write()函数将参数写到txt文本中。在这里不做详述。
--------------------- 
作者:木盏 
来源:CSDN 
原文:https://blog.csdn.net/leviopku/article/details/78510977 
版权声明:本文为博主原创文章,转载请附上博文链接!

TensorFlow ,我们可以使用 tf.train.Saver() 类来保存和加载模型参数,包括网络权重偏置。下面是一个示例代码,用于保存和加载一个简单的全连接神经网络权重偏置: ```python import tensorflow as tf # 定义一个简单的全连接神经网络 x = tf.placeholder(tf.float32, [None, 784]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x, W) + b) # 定义损失函数和优化器 y_ = tf.placeholder(tf.float32, [None, 10]) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 创建一个 Saver 对象 saver = tf.train.Saver() # 训练模型并保存参数 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for i in range(1000): batch_xs, batch_ys = ... # 获取训练数据 sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) saver.save(sess, './model.ckpt') # 保存模型参数 # 加载模型参数并进行推理 with tf.Session() as sess: saver.restore(sess, './model.ckpt') # 加载模型参数 # 进行推理 ... ``` 在上面的代码,我们首先定义了一个简单的全连接神经网络,并定义了损失函数和优化器。然后创建了一个 Saver 对象,用于保存和加载模型参数。在训练过程,我们使用 sess.run() 函数运行 train_step 操作,并在训练结束后使用 saver.save() 方法保存模型参数。在推理过程,我们使用 saver.restore() 方法加载模型参数,并进行推理。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值