5个简单的步骤使用Pytorch进行文本摘要总结


介绍
文本摘要是自然语言处理(NLP)的一项任务,其目的是生成源文本的简明摘要。不像摘录摘要,摘要不仅仅简单地从源文本复制重要的短语,还要提出新的相关短语,这可以被视为释义。摘要在不同的领域产生了大量的应用,从书籍和文献,科学和研发,金融研究和法律文件分析。

到目前为止,对抽象摘要最有效的方法是在摘要数据集上使用经过微调的transformer模型。在本文中,我们将演示如何在几个简单步骤中使用功能强大的模型轻松地总结文本。我们将要使用的模型已经经过了预先训练,所以不需要额外的训练:)

让我们开始吧!

步骤1:安装Transformers库
我们要用的库是Huggingface实现的Transformers 。如果你不熟悉Transformers ,你可以继续阅读我之前的文章。

要安装变压器,您可以简单地运行:

 pip install transformers
注意需要事先安装Pytorch。如果您还没有安装Pytorch,请访问Pytorch官方网站并按照说明安装它。

步骤2:导入库
成功安装transformer之后,现在可以开始将其导入到Python脚本中。我们也可以导入os来设置GPU在下一步使用的环境变量。注意,这是完全可选的,但如果您有多个gpu(如果您使用的是jupiter笔记本),这是防止错误的使用其他gpu的一个好做法。

 from transformers import pipeline
 import os
步骤3:设置使用的GPU和模型
如果你决定设置GPU(例如0),那么你可以如下图所示:

 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
现在,我们准备好选择要使用的摘要模型了。Huggingface提供两种强大的摘要模型使用:BART (BART -large-cnn)和t5 (t5-small, t5-base, t5-large, t5- 3b, t5- 11b)。你可以在他们的官方paper(BART paper, t5 paper)上了解更多。

要使用在CNN/每日邮报新闻数据集上训练的BART模型,您可以通过Huggingface的内置管道模块直接使用默认参数:

 summarizer = pipeline("summarization")
如果你想使用t5模型(例如t5-base),它是在c4 Common Crawl web语料库进行预训练的,那么你可以这样做:

 summarizer = pipeline("summarization", model="t5-base", tokenizer="t5-base", framework="tf")
步骤4:输入文本进行总结
现在,在我们准备好我们的模型之后,我们可以开始输入我们想要总结的文本。想象一下,我们想从MedicineNet的一篇文章中总结以下关于COVID-19疫苗的内容:

One month after the United States began what has become a troubled rollout  of a national COVID vaccination campaign, the effort is finally  gathering real steam.

Close to a million doses — over 951,000, to be more exact — made their way  into the arms of Americans in the past 24 hours, the U.S. Centers for  Disease Control and Prevention reported Wednesday. That’s the largest  number of shots given in one day since the rollout began and a big jump  from the previous day, when just under 340,000 doses were given, CBS News reported.

That number is likely to jump quickly after the federal government on  Tuesday gave states the OK to vaccinate anyone over 65 and said it would release all the doses of vaccine it has available for distribution.  Meanwhile, a number of states have now opened mass vaccination sites in  an effort to get larger numbers of people inoculated, CBS News reported.

我们定义变量:

 text = """One month after the United States began what has become a troubled rollout of a national COVID vaccination campaign, the effort is finally gathering real steam.
 Close to a million doses -- over 951,000, to be more exact -- made their way into the arms of Americans in the past 24 hours, the U.S. Centers for Disease Control and Prevention reported Wednesday. That's the largest number of shots given in one day since the rollout began and a big jump from the previous day, when just under 340,000 doses were given, CBS News reported.
 That number is likely to jump quickly after the federal government on Tuesday gave states the OK to vaccinate anyone over 65 and said it would release all the doses of vaccine it has available for distribution. Meanwhile, a number of states have now opened mass vaccination sites in an effort to get larger numbers of people inoculated, CBS News reported."""
