加载预训练模型,只训练fc层

实现:

def reload_sess(self):
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(max_to_keep=5, allow_empty=True, reshape=True)
        if self.args.pretrained_epoch > 0:
            print('\nlaoding...')
            ckpt = tf.train.get_checkpoint_state(self.args.ckpt_path)
            if ckpt and ckpt.model_checkpoint_path:
                self.global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
                init_assign_op, init_feed_dict = slim.assign_from_checkpoint(ckpt.model_checkpoint_path,
                                                                             slim.get_variables_to_restore(),
                                                                             ignore_missing_vars=True)
                self.sess.run(init_assign_op, feed_dict=init_feed_dict)
                print('loaded successful, global_step = %s, ' % self.global_step,
                      'path=%s\n' % ckpt.model_checkpoint_path)
            else:
                print('have no ckpt,start from 0 again')

其中init_assign_op, init_feed_dict = slim.assign_from_checkpoint(ckpt.model_checkpoint_path, slim.get_variables_to_restore(), ignore_missing_vars=True)
这一句 ignore_missing_vars=True即让找不到的变量随机初始化,所以我们在加载模型的时候,将原来的fc层的variable scope名改变一下,找不到fc层,就会将这一层进行随机初始化。

def get_logits(self, embedding, labels, out_num, w_init=None, ):
        with tf.variable_scope('logit'):# 原来为 logits,改为logit即可
            # inputs and weights norm
            embedding_norm = tf.norm(embedding, axis=1, keep_dims=True)
            embedding = tf.div(embedding, embedding_norm, name='norm_embedding') * self.s  # shape(16,512)

            weights = tf.get_variable(name='embedding_weights', shape=(embedding.get_shape()[-1], out_num),
                                     
                                      initializer=w_init, dtype=tf.float32)
            weights_norm = tf.norm(weights, axis=0, keep_dims=True)
            weights = tf.div(weights, weights_norm, name='norm_weights')  # shape(512,950)
            fc7 = tf.matmul(embedding, weights, name='cos_t')  # shape(16,950)
            one_hot_mask = tf.one_hot(labels, depth=out_num, name='one_hot_mask')          
            zy = tf.reduce_sum(fc7 * one_hot_mask, axis=1)  # shape(16,1)
            cos_t = zy / self.s
            t = tf.acos(cos_t, name='t')
            t_m = t + self.m

            # t_m = tf.where(tf.greater(t_m, math.pi), t, t_m, name='t_m_pi')  # 多么简洁,别弄反了,如果大于pi,不可以用pi,否则无法梯度下降
            cos_t_m = tf.cos(t_m, name='cos_t_m')
            new_zy = cos_t_m * self.s
            # 掩模的必要性,实现只再label的地方进行cos(t+m)
            diff = new_zy - zy
            diff = tf.expand_dims(diff, axis=1)
            mask = tf.multiply(one_hot_mask, diff, name='mask')
            # print(mask.shape)
            logits = (fc7 + mask)  # (16,950)(16,1),广播相乘

            return logits, fc7

还有一个要改的地方就是程序的out_num,加载预训练模型,只训练fc层,分类数目改为你想要的即可。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

我现在强的可怕~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值