进行fine-tune的关键在于保存数据和重新读取数据,在tensorflow中有多种方法,下面介绍一种空间复杂度较低的方法。
首先,在训练阶段
saver = tf.train.Saver(var_list=tf.trainable_variables()) #只保存可训练变量
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
...
在加载模型时,要相应地只加载可训练数据,然后根据需要更改后续的网络结构
train_op = tf.train.AdamOptimizer(lr).minimize(loss)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(var_list=tf.trainable_variables())
saver.restore(sess,filename)