mT5-m2o-chinese_simplified-CrossSum-中文文本生成摘要模块
一、下载地址
提示:Huggingface目前国内已经无法登陆,必须通过翻墙,这里提供网站:
HF-Mirror镜像网站,可用于下载模型:
二、模块介绍
huggingface上将长文本中文数据处理为短文本数据的模型,通过以下代码可以实现长文本批量处理
三、使用步骤
1.数据集
采用的是清华数据集的test.txt,格式是1段长文本+一个数字标签,按行读取,数字标签保留,长文本编程短文本
例如(其中一行):
中新经纬客户端5月19日电19日两市开盘之后震荡分化,沪指震荡走弱,创业板指在锂电池、消费电子板块带动下震荡拉升。截至收盘,沪指报3510.96点,跌幅0.51%,成交额3499.73亿元;深成指报14484.45点,涨幅0.23%,成交额4415.12亿元;创业板指报3114.72点,涨幅0.8%。【编辑:陈海峰】 0
2.使用方式(README.MD)
代码如下(示例):
data = pd.read_csv(
'import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
WHITESPACE_HANDLER = lambda k: re.sub('\s+', ' ', re.sub('\n+', ' ', k.strip()))
article_text = """Videos that say approved vaccines are dangerous and cause autism, cancer or infertility are among those that will be taken down, the company said. The policy includes the termination of accounts of anti-vaccine influencers. Tech giants have been criticised for not doing more to counter false health information on their sites. In July, US President Joe Biden said social media platforms were largely responsible for people's scepticism in getting vaccinated by spreading misinformation, and appealed for them to address the issue. YouTube, which is owned by Google, said 130,000 videos were removed from its platform since last year, when it implemented a ban on content spreading misinformation about Covid vaccines. In a blog post, the company said it had seen false claims about Covid jabs "spill over into misinformation about vaccines in general". The new policy covers long-approved vaccines, such as those against measles or hepatitis B. "We're expanding our medical misinformation policies on YouTube with new guidelines on currently administered vaccines that are approved and confirmed to be safe and effective by local health authorities and the WHO," the post said, referring to the World Health Organization."""
model_name = "csebuetnlp/mT5_m2o_chinese_simplified_crossSum"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
input_ids = tokenizer(
[WHITESPACE_HANDLER(article_text)],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=512
)["input_ids"]
output_ids = model.generate(
input_ids=input_ids,
max_length=84,
no_repeat_ngram_size=2,
num_beams=4
)[0]
summary = tokenizer.decode(
output_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)
print(summary)
这是对测试文本的读取,缺点是时间较长。
3.批量读取数据
#将代码粗略封装后,处理文件。代码如下(示例):
#将前面的代码封装在summary()中。
def process_lines(line):
processed_line = ''
for word in line.split():
if word.isdigit():
processed_line += '\t' + word
else:
word = summary(word)
processed_line += word
return processed_line
#打开test.txt文件,对文件行进行处理。
with open('test.txt', 'r', encoding='utf-8') as file:
lines = file.readlines()
lines = [process_lines(line) for line in lines]
#打开result.txt文件,将结果写入result.txt中。
with open('result.txt', 'w', encoding='utf-8') as result_file:
for line in lines:
result_file.write(line + '\n')
四、 总结-完整代码
利用huggingface上的中文文本摘要程序进行中文文本摘要处理,注意transformer的版本问题。
完整代码:
import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
WHITESPACE_HANDLER = lambda k: re.sub('\s+', ' ', re.sub('\n+', ' ', k.strip()))
def summary(texts):
#model名字。需要修改成你下载的模型名字,这里我用了model
model_name = r"C:\Users\Administrator\Desktop\model"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
input_ids = tokenizer(
[WHITESPACE_HANDLER(texts)],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=512
)["input_ids"]
output_ids = model.generate(
input_ids=input_ids,
max_length=84,
no_repeat_ngram_size=2,
num_beams=4
)[0]
summary = tokenizer.decode(
output_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)
return(summary)
#将前面的代码封装在summary()中。
def process_lines(line):
processed_line = ''
for word in line.split():
if word.isdigit():
processed_line += '\t' + word
else:
word = summary(word)
processed_line += word
return processed_line
#打开test.txt文件,对文件行进行处理。
with open('test.txt', 'r', encoding='utf-8') as file:
lines = file.readlines()
lines = [process_lines(line) for line in lines]
#打开result.txt文件,将结果写入result.txt中。
with open('result.txt', 'w', encoding='utf-8') as result_file:
for line in lines:
result_file.write(line + '\n')