一般我使用 tf.keras 的方式为继承,环境为 tensorflow1.15,下面是一个实例模型:
class ActorMLP(tf.keras.Model):
def __init__(self, ac_dim, name=None):
super(ActorMLP, self).__init__(name=name)
activation_fn = tf.keras.activations.tanh
kernel_initializer = None
self.dense1 = tf.keras.layers.Dense(
64, activation=activation_fn,
kernel_initializer=tf_ortho_init(np.sqrt(2)), name='fc1')
self.dense2 = tf.keras.layers.Dense(
64, activation=activation_fn,
kernel_initializer=tf_ortho_init(np.sqrt(2)), name='fc2')
self.dense3 = tf.keras.layers.Dense(
ac_dim[0], activation=None,
kernel_initializer=tf_ortho_init(0.01), name='fc3')
def call(self, state):
x = tf.cast(state, tf.float32)
x = self.dense1(x)
x = self.dense2(x)
x = self.dense3(x)
return x
使用方式大概像下面这样
model = ActorMLP(name='pi')
然后按照静态图的方法使用。这是没问题的。
问题是保存模型的时候,我发现 model.get_weights()
和 使用标准 tf1 方法(如下)获得的参数完全不同。同样,model.save_weights()
保存的参数也是错误的,更别提 model.load_weights()
了
scope = tf.compat.v1.get_default_graph().get_name_scope()
vars = tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES,
scope=os.path.join(scope, 'pi')+'/')
var_list = sess.run(vars)
后面,还是使用 tf1 提供的标准方法来保存参数,示例如下:
def _get_var_list(self, name=None):
scope = tf.compat.v1.get_default_graph().get_name_scope()
vars = tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES,
scope=os.path.join(scope, name)+'/')
return vars
def _build_saver(self):
pi_params = self._get_var_list('pi')
vf_params = self._get_var_list('vf')
return tf.compat.v1.train.Saver(var_list=pi_params+vf_params,
max_to_keep=4)
def save_weight(self, checkpoint_dir, epoch):
if not os.path.exists(checkpoint_dir):
raise
self.saver.save(
self.sess,
os.path.join(checkpoint_dir, 'tf_ckpt'),
global_step=epoch)
def load_weight(self, checkpoint_dir, epoch=None):
if not os.path.exists(checkpoint_dir):
raise
self.saver.restore(
self.sess,
os.path.join(checkpoint_dir, f'tf_ckpt-{epoch}'))