一. Bert 的多任务认识
以bert为代表的预训练模型多任务可以应用在多种场景。
1. 主辅任务:
比如当前要做的是任务A,可以再构造一个辅助任务B一起参与训练,以提高任务A的性能。例如人为地构造 MLM(Masked Language Model,即 BERT 预训练中的掩码语言模型)这样的辅助任务来辅助提升任务A;线上推理时并不输出任务B的结果。
2. 并行任务:
本身就需要完成多个任务,比如 A、B、C 三个任务同等重要且任务类型相似。如果分开训练就需要 3 个模型,此时可以尝试共享一个模型,即共享大部分参数,只让小部分参数随任务差异化。
二. 多任务模型结构的设计
1. encoder 完全共享
就像 finetune 一样,encoder 层完全共享,只在 BERT 最后一层(池化层)的输出之后,为不同任务分别设计各自的输出层。此时有两种不同的计算损失方式:
- 单条数据,根据 if - else 判断自身任务,只计算自身任务的 Loss 损失值,完成反向传播。
- 单条数据同时计算多个任务的 Loss 损失值,而后将所有损失值相加,完成反向传播。
目前我们采用第 2 种计算 Loss 的方式:两个任务的 Loss 直接相加(见下文 create_model 中的 loss = loss_cabinet + loss_threeseg)。
三. 多任务模型代码修改
以下修改全部在 run_classifier.py 文件中进行:
1. 修改 - 数据预处理
1)修改 main 函数
- 标签文件由一份 label.csv 改为两份(label_threeseg.csv 和 label_cabinet.csv),分别对应两个任务
# Excerpt from main() (indentation was lost when pasted): pick the data
# processor, then load one label list per task instead of a single label_list.
if task_name not in processors:
raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()
###### add multi task by nijiahui 20220422 start ##########
# Single-task original, kept for reference:
# label_list = processor.get_labels(data_dir)
# Labels of the threeseg task and of the cabinet task, each read from its own
# label file by the processor.
label_threeseg_list = processor.get_label_threeseg(data_dir)
label_cabinet_list = processor.get_label_cabinet(data_dir)
###### add multi task by nijiahui 20220422 end ##########
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
)
- 修改 model_fn_builder 参数 num_labels 为元组
# Excerpt from main(): num_labels is now a pair of label counts, one per
# classification head.
model_fn = model_fn_builder(
bert_config=bert_config,
###### add multi task by nijiahui 20220422 start ##########
# Order is (threeseg, cabinet); it must match the order of the label tuples
# used when writing and reading the TFRecords.
num_labels=(len(label_threeseg_list),len(label_cabinet_list)),
###### add multi task by nijiahui 20220422 end ##########
init_checkpoint=FLAGS.init_checkpoint,
learning_rate=FLAGS.learning_rate,
num_train_steps=num_train_steps,
num_warmup_steps=num_warmup_steps,
use_tpu=FLAGS.use_tpu,
use_one_hot_embeddings=FLAGS.use_tpu,
)
- 修改 file_based_convert_examples_to_features 函数的 label_list 入参为元组
# Excerpts from main(): serialize the train and eval examples to TFRecord files,
# passing a (threeseg, cabinet) tuple of label lists instead of one list.
if FLAGS.do_train:
train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
###### add multi task by nijiahui 20220422 start ##########
# file_based_convert_examples_to_features(
# train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file
# )
# Index 0 = threeseg labels, index 1 = cabinet labels.
label_list = (label_threeseg_list, label_cabinet_list)
file_based_convert_examples_to_features(
train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file
)
###### add multi task by nijiahui 20220422 end ##########
tf.logging.info("***** Running training *****")
# NOTE(review): indentation was lost in this excerpt; in the full main() the
# lines from here on presumably sit in the do_eval branch, not under
# `if FLAGS.do_train:` — confirm against run_classifier.py.
eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
###### add multi task by nijiahui 20220422 start ##########
label_list = (label_threeseg_list, label_cabinet_list)
###### add multi task by nijiahui 20220422 end ##########
file_based_convert_examples_to_features(
eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file
)
- 修改 eval 时逻辑,增加验证集每个任务的推理结果、每个任务是否推理正确、置信度等字段输出
# Excerpt from main() (eval branch): run prediction over the eval set and write
# one row per example with both tasks' predictions, gold labels, correctness
# flags and confidence scores.
###### add eval_data_result.csv file by nijiahui 20220420 start ###########
dev_result = estimator.predict(input_fn=eval_input_fn)
# NOTE(review): rows are joined with "\t" (see output_line below), so despite
# the .csv extension this file is tab-separated, and no header row is written.
eval_data_res_file = os.path.join(FLAGS.output_dir, "eval_data_results.csv")
with tf.gfile.GFile(eval_data_res_file, "w") as writer:
# Assumes dev_result yields predictions in the same order as eval_examples;
# zip stops at the shorter of the two.
for example, dev_prediction in zip(eval_examples, dev_result):
index = str(example.guid)
text = str(example.text_a)
# Gold label strings, in (threeseg, cabinet) order.
label_threeseg, label_cabinet = example.label
probabilities_threeseg = dev_prediction["probabilities_threeseg"]
probabilities_cabinet = dev_prediction["probabilities_cabinet"]
# Prediction = label at the argmax of each head's probabilities; the max
# probability is reported as that prediction's confidence.
pred_threeseg = str(label_threeseg_list[np.argmax(probabilities_threeseg)])
max_score_threeseg = str(np.max(probabilities_threeseg))
pred_cabinet = str(label_cabinet_list[np.argmax(probabilities_cabinet)])
max_score_cabinet = str(np.max(probabilities_cabinet))
# "1" when the predicted label string equals the gold label string, else "0".
is_right_threeseg = "1" if pred_threeseg==label_threeseg else "0"
is_right_cabinet = "1" if pred_cabinet==label_cabinet else "0"
# Column order: guid, text, threeseg (pred, gold, correct),
# cabinet (pred, gold, correct), threeseg score, cabinet score.
res_list = [
index,
text,
pred_threeseg, label_threeseg, is_right_threeseg,
pred_cabinet, label_cabinet, is_right_cabinet,
max_score_threeseg, max_score_cabinet,
]
# NOTE(review): text_a is written unescaped; a tab or newline inside it would
# shift or split the columns of this row.
output_line = ("\t".join(res_list) + "\n")
writer.write(output_line)
###### add eval_data_result.csv file by nijiahui 20220420 end ###########
2)修改 Cls_Processor 类
- 增加多任务的 get_label 方法
###### add multi task get_label by nijiahui 20220422 start ##########
# Single-task original, kept for reference:
# def get_labels(self, data_dir):
# """See base class."""
# labels = [
# line.rstrip().split(";")
# for line in open(os.path.join(data_dir, "label.csv"), "r", encoding="utf-8")
# if line.rstrip()
# ]
# return labels
def get_label_threeseg(self, data_dir):
    """Return the threeseg task's labels, read from `<data_dir>/label_threeseg.csv`.

    The file holds one label per line. Trailing whitespace is stripped and blank
    lines are skipped. A label's position in the returned list becomes its
    integer id when the TFRecords are written.
    """
    path = os.path.join(data_dir, "label_threeseg.csv")
    # `with` guarantees the file handle is closed once the labels are read.
    with open(path, "r", encoding="utf-8") as f:
        stripped = (line.rstrip() for line in f)
        return [label for label in stripped if label]


