一、PPO Code
The PPO code is taken from the following repo:
https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/contents/
二、Network Structure
1. Network Variables
The network in the original project takes vector inputs, but in the AI2THOR environment the observed state, and hence the network input, is a three-dimensional tensor, so the project's network structure has to be modified. The variable shapes become: S is [None, 84, 84, 3], Q_tar is [None, 1], Q_eval is [None, 1], pi is [None, action_size], and a is [None, ].
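A minimal sketch of the corresponding placeholder declarations inside the agent's __init__, using the tfs/tfdc_r/tfa names that appear in the code below (mapping: S -> tfs, Q_tar -> tfdc_r, a -> tfa; pi and Q_eval are network outputs rather than placeholders; A_DIM standing in for action_size is an illustrative assumption):
import tensorflow as tf

A_DIM = 4  # action_size; illustrative value only

self.tfs = tf.placeholder(tf.float32, [None, 84, 84, 3], 'state')     # S: image observation
self.tfdc_r = tf.placeholder(tf.float32, [None, 1], 'discounted_r')   # Q_tar: value target
self.tfa = tf.placeholder(tf.int32, [None, ], 'action')               # a: action indices
self.tfadv = tf.placeholder(tf.float32, [None, 1], 'advantage')       # advantage for the actor loss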
2. Actor and Critic Networks
# Forward pass for s: s -> conv1 -> relu -> conv2 -> relu -> fc -> relu -> conv_output_fc
with tf.variable_scope("base_conv", reuse=reuse):
    # Weights
    W_conv1, b_conv1 = self._conv_variable([8, 8, 3, 16], "base_conv1")
    W_conv2, b_conv2 = self._conv_variable([4, 4, 16, 32], "base_conv2")
    # Nodes
    h_conv1 = tf.nn.relu(self._conv2d(self.tfs, W_conv1, 4) + b_conv1)  # stride=4
    h_conv2 = tf.nn.relu(self._conv2d(h_conv1, W_conv2, 2) + b_conv2)   # stride=2
with tf.variable_scope("base_fc", reuse=reuse):
    # Weights
    W_fc1, b_fc1 = self._fc_variable([2592, 256], "base_fc1")
    # Nodes: flatten (-1, 9, 9, 32) -> (-1, 2592)
    conv_output = tf.reshape(h_conv2, [-1, 2592])
    conv_output_fc = tf.nn.relu(tf.matmul(conv_output, W_fc1) + b_fc1)
# conv_output_fc goes through a fully connected layer and softmax to produce pi
with tf.variable_scope("base_policy", reuse=reuse):
    # Weights for the policy output layer
    W_fc_p, b_fc_p = self._fc_variable([256, A_DIM], "base_fc_p")
    # Policy (output)
    base_pi = tf.nn.softmax(tf.matmul(conv_output_fc, W_fc_p) + b_fc_p)
# conv_output_fc goes through a fully connected layer to produce v
with tf.variable_scope("base_value", reuse=reuse):
    # Weights for the value output layer
    W_fc_v, b_fc_v = self._fc_variable([256, 1], "base_fc_v")
    # Value (output)
    v = tf.matmul(conv_output_fc, W_fc_v) + b_fc_v
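The helpers _conv_variable, _fc_variable, and _conv2d are not shown in the snippet; a minimal sketch of plausible implementations follows (the initializers are assumptions, but "VALID" padding is needed to reproduce the 9x9x32 = 2592 flatten size above):
def _conv_variable(self, shape, name):
    # shape = [filter_h, filter_w, in_channels, out_channels]
    W = tf.get_variable(name + "_W", shape, initializer=tf.truncated_normal_initializer(stddev=0.02))
    b = tf.get_variable(name + "_b", [shape[3]], initializer=tf.constant_initializer(0.0))
    return W, b

def _fc_variable(self, shape, name):
    # shape = [in_dim, out_dim]
    W = tf.get_variable(name + "_W", shape, initializer=tf.truncated_normal_initializer(stddev=0.02))
    b = tf.get_variable(name + "_b", [shape[1]], initializer=tf.constant_initializer(0.0))
    return W, b

def _conv2d(self, x, W, stride):
    # VALID padding: 84x84 -> 20x20 (stride 4) -> 9x9 (stride 2), i.e. 9*9*32 = 2592
    return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding="VALID")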
3. Variable Assignment
s = s[np.newaxis, :, :, :]  # s is [84, 84, 3] but tfs is [None, 84, 84, 3], so s needs an extra leading dimension
pi = self.sess.run(self.pi, feed_dict={self.tfs: s})
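Sampling an action from pi then only needs a categorical draw; a minimal sketch (the method name choose_action is an assumption, not from the repo):
def choose_action(self, s):
    s = s[np.newaxis, :, :, :]  # [84, 84, 3] -> [1, 84, 84, 3]
    pi = self.sess.run(self.pi, feed_dict={self.tfs: s})
    return np.random.choice(pi.shape[1], p=pi[0])  # sample an action index with probabilities pi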
4. Batch Training
The None in each variable's shape corresponds to the batch dimension, as in the example below.
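For instance, stacking a list of observed frames along a new batch axis (the variable frames is hypothetical) gives a policy output with one row per sample:
batch_s = np.stack(frames)  # e.g. 32 frames of [84, 84, 3] -> (32, 84, 84, 3)
batch_pi = self.sess.run(self.pi, feed_dict={self.tfs: batch_s})  # shape (32, A_DIM)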
5. Data fed into TensorFlow must be numpy arrays, not Python lists; otherwise you get errors or wrong matrix-operation results, which harms the network and the training.
def update(self, s, a, r):
    self.sess.run(self.update_oldpi_op)
    # convert lists to numpy arrays before feeding them in
    s = np.array(s)
    a = np.array(a)
    r = np.array(r)
    r = r[:, np.newaxis]  # (batch,) -> (batch, 1); otherwise r - v would broadcast to (batch, batch)
    adv = self.sess.run(self.advantage, {self.tfs: s, self.tfdc_r: r})
    # if s is not converted to a numpy array, the computed v ends up with shape (32, 32)
    # update actor
    [self.sess.run(self.atrain_op, {self.tfs: s, self.tfa: a, self.tfadv: adv}) for _ in range(A_UPDATE_STEPS)]
    # update critic
    [self.sess.run(self.ctrain_op, {self.tfs: s, self.tfdc_r: r}) for _ in range(C_UPDATE_STEPS)]
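For reference, in the repo linked above the advantage, atrain_op, and ctrain_op come from PPO's clipped surrogate objective; a minimal sketch adapted to the discrete softmax policy used here (EPSILON, A_LR, and C_LR are assumed constants, and self.oldpi is the frozen policy copy refreshed by update_oldpi_op):
# critic: regress v onto the discounted return; advantage = return - v
self.advantage = self.tfdc_r - self.v                              # (batch, 1)
self.closs = tf.reduce_mean(tf.square(self.advantage))
self.ctrain_op = tf.train.AdamOptimizer(C_LR).minimize(self.closs)

# actor: clipped surrogate loss on the probability of the taken action
a_indices = tf.stack([tf.range(tf.shape(self.tfa)[0]), self.tfa], axis=1)
pi_prob = tf.gather_nd(self.pi, a_indices)                         # pi(a|s), shape (batch,)
oldpi_prob = tf.gather_nd(self.oldpi, a_indices)
ratio = pi_prob / (oldpi_prob + 1e-8)
adv = tf.squeeze(self.tfadv, axis=1)                               # (batch, 1) -> (batch,)
surr = ratio * adv
self.aloss = -tf.reduce_mean(tf.minimum(
    surr, tf.clip_by_value(ratio, 1. - EPSILON, 1. + EPSILON) * adv))
self.atrain_op = tf.train.AdamOptimizer(A_LR).minimize(self.aloss)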