# Periodic evaluation: once every FLAGS.eval_every training steps, run the
# model over every validation batch and compute the overall accuracy.
if (train_batch_step + 1) % FLAGS.eval_every == 0:
    all_valid_acc_num = 0  # running count of correctly predicted examples
    all_valid_num = 0      # running count of validation examples seen
    valid_batches = data_help.valid_batch_iterator()
    for valid_batch_q in valid_batches:
        all_valid_num += len(valid_batch_q)
        # Every validation batch is paired with the same data_help.std_batch.
        # NOTE(review): std_batch is fed again on each iteration; if it is
        # large this is a likely cause of slow validation - confirm by
        # profiling before restructuring.
        valid_batch = (valid_batch_q, data_help.std_batch)
        # NOTE(review): the 1.0 argument is presumably the dropout keep
        # probability (no dropout during evaluation) - confirm against
        # feed_dict_builder.
        valid_prob = sess.run(
            dssm.prob_pre,
            feed_dict=feed_dict_builder(valid_batch, 1.0, dssm))
        # real_labels is not used within this block; the name is kept in case
        # code after this block reads it.
        valid_acc_num, real_labels, _ = cal_predict_acc_num(
            valid_prob, valid_batch_q, data_help.id2label)
        all_valid_acc_num += valid_acc_num
    # Guard against an empty validation set, which would otherwise raise
    # ZeroDivisionError; "* 1.0" keeps true division under Python 2 as well.
    if all_valid_num:
        current_acc = all_valid_acc_num * 1.0 / all_valid_num
    else:
        current_acc = 0.0
这段在验证集上做评估的代码运行非常慢。
例如，想利用 LCQMC 数据集，从每一个标记为 1 的句子对中抽出一句，作为标准集的一个类，那么类别会非常多。