def get_label_cabinet(self, data_dir):
    """Return the cabinet task's labels, read from `<data_dir>/label_cabinet.csv`.

    The file holds one label per line. Trailing whitespace is stripped and blank
    lines are skipped. A label's position in the returned list becomes its
    integer id when the TFRecords are written.
    """
    path = os.path.join(data_dir, "label_cabinet.csv")
    # `with` guarantees the file handle is closed once the labels are read.
    with open(path, "r", encoding="utf-8") as f:
        stripped = (line.rstrip() for line in f)
        return [label for label in stripped if label]
###### add multi task get_label by nijiahui 20220422 end ##########
- 修改 _create_examples / _create_examples_train 方法,将 example 的 label 属性改为包含两个任务标签字符串的元组 (label_threeseg, label_cabinet)
###### modify multi task create_examples by nijiahui 20220422 start ##########
def _create_examples(self, lines, set_type):
    """Creates multi-task examples for the train, dev or test set.

    Column 0 of each line is the text. Outside the test set, columns 1 and 2
    are the threeseg and cabinet labels. Each example's `label` is the tuple
    (label_threeseg, label_cabinet), and its guid is "<set_type>-<line index>".
    """
    examples = []
    for (i, line) in enumerate(lines):
        guid = "%s-%s" % (set_type, i)
        text_a = tokenization.convert_to_unicode(line[0])
        if set_type == "test":
            # Test lines carry no gold labels; "0" is used as a placeholder.
            # NOTE(review): "0" must appear in both label files, otherwise the
            # label-id lookup in convert_single_example raises KeyError — confirm.
            label_threeseg = "0"
            label_cabinet = "0"
        else:
            # Train and dev lines share the same layout: text, threeseg, cabinet.
            label_threeseg = tokenization.convert_to_unicode(line[1])
            label_cabinet = tokenization.convert_to_unicode(line[2])
        examples.append(
            InputExample(guid=guid, text_a=text_a, text_b=None, label=(label_threeseg, label_cabinet))
        )
    return examples


