Implementation:
import tensorflow as tf
import tensorflow.contrib.slim as slim

def reload_sess(self):
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(max_to_keep=5, allow_empty=True, reshape=True)
    if self.args.pretrained_epoch > 0:
        print('\nloading...')
        ckpt = tf.train.get_checkpoint_state(self.args.ckpt_path)
        if ckpt and ckpt.model_checkpoint_path:
            # the global step is encoded in the checkpoint filename, e.g. model-1000
            self.global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
                ckpt.model_checkpoint_path,
                slim.get_variables_to_restore(),
                ignore_missing_vars=True)
            self.sess.run(init_assign_op, feed_dict=init_feed_dict)
            print('loaded successfully, global_step = %s, ' % self.global_step,
                  'path=%s\n' % ckpt.model_checkpoint_path)
        else:
            print('no checkpoint found, starting from step 0')
In the call init_assign_op, init_feed_dict = slim.assign_from_checkpoint(ckpt.model_checkpoint_path, slim.get_variables_to_restore(), ignore_missing_vars=True), the argument ignore_missing_vars=True makes slim silently skip any variable that is absent from the checkpoint; such variables simply keep the random values assigned earlier by tf.global_variables_initializer(). So when loading the pretrained model, all we have to do is rename the variable scope of the original fc layer: the checkpoint no longer contains a match for it, and that layer stays randomly initialized.
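As a sanity check before running the assign op, you can list which graph variables the checkpoint will actually restore. A minimal sketch (report_restorable is my own helper name; it relies on the standard tf.train.NewCheckpointReader):

import tensorflow as tf
import tensorflow.contrib.slim as slim

def report_restorable(ckpt_path):
    # map of variable names stored in the checkpoint to their shapes
    ckpt_vars = tf.train.NewCheckpointReader(ckpt_path).get_variable_to_shape_map()
    for var in slim.get_variables_to_restore():
        name = var.op.name
        status = 'restored' if name in ckpt_vars else 'missing -> keeps random init'
        print('%-60s %s' % (name, status))

Run it right after building the graph; the renamed fc variables should all show up as missing.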
def get_logits(self, embedding, labels, out_num, w_init=None):
    with tf.variable_scope('logit'):  # originally 'logits'; renaming the scope is all it takes
        # L2-normalize the input embeddings and the weights
        embedding_norm = tf.norm(embedding, axis=1, keep_dims=True)
        embedding = tf.div(embedding, embedding_norm, name='norm_embedding') * self.s  # shape (16, 512)
        weights = tf.get_variable(name='embedding_weights', shape=(embedding.get_shape()[-1], out_num),
                                  initializer=w_init, dtype=tf.float32)
        weights_norm = tf.norm(weights, axis=0, keep_dims=True)
        weights = tf.div(weights, weights_norm, name='norm_weights')  # shape (512, 950)
        fc7 = tf.matmul(embedding, weights, name='cos_t')  # shape (16, 950)
        one_hot_mask = tf.one_hot(labels, depth=out_num, name='one_hot_mask')
        zy = tf.reduce_sum(fc7 * one_hot_mask, axis=1)  # target logit per sample, shape (16,)
        cos_t = zy / self.s
        t = tf.acos(cos_t, name='t')
        t_m = t + self.m
        # t_m = tf.where(tf.greater(t_m, math.pi), t, t_m, name='t_m_pi')
        # concise, but don't get it backwards: when t + m exceeds pi, fall back to t
        # rather than clamping to pi, otherwise gradient descent cannot proceed
        cos_t_m = tf.cos(t_m, name='cos_t_m')
        new_zy = cos_t_m * self.s
        # the mask is what restricts cos(t + m) to the label positions only
        diff = new_zy - zy
        diff = tf.expand_dims(diff, axis=1)
        mask = tf.multiply(one_hot_mask, diff, name='mask')
        # print(mask.shape)
        logits = fc7 + mask  # (16, 950); the (16, 1) diff was broadcast into the mask
        return logits, fc7
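To see what the mask actually does to the logits, here is a NumPy mirror of the same arithmetic (a toy sketch: 2 samples, 3 classes, and illustrative s and m values; arcface_logits_np is my own name for it):

import numpy as np

def arcface_logits_np(cos_theta, labels, s=64.0, m=0.5):
    # NumPy mirror of get_logits: apply cos(t + m) only at the label column
    fc7 = s * cos_theta                       # scaled cosine logits, shape (N, C)
    one_hot = np.eye(cos_theta.shape[1])[labels]
    zy = np.sum(fc7 * one_hot, axis=1)        # target logit per sample, shape (N,)
    t = np.arccos(zy / s)
    new_zy = s * np.cos(t + m)                # margin-augmented target logit
    mask = one_hot * (new_zy - zy)[:, None]   # non-zero only at label positions
    return fc7 + mask

cos_theta = np.array([[0.9, 0.1, 0.2],
                      [0.3, 0.8, 0.1]])
print(arcface_logits_np(cos_theta, labels=np.array([0, 1])))

Only the label column of each row is pushed down by the margin; all other logits pass through unchanged, exactly as in the TensorFlow version.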
The last thing to change is out_num: since we load the pretrained model and train only the fc layer, simply set the number of classes to whatever your new task needs.
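Training only the fc layer then comes down to passing just the variables under the renamed scope to the optimizer. A minimal sketch, assuming the scope is 'logit' as above and using toy stand-ins for the real backbone, loss, and optimizer settings:

import tensorflow as tf

# toy stand-ins for the real graph; the 'logit' scope matches the code above
x = tf.placeholder(tf.float32, [None, 512])
labels = tf.placeholder(tf.int64, [None])
with tf.variable_scope('backbone'):
    h = tf.layers.dense(x, 512, activation=tf.nn.relu)
with tf.variable_scope('logit'):
    logits = tf.layers.dense(h, 950)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

# pass only the fc-scope variables so the pretrained backbone stays frozen
fc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='logit')
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, var_list=fc_vars)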