Tensorflow2.0学习笔记-提取可训练参数

参数的提取,把参数存入文本

1. 参数打印输出。

在实际训练过程中,我们可以将参数的训练通过文本进行记录,或者打印出来进行查看。
其中model.trainable_variables可以返回模型中的参数。
我们可以使用printf进行打印,但是直接使用打印可能为出现很多数据无法显示,我们可以先设置

np.set_printoptions(threshold=10)  # 其中threshold表示输出的阈值,超出阈值的参数会用省略号表示,
                                   # 当阈值设置为np.inf时,表示数据全部输出
2. 参数写入文本

还可以使用代码将参数写入文本中进行查看,操作如下:

file = open('./weights.txt', 'w')
# 将训练参数存入文本,其中包括参数名称,大小,数据
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()
3. 示例代码
import tensorflow as tf
import os
import numpy as np
np.set_printoptions(threshold=np.inf)
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])
checkpoint_save_path = "./checkpoint/mnist.ckpt"
# 判断是否拥有模型
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model--- --------------')
    # 加载模型数据
    model.load_weights(checkpoint_save_path)
# 保存模型数据
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

# 设置打印输出格式
# 参数可以设置输出模式,超过阈值threshold的参数用省略号显示,使用inf时表示全部显示
# 其中model.trainable_variables可以返回模型中的参数
np.set_printoptions(threshold=10)
print(model.trainable_variables)

file = open('./weights.txt', 'w')
# 将训练参数存入文本,其中包括参数名称,大小,数据
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

4. 运行结果

运行结果显示我们的参数,其中有四个array,分别是第一、二层神经元的w和b。

#  784 = 28 * 28 第一层神经元中一个w的参数个数和输入参数个数是匹配的,拥有128个神经元,也就代表拥有128个b
[<tf.Variable 'sequential/dense/kernel:0' shape=(784, 128) dtype=float32, numpy=
array([[-0.04371355, -0.03728049,  0.00738895, ..., -0.03773092,
......
 <tf.Variable 'sequential/dense/bias:0' shape=(128,) dtype=float32, numpy=
array([-0.16389479,  0.1481038 , -0.00113325, ...,  0.10987901,
        0.00905252, -0.19073108], dtype=float32)>, 
#  第二层神经元输入为128,拥有10个神经元。
 <tf.Variable 'sequential/dense_1/kernel:0' shape=(128, 10) dtype=float32, numpy=
array([[ 0.02201208,  0.12389698, -0.16859886, ...,  0.19811475,
......
 <tf.Variable 'sequential/dense_1/bias:0' shape=(10,) dtype=float32, numpy=
array([-0.14484479, -0.24807121,  0.11401606, -0.20439751,  0.12878774,
       -0.09878621, -0.10151227, -0.23220563,  0.5990565 , -0.03787111],
      dtype=float32)>]

然后我们可以在工程文件夹下看见我们建立的weight.txt文件,其中参数和我们打印训练的参数一致。

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值