Multi-Task Adaptation Based on BERT

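This article describes how to implement multi-task training on top of the pre-trained BERT model by constructing auxiliary tasks and parallel task structures. It covers combining main and auxiliary tasks, model structure design, concrete code changes, and how to handle multi-task labels and evaluation metrics in run_classifier.py.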

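I. Multi-task learning with BERT

Pre-trained models such as BERT can be used for multi-task learning in several scenarios.

1. Main and auxiliary tasks:

Suppose we are working on task A. We can construct an auxiliary task B and train it jointly with A in order to improve A's performance. For example, we can add MLM (the masked language model objective used in BERT pre-training) as the auxiliary task. At inference time in production, the output of task B is not used.

2. Parallel tasks:

Several tasks, say A, B, and C, are all required and equally important, and they are of a similar type. Training them separately would need three models. Instead, we can try sharing one model: most parameters are shared, and only a small part is task-specific.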

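II. Designing the multi-task model structure

1. Fully shared encoder

As in ordinary fine-tuning, the encoder layers are fully shared, and task-specific layers are added only on top of BERT's final (pooled) output. There are two ways to compute the loss:

  1. For each example, use an if-else branch to determine which task it belongs to, compute only that task's loss, and backpropagate.
  2. For each example, compute the losses of all tasks, sum them, and backpropagate.

We currently use the second approach.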
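The two loss strategies can be sketched in plain Python, independent of TensorFlow. This is only an illustration under stated assumptions: the helper names (`cross_entropy`, `loss_own_task_only`, `loss_sum_all_tasks`) and the toy logits are hypothetical and do not come from run_classifier.py.

```python
import math

def cross_entropy(logits, label):
    """Softmax cross-entropy for one example; logits is a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[label]

def loss_own_task_only(task_logits, task_labels, task_id):
    """Strategy 1: only the example's own task contributes to the loss."""
    return cross_entropy(task_logits[task_id], task_labels[task_id])

def loss_sum_all_tasks(task_logits, task_labels):
    """Strategy 2: every task contributes; the per-task losses are summed."""
    return sum(cross_entropy(l, y) for l, y in zip(task_logits, task_labels))

# One example scored by two task heads (3 classes and 2 classes).
task_logits = [[2.0, 0.5, -1.0], [0.1, 0.3]]
task_labels = [0, 1]

loss_1 = loss_own_task_only(task_logits, task_labels, task_id=0)
loss_2 = loss_sum_all_tasks(task_logits, task_labels)
```

Strategy 2 requires every example to carry a label for every task, which is why the data pipeline in the next section attaches two labels to each example.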

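III. Code changes for the multi-task model

All of the changes below are made in run_classifier.py.

1. Changes - data preprocessing
1) Modify the main function
  • Read two label files instead of a single label.csv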
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    ###### add multi task by nijiahui 20220422 start ##########
    # label_list = processor.get_labels(data_dir)
    label_threeseg_list = processor.get_label_threeseg(data_dir)
    label_cabinet_list = processor.get_label_cabinet(data_dir)
    ###### add multi task by nijiahui 20220422 end ##########
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
    )
  • Change the num_labels argument of model_fn_builder to a tuple
    model_fn = model_fn_builder(
        bert_config=bert_config,
        ###### add multi task by nijiahui 20220422 start ##########
        num_labels=(len(label_threeseg_list), len(label_cabinet_list)),
        ###### add multi task by nijiahui 20220422 end ##########
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
    )
  • Change the label_list argument of file_based_convert_examples_to_features to a tuple
    if FLAGS.do_train:
        train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
        ###### add multi task by nijiahui 20220422 start ##########
        # file_based_convert_examples_to_features(
        #     train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file
        # )
        label_list = (label_threeseg_list, label_cabinet_list)
        file_based_convert_examples_to_features(
            train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file
        )
        ###### add multi task by nijiahui 20220422 end ##########
        tf.logging.info("***** Running training *****")
        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        ###### add multi task by nijiahui 20220422 start ##########
        label_list = (label_threeseg_list, label_cabinet_list)
        ###### add multi task by nijiahui 20220422 end ##########
        file_based_convert_examples_to_features(
            eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file
        )
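  • Modify the eval logic to also write, for each validation example, every task's prediction, whether that prediction is correct, and its confidence score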
        ###### add eval_data_result.csv file by nijiahui 20220420 start ###########
        dev_result = estimator.predict(input_fn=eval_input_fn)
        eval_data_res_file = os.path.join(FLAGS.output_dir, "eval_data_results.csv")
        with tf.gfile.GFile(eval_data_res_file, "w") as writer:
            for example, dev_prediction in zip(eval_examples, dev_result):
                index = str(example.guid)
                text = str(example.text_a)
                label_threeseg, label_cabinet = example.label

                probabilities_threeseg = dev_prediction["probabilities_threeseg"]
                probabilities_cabinet = dev_prediction["probabilities_cabinet"]

                pred_threeseg = str(label_threeseg_list[np.argmax(probabilities_threeseg)])
                max_score_threeseg = str(np.max(probabilities_threeseg))

                pred_cabinet = str(label_cabinet_list[np.argmax(probabilities_cabinet)])
                max_score_cabinet = str(np.max(probabilities_cabinet))

                is_right_threeseg = "1" if pred_threeseg==label_threeseg else "0"
                is_right_cabinet = "1" if pred_cabinet==label_cabinet else "0"
                res_list = [
                    index,
                    text,
                    pred_threeseg, label_threeseg, is_right_threeseg,
                    pred_cabinet, label_cabinet, is_right_cabinet,
                    max_score_threeseg, max_score_cabinet,
                ]
                output_line = ("\t".join(res_list) + "\n")
                writer.write(output_line)
        ###### add eval_data_result.csv file by nijiahui 20220420 end ###########
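The row-building logic above generalizes to any number of tasks. The following is a minimal sketch in plain Python, assuming the same column order as the snippet (predictions, gold labels, and hit flags per task first, then the confidence scores). The function names `argmax` and `build_result_row` and the sample values are hypothetical.

```python
def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])

def build_result_row(guid, text, gold_labels, task_probs, task_label_lists):
    """Builds one tab-separated row: per-task prediction, gold label, hit flag, then scores."""
    row = [str(guid), str(text)]
    scores = []
    for gold, probs, label_list in zip(gold_labels, task_probs, task_label_lists):
        pred = str(label_list[argmax(probs)])
        row += [pred, str(gold), "1" if pred == str(gold) else "0"]
        scores.append(str(max(probs)))
    return "\t".join(row + scores)

row = build_result_row(
    guid="dev-0",
    text="sample text",
    gold_labels=("B", "x"),
    task_probs=([0.1, 0.7, 0.2], [0.4, 0.6]),
    task_label_lists=(["A", "B", "C"], ["x", "y"]),
)
```

Note that the rows are tab-separated even though the output file in the snippet is named eval_data_results.csv, so a reader of that file should be configured with a tab delimiter.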
2) Modify the Cls_Processor class
  • Add a get_label method for each task
    ###### add multi task get_label by nijiahui 20220422 start ##########
    # def get_labels(self, data_dir):
    #     """See base class."""
    #     labels = [
    #         line.rstrip().split(";")
    #         for line in open(os.path.join(data_dir, "label.csv"), "r", encoding="utf-8")
    #         if line.rstrip()
    #     ]
    #     return labels

    def get_label_threeseg(self, data_dir):
        """See base class."""
        labels = [
            line.rstrip()
            for line in open(os.path.join(data_dir, "label_threeseg.csv"), "r", encoding="utf-8")
            if line.rstrip()
        ]
        return labels

    def get_label_cabinet(self, data_dir):
        """See base class."""
        labels = [
            line.rstrip()
            for line in open(os.path.join(data_dir, "label_cabinet.csv"), "r", encoding="utf-8")
            if line.rstrip()
        ]
        return labels

    ###### add multi task get_label by nijiahui 20220422 end ##########
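The two methods above differ only in the file name, so the reading logic can be shared. The sketch below assumes one label per line, as in the snippet, and uses a context manager so the file handle is closed. The helper names `read_labels` and `get_task_labels` are hypothetical, and the label files are created in a temporary directory purely for demonstration.

```python
import os
import tempfile

def read_labels(path):
    """Returns the non-empty lines of a label file, one label per line."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.rstrip() for line in f if line.rstrip()]

def get_task_labels(data_dir, filenames=("label_threeseg.csv", "label_cabinet.csv")):
    """Returns one label list per task, in the order given by filenames."""
    return tuple(read_labels(os.path.join(data_dir, name)) for name in filenames)

# Demo with throwaway label files; real runs would point at FLAGS.data_dir.
with tempfile.TemporaryDirectory() as data_dir:
    with open(os.path.join(data_dir, "label_threeseg.csv"), "w", encoding="utf-8") as f:
        f.write("a\nb\n\nc\n")
    with open(os.path.join(data_dir, "label_cabinet.csv"), "w", encoding="utf-8") as f:
        f.write("yes\nno\n")
    label_lists = get_task_labels(data_dir)
    num_labels = tuple(len(labels) for labels in label_lists)
```

The resulting `num_labels` tuple has the same shape as the value passed to model_fn_builder in main.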
  • Modify the _create_examples methods so that an example's label attribute is a tuple holding one label per task
    ###### modify multi task create_examples by nijiahui 20220422 start ##########
    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            if set_type == "test":
                text_a = tokenization.convert_to_unicode(line[0])
                label_threeseg = "0"                                        # add
                label_cabinet = "0"                                         # add
            elif set_type == "dev":
                text_a = tokenization.convert_to_unicode(line[0])
                label_threeseg = tokenization.convert_to_unicode(line[1])   # add
                label_cabinet = tokenization.convert_to_unicode(line[2])    # add
            else:
                text_a = tokenization.convert_to_unicode(line[0])
                label_threeseg = tokenization.convert_to_unicode(line[1])   # add
                label_cabinet = tokenization.convert_to_unicode(line[2])    # add
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=(label_threeseg,label_cabinet))   # modify label
            )
        return examples

    def _create_examples_train(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            # Horovod: shard the data across workers; each rank keeps only its own slice.
            # if i % hvd.size() == hvd.rank():
            if i % hvd.size() != hvd.rank():
                continue
            guid = "%s-%s" % (set_type, i)
            if set_type == "test":
                text_a = tokenization.convert_to_unicode(line[0])
                label_threeseg = "0"                                        # add
                label_cabinet = "0"                                         # add
            elif set_type == "dev":
                text_a = tokenization.convert_to_unicode(line[0])
                label_threeseg = tokenization.convert_to_unicode(line[1])   # add
                label_cabinet = tokenization.convert_to_unicode(line[2])    # add
            else:
                text_a = tokenization.convert_to_unicode(line[0])
                label_threeseg = tokenization.convert_to_unicode(line[1])   # add
                label_cabinet = tokenization.convert_to_unicode(line[2])    # add
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=(label_threeseg,label_cabinet))   # modify label
            )
        return examples
    ###### modify multi task _create_examples by nijiahui 20220422 end ##########
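The sharding rule in _create_examples_train (keep line i only when i % hvd.size() == hvd.rank()) can be checked without Horovod. In this sketch, `num_workers` and `rank` stand in for hvd.size() and hvd.rank(), the examples are plain dicts rather than InputExample objects, and the helper names and input lines are hypothetical.

```python
def shard_lines(lines, num_workers, rank):
    """Keeps (index, line) pairs whose index satisfies index % num_workers == rank."""
    return [(i, line) for i, line in enumerate(lines) if i % num_workers == rank]

def to_example(index, line, set_type):
    """Builds a dict with a two-element label tuple; test rows get placeholder labels."""
    if set_type == "test":
        labels = ("0", "0")
    else:
        labels = (line[1], line[2])
    return {"guid": "%s-%s" % (set_type, index), "text_a": line[0], "label": labels}

# Seven made-up rows: text, threeseg label, cabinet label.
lines = [["text %d" % i, "seg%d" % (i % 3), "cab%d" % (i % 2)] for i in range(7)]

shards = [shard_lines(lines, num_workers=3, rank=r) for r in range(3)]
examples_rank0 = [to_example(i, line, "train") for i, line in shards[0]]
```

Because the guid is built from the original line index, examples keep a globally unique id even though each worker sees only part of the file.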
3) Modify the InputFeatures class
class InputFeatures(object):
    """A single set of features of data."""

    def __init__(
        self, input_ids, input_mask, segment_ids, label_id, label_id_cabinet, is_real_example=True
    ):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        ###### add multi task InputFeatures.label_id  by nijiahui 20220422 start ##########
        self.label_id = label_id
        self.label_id_cabinet = label_id_cabinet
        ###### add multi task InputFeatures.label_id  by nijiahui 20220422 end ##########
        self.is_real_example = is_real_example
4) Modify convert_single_example to include the multi-task labels when building the features
def convert_single_example(ex_index, example, label_map, max_seq_length, tokenizer):
    """Converts a single `InputExample` into a single `InputFeatures`."""

    if isinstance(example, PaddingInputExample):
        return InputFeatures(
            input_ids=[0] * max_seq_length,
            input_mask=[0] * max_seq_length,
            segment_ids=[0] * max_seq_length,
            ###### add multi task by nijiahui 20220422 start ##########
            label_id=0,
            label_id_cabinet=0,
            ###### add multi task by nijiahui 20220422 end ##########
            is_real_example=False,
        )

    tokens_a = tokenizer.tokenize(example.text_a)
    tokens_b = None
    if example.text_b:
        tokens_b = tokenizer.tokenize(example.text_b)

    if tokens_b:
        # Modifies `tokens_a` and `tokens_b` in place so that the total
        # length is less than the specified length.
        # Account for [CLS], [SEP], [SEP] with "- 3"
        _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    else:
        # Account for [CLS] and [SEP] with "- 2"
        if len(tokens_a) > max_seq_length - 2:
            tokens_a = tokens_a[0 : (max_seq_length - 2)]

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0     0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = []
    segment_ids = []
    tokens.append("[CLS]")
    segment_ids.append(0)
    for token in tokens_a:
        tokens.append(token)
        segment_ids.append(0)
    tokens.append("[SEP]")
    segment_ids.append(0)

    if tokens_b:
        for token in tokens_b:
            tokens.append(token)
            segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    ###### add multi task by nijiahui 20220422 start ##########
    # label_id = label_map[example.label]
    label_id_threeseg = label_map[0][example.label[0]]
    label_id_cabinet = label_map[1][example.label[1]]
    ###### add multi task by nijiahui 20220422 end ##########

    if ex_index < 5:
        tf.logging.info("*** Example ***")
        tf.logging.info("guid: %s" % (example.guid))
        tf.logging.info(
            "tokens: %s" % " ".join([tokenization.printable_text(x) for x in tokens])
        )
        tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        ###### add multi task by nijiahui 20220422 start ##########
        tf.logging.info(f"label: {example.label} (label_id_threeseg = {label_id_threeseg}, label_id_cabinet = {label_id_cabinet})")
        ###### add multi task by nijiahui 20220422 end ##########

    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        ###### modify multi task InputFeatures.label_id by nijiahui 20220422 start ##########
        label_id=label_id_threeseg,
        label_id_cabinet=label_id_cabinet,
        ###### modify multi task InputFeatures.label_id by nijiahui 20220422 end ##########
        is_real_example=True,
    )
    return feature

5) Modify file_based_convert_examples_to_features to write the multi-task labels with each example
def file_based_convert_examples_to_features(
    examples, label_list, max_seq_length, tokenizer, output_file
):
    """Convert a set of `InputExample`s to a TFRecord file."""

    writer = tf.python_io.TFRecordWriter(output_file)

    ###### add multi task by nijiahui 20220422 start ##########
    # label_map = {}
    # for (i, label) in enumerate(label_list):
    #     label_map[label] = i
    label_map1 = {}
    label_map2 = {}

    for (i, label) in enumerate(label_list[0]):
        label_map1[label] = i
    for (i, label) in enumerate(label_list[1]):
        label_map2[label] = i
    label_map = (label_map1, label_map2)
    ###### add multi task by nijiahui 20220422 end ##########

    for (ex_index, example) in enumerate(examples):
        if ex_index % 100000 == 0:
            tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

        feature = convert_single_example(
            ex_index, example, label_map, max_seq_length, tokenizer
        )

        def create_int_feature(values):
            f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
            return f

        features = collections.OrderedDict()
        features["input_ids"] = create_int_feature(feature.input_ids)  # input: List[int]
        features["input_mask"] = create_int_feature(feature.input_mask)  # input: List[int]
        features["segment_ids"] = create_int_feature(feature.segment_ids)  # input: List[int]
        features["label_ids"] = create_int_feature([feature.label_id])  # input: List[int]
        ###### add multi task by nijiahui 20220422 start ##########
        features["label_ids_cabinet"] = create_int_feature([feature.label_id_cabinet])
        ###### add multi task by nijiahui 20220422 end ##########
        features["is_real_example"] = create_int_feature([int(feature.is_real_example)])  # input: List[int]
        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
        writer.write(tf_example.SerializeToString())
    writer.close()
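The label handling in steps 4) and 5) boils down to one label-to-id dictionary per task, looked up position by position. A minimal sketch follows, with hypothetical helper names (`build_label_maps`, `labels_to_ids`) and made-up labels.

```python
def build_label_maps(label_lists):
    """Returns one {label: index} dict per task, in the same order as label_lists."""
    return tuple({label: i for i, label in enumerate(labels)} for labels in label_lists)

def labels_to_ids(example_labels, label_maps):
    """Maps an example's label tuple to a tuple of integer ids, task by task."""
    return tuple(task_map[label] for task_map, label in zip(label_maps, example_labels))

label_lists = (["seg_a", "seg_b", "seg_c"], ["cab_x", "cab_y"])
label_maps = build_label_maps(label_lists)
label_ids = labels_to_ids(("seg_c", "cab_x"), label_maps)
```

A label that is missing from its task's label file raises KeyError here, just as the dictionary lookup in convert_single_example would, so the label files must cover every label that appears in the data.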
6) Modify file_based_input_fn_builder to parse the multi-task label_ids
def file_based_input_fn_builder(input_file, seq_length, is_training, drop_remainder):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""

    name_to_features = {
        "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
        ###### add multi task label_ids by nijiahui 20220422 start ##########
        "label_ids": tf.FixedLenFeature([], tf.int64),
        "label_ids_cabinet": tf.FixedLenFeature([], tf.int64),
        ###### add multi task label_ids by nijiahui 20220422 end ##########
        "is_real_example": tf.FixedLenFeature([], tf.int64),
    }
2. Changes - model architecture
1) Modify the create_model function
def create_model(
    bert_config,
    is_training,
    input_ids,
    input_mask,
    segment_ids,
    labels,
    num_labels,
    use_one_hot_embeddings,
):
    """Creates a classification model."""
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings,
    )

    # In the demo, we are doing a simple classification task on the entire
    # segment.
    #
    # If you want to use the token-level output, use model.get_sequence_output()
    # instead.
    output_layer = model.get_pooled_output()

    hidden_size = output_layer.shape[-1].value

    ###### add multi task by nijiahui 20220422 start ##########
    # label_threeseg, label_cabinet = labels  # Note: a Tensor cannot be unpacked this way
    label_threeseg = labels[0]              # add
    label_cabinet = labels[1]               # add
    num_labels_threeseg = num_labels[0]     # add
    num_labels_cabinet = num_labels[1]      # add

    # Modify the Dense (fully connected) layer before softmax, modify_start:
    output_weights_threeseg = tf.get_variable(
        "output_weights_threeseg", [num_labels_threeseg, hidden_size], initializer=tf.truncated_normal_initializer(stddev=0.02),
    )
    output_bias_threeseg = tf.get_variable(
        "output_bias_threeseg", [num_labels_threeseg], initializer=tf.zeros_initializer()
    )
    output_weights_cabinet = tf.get_variable(
        "output_weights_cabinet", [num_labels_cabinet, hidden_size], initializer=tf.truncated_normal_initializer(stddev=0.02),
    )
    output_bias_cabinet = tf.get_variable(
        "output_bias_cabinet", [num_labels_cabinet], initializer=tf.zeros_initializer()
    )
    # Modify the Dense (fully connected) layer before softmax, modify_end.

    with tf.variable_scope("loss"):
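        # Default to the pooled output so both names are also defined in eval/predict mode.
        output_layer_threeseg = output_layer
        output_layer_cabinet = output_layer
        if is_training:
            # keep_prob=0.9 means 0.1 dropout, applied independently per task head.
            output_layer_threeseg = tf.nn.dropout(output_layer, keep_prob=0.9)     # add
            output_layer_cabinet = tf.nn.dropout(output_layer, keep_prob=0.9)      # add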

        # threeseg Dense && softmax logic
        logits_threeseg = tf.matmul(output_layer_threeseg, output_weights_threeseg, transpose_b=True)
        logits_threeseg = tf.nn.bias_add(logits_threeseg, output_bias_threeseg)
        probabilities_threeseg = tf.nn.softmax(logits_threeseg, axis=-1)
        log_probs_threeseg = tf.nn.log_softmax(logits_threeseg, axis=-1)
        one_hot_labels_threeseg = tf.one_hot(label_threeseg, depth=num_labels_threeseg, dtype=tf.float32)
        per_example_loss_threeseg = -tf.reduce_sum(one_hot_labels_threeseg * log_probs_threeseg, axis=-1)
        loss_threeseg = tf.reduce_mean(per_example_loss_threeseg)

        # cabinet Dense && softmax logic
        logits_cabinet = tf.matmul(output_layer_cabinet, output_weights_cabinet, transpose_b=True)
        logits_cabinet = tf.nn.bias_add(logits_cabinet, output_bias_cabinet)
        probabilities_cabinet = tf.nn.softmax(logits_cabinet, axis=-1)
        log_probs_cabinet = tf.nn.log_softmax(logits_cabinet, axis=-1)
        one_hot_labels_cabinet = tf.one_hot(label_cabinet, depth=num_labels_cabinet, dtype=tf.float32)
        per_example_loss_cabinet = -tf.reduce_sum(one_hot_labels_cabinet * log_probs_cabinet, axis=-1)
        loss_cabinet = tf.reduce_mean(per_example_loss_cabinet)

        # combine loss_threeseg and loss_cabinet
        loss = loss_cabinet + loss_threeseg         # add

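        # modify: also return the embedding-layer output.
        # NOTE: the original snippet never defines embedding_output; taking it from
        # BertModel.get_embedding_output() is an assumption about the author's intent.
        embedding_output = model.get_embedding_output()
        return loss, \
               [per_example_loss_threeseg, per_example_loss_cabinet], \
               [logits_threeseg, logits_cabinet], \
               [probabilities_threeseg, probabilities_cabinet], \
               embedding_output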
    ##### add multi task by nijiahui 20220422 end ##########
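To make the structure of create_model easier to follow, here is a framework-free sketch of two classification heads reading the same pooled vector, with the per-task cross-entropy losses summed into one total. It is not the TensorFlow implementation above: the function names, the two-dimensional `pooled` vector, and the weight values are all made up for illustration, and dropout is omitted.

```python
import math

def linear(x, weights, bias):
    """Computes W x + b, with weights shaped [num_labels][hidden_size]."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def two_head_forward(pooled, heads, labels):
    """Runs every head on the shared vector; returns (summed loss, per-task losses, probabilities)."""
    probs, losses = [], []
    for (weights, bias), label in zip(heads, labels):
        p = softmax(linear(pooled, weights, bias))
        probs.append(p)
        losses.append(-math.log(p[label]))
    return sum(losses), losses, probs

pooled = [1.0, -1.0]  # stands in for the shared pooled output of the encoder
heads = [
    ([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]], [0.0, 0.0, 0.0]),  # 3-class head
    ([[0.5, 0.5], [0.0, 0.0]], [0.0, 0.0]),                   # 2-class head
]
total_loss, per_task_loss, probs = two_head_forward(pooled, heads, labels=(0, 1))
```

Since the total is a plain sum, both heads send gradients into the shared encoder with equal weight; weighting the terms differently would be a separate design choice that the article does not make.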
2) Modify the model_fn_builder function
def model_fn_builder(
    bert_config,
    num_labels,
    init_checkpoint,
    learning_rate,
    num_train_steps,
    num_warmup_steps,
    use_tpu,
    use_one_hot_embeddings,
):
    """Returns `model_fn` closure for TPUEstimator."""

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        ##### add label_ids_cabinet and label_ids_threeseg by nijiahui 20220422  start
        label_ids = features["label_ids"]
        label_ids_cabinet = features["label_ids_cabinet"]
        label_ids = [label_ids, label_ids_cabinet]
        ##### add label_ids_cabinet and label_ids_threeseg by nijiahui 20220422  end
        is_real_example = None
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
        else:
            # label_ids is now a list of two tensors, so take the shape from one task's labels.
            is_real_example = tf.ones(tf.shape(label_ids[0]), dtype=tf.float32)

        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # add embedding_out
        (total_loss, per_example_loss, logits, probabilities, embedding_out) = create_model(
            bert_config,
            is_training,
            input_ids,
            input_mask,
            segment_ids,
            label_ids,
            num_labels,
            use_one_hot_embeddings,
        )

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (
                assignment_map,
                initialized_variable_names,
            ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info(
                "  name = %s, shape = %s%s", var.name, var.shape, init_string
            )

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:

            train_op = optimization_hvd.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu
            )
            ####### add label_ids_cabinet and label_ids_threeseg by nijiahui 20220422  start
            # unpack
            logits_threeseg = logits[0]
            logits_cabinet = logits[1]

            label_threeseg = label_ids[0]
            label_cabinet = label_ids[1]
            # threeseg
            predictions_threeseg = tf.argmax(logits_threeseg, axis=-1, output_type=tf.int32)
            accuracy_threeseg = tf.metrics.accuracy(label_threeseg, predictions_threeseg)

            # cabinet
            predictions_cabinet = tf.argmax(logits_cabinet, axis=-1, output_type=tf.int32)
            accuracy_cabinet = tf.metrics.accuracy(label_cabinet, predictions_cabinet)

            logging_hook = tf.train.LoggingTensorHook(
                {
                    "loss": total_loss,
                    "steps": tf.train.get_or_create_global_step(),
                    "accuracy_threeseg": accuracy_threeseg[1],
                    "accuracy_cabinet": accuracy_cabinet[1],
                },
                every_n_iter=int(num_train_steps / FLAGS.num_train_epochs),
            )
            ####### add label_ids_cabinet and label_ids_threeseg by nijiahui 20220422  end

            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                training_hooks=[logging_hook],
                train_op=train_op,
            )


        elif mode == tf.estimator.ModeKeys.EVAL:
            def metric_fn(per_example_loss, label_ids, logits, is_real_example):
                ###### add multi task by nijiahui 20220422 start ##########
                per_example_loss_threeseg = per_example_loss[0]
                per_example_loss_cabinet = per_example_loss[1]
                logits_threeseg = logits[0]
                logits_cabinet = logits[1]
                label_threeseg = label_ids[0]
                label_cabinet = label_ids[1]

                # threeseg
                predictions_threeseg = tf.argmax(logits_threeseg, axis=-1, output_type=tf.int32)
                accuracy_threeseg = tf.metrics.accuracy(
                    labels=label_threeseg, predictions=predictions_threeseg, weights=is_real_example
                )
                loss_threeseg = tf.metrics.mean(values=per_example_loss_threeseg, weights=is_real_example)

                # cabinet
                predictions_cabinet = tf.argmax(logits_cabinet, axis=-1, output_type=tf.int32)
                accuracy_cabinet = tf.metrics.accuracy(
                    labels=label_cabinet, predictions=predictions_cabinet, weights=is_real_example
                )
                loss_cabinet = tf.metrics.mean(values=per_example_loss_cabinet, weights=is_real_example)

                return {
                    "eval_accuracy_threeseg": accuracy_threeseg,
                    "eval_loss_threeseg": loss_threeseg,
                    "eval_accuracy_cabinet": accuracy_cabinet,
                    "eval_loss_cabinet": loss_cabinet,
                }
                ###### add multi task by nijiahui 20220422 end ##########
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metric_ops=metric_fn(per_example_loss, label_ids, logits, is_real_example))

        else:
            ################ remove tpu && modify probabilities by nijiahui 20220418 start ##############
            probabilities_threeseg = probabilities[0]
            probabilities_cabinet = probabilities[1]

            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={'probabilities_threeseg': probabilities_threeseg,
                             'probabilities_cabinet':probabilities_cabinet})
            ################ remove tpu && modify probabilities by nijiahui  20220418 end ##############

        return output_spec

    return model_fn
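The per-task metrics in metric_fn weight each example by is_real_example, so padding examples do not affect accuracy. The same idea in plain Python, as a sketch only: `weighted_accuracy` and `per_task_accuracy` are hypothetical helpers, the data is made up, and the streaming behavior of tf.metrics is not reproduced.

```python
def weighted_accuracy(labels, predictions, weights):
    """Fraction of total weight that falls on correctly predicted examples."""
    total = sum(weights)
    if total == 0:
        return 0.0
    hits = sum(w for y, p, w in zip(labels, predictions, weights) if y == p)
    return hits / total

def per_task_accuracy(task_labels, task_logits, is_real_example):
    """Takes the argmax of each task's logits and scores it against that task's labels."""
    metrics = {}
    for name in task_labels:
        preds = [max(range(len(row)), key=row.__getitem__) for row in task_logits[name]]
        metrics["eval_accuracy_" + name] = weighted_accuracy(
            task_labels[name], preds, is_real_example
        )
    return metrics

# Three examples; the last one is padding and carries weight 0.
is_real_example = [1.0, 1.0, 0.0]
task_labels = {"threeseg": [0, 1, 0], "cabinet": [1, 1, 0]}
task_logits = {
    "threeseg": [[2.0, 1.0], [0.0, 3.0], [0.0, 5.0]],
    "cabinet": [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],
}
metrics = per_task_accuracy(task_labels, task_logits, is_real_example)
```

In this toy data the padding example is misclassified for threeseg, yet the accuracy for that task is still 1.0 because its weight is zero.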