[tensorflow] 继承 tf.keras 模型后,参数不对

一般我使用 tf.keras 的方式为继承,环境为 tensorflow1.15,下面是一个实例模型:

class ActorMLP(tf.keras.Model):
    def __init__(self, ac_dim, name=None):
        super(ActorMLP, self).__init__(name=name)
        activation_fn = tf.keras.activations.tanh
        kernel_initializer = None
        self.dense1 = tf.keras.layers.Dense(
            64, activation=activation_fn,
            kernel_initializer=tf_ortho_init(np.sqrt(2)), name='fc1')
        self.dense2 = tf.keras.layers.Dense(
            64, activation=activation_fn,
            kernel_initializer=tf_ortho_init(np.sqrt(2)), name='fc2')
        self.dense3 = tf.keras.layers.Dense(
            ac_dim[0], activation=None,
            kernel_initializer=tf_ortho_init(0.01), name='fc3')

    def call(self, state):
        x = tf.cast(state, tf.float32)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dense3(x)
        return x

使用方式大概像下面这样

model = ActorMLP(name='pi')

然后按照静态图的方法使用。这是没问题的。

问题是保存模型的时候,我发现 model.get_weights() 和 使用标准 tf1 方法(如下)获得的参数完全不同。同样,model.save_weights() 保存的参数也是错误的,更别提 model.load_weights()

scope = tf.compat.v1.get_default_graph().get_name_scope()
vars = tf.compat.v1.get_collection(
    tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES,
    scope=os.path.join(scope, 'pi')+'/')
var_list = sess.run(vars)

后面,还是使用 tf1 提供的标准方法来保存参数,示例如下:

def _get_var_list(self, name=None):
    scope = tf.compat.v1.get_default_graph().get_name_scope()
    vars = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES,
        scope=os.path.join(scope, name)+'/')
    return vars

def _build_saver(self):
    pi_params = self._get_var_list('pi')
    vf_params = self._get_var_list('vf')
    return tf.compat.v1.train.Saver(var_list=pi_params+vf_params,
                                    max_to_keep=4)

def save_weight(self, checkpoint_dir, epoch):
    if not os.path.exists(checkpoint_dir):
        raise
    self.saver.save(
        self.sess,
        os.path.join(checkpoint_dir, 'tf_ckpt'),
        global_step=epoch)

def load_weight(self, checkpoint_dir, epoch=None):
    if not os.path.exists(checkpoint_dir):
        raise
    self.saver.restore(
        self.sess,
        os.path.join(checkpoint_dir, f'tf_ckpt-{epoch}'))
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值