步骤4:总结
最后,我们可以开始总结输入的文本。这里,我们声明了希望汇总输出的min_length和max_length,并且关闭了采样以生成固定的汇总。我们可以通过运行以下命令来实现:

 summary_text = summarizer(text, max_length=100, min_length=5, do_sample=False)[0]['summary_text']
 print(summary_text)
我们得到总结文本:

Over 951,000 doses of vaccine given in one day in the past 24 hours, CDC says . That’s the largest number of shots given in a month since the  rollout began . The federal government gave states the OK to vaccinate  anyone over 65 on Tuesday . A number of states have now opened mass  vaccination sites in an effort to get more people inoculated, CBS News  reports .

从总结的文本中可以看出,该模型知道24小时相当于一天,并聪明地将美国疾病控制与预防中心(U.S. Centers for Disease Control and Prevention)缩写为CDC。此外,该模型成功地从第一段和第二段链接信息,指出这是自上个月开始展示以来给出的最大次数。我们可以看到,该摘要模型的性能相当不错。

最后把所有这些放在一起,这里是jupyter notebook形式的整个代码:

https://gist.github.com/itsuncheng/f3c4dde81ac4651383c4480958da4f8e#file-summarization-ipynb

Lewis, Mike, et al. “Bart: Denoising sequence-to-sequence pre-training for natural language generation, translation, and comprehension.” arXiv preprint arXiv:1910.13461 (2019).

