NLP模型——UTC

适用:文本分类,小样本

1、环境:
python >= 3.7

pip install paddlepaddle >= 2.5.1
pip install --pre --upgrade paddlenlp -f https://www.paddlepaddle.org.cn/whl/paddlenlp.html

2、直接使用测试

from paddlenlp import Taskflow
schema = ["病情诊断", "治疗方案", "病因分析", "指标解读", "就医建议", "疾病表述", "后果表述", "注意事项", "功效作用", "医疗费用"]
cls = Taskflow("zero_shot_text_classification", schema=schema )
cls("先天性厚甲症去哪里治")

输出

[{'predictions': [{'label': '就医建议', 'score': 0.5494891306403806}], 'text_a': '先天性厚甲症去哪里治'}]

**
说明

schema是语句所有类别,cls(“先天性厚甲症去哪里治”)是测试短句

3、训练

标注软件:label-studio

具体使用
https://blog.csdn.net/qq_43117155/article/details/135343557?spm=1001.2014.3001.5502

数据处理,这里需要将label-studio的数据转化成paddle nlp的训练格式
在这里插入图片描述
代码():

import argparse
import json
import os
import random
import time
from decimal import Decimal

import numpy as np
import paddle

from paddlenlp.utils.log import logger


def set_seed(seed):    paddle.seed(seed)
    random.seed(seed)
    np.random.seed(seed)


class LabelStudioDataConverter(object):
    """
    DataConverter to convert data export from LabelStudio platform
    """

    def __init__(self, options, text_separator):
        super().__init__()
        if isinstance(options, list) and len(options) == 1 and os.path.isfile(options[0]):
            with open(options[0], "r", encoding="utf-8") as fp:
                self.options = [x.strip() for x in fp]
        elif isinstance(options, list) and len(options) > 0:
            self.options = options
        else:
            raise ValueError(
                "Invalid options. Please use file with one label per line or set `options` with condidate labels."
            )
        self.text_separator = text_separator

    def convert_utc_examples(self, raw_examples):
        utc_examples = []
        for example in raw_examples:
            raw_text = example["data"]["text"].split(self.text_separator)
            if len(raw_text) < 1:
                continue
            elif len(raw_text) == 1:
                raw_text.append("")
            elif len(raw_text) > 2:
                raw_text = ["".join(raw_text[:-1]), raw_text[-1]]

            label_list = []
            for raw_label in example["annotations"][0]["result"][0]["value"]["choices"]:
                if raw_label not in self.options:
                    raise ValueError(
                        f"Label `{raw_label}` not found in label candidates `options`. Please recheck the data."
                    )
                label_list.append(np.where(np.array(self.options) == raw_label)[0].tolist()[0])

            utc_examples.append(
                {
                    "text_a": raw_text[0],
                    "text_b": raw_text[1],
                    "question": "",
                    "choices": self.options,
                    "labels": label_list,
                }
            )
        return utc_examples


def do_convert():
    set_seed(args.seed)

    tic_time = time.time()
    if not os.path.exists(args.label_studio_file):
        raise ValueError("Please input the correct path of label studio file.")

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    if len(args.splits) != 0 and len(args.splits) != 3:
        raise ValueError("Only []/ len(splits)==3 accepted for splits.")

    def _check_sum(splits):
        return Decimal(str(splits[0])) + Decimal(str(splits[1])) + Decimal(str(splits[2])) == Decimal("1")

    if len(args.splits) == 3 and not _check_sum(args.splits):
        raise ValueError("Please set correct splits, sum of elements in splits should be equal to 1.")

    with open(args.label_studio_file, "r", encoding="utf-8") as f:
        raw_examples = json.loads(f.read())

    if args.is_shuffle:
        indexes = np.random.permutation(len(raw_examples))
        index_list = indexes.tolist()
        raw_examples = [raw_examples[i] for i in indexes]

    i1, i2, _ = args.splits
    p1 = int(len(raw_examples) * i1)
    p2 = int(len(raw_examples) * (i1 + i2))

    train_ids = index_list[:p1]
    dev_ids = index_list[p1:p2]
    test_ids = index_list[p2:]

    with open(os.path.join(args.save_dir, "sample_index.json"), "w") as fp:
        maps = {"train_ids": train_ids, "dev_ids": dev_ids, "test_ids": test_ids}
        fp.write(json.dumps(maps))

    data_converter = LabelStudioDataConverter(args.options, args.text_separator)

    train_examples = data_converter.convert_utc_examples(raw_examples[:p1])
    dev_examples = data_converter.convert_utc_examples(raw_examples[p1:p2])
    test_examples = data_converter.convert_utc_examples(raw_examples[p2:])

    def _save_examples(save_dir, file_name, examples):
        count = 0
        save_path = os.path.join(save_dir, file_name)
        with open(save_path, "w", encoding="utf-8") as f:
            for example in examples:
                f.write(json.dumps(example, ensure_ascii=False) + "\n")
                count += 1
        logger.info("Save %d examples to %s." % (count, save_path))

    _save_examples(args.save_dir, "train.txt", train_examples)
    _save_examples(args.save_dir, "dev.txt", dev_examples)
    _save_examples(args.save_dir, "test.txt", test_examples)

    logger.info("Finished! It takes %.2f seconds" % (time.time() - tic_time))


if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser()

    parser.add_argument("--label_studio_file", default="./data/label_studio.json", type=str, help="The annotation file exported from label studio platform.")
    parser.add_argument("--save_dir", default="./data", type=str, help="The path of data that you wanna save.")
    parser.add_argument("--splits", default=[0.8, 0.1, 0.1], type=float, nargs="*", help="The ratio of samples in datasets. [0.6, 0.2, 0.2] means 60% samples used for training, 20% for evaluation and 20% for test.")
    parser.add_argument("--text_separator", type=str, default='\t', help="Separator for classification with two input texts.")
    parser.add_argument("--options", default=None, type=str, nargs="+", help="The options for classification.")
    parser.add_argument("--is_shuffle", default=True, type=bool, help="Whether to shuffle the labeled dataset, defaults to True.")
    parser.add_argument("--seed", type=int, default=1000, help="Random seed for initialization")
    args = parser.parse_args()
    # yapf: enable

    do_convert()

执行代码
在这里插入图片描述
–label_studio_file是你导出的label_studio 文件,
–save_dir写你想输出的文件位置,
–splits代表训练、测试、验证集合的比例分配,就0.8 0.1 0.1比较合适
–options代表类别标签(我的标签如下,格式就一个类别一行)
在这里插入图片描述

训练

源码拿下来后,进入到这个位置
在这里插入图片描述

python run_train.py  \
    --device gpu \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --seed 1000 \
    --model_name_or_path utc-base \
    --output_dir ./checkpoint/model_best \
    --dataset_path ./data/ \
    --max_seq_length 512  \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --num_train_epochs 20 \
    --learning_rate 1e-5 \
    --do_train \
    --do_eval \
    --do_export \
    --export_model_dir ./checkpoint/model_best \
    --overwrite_output_dir \
    --disable_tqdm True \
    --metric_for_best_model macro_f1 \
    --load_best_model_at_end  True \
    --save_total_limit 1 \
    --save_plm

–dataset_path ./data/ 这个就是上一步转好的数据位置

训练开始
在这里插入图片描述
这样就开始训练了

  • 8
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值