本篇文章用一个paraphrase generation的任务,带领大家走完Bart实战实用的全过程。
每一部分的代码见github链接。下面按照代码的顺序进行一一讲解。
准备工作
-
选择平台,除了Linux等常规的服务器平台,这里安利一下google的平台 colab,在使用服务器不方便的情况下,可以在colab实现深度学习,以及神经网络代码。
-
安装pytorch,Linux环境的pytorch安装可以参考链接: MAC系统Linux服务器下用Anaconda安装Pytorch;在colab上,则使用
!pip install pytorch
不需要其他的配置环境步骤。 -
安装transformers
-
下载数据集,本文为大家提供了用于运行代码的两份数据,分别是train.json和eval.json。数据来自于Paraphrase Data,经过了处理保留了每对数据的sentence和headline部分,处理过程见另一篇文章。其中sentences是inputs,其对应的headlines是labels,每份训练文件存储20000对训练数据,而测试集存储10000对数据。
import numpy as np
import os
import re
import json
import torch
import torch.optim as optim
from torch import tensor
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import BartForConditionalGeneration, BartTokenizer, BartModel, AutoModel
from transformers.utils.import_utils import SENTENCEPIECE_IMPORT_ERROR
import dataset # from dataset.py
from dataset import Sent_Comp_Dataset
from dataset import collate_batch
import def_train_set # from def_train_set.py
from def_train_set import train,test
载入dataloader
全代码地址: dataset.py
下面是dataset.py代码的第一部分:代码讲解见代码后的注解。
class Sent_Comp_Dataset(Dataset): # 定义class
def __init__(self, path="", prefix="train"):
self.data_path = path # dataset的路径
self.sentence = []
self.headline = []
with open(self.data_path, encoding="utf-8", mode = 'r') as source:
context = json.load(source) # 读取.json文件的数据
for i in context:
element_sent = context[i]['sentence']
self.sentence.append(element_sent) # 取出每对数据中的sentence部分
element_head = context[i]['headline']
self.headline.append(element_head) # 取出每对数据中的headline部分
print('Files already downloaded and verified')
def __len__(self):
return len