def _create_examples_train(self, lines, set_type):
    """Like `_create_examples`, but keeps only this Horovod worker's shard.

    Line i is kept only when i % hvd.size() == hvd.rank(), so different ranks
    keep different lines. Guids still use the original line index.
    """
    examples = []
    for (i, line) in enumerate(lines):
        # Horovod data sharding: skip lines that belong to other ranks.
        if i % hvd.size() != hvd.rank():
            continue
        guid = "%s-%s" % (set_type, i)
        text_a = tokenization.convert_to_unicode(line[0])
        if set_type == "test":
            # Placeholder labels for unlabeled test lines (see _create_examples).
            label_threeseg = "0"
            label_cabinet = "0"
        else:
            label_threeseg = tokenization.convert_to_unicode(line[1])
            label_cabinet = tokenization.convert_to_unicode(line[2])
        examples.append(
            InputExample(guid=guid, text_a=text_a, text_b=None, label=(label_threeseg, label_cabinet))
        )
    return examples
###### modify multi task _create_examples by nijiahui 20220422 end ##########
3)修改 InputFeatures 类
class InputFeatures(object):
    """Model-ready features of one example, with one label id per task.

    Attributes:
      input_ids, input_mask, segment_ids: per-token lists fed to BERT.
      label_id: integer label id of the threeseg task.
      label_id_cabinet: integer label id of the cabinet task.
      is_real_example: False for padding examples, True otherwise.
    """

    def __init__(
        self, input_ids, input_mask, segment_ids, label_id, label_id_cabinet, is_real_example=True
    ):
        self.is_real_example = is_real_example
        # Token-level inputs.
        self.input_ids, self.input_mask, self.segment_ids = (
            input_ids,
            input_mask,
            segment_ids,
        )
        ###### add multi task InputFeatures.label_id by nijiahui 20220422 start ##########
        # label_id keeps its single-task name and holds the threeseg label id.
        self.label_id, self.label_id_cabinet = label_id, label_id_cabinet
        ###### add multi task InputFeatures.label_id by nijiahui 20220422 end ##########
