上一篇文章简单讲了怎么断点续训,但是有时候吧,想要看看训练过程中那些参数都长啥样,从开始到结束各个参数如何变化,在网络训练过程中只有accuracy和loss可以看到,这时候需要使用参数提取的方法。
在上一篇基础上,加入一些新内容
import tensorflow as tf
import os
from tensorflow import keras
import numpy as np
# 设置打印格式,threshold=超过多少不打印,np.inf无限大
np.set_printoptions(threshold=np.inf)
mnist = keras.datasets.mnist
(x_train,y_train),(x_test,y_test) = mnist.load_data()
x_train,x_test = x_train/255,x_test/255
model = keras.models.Sequential([
keras.layers.Flatten(),
keras.layers.Dense(128,activation='relu'),
keras.layers.Dense(10,activation='softmax')
])
model.compile(optimizer='adam',
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
checkpoint_save_path = './临时文件/mnist.ckpt'
# 生成ckpt的同时会生成索引文件
if os.path.exists(checkpoint_save_path+'.index'):
print('-----------load the model----------------')
model.load_weights(checkpoint_save_path)
cp_callback = keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)
history = model.fit(x_train,y_train,batch_size=32,epochs=5,validation_data=(x_test,y_test),validation_freq=1,callbacks=[cp_callback])
model.summary()
查看一下训练过程中各参数
print(model.trainable_variables)
然后把参数保存在文件中
with open('./临时文件/weights.txt','w')as f:
for v in model.trainable_variables:
f.write(str(v.name)+'\n')
f.write(str(v.shape) + '\n')
f.write(str(v.numpy()) + '\n')
完整代码如下
import tensorflow as tf
import os
from tensorflow import keras
import numpy as np
# 设置打印格式,threshold=超过多少不打印,np.inf无限大
np.set_printoptions(threshold=np.inf)
mnist = keras.datasets.mnist
(x_train,y_train),(x_test,y_test) = mnist.load_data()
x_train,x_test = x_train/255,x_test/255
model = keras.models.Sequential([
keras.layers.Flatten(),
keras.layers.Dense(128,activation='relu'),
keras.layers.Dense(10,activation='softmax')
])
model.compile(optimizer='adam',
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
checkpoint_save_path = './临时文件/mnist.ckpt'
# 生成ckpt的同时会生成索引文件
if os.path.exists(checkpoint_save_path+'.index'):
print('-----------load the model----------------')
model.load_weights(checkpoint_save_path)
cp_callback = keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)
history = model.fit(x_train,y_train,batch_size=32,epochs=5,validation_data=(x_test,y_test),validation_freq=1,callbacks=[cp_callback])
model.summary()
print(model.trainable_variables)
with open('./临时文件/weights.txt','w')as f:
for v in model.trainable_variables:
f.write(str(v.name)+'\n')
f.write(str(v.shape) + '\n')
f.write(str(v.numpy()) + '\n')
打开刚才保存的文件看一下