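2024.4.28 Training and Evaluating the UIE Model
UIE uses doccano for data annotation and provides code for converting doccano data into the training data format.
1. Converting the crawled JSON data into the doccano format
Starting from the large set of crawled document texts and the entity relations given by their hyperlinks, we convert the entities already annotated in the original text into doccano's entity annotation format.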
import json
import re


def split_doc(doc):
    # Split after each sentence-ending punctuation mark, then greedily merge
    # consecutive sentences into chunks shorter than 512 characters.
    single_sentences = re.split(r"(?<=[。!?])", doc)
    sentences = []
    merged_sentence = ''
    for sentence in single_sentences:
        if len(merged_sentence) + len(sentence) < 512:
            merged_sentence += sentence
        else:
            sentences.append(merged_sentence)
            merged_sentence = sentence
    if merged_sentence != '':
        sentences.append(merged_sentence)
    return sentences


if __name__ == '__main__':
    domain_name = 'computer_network'
    doc_f = open('/home/yunpu/Data/codes/VCRS/data/wiki_data/' + domain_name + '_docs.json', 'r', encoding='utf-8')
    mention_f = open('/home/yunpu/Data/codes/VCRS/data/wiki_data/' + domain_name + '_mentions_with_title.json', 'r', encoding='utf-8')
    out_f = open('/home/yunpu/Data/codes/VCRS/UIE/data/' + domain_name + '_daccano.jsonl', 'w', encoding='utf-8')

    doc_titles = set()
    docs = {}        # document title -> document content
    entity_id = {}   # entity name -> integer id
    cnt = 0
    mentions = {}    # document title -> list of mentions in that document

    # Every document title is itself an entity; titles must be unique.
    for line in doc_f:
        doc = json.loads(line)
        assert doc['title'] not in doc_titles
        doc_titles.add(doc['title'])
        docs[doc['title']] = doc['content']
        entity_id[doc['title']] = cnt
        cnt += 1

    # Group the hyperlink mentions by the document they occur in.
    for line in mention_f:
        mention = json.loads(line)
        if mention['doc_title'] not in mentions.keys():
            mentions[mention['doc_title']] = [mention]
        else:
            mentions[mention['doc_title']].append(mention)

    sample_cnt = 0
    for doc_title in docs.keys():
        if doc_title not in mentions.keys():
            continue
        sentences = split_doc(docs[doc_title])
        if sample_cnt == 0:
            print(sentences)
        doc_mentions = mentions[doc_title]
        prelen = 0  # offset of the current chunk within the document
        cur = 0     # next unprocessed mention; the loops rely on mentions being sorted by start_pos
        for sentence in sentences:
            chunk_end = prelen + len(sentence)
            if len(sentence) < 512:
                sample = {}
                sample['id'] = sample_cnt
                sample_cnt += 1
                sample['text'] = sentence
                sample['relations'] = []
                sample['entities'] = []
                # Collect every mention that lies entirely inside this chunk.
                while cur < len(doc_mentions) and doc_mentions[cur]['start_pos'] >= prelen and doc_mentions[cur]['end_pos'] <= chunk_end:
                    if doc_mentions[cur]['entity_name'] not in entity_id.keys():
                        entity_id[doc_mentions[cur]['entity_name']] = cnt
                        cnt += 1
                    entity = {
                        "id": entity_id[doc_mentions[cur]['entity_name']],
                        # Offsets are made relative to the start of the chunk.
                        "start_offset": doc_mentions[cur]['start_pos'] - prelen,
                        "end_offset": doc_mentions[cur]['end_pos'] - prelen,
                        "label": "实体"
                    }
                    sample['entities'].append(entity)
                    cur += 1
                if len(sample['entities']) == 0:
                    # Chunks without entities are dropped, so give the id back.
                    sample_cnt -= 1
                else:
                    out_f.write(json.dumps(sample, ensure_ascii=False) + '\n')
            # Bug fix: skip mentions that straddle the chunk boundary or fall in a
            # skipped chunk; otherwise they would block every later mention.
            while cur < len(doc_mentions) and doc_mentions[cur]['start_pos'] < chunk_end:
                cur += 1
            prelen = chunk_end

    doc_f.close()
    mention_f.close()
    out_f.close()
The doccano data produced by the script
{"id": 0, "text": "计算机网络是指将地理位置不同的具有独立功能的多台计算机及其外部设备,通过通信线路和通信设备连接起来,在网络操作系统,网络管理软件及网络通信协议的管理和协调下,实现资源共享和信息传递的计算机系统。计算机网络主要是由一些通用的、可编程的硬件互连而成的。这些可编程的硬件能够用来传送多种不同类型的数据,并能支持广泛的和日益增长的应用。计算机网络Computer network计算机网络系统互联网信息的传输与共享网络操作系统计算机网络也称计算机通信网。关于计算机网络的最简单定义是:一些相互连接的、以共享资源为目的的、自治的计算机的集合。若按此定义,则早期的面向终端的网络都不能算是计算机网络,而只能称为联机系统(因为那时的许多终端不能算是自治的计算机)。但随着硬件价格的下降,许多终端都具有一定的智能,因而“终端”和“自治的计算机”逐渐失去了严格的界限。若用微型计算机作为终端使用,按上述定义,则早期的那种面向终端的网络也可称为计算机网络。另外,从逻辑功能上看,计算机网络是以传输信息为基础目的,用通信线路将多个计算机连接起来的计算机系统的集合,一个计算机网络组成包括传输介质和通信设备。", "relations": [], "entities": [{"id": 1, "start_offset": 8, "end_offset": 12, "label": "实体"}, {"id": 2, "start_offset": 29, "end_offset": 33, "label": "实体"}, {"id": 3, "start_offset": 36, "end_offset": 40, "label": "实体"}, {"id": 4, "start_offset": 51, "end_offset": 57, "label": "实体"}, {"id": 5, "start_offset": 58, "end_offset": 64, "label": "实体"}, {"id": 6, "start_offset": 65, "end_offset": 71, "label": "实体"}, {"id": 7, "start_offset": 81, "end_offset": 85, "label": "实体"}, {"id": 8, "start_offset": 86, "end_offset": 90, "label": "实体"}, {"id": 9, "start_offset": 91, "end_offset": 96, "label": "实体"}, {"id": 10, "start_offset": 216, "end_offset": 222, "label": "实体"}, {"id": 11, "start_offset": 247, "end_offset": 251, "label": "实体"}, {"id": 12, "start_offset": 299, "end_offset": 303, "label": "实体"}, {"id": 13, "start_offset": 377, "end_offset": 382, "label": "实体"}, {"id": 3, "start_offset": 447, "end_offset": 451, "label": "实体"}, {"id": 14, "start_offset": 482, "end_offset": 486, "label": "实体"}, {"id": 15, "start_offset": 487, "end_offset": 491, "label": "实体"}]}
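A quick way to sanity-check a record like the one above is to slice the text with each entity's offsets. The sketch below assumes that `end_offset` is exclusive, so `text[start_offset:end_offset]` is the entity string; the record is a shortened, illustrative copy of the opening of the sample, and `entity_spans` is a hypothetical helper name.

```python
import json


def entity_spans(record):
    """Return the surface string of every annotated entity in a doccano record."""
    text = record["text"]
    return [text[e["start_offset"]:e["end_offset"]] for e in record["entities"]]


# Shortened, illustrative record built from the first clause of the sample above.
line = json.dumps({
    "id": 0,
    "text": "计算机网络是指将地理位置不同的具有独立功能的多台计算机及其外部设备",
    "relations": [],
    "entities": [
        {"id": 1, "start_offset": 8, "end_offset": 12, "label": "实体"},
        {"id": 2, "start_offset": 29, "end_offset": 33, "label": "实体"},
    ],
}, ensure_ascii=False)

record = json.loads(line)
print(entity_spans(record))
```

If a slice comes back shifted or cut off, the chunk offset (`prelen` in the script) is the first thing to check.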
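2. Manual annotation
Baidu encyclopedia entries annotate only a subset of entities and ignore repeated occurrences of the same entity. We therefore first used code to add annotations for the repeated occurrences, then manually annotated the entities the code had missed.
Doccano is an open-source text annotation tool that supports language tasks such as named entity recognition, sentiment analysis, and text classification. The detailed local deployment steps follow.
1. Environment setup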
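The supplementary annotation step is not shown in these notes. A minimal sketch of one way it could be done: find every further occurrence of a string that is already annotated and add a span for it unless it overlaps an existing one. The function name `add_repeated_mentions` and the id scheme are assumptions made for illustration, not the code that was actually used.

```python
def add_repeated_mentions(record, label="实体"):
    """Annotate every further occurrence of an already-annotated surface string."""
    text = record["text"]
    entities = list(record["entities"])
    covered = {(e["start_offset"], e["end_offset"]) for e in entities}
    surfaces = {text[s:e] for s, e in covered if s < e}
    next_id = max((e["id"] for e in entities), default=0) + 1
    # Longer strings first, so a long entity wins over a substring of it.
    for surface in sorted(surfaces, key=len, reverse=True):
        start = text.find(surface)
        while start != -1:
            end = start + len(surface)
            overlaps = any(start < e and s < end for s, e in covered)
            if not overlaps:
                entities.append({"id": next_id, "start_offset": start,
                                 "end_offset": end, "label": label})
                covered.add((start, end))
                next_id += 1
            start = text.find(surface, start + 1)
    entities.sort(key=lambda e: e["start_offset"])
    return {**record, "entities": entities}


example = {
    "id": 0,
    "text": "计算机网络也称计算机通信网。计算机网络由硬件组成。",
    "relations": [],
    "entities": [{"id": 1, "start_offset": 0, "end_offset": 5, "label": "实体"}],
}
completed = add_repeated_mentions(example)
print([(e["start_offset"], e["end_offset"]) for e in completed["entities"]])
```

Exact string matching will also tag occurrences that are not entity mentions in context, which is one reason a manual pass in doccano is still needed afterwards.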
- Operating system: Linux (Ubuntu; the installation commands below use apt-get)
- Python: 3.8 or later
- Docker: latest version
2. Installing Docker
sudo apt-get update
sudo apt-get install -y \
ca-certificates \
curl \
gnupg \
lsb-release
sudo mkdir -p /etc/apt/keyrings
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
echo \
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \
$(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
sudo apt-get update
sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-compose-plugin
Start the Docker service:
sudo systemctl start docker
sudo systemctl enable docker
3. Downloading and starting Doccano
Deploying Doccano with Docker is straightforward. First, pull the Doccano Docker image:
docker pull doccano/doccano
Then run the Doccano container:
docker run -d --name doccano -p 8000:8000 doccano/doccano
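This starts Doccano and serves it on local port 8000.
4. Accessing Doccano
Open http://localhost:8000 in a browser and you will see the Doccano login page. On the first run you need to create a superuser to manage projects and users.
5. Data annotation
First, import the JSON file that the script above produced.
The automatic annotations from the script are now loaded into doccano. On top of them, we annotated the entities that had been missed.
In the end we obtained about 50 annotated samples, each roughly 500 characters long, for fine-tuning the UIE model.
Exporting the annotated data
Annotated entities
Annotated relations
3. Data format conversion
Convert the doccano-format data into the model's training format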
# coding=utf-8
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import time
from decimal import Decimal

import numpy as np
from utils import convert_cls_examples, convert_ext_examples, set_seed

from paddlenlp.trainer.argparser import strtobool
from paddlenlp.utils.log import logger


def do_convert():
    set_seed(args.seed)

    tic_time = time.time()
    if not os.path.exists(args.doccano_file):
        raise ValueError("Please input the correct path of doccano file.")

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    if len(args.splits) != 0 and len(args.splits) != 3:
        raise ValueError("Only []/ len(splits)==3 accepted for splits.")

    def _check_sum(splits):
        return Decimal(str(splits[0])) + Decimal(str(splits[1])) + Decimal(str(splits[2])) == Decimal("1")

    if len(args.splits) == 3 and not _check_sum(args.splits):
        raise ValueError("Please set correct splits, sum of elements in splits should be equal to 1.")

    with open(args.doccano_file, "r", encoding="utf-8") as f:
        raw_examples = f.readlines()

    def _create_ext_examples(
        examples,
        negative_ratio,
        prompt_prefix="情感倾向",
        options=["正向", "负向"],
        separator="##",
        shuffle=False,
        is_train=True,
        schema_lang="ch",
    ):
        entities, relations, aspects = convert_ext_examples(
            examples, negative_ratio, prompt_prefix, options, separator, is_train, schema_lang
        )
        examples = entities + relations + aspects
        if shuffle:
            indexes = np.random.permutation(len(examples))
            examples = [examples[i] for i in indexes]
        return examples

    def _create_cls_examples(examples, prompt_prefix, options, shuffle=False):
        examples = convert_cls_examples(examples, prompt_prefix, options)
        if shuffle:
            indexes = np.random.permutation(len(examples))
            examples = [examples[i] for i in indexes]
        return examples

    def _save_examples(save_dir, file_name, examples):
        count = 0
        save_path = os.path.join(save_dir, file_name)
        with open(save_path, "w", encoding="utf-8") as f:
            for example in examples:
                f.write(json.dumps(example, ensure_ascii=False) + "\n")
                count += 1
        logger.info("Save %d examples to %s." % (count, save_path))

    if len(args.splits) == 0:
        if args.task_type == "ext":
            examples = _create_ext_examples(
                raw_examples,
                args.negative_ratio,
                args.prompt_prefix,
                args.options,
                args.separator,
                args.is_shuffle,
                schema_lang=args.schema_lang,
            )
        else:
            examples = _create_cls_examples(raw_examples, args.prompt_prefix, args.options, args.is_shuffle)
        _save_examples(args.save_dir, "train.txt", examples)
    else:
        if args.is_shuffle:
            indexes = np.random.permutation(len(raw_examples))
            index_list = indexes.tolist()
            raw_examples = [raw_examples[i] for i in indexes]
        else:
            index_list = list(range(len(raw_examples)))

        i1, i2, _ = args.splits
        p1 = int(len(raw_examples) * i1)
        p2 = int(len(raw_examples) * (i1 + i2))

        train_ids = index_list[:p1]
        dev_ids = index_list[p1:p2]
        test_ids = index_list[p2:]

        with open(os.path.join(args.save_dir, "sample_index.json"), "w") as fp:
            maps = {"train_ids": train_ids, "dev_ids": dev_ids, "test_ids": test_ids}
            fp.write(json.dumps(maps))

        if args.task_type == "ext":
            train_examples = _create_ext_examples(
                raw_examples[:p1],
                args.negative_ratio,
                args.prompt_prefix,
                args.options,
                args.separator,
                args.is_shuffle,
                schema_lang=args.schema_lang,
            )
            dev_examples = _create_ext_examples(
                raw_examples[p1:p2],
                -1,
                args.prompt_prefix,
                args.options,
                args.separator,
                is_train=False,
                schema_lang=args.schema_lang,
            )
            test_examples = _create_ext_examples(
                raw_examples[p2:],
                -1,
                args.prompt_prefix,
                args.options,
                args.separator,
                is_train=False,
                schema_lang=args.schema_lang,
            )
        else:
            train_examples = _create_cls_examples(raw_examples[:p1], args.prompt_prefix, args.options)
            dev_examples = _create_cls_examples(raw_examples[p1:p2], args.prompt_prefix, args.options)
            test_examples = _create_cls_examples(raw_examples[p2:], args.prompt_prefix, args.options)

        _save_examples(args.save_dir, "train.txt", train_examples)
        _save_examples(args.save_dir, "dev.txt", dev_examples)
        _save_examples(args.save_dir, "test.txt", test_examples)

    logger.info("Finished! It takes %.2f seconds" % (time.time() - tic_time))


if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser()

    parser.add_argument("--doccano_file", default="./data/doccano.json", type=str, help="The doccano file exported from doccano platform.")
    parser.add_argument("--save_dir", default="./data", type=str, help="The path of data that you wanna save.")
    parser.add_argument("--negative_ratio", default=5, type=int, help="Used only for the extraction task, the ratio of positive and negative samples, number of negtive samples = negative_ratio * number of positive samples")
    parser.add_argument("--splits", default=[0.8, 0.1, 0.1], type=float, nargs="*", help="The ratio of samples in datasets. [0.6, 0.2, 0.2] means 60% samples used for training, 20% for evaluation and 20% for test.")
    parser.add_argument("--task_type", choices=['ext', 'cls'], default="ext", type=str, help="Select task type, ext for the extraction task and cls for the classification task, defaults to ext.")
    parser.add_argument("--options", default=["正向", "负向"], type=str, nargs="+", help="Used only for the classification task, the options for classification")
    parser.add_argument("--prompt_prefix", default="情感倾向", type=str, help="Used only for the classification task, the prompt prefix for classification")
    parser.add_argument("--is_shuffle", default="True", type=strtobool, help="Whether to shuffle the labeled dataset, defaults to True.")
    parser.add_argument("--seed", type=int, default=1000, help="Random seed for initialization")
    parser.add_argument("--separator", type=str, default='##', help="Used only for entity/aspect-level classification task, separator for entity label and classification label")
    parser.add_argument("--schema_lang", choices=["ch", "en"], default="ch", help="Select the language type for schema.")

    args = parser.parse_args()
    # yapf: enable

    do_convert()
python doccano.py \
--doccano_file ./data/doccano_ext.json \
--task_type ext \
--save_dir ./data \
--splits 0.8 0.2 0 \
--schema_lang ch
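The effect of `--splits 0.8 0.2 0` can be worked out by hand. The sketch below mirrors the index arithmetic in the script above (`p1 = int(n * i1)`, `p2 = int(n * (i1 + i2))`); the helper name `split_sizes` is hypothetical, and the count of 50 comes from the "about 50" annotated samples mentioned earlier.

```python
def split_sizes(num_examples, splits):
    """Return (train, dev, test) sizes using the same arithmetic as doccano.py."""
    i1, i2, _ = splits
    p1 = int(num_examples * i1)
    p2 = int(num_examples * (i1 + i2))
    return p1, p2 - p1, num_examples - p2


sizes = split_sizes(50, [0.8, 0.2, 0])
print(sizes)
```

With these ratios every record goes to train or dev and the test split receives none, so evaluation during fine-tuning relies on dev.txt alone. The counts are raw doccano records; the number of lines in train.txt is larger because each record expands into several prompt-specific examples plus negative samples.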
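Data format (the record below is a doccano export with entities, relations, and Comments fields, i.e. the input that doccano.py converts)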
{"id":0,"text":"计算机网络是指将地理位置不同的具有独立功能的多台计算机及其外部设备,通过通信线路和通信设备连接起来,在网络操作系统,网络管理软件及网络通信协议的管理和协调下,实现资源共享和信息传递的计算机系统。计算机网络主要是由一些通用的、可编程的硬件互连而成的。这些可编程的硬件能够用来传送多种不同类型的数据,并能支持广泛的和日益增长的应用。计算机网络Computer network计算机网络系统互联网信息的传输与共享网络操作系统计算机网络也称计算机通信网。关于计算机网络的最简单定义是:一些相互连接的、以共享资源为目的的、自治的计算机的集合。若按此定义,则早期的面向终端的网络都不能算是计算机网络,而只能称为联机系统(因为那时的许多终端不能算是自治的计算机)。但随着硬件价格的下降,许多终端都具有一定的智能,因而“终端”和“自治的计算机”逐渐失去了严格的界限。若用微型计算机作为终端使用,按上述定义,则早期的那种面向终端的网络也可称为计算机网络。另外,从逻辑功能上看,计算机网络是以传输信息为基础目的,用通信线路将多个计算机连接起来的计算机系统的集合,一个计算机网络组成包括传输介质和通信设备。","entities":[{"id":350447,"label":"实体","start_offset":8,"end_offset":12},{"id":350448,"label":"实体","start_offset":29,"end_offset":33},{"id":350449,"label":"实体","start_offset":36,"end_offset":40},{"id":350450,"label":"实体","start_offset":51,"end_offset":57},{"id":350451,"label":"实体","start_offset":58,"end_offset":64},{"id":350452,"label":"实体","start_offset":65,"end_offset":71},{"id":350453,"label":"实体","start_offset":81,"end_offset":85},{"id":350454,"label":"实体","start_offset":86,"end_offset":90},{"id":350455,"label":"实体","start_offset":91,"end_offset":96},{"id":350456,"label":"实体","start_offset":216,"end_offset":222},{"id":350457,"label":"实体","start_offset":247,"end_offset":251},{"id":350458,"label":"实体","start_offset":299,"end_offset":303},{"id":350459,"label":"实体","start_offset":377,"end_offset":382},{"id":350460,"label":"实体","start_offset":447,"end_offset":451},{"id":350461,"label":"实体","start_offset":482,"end_offset":486},{"id":350462,"label":"实体","start_offset":487,"end_offset":491},{"id":350885,"label":"实体","start_offset":0,"end_offset":5},{"id":350886,"label":"实体","start_offset":41,"end_offset":45},{"id":350887,"label":"实体","start_offset":24,"end_offset":27},{"id":350888,"label":"实体","start_offset":209,"end_offset":214},{"id":350889,"label":"实体","start_offset":473,"end_offset":478},{"id":350890,"label":"实体","start_offset":454,"end_offset":457}],"relations":[{"id":222,"from_id":350885,"to_id":350453,"type":"功能"},{"id":223,"from_id":350885,"to_id":350454,"type":"功能"},{"id":224,"from_id":350885,"to_id":350455,"type":"定义"},{"id":225,"from_id":350888,"to_id":350456,"type":"别名"},{"id":226,"from_id":350889,"to_id":350461,"type":"包含"},{"id":227,"from_id":350889,"to_id":350462,"type":"包含"},{"id":228,"from_id":350460,"to_id":350890,"type":"连接"},{"id":229,"from_id":350449,"to_id":350887,"type":"连接"},{"id":230,"from_id":350449,"to_id":350448,"type":"连接"},{"id":231,"from_id":350886,"to_id":350887,"type":"连接"},{"id":232,"from_id":350886,"to_id":350448,"type":"连接"},{"id":233,"from_id":350885,"to_id":350450,"type":"包含"},{"id":234,"from_id":350885,"to_id":350451,"type":"包含"},{"id":235,"from_id":350885,"to_id":350452,"type":"包含"}],"Comments":[]}
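For orientation, the sketch below shows roughly how the entity part of a doccano record maps onto UIE's prompt-style training examples, which (as far as the PaddleNLP UIE documentation describes them) carry `content`, `result_list`, and `prompt` fields. It is a simplified illustration, not the real `convert_ext_examples`: relation prompts and negative sampling are left out, and `to_uie_entity_examples` is a hypothetical name. The generated train.txt is the reference for the exact format.

```python
def to_uie_entity_examples(record):
    """Group a doccano record's entity spans by label into prompt-style examples."""
    text = record["text"]
    by_label = {}
    for ent in record["entities"]:
        span = {
            "text": text[ent["start_offset"]:ent["end_offset"]],
            "start": ent["start_offset"],
            "end": ent["end_offset"],
        }
        by_label.setdefault(ent["label"], []).append(span)
    # One training example per label: the label acts as the prompt.
    return [
        {"content": text, "result_list": spans, "prompt": label}
        for label, spans in by_label.items()
    ]


# Shortened, illustrative record (not the full sample above).
doccano_record = {
    "id": 0,
    "text": "计算机网络是指将地理位置不同的具有独立功能的多台计算机及其外部设备",
    "entities": [
        {"id": 350447, "label": "实体", "start_offset": 8, "end_offset": 12},
        {"id": 350448, "label": "实体", "start_offset": 29, "end_offset": 33},
    ],
    "relations": [],
}
uie_examples = to_uie_entity_examples(doccano_record)
print(uie_examples)
```

Because every entity here carries the single label 实体, each record yields one entity example whose `result_list` holds all of its spans.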
4. Model fine-tuning
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from functools import partial
from typing import List, Optional

import paddle
from utils import convert_example, reader

from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import load_dataset
from paddlenlp.metrics import SpanEvaluator
from paddlenlp.trainer import (
    CompressionArguments,
    PdArgumentParser,
    Trainer,
    get_last_checkpoint,
)
from paddlenlp.transformers import UIE, UIEM, AutoTokenizer, export_model
from paddlenlp.utils.log import logger


@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    Using `PdArgumentParser` we can turn this class into argparse arguments to be able to
    specify them on the command line.
    """

    train_path: str = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )

    dev_path: str = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )

    max_seq_length: Optional[int] = field(
        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )

    dynamic_max_length: Optional[List[int]] = field(
        default=None,
        metadata={"help": "dynamic max length from batch, it can be array of length, eg: 16 32 64 128"},
    )


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: Optional[str] = field(
        default="uie-base",
        metadata={
            "help": "Path to pretrained model, such as 'uie-base', 'uie-tiny', "
            "'uie-medium', 'uie-mini', 'uie-micro', 'uie-nano', 'uie-base-en', "
            "'uie-m-base', 'uie-m-large', or finetuned model path."
        },
    )
    export_model_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to directory to store the exported inference model."},
    )
    multilingual: bool = field(default=False, metadata={"help": "Whether the model is a multilingual model."})


def main():
    parser = PdArgumentParser((ModelArguments, DataArguments, CompressionArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if model_args.model_name_or_path in ["uie-m-base", "uie-m-large"]:
        model_args.multilingual = True

    # Log model and data config
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    paddle.set_device(training_args.device)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    if model_args.multilingual:
        model = UIEM.from_pretrained(model_args.model_name_or_path)
    else:
        model = UIE.from_pretrained(model_args.model_name_or_path)

    train_ds = load_dataset(reader, data_path=data_args.train_path, max_seq_len=data_args.max_seq_length, lazy=False)
    dev_ds = load_dataset(reader, data_path=data_args.dev_path, max_seq_len=data_args.max_seq_length, lazy=False)

    trans_fn = partial(
        convert_example,
        tokenizer=tokenizer,
        max_seq_len=data_args.max_seq_length,
        multilingual=model_args.multilingual,
        dynamic_max_length=data_args.dynamic_max_length,
    )

    train_ds = train_ds.map(trans_fn)
    dev_ds = dev_ds.map(trans_fn)

    if training_args.device == "npu":
        data_collator = DataCollatorWithPadding(tokenizer, padding="longest")
    else:
        data_collator = DataCollatorWithPadding(tokenizer)

    criterion = paddle.nn.BCELoss()

    def uie_loss_func(outputs, labels):
        # Average the binary cross-entropy of the start and end pointers.
        start_ids, end_ids = labels
        start_prob, end_prob = outputs
        start_ids = paddle.cast(start_ids, "float32")
        end_ids = paddle.cast(end_ids, "float32")
        loss_start = criterion(start_prob, start_ids)
        loss_end = criterion(end_prob, end_ids)
        loss = (loss_start + loss_end) / 2.0
        return loss

    def compute_metrics(p):
        metric = SpanEvaluator()
        start_prob, end_prob = p.predictions
        start_ids, end_ids = p.label_ids
        metric.reset()

        num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
        metric.update(num_correct, num_infer, num_label)
        precision, recall, f1 = metric.accumulate()
        metric.reset()

        return {"precision": precision, "recall": recall, "f1": f1}

    trainer = Trainer(
        model=model,
        criterion=uie_loss_func,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_ds if training_args.do_train or training_args.do_compress else None,
        eval_dataset=dev_ds if training_args.do_eval or training_args.do_compress else None,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    trainer.optimizer = paddle.optimizer.AdamW(
        learning_rate=training_args.learning_rate, parameters=model.parameters()
    )
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        trainer.save_model()
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluate and tests model
    if training_args.do_eval:
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)

    # export inference model
    if training_args.do_export:
        # You can also load from certain checkpoint
        # trainer.load_state_dict_from_checkpoint("/path/to/checkpoint/")
        if training_args.device == "npu":
            # npu will transform int64 to int32 for internal calculation.
            # To reduce useless transformation, we feed int32 inputs.
            input_spec_dtype = "int32"
        else:
            input_spec_dtype = "int64"
        if model_args.multilingual:
            input_spec = [
                paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="input_ids"),
                paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="position_ids"),
            ]
        else:
            input_spec = [
                paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="input_ids"),
                paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="token_type_ids"),
                paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="position_ids"),
                paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="attention_mask"),
            ]
        if model_args.export_model_dir is None:
            model_args.export_model_dir = os.path.join(training_args.output_dir, "export")
        export_model(model=trainer.model, input_spec=input_spec, path=model_args.export_model_dir)
        trainer.tokenizer.save_pretrained(model_args.export_model_dir)

    if training_args.do_compress:

        @paddle.no_grad()
        def custom_evaluate(self, model, data_loader):
            metric = SpanEvaluator()
            model.eval()
            metric.reset()
            for batch in data_loader:
                if model_args.multilingual:
                    logits = model(input_ids=batch["input_ids"], position_ids=batch["position_ids"])
                else:
                    logits = model(
                        input_ids=batch["input_ids"],
                        token_type_ids=batch["token_type_ids"],
                        position_ids=batch["position_ids"],
                        attention_mask=batch["attention_mask"],
                    )
                start_prob, end_prob = logits
                start_ids, end_ids = batch["start_positions"], batch["end_positions"]
                num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
                metric.update(num_correct, num_infer, num_label)
            precision, recall, f1 = metric.accumulate()
            # Log recall in the recall slot (the pasted code repeated f1 here).
            logger.info("f1: %s, precision: %s, recall: %s" % (f1, precision, recall))
            model.train()
            return f1

        trainer.compress(custom_evaluate=custom_evaluate)


if __name__ == "__main__":
    main()
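The `uie_loss_func` in the script above averages two binary cross-entropy terms, one for the start-pointer probabilities and one for the end-pointer probabilities. A framework-free sketch of the same arithmetic for a single sequence, assuming mean reduction (the default of `paddle.nn.BCELoss`); the helper names `bce` and `uie_loss` and the `eps` clamp are illustrative additions:

```python
import math


def bce(probs, targets, eps=1e-12):
    """Mean binary cross-entropy between predicted probabilities and 0/1 targets."""
    total = 0.0
    for p, t in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(probs)


def uie_loss(start_prob, end_prob, start_ids, end_ids):
    """Average of the start-pointer and end-pointer losses."""
    return (bce(start_prob, start_ids) + bce(end_prob, end_ids)) / 2.0


# A four-token sequence whose gold span starts at token 1 and ends at token 2.
start_ids, end_ids = [0, 1, 0, 0], [0, 0, 1, 0]
confident = uie_loss([0.01, 0.98, 0.01, 0.01], [0.01, 0.02, 0.97, 0.01], start_ids, end_ids)
uncertain = uie_loss([0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5], start_ids, end_ids)
print(confident, uncertain)
```

Each token is scored independently as "is a span start" and "is a span end", which is why a single text can yield many entity spans for one prompt.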
We continue with the fine-tuning code provided by Paddle and set the fine-tuning parameters
export finetuned_model=./checkpoint/model_best
python finetune.py \
--device gpu \
--logging_steps 10 \
--save_steps 100 \
--eval_steps 100 \
--seed 42 \
--model_name_or_path uie-base \
--output_dir $finetuned_model \
--train_path data/train.txt \
--dev_path data/dev.txt \
--max_seq_length 512 \
--per_device_eval_batch_size 16 \
--per_device_train_batch_size 16 \
--num_train_epochs 20 \
--learning_rate 1e-5 \
--label_names "start_positions" "end_positions" \
--do_train \
--do_eval \
--do_export \
--export_model_dir $finetuned_model \
--overwrite_output_dir \
--disable_tqdm True \
--metric_for_best_model eval_f1 \
--load_best_model_at_end True \
--save_total_limit 1
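`--metric_for_best_model eval_f1` keeps the checkpoint with the highest span-level F1 on dev.txt. The sketch below shows the conventional way precision, recall, and F1 follow from the three counts that `compute_metrics` collects (correct spans, predicted spans, gold spans). It uses the standard definitions and is assumed, not verified, to match what `SpanEvaluator.accumulate()` returns; `span_prf` is a hypothetical helper.

```python
def span_prf(num_correct, num_infer, num_label):
    """Precision, recall and F1 from span counts, with zero-division guards."""
    precision = num_correct / num_infer if num_infer else 0.0
    recall = num_correct / num_label if num_label else 0.0
    f1 = 2 * precision * recall / (precision + recall) if num_correct else 0.0
    return precision, recall, f1


# Illustrative counts: 10 predicted spans, 16 gold spans, 8 of them matching.
precision, recall, f1 = span_prf(num_correct=8, num_infer=10, num_label=16)
print(precision, recall, f1)
```

With only about 50 annotated records, a handful of dev spans can move F1 noticeably, so small differences between checkpoints should be read with caution.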
Model training process
(vcrs) (base) yunpu@sduu-SYS-4029GP-TRT2:~/Data/codes/VCRS/UIE$ ./finetune.sh
/home/yunpu/Apps/anaconda3/envs/vcrs/lib/python3.8/site-packages/_distutils_hack/__init__.py:33: UserWarning: Setuptools is replacing distutils.
warnings.warn("Setuptools is replacing distutils.")
[2024-06-23 17:20:36,982] [ WARNING] - evaluation_strategy reset to IntervalStrategy.STEPS for do_eval is True. you can also set evaluation_strategy='epoch'.
[2024-06-23 17:20:36,982] [ INFO] - The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
[2024-06-23 17:20:36,982] [ INFO] - ============================================================
[2024-06-23 17:20:36,982] [ INFO] - Model Configuration Arguments
[2024-06-23 17:20:36,983] [ INFO] - paddle commit id :fbf852dd832bc0e63ae31cd4aa37defd829e4c03
[2024-06-23 17:20:36,983] [ INFO] - export_model_dir :./checkpoint/model_best
[2024-06-23 17:20:36,983] [ INFO] - model_name_or_path :/home/yunpu/Data/codes/VCRS/UIE/checkpoint_all_entities/model_best
[2024-06-23 17:20:36,983] [ INFO] - multilingual :True
[2024-06-23 17:20:36,983] [ INFO] -
[2024-06-23 17:20:36,983] [ INFO] - ============================================================
[2024-06-23 17:20:36,983] [ INFO] - Data Configuration Arguments
[2024-06-23 17:20:36,983] [ INFO] - paddle commit id :fbf852dd832bc0e63ae31cd4aa37defd829e4c03
[2024-06-23 17:20:36,983] [ INFO] - dev_path :data/ner_cn/dev.txt
[2024-06-23 17:20:36,983] [ INFO] - dynamic_max_length :None
[2024-06-23 17:20:36,983] [ INFO] - max_seq_length :512
[2024-06-23 17:20:36,983] [ INFO] - train_path :data/ner_cn/train.txt
[2024-06-23 17:20:36,983] [ INFO] -
[2024-06-23 17:20:36,983] [ WARNING] - Process rank: -1, device: gpu, world_size: 1, distributed training: False, 16-bits training: False
[2024-06-23 17:20:36,993] [ INFO] - We are using <class 'paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer'> to load '/home/yunpu/Data/codes/VCRS/UIE/checkpoint_all_entities/model_best'.
[2024-06-23 17:20:37,776] [ INFO] - Loading configuration file /home/yunpu/Data/codes/VCRS/UIE/checkpoint_all_entities/model_best/config.json
[2024-06-23 17:20:37,786] [ INFO] - Loading weights file /home/yunpu/Data/codes/VCRS/UIE/checkpoint_all_entities/model_best/model_state.pdparams
[2024-06-23 17:20:39,740] [ INFO] - Loaded weights file from disk, setting weights to model.
W0623 17:20:39.767447 43400 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 8.6, Driver API Version: 12.4, Runtime API Version: 11.8
W0623 17:20:39.768709 43400 gpu_resources.cc:164] device: 0, cuDNN Version: 8.9.
[2024-06-23 17:20:42,136] [ INFO] - All model checkpoint weights were used when initializing UIEM.
[2024-06-23 17:20:42,136] [ INFO] - All the weights of UIEM were initialized from the model checkpoint at /home/yunpu/Data/codes/VCRS/UIE/checkpoint_all_entities/model_best.
If your task is similar to the task the model of the checkpoint was trained on, you can already use UIEM for predictions without further training.
[2024-06-23 17:20:43,438] [ INFO] - The global seed is set to 42, local seed is set to 43 and random seed is set to 42.
[2024-06-23 17:20:43,760] [ DEBUG] - ============================================================
[2024-06-23 17:20:43,760] [ DEBUG] - Training Configuration Arguments
[2024-06-23 17:20:43,760] [ DEBUG] - paddle commit id : fbf852dd832bc0e63ae31cd4aa37defd829e4c03
[2024-06-23 17:20:43,760] [ DEBUG] - paddlenlp commit id : 3105c18b013e1cdcbf860af1c6c54f4e33c88ee7
[2024-06-23 17:20:43,761] [ DEBUG] - _no_sync_in_gradient_accumulation: True
[2024-06-23 17:20:43,761] [ DEBUG] - activation_quantize_type : None
[2024-06-23 17:20:43,761] [ DEBUG] - adam_beta1 : 0.9
[2024-06-23 17:20:43,761] [ DEBUG] - adam_beta2 : 0.999
[2024-06-23 17:20:43,761] [ DEBUG] - adam_epsilon : 1e-08
[2024-06-23 17:20:43,761] [ DEBUG] - algo_list : None
[2024-06-23 17:20:43,761] [ DEBUG] - amp_custom_black_list : None
[2024-06-23 17:20:43,761] [ DEBUG] - amp_custom_white_list : None
[2024-06-23 17:20:43,761] [ DEBUG] - amp_master_grad : False
[2024-06-23 17:20:43,761] [ DEBUG] - batch_num_list : None
[2024-06-23 17:20:43,761] [ DEBUG] - batch_size_list : None
[2024-06-23 17:20:43,761] [ DEBUG] - bf16 : False
[2024-06-23 17:20:43,761] [ DEBUG] - bf16_full_eval : False
[2024-06-23 17:20:43,761] [ DEBUG] - bias_correction : False
[2024-06-23 17:20:43,761] [ DEBUG] - current_device : gpu:0
[2024-06-23 17:20:43,761] [ DEBUG] - data_parallel_config :
[2024-06-23 17:20:43,762] [ DEBUG] - data_parallel_rank : 0
[2024-06-23 17:20:43,762] [ DEBUG] - dataloader_drop_last : False
[2024-06-23 17:20:43,762] [ DEBUG] - dataloader_num_workers : 0
[2024-06-23 17:20:43,762] [ DEBUG] - dataset_rank : 0
[2024-06-23 17:20:43,762] [ DEBUG] - dataset_world_size : 1
[2024-06-23 17:20:43,762] [ DEBUG] - device : gpu
[2024-06-23 17:20:43,762] [ DEBUG] - disable_tqdm : True
[2024-06-23 17:20:43,762] [ DEBUG] - distributed_dataloader : False
[2024-06-23 17:20:43,762] [ DEBUG] - do_compress : False
[2024-06-23 17:20:43,762] [ DEBUG] - do_eval : True
[2024-06-23 17:20:43,762] [ DEBUG] - do_export : True
[2024-06-23 17:20:43,762] [ DEBUG] - do_predict : False
[2024-06-23 17:20:43,762] [ DEBUG] - do_train : True
[2024-06-23 17:20:43,762] [ DEBUG] - enable_auto_parallel : False
[2024-06-23 17:20:43,762] [ DEBUG] - eval_accumulation_steps : None
[2024-06-23 17:20:43,762] [ DEBUG] - eval_batch_size : 16
[2024-06-23 17:20:43,762] [ DEBUG] - eval_steps : 2727
[2024-06-23 17:20:43,762] [ DEBUG] - evaluation_strategy : IntervalStrategy.STEPS
[2024-06-23 17:20:43,763] [ DEBUG] - flatten_param_grads : False
[2024-06-23 17:20:43,763] [ DEBUG] - force_reshard_pp : False
[2024-06-23 17:20:43,763] [ DEBUG] - fp16 : False
[2024-06-23 17:20:43,763] [ DEBUG] - fp16_full_eval : False
[2024-06-23 17:20:43,763] [ DEBUG] - fp16_opt_level : O1
[2024-06-23 17:20:43,763] [ DEBUG] - gradient_accumulation_steps : 1
[2024-06-23 17:20:43,763] [ DEBUG] - greater_is_better : True
[2024-06-23 17:20:43,763] [ DEBUG] - hybrid_parallel_topo_order : pp_first
[2024-06-23 17:20:43,763] [ DEBUG] - ignore_data_skip : False
[2024-06-23 17:20:43,763] [ DEBUG] - ignore_load_lr_and_optim : False
[2024-06-23 17:20:43,763] [ DEBUG] - ignore_save_lr_and_optim : False
[2024-06-23 17:20:43,763] [ DEBUG] - input_dtype : int64
[2024-06-23 17:20:43,763] [ DEBUG] - input_infer_model_path : None
[2024-06-23 17:20:43,763] [ DEBUG] - label_names : ['start_positions', 'end_positions']
[2024-06-23 17:20:43,763] [ DEBUG] - lazy_data_processing : True
[2024-06-23 17:20:43,763] [ DEBUG] - learning_rate : 1e-05
[2024-06-23 17:20:43,763] [ DEBUG] - load_best_model_at_end : True
[2024-06-23 17:20:43,763] [ DEBUG] - load_sharded_model : False
[2024-06-23 17:20:43,763] [ DEBUG] - local_process_index : 0
[2024-06-23 17:20:43,764] [ DEBUG] - local_rank : -1
[2024-06-23 17:20:43,764] [ DEBUG] - log_level : -1
[2024-06-23 17:20:43,764] [ DEBUG] - log_level_replica : -1
[2024-06-23 17:20:43,764] [ DEBUG] - log_on_each_node : True
[2024-06-23 17:20:43,764] [ DEBUG] - logging_dir : ./checkpoint/model_best/runs/Jun23_17-20-36_sduu-SYS-4029GP-TRT2
[2024-06-23 17:20:43,764] [ DEBUG] - logging_first_step : False
[2024-06-23 17:20:43,764] [ DEBUG] - logging_steps : 100
[2024-06-23 17:20:43,764] [ DEBUG] - logging_strategy : IntervalStrategy.STEPS
[2024-06-23 17:20:43,764] [ DEBUG] - logical_process_index : 0
[2024-06-23 17:20:43,764] [ DEBUG] - lr_end : 1e-07
[2024-06-23 17:20:43,764] [ DEBUG] - lr_scheduler_type : SchedulerType.LINEAR
[2024-06-23 17:20:43,764] [ DEBUG] - max_evaluate_steps : -1
[2024-06-23 17:20:43,764] [ DEBUG] - max_grad_norm : 1.0
[2024-06-23 17:20:43,764] [ DEBUG] - max_steps : -1
[2024-06-23 17:20:43,764] [ DEBUG] - metric_for_best_model : eval_f1
[2024-06-23 17:20:43,764] [ DEBUG] - minimum_eval_times : None
[2024-06-23 17:20:43,764] [ DEBUG] - moving_rate : 0.9
[2024-06-23 17:20:43,764] [ DEBUG] - no_cuda : False
[2024-06-23 17:20:43,765] [ DEBUG] - num_cycles : 0.5
[2024-06-23 17:20:43,765] [ DEBUG] - num_train_epochs : 20.0
[2024-06-23 17:20:43,765] [ DEBUG] - onnx_format : True
[2024-06-23 17:20:43,765] [ DEBUG] - optim : OptimizerNames.ADAMW
[2024-06-23 17:20:43,765] [ DEBUG] - optimizer_name_suffix : None
[2024-06-23 17:20:43,765] [ DEBUG] - output_dir : ./checkpoint/model_best
[2024-06-23 17:20:43,765] [ DEBUG] - overwrite_output_dir : True
[2024-06-23 17:20:43,765] [ DEBUG] - past_index : -1
[2024-06-23 17:20:43,765] [ DEBUG] - per_device_eval_batch_size : 16
[2024-06-23 17:20:43,765] [ DEBUG] - per_device_train_batch_size : 16
[2024-06-23 17:20:43,765] [ DEBUG] - pipeline_parallel_config :
[2024-06-23 17:20:43,765] [ DEBUG] - pipeline_parallel_degree : -1
[2024-06-23 17:20:43,765] [ DEBUG] - pipeline_parallel_rank : 0
[2024-06-23 17:20:43,765] [ DEBUG] - power : 1.0
[2024-06-23 17:20:43,765] [ DEBUG] - prediction_loss_only : False
[2024-06-23 17:20:43,765] [ DEBUG] - process_index : 0
[2024-06-23 17:20:43,765] [ DEBUG] - prune_embeddings : False
[2024-06-23 17:20:43,765] [ DEBUG] - recompute : False
[2024-06-23 17:20:43,765] [ DEBUG] - remove_unused_columns : True
[2024-06-23 17:20:43,766] [ DEBUG] - report_to : ['visualdl']
[2024-06-23 17:20:43,766] [ DEBUG] - resume_from_checkpoint : None
[2024-06-23 17:20:43,766] [ DEBUG] - round_type : round
[2024-06-23 17:20:43,766] [ DEBUG] - run_name : ./checkpoint/model_best
[2024-06-23 17:20:43,766] [ DEBUG] - save_on_each_node : False
[2024-06-23 17:20:43,766] [ DEBUG] - save_sharded_model : False
[2024-06-23 17:20:43,766] [ DEBUG] - save_steps : 2727
[2024-06-23 17:20:43,766] [ DEBUG] - save_strategy : IntervalStrategy.STEPS
[2024-06-23 17:20:43,766] [ DEBUG] - save_total_limit : 1
[2024-06-23 17:20:43,766] [ DEBUG] - scale_loss : 32768
[2024-06-23 17:20:43,766] [ DEBUG] - seed : 42
[2024-06-23 17:20:43,766] [ DEBUG] - sep_parallel_degree : -1
[2024-06-23 17:20:43,766] [ DEBUG] - sharding : []
[2024-06-23 17:20:43,766] [ DEBUG] - sharding_degree : -1
[2024-06-23 17:20:43,766] [ DEBUG] - sharding_parallel_config :
[2024-06-23 17:20:43,766] [ DEBUG] - sharding_parallel_degree : -1
[2024-06-23 17:20:43,766] [ DEBUG] - sharding_parallel_rank : 0
[2024-06-23 17:20:43,766] [ DEBUG] - should_load_dataset : True
[2024-06-23 17:20:43,767] [ DEBUG] - should_load_sharding_stage1_model: False
[2024-06-23 17:20:43,767] [ DEBUG] - should_log : True
[2024-06-23 17:20:43,767] [ DEBUG] - should_save : True
[2024-06-23 17:20:43,767] [ DEBUG] - should_save_model_state : True
[2024-06-23 17:20:43,767] [ DEBUG] - should_save_sharding_stage1_model: False
[2024-06-23 17:20:43,767] [ DEBUG] - skip_memory_metrics : True
[2024-06-23 17:20:43,767] [ DEBUG] - skip_profile_timer : True
[2024-06-23 17:20:43,767] [ DEBUG] - strategy : dynabert+ptq
[2024-06-23 17:20:43,767] [ DEBUG] - tensor_parallel_config :
[2024-06-23 17:20:43,767] [ DEBUG] - tensor_parallel_degree : -1
[2024-06-23 17:20:43,767] [ DEBUG] - tensor_parallel_rank : 0
[2024-06-23 17:20:43,767] [ DEBUG] - to_static : False
[2024-06-23 17:20:43,767] [ DEBUG] - train_batch_size : 16
[2024-06-23 17:20:43,767] [ DEBUG] - unified_checkpoint : False
[2024-06-23 17:20:43,767] [ DEBUG] - unified_checkpoint_config :
[2024-06-23 17:20:43,767] [ DEBUG] - use_hybrid_parallel : False
[2024-06-23 17:20:43,767] [ DEBUG] - use_pact : True
[2024-06-23 17:20:43,767] [ DEBUG] - wandb_api_key : None
[2024-06-23 17:20:43,767] [ DEBUG] - warmup_ratio : 0.1
[2024-06-23 17:20:43,768] [ DEBUG] - warmup_steps : 0
[2024-06-23 17:20:43,768] [ DEBUG] - weight_decay : 0.0
[2024-06-23 17:20:43,768] [ DEBUG] - weight_name_suffix : None
[2024-06-23 17:20:43,768] [ DEBUG] - weight_quantize_type : channel_wise_abs_max
[2024-06-23 17:20:43,768] [ DEBUG] - width_mult_list : None
[2024-06-23 17:20:43,768] [ DEBUG] - world_size : 1
[2024-06-23 17:20:43,768] [ DEBUG] -
[2024-06-23 17:20:43,768] [ INFO] - Starting training from resume_from_checkpoint : None
/home/yunpu/Apps/anaconda3/envs/vcrs/lib/python3.8/site-packages/paddle/distributed/parallel.py:410: UserWarning: The program will return to single-card operation. Please check 1, whether you use spawn or fleetrun to start the program. 2, Whether it is a multi-card program. 3, Is the current environment multi-card.
warnings.warn(
[2024-06-23 17:20:43,769] [ INFO] - [timelog] checkpoint loading time: 0.00s (2024-06-23 17:20:43)
[2024-06-23 17:20:43,769] [ INFO] - ***** Running training *****
[2024-06-23 17:20:43,769] [ INFO] - Num examples = 46,720
[2024-06-23 17:20:43,770] [ INFO] - Num Epochs = 20
[2024-06-23 17:20:43,770] [ INFO] - Instantaneous batch size per device = 16
[2024-06-23 17:20:43,770] [ INFO] - Total train batch size (w. parallel, distributed & accumulation) = 16
[2024-06-23 17:20:43,770] [ INFO] - Gradient Accumulation steps = 1
[2024-06-23 17:20:43,770] [ INFO] - Total optimization steps = 58,400
[2024-06-23 17:20:43,770] [ INFO] - Total num train samples = 934,400
[2024-06-23 17:20:43,772] [ DEBUG] - Number of trainable parameters = 278,044,418 (per device)
/home/yunpu/Apps/anaconda3/envs/vcrs/lib/python3.8/site-packages/paddlenlp/transformers/tokenizer_utils_base.py:2512: FutureWarning: The `max_seq_len` argument is deprecated and will be removed in a future version, please use `max_length` instead.
warnings.warn(
/home/yunpu/Apps/anaconda3/envs/vcrs/lib/python3.8/site-packages/paddlenlp/transformers/tokenizer_utils_base.py:1912: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
warnings.warn(
/home/yunpu/Apps/anaconda3/envs/vcrs/lib/python3.8/site-packages/paddlenlp/transformers/tokenizer_utils_base.py:3023: UserWarning: Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
warnings.warn(
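The step counts in the log above follow directly from the dataset size, the batch size, and the number of epochs. The short check below reproduces them from values copied out of the log; the variable names are chosen here for illustration, and the number of evaluation runs assumes one evaluation every `eval_steps` optimizer steps.

```python
import math

# Values copied from the training log above.
num_examples = 46_720
per_device_train_batch_size = 16
gradient_accumulation_steps = 1
world_size = 1
num_train_epochs = 20
eval_steps = 2_727

# Effective batch size across devices and accumulation steps.
total_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size
steps_per_epoch = math.ceil(num_examples / total_batch_size)
total_steps = steps_per_epoch * num_train_epochs
total_samples = num_examples * num_train_epochs

# Assumption: an evaluation (and checkpoint save) is triggered every `eval_steps` steps.
num_evaluations = total_steps // eval_steps

print(total_batch_size, steps_per_epoch, total_steps, total_samples, num_evaluations)
# prints: 16 2920 58400 934400 21
```

Since `eval_steps` (2727) is slightly smaller than one epoch (2920 steps), the model is evaluated a little more often than once per epoch.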
The resulting model:
5. Model Evaluation
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from functools import partial
import paddle
from utils import (
convert_example,
create_data_loader,
get_relation_type_dict,
reader,
unify_prompt_name,
)
from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import MapDataset, load_dataset
from paddlenlp.metrics import SpanEvaluator
from paddlenlp.transformers import UIE, UIEM, AutoTokenizer
from paddlenlp.utils.log import logger
@paddle.no_grad()
def evaluate(model, metric, data_loader, multilingual=False):
    """
    Given a dataset, it evals model and computes the metric.
    Args:
        model(obj:`paddle.nn.Layer`): A model to classify texts.
        metric(obj:`paddle.metric.Metric`): The evaluation metric.
        data_loader(obj:`paddle.io.DataLoader`): The dataset loader which generates batches.
        multilingual(bool): Whether is the multilingual model.
    """
    model.eval()
    metric.reset()
    for batch in data_loader:
        if multilingual:
            start_prob, end_prob = model(batch["input_ids"], batch["position_ids"])
        else:
            start_prob, end_prob = model(
                batch["input_ids"], batch["token_type_ids"], batch["position_ids"], batch["attention_mask"]
            )
        start_ids = paddle.cast(batch["start_positions"], "float32")
        end_ids = paddle.cast(batch["end_positions"], "float32")
        num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
        metric.update(num_correct, num_infer, num_label)
    precision, recall, f1 = metric.accumulate()
    model.train()
    return precision, recall, f1
def do_eval():
    paddle.set_device(args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if args.multilingual:
        model = UIEM.from_pretrained(args.model_path)
    else:
        model = UIE.from_pretrained(args.model_path)
    test_ds = load_dataset(reader, data_path=args.test_path, max_seq_len=args.max_seq_len, lazy=False)
    class_dict = {}
    relation_data = []
    if args.debug:
        for data in test_ds:
            class_name = unify_prompt_name(data["prompt"])
            # Only positive examples are evaluated in debug mode
            if len(data["result_list"]) != 0:
                p = "的" if args.schema_lang == "ch" else " of "
                if p not in data["prompt"]:
                    class_dict.setdefault(class_name, []).append(data)
                else:
                    relation_data.append((data["prompt"], data))
        relation_type_dict = get_relation_type_dict(relation_data, schema_lang=args.schema_lang)
    else:
        class_dict["all_classes"] = test_ds
    trans_fn = partial(
        convert_example, tokenizer=tokenizer, max_seq_len=args.max_seq_len, multilingual=args.multilingual
    )
    # Created once up front so the relation loop below can use it even if class_dict is empty.
    data_collator = DataCollatorWithPadding(tokenizer)
    for key in class_dict.keys():
        if args.debug:
            test_ds = MapDataset(class_dict[key])
        else:
            test_ds = class_dict[key]
        test_ds = test_ds.map(trans_fn)
        test_data_loader = create_data_loader(test_ds, mode="test", batch_size=args.batch_size, trans_fn=data_collator)
        metric = SpanEvaluator()
        precision, recall, f1 = evaluate(model, metric, test_data_loader, args.multilingual)
        logger.info("-----------------------------")
        logger.info("Class Name: %s" % key)
        logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" % (precision, recall, f1))
    if args.debug and len(relation_type_dict.keys()) != 0:
        for key in relation_type_dict.keys():
            test_ds = MapDataset(relation_type_dict[key])
            test_ds = test_ds.map(trans_fn)
            test_data_loader = create_data_loader(
                test_ds, mode="test", batch_size=args.batch_size, trans_fn=data_collator
            )
            metric = SpanEvaluator()
            # Pass the multilingual flag here too, so UIEM is called with the inputs it expects.
            precision, recall, f1 = evaluate(model, metric, test_data_loader, args.multilingual)
            logger.info("-----------------------------")
            if args.schema_lang == "ch":
                logger.info("Class Name: X的%s" % key)
            else:
                logger.info("Class Name: %s of X" % key)
            logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" % (precision, recall, f1))
if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default=None, help="The path of saved model that you want to load.")
    parser.add_argument("--test_path", type=str, default=None, help="The path of test set.")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size per GPU/CPU/NPU for training.")
    parser.add_argument("--device", type=str, default="gpu", choices=["gpu", "cpu", "npu"], help="Device selected for evaluate.")
    parser.add_argument("--max_seq_len", type=int, default=512, help="The maximum total input sequence length after tokenization.")
    parser.add_argument("--debug", action='store_true', help="Precision, recall and F1 score are calculated for each class separately if this option is enabled.")
    parser.add_argument("--multilingual", action='store_true', help="Whether is the multilingual model.")
    parser.add_argument("--schema_lang", choices=["ch", "en"], default="ch", help="Select the language type for schema.")
    args = parser.parse_args()
    # yapf: enable
    do_eval()
python evaluate.py \
--model_path ./checkpoint/model_best \
--test_path ./data/dev.txt \
--batch_size 16 \
--max_seq_len 512 \
--multilingual
On the entity recognition task the model reached Evaluation Precision: 0.61081 | Recall: 0.58020 | F1: 0.59511, which I consider good enough for our entity annotation needs.
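The three numbers reported above are span-level metrics accumulated over all batches: precision is the share of predicted spans that are correct, recall is the share of labelled spans that were found, and F1 is their harmonic mean. The sketch below mirrors that bookkeeping in plain Python. `SpanCounts` and `f1_score` are hypothetical helpers written for this note, not the `SpanEvaluator` implementation, and the per-batch counts are made up for illustration.

```python
from dataclasses import dataclass


def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 when both are 0)."""
    total = precision + recall
    return 2 * precision * recall / total if total else 0.0


@dataclass
class SpanCounts:
    """Running totals for micro-averaged span metrics (hypothetical helper)."""
    num_correct: int = 0
    num_infer: int = 0
    num_label: int = 0

    def update(self, num_correct: int, num_infer: int, num_label: int) -> None:
        # Add the counts of one batch to the running totals.
        self.num_correct += num_correct
        self.num_infer += num_infer
        self.num_label += num_label

    def accumulate(self):
        # Metrics are computed from the totals, not averaged per batch.
        precision = self.num_correct / self.num_infer if self.num_infer else 0.0
        recall = self.num_correct / self.num_label if self.num_label else 0.0
        return precision, recall, f1_score(precision, recall)


# Illustrative (made-up) per-batch counts: (correct, predicted, labelled).
counts = SpanCounts()
for batch_counts in [(8, 10, 12), (5, 10, 8)]:
    counts.update(*batch_counts)
print(counts.accumulate())

# Consistency check on the reported result: F1 recomputed from P and R.
reported_f1 = f1_score(0.61081, 0.58020)
print(round(reported_f1, 5))
```

Recomputing F1 from the reported precision and recall gives a value that agrees with the logged 0.59511 to five decimal places.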
Testing on a real example
from pprint import pprint
from paddlenlp import Taskflow
# Schema for entity extraction: a single label '实体' ("entity"), which is the key in the output below.
# The schema has to be set before the Taskflow is created; reassigning the variable
# afterwards has no effect unless ie.set_schema(schema) is called.
schema = ['实体']
# Other schemas tried:
# schema = ['概念', '定义', '别名']  # concept, definition, alias
# schema = {'实体':['相关', '定义', '提供', '包含', '支持', '任务', '指导', '别名', '解决' ,'实现', '功能', '影响']}  # relation extraction
ie = Taskflow('information_extraction', schema=schema, task_path='/home/yunpu/Data/codes/VCRS/UIE/checkpoint/model_best')
pprint(ie("数据链路层的功能是为网络层提供服务。最主要的服务是将数据从源机器的网络层传输到目标机器的网络层。在源机器的网络层有一个实体(称为进程),它将一些比特交给数据链路层,要求传输到目标机器。数据链路层的任务就是将这些比特传输给目标机器,然后再进一步交付给网络层,如图3-2 Ca)所示。实际的传输过程则是沿着图3-2 Cb)所示的路径进行的,但很容易将这个过程想象成两个数据链路层的进程使用一个数据链路协议进行通信。基于这个原因,在本章中我们将隐式使用图3-2 Ca)的模型。"))
Output
[{'实体': [{'end': 5,
          'probability': 0.6005318961963724,
          'start': 0,
          'text': '数据链路层'},
         {'end': 215,
          'probability': 0.6650135868184393,
          'start': 213,
          'text': '本章'},
         {'end': 13,
          'probability': 0.6238407158454891,
          'start': 10,
          'text': '网络层'},
         {'end': 66,
          'probability': 0.8487259132296145,
          'start': 64,
          'text': '进程'},
         {'end': 200,
          'probability': 0.9407074886524889,
          'start': 194,
          'text': '数据链路协议'}]}]
The results are clearly much better than before fine-tuning.
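When the extracted spans are used for annotation, it can help to order them by position and drop low-confidence ones. The sketch below does this on the output shown above. `filter_spans` is a hypothetical helper and the 0.7 threshold is an arbitrary value for illustration, not a tuned setting.

```python
# Output copied from the Taskflow run above.
result = [{'实体': [
    {'end': 5, 'probability': 0.6005318961963724, 'start': 0, 'text': '数据链路层'},
    {'end': 215, 'probability': 0.6650135868184393, 'start': 213, 'text': '本章'},
    {'end': 13, 'probability': 0.6238407158454891, 'start': 10, 'text': '网络层'},
    {'end': 66, 'probability': 0.8487259132296145, 'start': 64, 'text': '进程'},
    {'end': 200, 'probability': 0.9407074886524889, 'start': 194, 'text': '数据链路协议'},
]}]


def filter_spans(result, label, min_probability=0.0):
    """Return the spans for `label` in the first input text, sorted by start offset.

    Hypothetical helper: keeps only spans whose probability is at least
    `min_probability`.
    """
    spans = result[0].get(label, [])
    kept = [span for span in spans if span['probability'] >= min_probability]
    return sorted(kept, key=lambda span: span['start'])


all_spans = filter_spans(result, '实体')
confident_spans = filter_spans(result, '实体', min_probability=0.7)

for span in confident_spans:
    print(span['start'], span['end'], span['text'], round(span['probability'], 3))
```

With this threshold only '进程' and '数据链路协议' remain. Correct entities such as '数据链路层' and '网络层' score between 0.60 and 0.63 here, so any cut-off should be chosen against the precision/recall trade-off on the dev set.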