4)修改 convert_single_example 方法,构造 InputFeatures 时加入多任务的 label_id
def convert_single_example(ex_index, example, label_map, max_seq_length, tokenizer):
    """Build the `InputFeatures` for one multi-task `InputExample`.

    Args:
      ex_index: position of the example in its set; the first five are logged.
      example: an `InputExample` whose `label` is (threeseg label, cabinet label),
        or a `PaddingInputExample`.
      label_map: ({threeseg label: id}, {cabinet label: id}).
      max_seq_length: length that every id list is truncated/padded to.
      tokenizer: object providing `tokenize` and `convert_tokens_to_ids`.

    Returns:
      An `InputFeatures`. A padding example maps to all-zero inputs, label ids 0
      and `is_real_example=False`.
    """
    if isinstance(example, PaddingInputExample):
        return InputFeatures(
            input_ids=[0] * max_seq_length,
            input_mask=[0] * max_seq_length,
            segment_ids=[0] * max_seq_length,
            ###### add multi task by nijiahui 20220422 start ##########
            label_id=0,
            label_id_cabinet=0,
            ###### add multi task by nijiahui 20220422 end ##########
            is_real_example=False,
        )

    tokens_a = tokenizer.tokenize(example.text_a)
    tokens_b = tokenizer.tokenize(example.text_b) if example.text_b else None

    if not tokens_b:
        # Single sequence: leave room for [CLS] and [SEP].
        room = max_seq_length - 2
        if len(tokens_a) > room:
            tokens_a = tokens_a[0:room]
    else:
        # Sequence pair: leave room for [CLS], [SEP], [SEP]; trims both lists in place.
        _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)

    # BERT input layout:
    #   pair:   [CLS] a ... [SEP] b ... [SEP]   type ids 0 ... 0 1 ... 1
    #   single: [CLS] a ... [SEP]               type ids 0 ... 0
    # The type ids tell the model which segment a token belongs to; for
    # classification the [CLS] position serves as the sentence vector.
    tokens = ["[CLS]", *tokens_a, "[SEP]"]
    segment_ids = [0] * len(tokens)
    if tokens_b:
        tokens.extend(tokens_b)
        tokens.append("[SEP]")
        segment_ids.extend([1] * (len(tokens_b) + 1))

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    # Real tokens get mask 1; padding gets 0 and is not attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad all three lists by the same amount, up to max_seq_length.
    pad_len = max_seq_length - len(input_ids)
    for seq in (input_ids, input_mask, segment_ids):
        seq.extend([0] * pad_len)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    ###### add multi task by nijiahui 20220422 start ##########
    # Each task's label string is looked up in that task's own label map.
    label_id_threeseg = label_map[0][example.label[0]]
    label_id_cabinet = label_map[1][example.label[1]]
    ###### add multi task by nijiahui 20220422 end ##########

    if ex_index < 5:
        tf.logging.info("*** Example ***")
        tf.logging.info("guid: %s" % (example.guid))
        tf.logging.info(
            "tokens: %s" % " ".join(map(tokenization.printable_text, tokens))
        )
        tf.logging.info("input_ids: %s" % " ".join(map(str, input_ids)))
        tf.logging.info("input_mask: %s" % " ".join(map(str, input_mask)))
        tf.logging.info("segment_ids: %s" % " ".join(map(str, segment_ids)))
        tf.logging.info(f"label: {example.label} (label_id_threeseg = {label_id_threeseg}, label_id_cabinet = {label_id_cabinet})")

    return InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_id=label_id_threeseg,
        label_id_cabinet=label_id_cabinet,
        is_real_example=True,
    )
5)修改 file_based_convert_examples_to_features 方法,写入 TFRecord 时加入多任务的 label
def file_based_convert_examples_to_features(
    examples, label_list, max_seq_length, tokenizer, output_file
):
    """Convert a set of multi-task `InputExample`s to a TFRecord file.

    Args:
      examples: `InputExample`s whose `label` is (threeseg label, cabinet label).
      label_list: (threeseg labels, cabinet labels). A label's index in its
        task's list becomes its integer id.
      max_seq_length: sequence length every feature is padded/truncated to.
      tokenizer: passed through to `convert_single_example`.
      output_file: path of the TFRecord file to write.

    Each record stores input_ids, input_mask, segment_ids, label_ids (threeseg),
    label_ids_cabinet and is_real_example as int64 features.
    """

    def create_int_feature(values):
        # Wrap a sequence of ints as an int64 tf.train.Feature.
        return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

    ###### add multi task by nijiahui 20220422 start ##########
    # One {label: id} map per task, in the same order as label_list
    # (index 0 = threeseg, index 1 = cabinet).
    label_map = tuple(
        {label: i for (i, label) in enumerate(task_labels)} for task_labels in label_list
    )
    ###### add multi task by nijiahui 20220422 end ##########

    writer = tf.python_io.TFRecordWriter(output_file)
    try:
        for (ex_index, example) in enumerate(examples):
            if ex_index % 100000 == 0:
                tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

            feature = convert_single_example(
                ex_index, example, label_map, max_seq_length, tokenizer
            )

            features = collections.OrderedDict()
            features["input_ids"] = create_int_feature(feature.input_ids)
            features["input_mask"] = create_int_feature(feature.input_mask)
            features["segment_ids"] = create_int_feature(feature.segment_ids)
            # Threeseg label id, stored under the original single-task key.
            features["label_ids"] = create_int_feature([feature.label_id])
            ###### add multi task by nijiahui 20220422 start ##########
            features["label_ids_cabinet"] = create_int_feature([feature.label_id_cabinet])
            ###### add multi task by nijiahui 20220422 end ##########
            features["is_real_example"] = create_int_feature([int(feature.is_real_example)])

            tf_example = tf.train.Example(features=tf.train.Features(feature=features))
            writer.write(tf_example.SerializeToString())
    finally:
        # Close the writer even if conversion fails part-way through.
        writer.close()
