1、导入 Python 库
import tensorflow as tf
import numpy as np
import math
2、声明超参数
# Hyperparameters for the critic network (their uses live outside this excerpt).
# Presumably the unit counts of the first and second hidden layers.
LAYER1_SIZE, LAYER2_SIZE = 400, 300
# Step size handed to the critic's optimizer (presumably in create_training_method).
LEARNING_RATE = 0.001
# Presumably the soft target-update rate -- confirm where TAU is consumed.
TAU = 1e-3
# Presumably an L2 weight-regularization coefficient -- confirm at its use site.
L2 = 1e-2
3、定义 CriticNetwork 类
class CriticNetwork:
    """Critic network that maps a (state, action) pair to a Q-value.

    Only the opening of the class is shown in this snippet; the constructor
    here merely resets the step counter.
    """

    def __init__(self, sess, state_dim, action_dim):
        # Step counter (presumably advanced during training); starts at zero.
        self.time_step = 0
（以下均为 CriticNetwork 类中的方法）
3.1 初始化函数
def __init__(self, sess, state_dim, action_dim):
    """Build the critic, its target copy and training ops, then initialize.

    Args:
        sess: TensorFlow session that runs the variable initializer (and,
            presumably, every later evaluation of this critic's graph).
        state_dim: size of the state input, forwarded to the network builders.
        action_dim: size of the action input, forwarded to the network builders.
    """
    self.time_step = 0
    self.sess = sess

    # Online critic: state/action inputs, the Q-value output and self.net
    # (presumably its trainable variables -- confirm in create_q_network).
    (self.state_input,
     self.action_input,
     self.q_value_output,
     self.net) = self.create_q_network(state_dim, action_dim)

    # Target critic with the same structure, built from self.net;
    # target_update is presumably the op that moves the target weights
    # toward self.net -- confirm in create_target_q_network.
    (self.target_state_input,
     self.target_action_input,
     self.target_q_value_output,
     self.target_update) = self.create_target_q_network(
         state_dim, action_dim, self.net)

    self.create_training_method()

    # tf.initialize_all_variables() is deprecated; global_variables_initializer()
    # is its drop-in replacement. It initializes *every* global variable in the
    # graph, so variables of any other network already built in the same graph
    # are (re)initialized here as well.
    self.sess.run(tf.global_variables_initializer())

    # Presumably copies the freshly initialized online weights into the target
    # network so both start identical -- confirm in update_target().
    self.update_target()
(1) 调用 create_q_network() 和 create_target_q_network() 两个函数，分别创建 critic 网络和 target_critic 网络；
(2) 调用 create_training_method() 创建训练方法；
(3) 调用 self.sess.run() 执行变量初始化操作，初始化计算图中的全部变量；
(4) 调用 self.update_target() 更新 target_critic 网络。
3.2 其他函数
def variable(self,shape,f):