保存模型
RL中,我们一般都把一个网络结构写在一个类里面,保存的时候也是,可以如下写一个 save_net 函数:
def save_net(self):
saver = tf.train.Saver()
save_path = saver.save(self.sess, "./dqn/model/file_name.ckpt")
print("Save to path: ", save_path)
在RL算法进行完N轮的训练之后,调用该函数进行模型保存:agent.save_net()
可以看到,会在model文件夹下多出四个文件:
也可以输出保存前的参数,进行观察,以便确认读取模型时是否成功读取了参数:
w1 = tf.get_default_graph().get_tensor_by_name('eval_net/l1/w1:0') # 获得variable对应的Tensor
print(self.sess.run(w1)) # run一下这个Tensor得到结果
读取模型
首先注意,读取模型用于测试时,我们需要保证用到的变量和训练时的是一样的,比如测试DQN模型的效果:
class Test4DQN:
def __init__(self):
self.sess = tf.Session()
self._build_net()
def _build_net(self):
# 测试时,只需要建立 evaluate_net,用来选择动作
self.s = tf.placeholder(tf.float32, [None, 11])
with tf.variable_scope('eval_net'):
with tf.variable_scope('l1'):
w1 = tf.Variable(np.arange(110).reshape((11, 10)), dtype=tf.float32, name="w1")
b1 = tf.Variable(np.arange(10).reshape((1, 10)), dtype=tf.float32, name="b1")
l1 = tf.nn.relu(tf.matmul(self.s, w1) + b1)
with tf.variable_scope('l2'):
w2 = tf.Variable(np.arange(240).reshape((10, 24)), dtype=tf.float32, name="w2")
b2 = tf.Variable(np.arange(24).reshape((1, 24)), dtype=tf.float32, name="b2")
self.q_eval = tf.matmul(l1, w2) + b2
# 读取模型参数
saver = tf.train.Saver()
init = tf.global_variables_initializer()
self.sess.run(init)
saver.restore(self.sess, "./xxxxx/model/file_name.ckpt")
print(self.sess.run(w1)) # 可以再次输出,和我们保存时的输出结果进行对比,保证正确读取
def choose_action(self, observation):
observation = observation[np.newaxis, :]
actions_value = self.sess.run(self.q_eval, feed_dict={self.s: observation})
action = np.argmax(actions_value)
return action
总结一下,就是先初始化一下测试框架中定义的变量(注意层级和名称需要对应,原来叫’‘w1’‘现在也要叫’‘w1’’),然后调用saver.restore(self.sess, "./xxxxx/model/file_name.ckpt")
,即可将保存的网络参数赋值给现在的网络。
之后,和原来RL的流程一样,只是不再需要保存记忆和训练而已,最后可以得到测试的效果。