6)修改 file_based_input_fn_builder 方法,加入多任务 label_ids
def file_based_input_fn_builder(input_file, seq_length, is_training, drop_remainder):
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
# Excerpt: only the TFRecord parsing schema changes for multi-task; the rest of
# the function is not shown here.
name_to_features = {
"input_ids": tf.FixedLenFeature([seq_length], tf.int64),
"input_mask": tf.FixedLenFeature([seq_length], tf.int64),
"segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
###### add multi task label_ids by nijiahui 20220422 start ##########
# One scalar int64 label id per example for each task ("label_ids" = threeseg).
# These keys must match the feature names written by
# file_based_convert_examples_to_features.
"label_ids": tf.FixedLenFeature([], tf.int64),
"label_ids_cabinet": tf.FixedLenFeature([], tf.int64),
###### add multi task label_ids by nijiahui 20220422 end ##########
"is_real_example": tf.FixedLenFeature([], tf.int64),
}
2. 修改 - 模型架构
1)修改 create_model 函数
def create_model(
    bert_config,
    is_training,
    input_ids,
    input_mask,
    segment_ids,
    labels,
    num_labels,
    use_one_hot_embeddings,
):
    """Creates a two-task classification model on a fully shared BERT encoder.

    Each task (threeseg, cabinet) gets its own dense + softmax head on the pooled
    output. The loss is the sum of the two tasks' mean cross-entropy losses.

    Args:
      bert_config, input_ids, input_mask, segment_ids, use_one_hot_embeddings:
        passed to `modeling.BertModel`.
      is_training: passed to `modeling.BertModel`; also enables head dropout.
      labels: (threeseg label ids, cabinet label ids), indexable by 0 and 1.
      num_labels: (number of threeseg labels, number of cabinet labels).

    Returns:
      A 5-tuple:
        loss: loss_threeseg + loss_cabinet.
        [per_example_loss_threeseg, per_example_loss_cabinet]
        [logits_threeseg, logits_cabinet]
        [probabilities_threeseg, probabilities_cabinet]
        embedding_output: output of BERT's embedding layer.
    """
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings,
    )

    # Both heads classify the whole segment from the pooled output. For
    # token-level tasks use model.get_sequence_output() instead.
    output_layer = model.get_pooled_output()
    hidden_size = output_layer.shape[-1].value
    # Returned as the 5th element; it was previously referenced without ever
    # being defined.
    embedding_output = model.get_embedding_output()

    ###### add multi task by nijiahui 20220422 start ##########
    # Index instead of tuple-unpacking: Tensors cannot be unpacked in graph mode,
    # so this works whether `labels` is a list of tensors or a stacked tensor.
    label_threeseg = labels[0]
    label_cabinet = labels[1]
    num_labels_threeseg = num_labels[0]
    num_labels_cabinet = num_labels[1]

    # Task-specific dense layers, created outside the "loss" variable scope so
    # the variable names are exactly the strings below.
    output_weights_threeseg = tf.get_variable(
        "output_weights_threeseg",
        [num_labels_threeseg, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02),
    )
    output_bias_threeseg = tf.get_variable(
        "output_bias_threeseg", [num_labels_threeseg], initializer=tf.zeros_initializer()
    )
    output_weights_cabinet = tf.get_variable(
        "output_weights_cabinet",
        [num_labels_cabinet, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02),
    )
    output_bias_cabinet = tf.get_variable(
        "output_bias_cabinet", [num_labels_cabinet], initializer=tf.zeros_initializer()
    )

    with tf.variable_scope("loss"):
        if is_training:
            # Independent dropout (keep_prob 0.9) for each head.
            output_layer_threeseg = tf.nn.dropout(output_layer, keep_prob=0.9)
            output_layer_cabinet = tf.nn.dropout(output_layer, keep_prob=0.9)
        else:
            # No dropout in eval/predict: both heads read the pooled output directly.
            output_layer_threeseg = output_layer
            output_layer_cabinet = output_layer

        # threeseg head: dense -> softmax -> cross-entropy.
        logits_threeseg = tf.matmul(output_layer_threeseg, output_weights_threeseg, transpose_b=True)
        logits_threeseg = tf.nn.bias_add(logits_threeseg, output_bias_threeseg)
        probabilities_threeseg = tf.nn.softmax(logits_threeseg, axis=-1)
        log_probs_threeseg = tf.nn.log_softmax(logits_threeseg, axis=-1)
        one_hot_labels_threeseg = tf.one_hot(label_threeseg, depth=num_labels_threeseg, dtype=tf.float32)
        per_example_loss_threeseg = -tf.reduce_sum(one_hot_labels_threeseg * log_probs_threeseg, axis=-1)
        loss_threeseg = tf.reduce_mean(per_example_loss_threeseg)

        # cabinet head: dense -> softmax -> cross-entropy.
        logits_cabinet = tf.matmul(output_layer_cabinet, output_weights_cabinet, transpose_b=True)
        logits_cabinet = tf.nn.bias_add(logits_cabinet, output_bias_cabinet)
        probabilities_cabinet = tf.nn.softmax(logits_cabinet, axis=-1)
        log_probs_cabinet = tf.nn.log_softmax(logits_cabinet, axis=-1)
        one_hot_labels_cabinet = tf.one_hot(label_cabinet, depth=num_labels_cabinet, dtype=tf.float32)
        per_example_loss_cabinet = -tf.reduce_sum(one_hot_labels_cabinet * log_probs_cabinet, axis=-1)
        loss_cabinet = tf.reduce_mean(per_example_loss_cabinet)

        # Unweighted sum of the two task losses.
        loss = loss_cabinet + loss_threeseg
    ##### add multi task by nijiahui 20220422 end ##########

    return (
        loss,
        [per_example_loss_threeseg, per_example_loss_cabinet],
        [logits_threeseg, logits_cabinet],
        [probabilities_threeseg, probabilities_cabinet],
        embedding_output,
    )