Raffel, Colin, et al. “Exploring the limits of transfer learning with a unified text-to-text transformer.” arXiv preprint arXiv:1910.10683 (2019).
————————————————
版权声明:本文为CSDN博主「uoiqu90093jgj」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/ai52learn/article/details/112765454

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: PyTorch使用TensorBoard可以通过安装TensorBoardX库来实现。TensorBoardX是一个PyTorch的扩展库,它提供了一种将PyTorch的数据可视化的方法,可以将训练过程的损失函数、准确率等指标以图表的形式展示出来,方便用户对模型的训练过程进行监控和调试。具体使用方法可以参考TensorBoardX的官方文档。 ### 回答2: PyTorch是一款流行的深度学习框架,用于实现神经网络模型和训练过程。TensorBoard是与TensorFlow框架一起使用的一个可视化工具,方便进行模型训练和性能调优。但是,PyTorch用户也可以充分利用TensorBoard来监控他们的模型。 在PyTorch使用TensorBoard主要包括以下几个步骤: 1. 安装TensorBoard和TensorFlow:需要在PyTorch的虚拟环境安装TensorFlow和TensorBoard,这可以使用pip来完成。 2. 导入所需的库:首先,需要导入PyTorch库和TensorFlow库。在这里,PyTorch库用于定义、训练和测试模型,而TensorFlow库用于可视化和监视模型训练过程。可以使用以下代码导入这些库: ``` import tensorflow as tf from torch.utils.tensorboard import SummaryWriter ``` 3. 创建SummaryWriter对象:SummaryWriter是TensorBoard类的主要接口。可以使用它来创建TensorBoard的摘要文件和事件文件。在下面的代码,可以创建一个名为“runs/xxx”的摘要写入器: ``` writer = SummaryWriter('runs/xxx') ``` 4. 定义模型:在PyTorch定义模型。在下面的代码,定义了一个包含两个全连接层的简单线性模型: ``` import torch.nn as nn class LinearModel(nn.Module): def __init__(self): super(LinearModel, self).__init__() self.fc1 = nn.Linear(784, 100) self.fc2 = nn.Linear(100, 10) def forward(self, x): x = x.view(-1, 784) x = nn.functional.relu(self.fc1(x)) x = self.fc2(x) return x ``` 5. 记录数据:使用writer对象记录数据。可以使用以下代码来记录训练数据: ``` for epoch in range(num_epochs): for i, (images, labels) in enumerate(train_loader): # 定义前向传递 outputs = model(images) # 计算损失 loss = criterion(outputs, labels) # 后向传递和优化器的更新 optimizer.zero_grad() loss.backward() optimizer.step() # 记录损失和准确率 writer.add_scalar('Training/Loss', loss.item(), epoch * len(train_loader) + i) total = labels.size(0) _, predicted = torch.max(outputs.data, 1) correct = (predicted == labels).sum().item() writer.add_scalar('Training/Accuracy', 100 * correct / total, epoch * len(train_loader) + i) ``` 6. 可视化和监控:在运行完上述代码后,可以返回到TensorBoard,可视化和监视训练过程。输入以下命令,启动TensorBoard服务: ``` tensorboard --logdir=runs ``` 然后,在Web浏览器,输入http://localhost:6006访问TensorBoard服务器。此时,可以看到图形界面显示了许多模型指标,例如损失和准确率。点击“Scalars”选项卡,就可以查看训练过程的损失和准确率曲线。 总之,在PyTorch使用TensorBoard可以方便地监视模型的训练和性能,并且TensorBoard可以提供可视化和交互式工具来帮助调试模型。 ### 回答3: PyTorch是近年来开发迅速的深度学习框架之一,基于Python语言,操作简便易学,广受欢迎。其应用范围广泛,包括图像识别、文本分类、语言模型等多种场景。 TensorBoard是TensorFlow框架提供的可视化工具,能够展现模型训练过程的各类参数、数据和图形化结果。然而,使用PyTorch的开发者也可以使用TensorBoard,PyTorch支持使用TensorBoard进行训练过程可视化。 下面是关于使用TensorBoard来监测PyTorch训练过程的几种方法: 一、使用TensorboardX TensorBoardX是一种基于PyTorch创建的TensorBoard工具,它使用了TensorFlow的tensorboard接口。使用该工具需要对PyTorch进行一些包的安装。 首先安装TensorboardX包: ```python !pip install tensorboardX ``` 然后,创建一个SummaryWriter,监测损失函数、准确率、图像等数据: ```python from tensorboardX import SummaryWriter writer = SummaryWriter("tb_dir") for i in range(100): writer.add_scalar('loss/train', i**2, i) writer.add_scalar('loss/test', 0.7*i**2, i) writer.add_scalar('accuracy/test', 0.9*i, i) writer.add_scalar('accuracy/train', 0.6*i, i) ``` 最后启动TensorBoard,运行 pytorch使用tensorboard的命令行。 ``` tensorboard --logdir tb_dir --host localhost --port 8088 ``` 二、使用PyTorch内置的TensorBoard可视化 pytorch 1.2版本以上,又增加了 PyTorch自带的TensorBoard可视化,PyTorch 内置的与TensorBoard的API兼容,创建SummaryWriter的方法更加简便,而不需要安装多个包。在训练过程,与使用TensorBoardX类似,将需要监测的数据文件写入到SummaryWriter: ```python from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter() for i in range(100): writer.add_scalar('loss/train', i**2, i) writer.add_scalar('loss/test', 0.7*i**2, i) writer.add_scalar('accuracy/test', 0.9*i, i) writer.add_scalar('accuracy/train', 0.6*i, i) ``` 运行 tensorboard --logdir , 输入PyTorch写入的文件即可。 三、使用Fastai集成的TensorBoardCallback 除了TensorboardX和PyTorch内置的TensorBoard可视化外,有另外一个可选方案,即使用Fastai的TensorBoardCallback。Fastai是基于PyTorch的高级深度学习框架,其包含了处理端到端的许多好用工具,包括用于监控训练进程的TensorBoardCallback。下面是使用方法: ```python from fastai.basics import * path = untar_data(URLs.MNIST_SAMPLE) data = ImageDataBunch.from_folder(path) learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=ShowGraph) learn.fit(5) ``` 设置callback_fns的ShowGraph即可可视化监测模型的训练过程。 总结 PyTorch是一个强大的深度学习框架,它提供了多种工具监测模型的训练过程。TensorBoard是目前广泛使用的可视化工具之一,使用TensorboardX、PyTorch内置的可视化、Fastai的TensorBoardCallback等方法均可实现PyTorch训练过程的监测和可视化,方便开发者了解模型的训练进程,发现问题并进行调整优化。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值