2)修改 model_fn_builder 函数
def model_fn_builder(
    bert_config,
    num_labels,
    init_checkpoint,
    learning_rate,
    num_train_steps,
    num_warmup_steps,
    use_tpu,
    use_one_hot_embeddings,
):
    """Returns `model_fn` closure for TPUEstimator.

    Multi-task variant: `num_labels` is (num_threeseg_labels, num_cabinet_labels),
    and the model has one classification head per task (built by `create_model`).

    TRAIN: optimizes the summed loss and logs the running accuracy of both heads.
    EVAL: reports accuracy and mean loss per task, weighted by `is_real_example`.
    PREDICT: returns `probabilities_threeseg` and `probabilities_cabinet`.
    """

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        ##### add label_ids_cabinet and label_ids_threeseg by nijiahui 20220422 start
        # "label_ids" keeps the single-task feature name and holds the threeseg
        # labels; "label_ids_cabinet" holds the cabinet labels.
        label_ids_threeseg = features["label_ids"]
        label_ids_cabinet = features["label_ids_cabinet"]
        # Passed to create_model and metric_fn as a (threeseg, cabinet) pair.
        label_ids = [label_ids_threeseg, label_ids_cabinet]
        ##### add label_ids_cabinet and label_ids_threeseg by nijiahui 20220422 end

        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
        else:
            # Use one task's label tensor for the shape: tf.shape(label_ids) would
            # pack the two-tensor list into a (2, batch) tensor, giving metric
            # weights of the wrong shape.
            is_real_example = tf.ones(tf.shape(label_ids_threeseg), dtype=tf.float32)

        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # embedding_out is returned by create_model but not used below.
        (total_loss, per_example_loss, logits, probabilities, embedding_out) = create_model(
            bert_config,
            is_training,
            input_ids,
            input_mask,
            segment_ids,
            label_ids,
            num_labels,
            use_one_hot_embeddings,
        )

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (
                assignment_map,
                initialized_variable_names,
            ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            if use_tpu:
                # NOTE(review): scaffold_fn is never passed to the EstimatorSpecs
                # below, so with use_tpu=True the checkpoint is not loaded — confirm
                # the TPU path is unused.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info(
                " name = %s, shape = %s%s", var.name, var.shape, init_string
            )

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization_hvd.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu
            )
            ####### add label_ids_cabinet and label_ids_threeseg by nijiahui 20220422 start
            logits_threeseg = logits[0]
            logits_cabinet = logits[1]
            # Running training accuracy of each head; element [1] of each metric
            # tuple is the update op that is evaluated by the logging hook.
            predictions_threeseg = tf.argmax(logits_threeseg, axis=-1, output_type=tf.int32)
            accuracy_threeseg = tf.metrics.accuracy(label_ids_threeseg, predictions_threeseg)
            predictions_cabinet = tf.argmax(logits_cabinet, axis=-1, output_type=tf.int32)
            accuracy_cabinet = tf.metrics.accuracy(label_ids_cabinet, predictions_cabinet)
            # Log every num_train_steps / num_train_epochs steps (roughly once per
            # epoch). Clamp to >= 1, because LoggingTensorHook rejects a
            # non-positive every_n_iter, which happens when there are fewer
            # training steps than epochs.
            log_every_n_iter = max(1, int(num_train_steps / FLAGS.num_train_epochs))
            logging_hook = tf.train.LoggingTensorHook(
                {
                    "loss": total_loss,
                    "steps": tf.train.get_or_create_global_step(),
                    "accuracy_threeseg": accuracy_threeseg[1],
                    "accuracy_cabinet": accuracy_cabinet[1],
                },
                every_n_iter=log_every_n_iter,
            )
            ####### add label_ids_cabinet and label_ids_threeseg by nijiahui 20220422 end
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                training_hooks=[logging_hook],
                train_op=train_op,
            )
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, logits, is_real_example):
                """Per-task accuracy and mean loss; padding examples get weight 0."""
                ###### add multi task by nijiahui 20220422 start ##########
                per_example_loss_threeseg = per_example_loss[0]
                per_example_loss_cabinet = per_example_loss[1]
                logits_threeseg = logits[0]
                logits_cabinet = logits[1]
                label_threeseg = label_ids[0]
                label_cabinet = label_ids[1]
                # threeseg
                predictions_threeseg = tf.argmax(logits_threeseg, axis=-1, output_type=tf.int32)
                accuracy_threeseg = tf.metrics.accuracy(
                    labels=label_threeseg, predictions=predictions_threeseg, weights=is_real_example
                )
                loss_threeseg = tf.metrics.mean(values=per_example_loss_threeseg, weights=is_real_example)
                # cabinet
                predictions_cabinet = tf.argmax(logits_cabinet, axis=-1, output_type=tf.int32)
                accuracy_cabinet = tf.metrics.accuracy(
                    labels=label_cabinet, predictions=predictions_cabinet, weights=is_real_example
                )
                loss_cabinet = tf.metrics.mean(values=per_example_loss_cabinet, weights=is_real_example)
                return {
                    "eval_accuracy_threeseg": accuracy_threeseg,
                    "eval_loss_threeseg": loss_threeseg,
                    "eval_accuracy_cabinet": accuracy_cabinet,
                    "eval_loss_cabinet": loss_cabinet,
                }
                ###### add multi task by nijiahui 20220422 end ##########

            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metric_ops=metric_fn(per_example_loss, label_ids, logits, is_real_example))
        else:
            ################ remove tpu && modify probabilities by nijiahui 20220418 start ##############
            probabilities_threeseg = probabilities[0]
            probabilities_cabinet = probabilities[1]
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={'probabilities_threeseg': probabilities_threeseg,
                             'probabilities_cabinet': probabilities_cabinet})
            ################ remove tpu && modify probabilities by nijiahui 20220418 end ##############
        return output_spec

    return model_fn