tokenizer.batch_encode_plus(
    ['Hello , this is one sentence', 'This is another sentence.'])

{'input_ids': [[8774,3,6,48,19,80,7142,1], [100,19,430,7142,5,1]], 'attention_mask': [[1,1,1,1,1,1,1,1], [1,1,1,1,1,1]]}

For T5 the source and target tokenizers are identical, so encoding the same sentences inside as_target_tokenizer() yields exactly the same result:

with tokenizer.as_target_tokenizer():
    print(tokenizer.batch_encode_plus(
        ['Hello , this is one sentence', 'This is another sentence.']))

{'input_ids': [[8774,3,6,48,19,80,7142,1], [100,19,430,7142,5,1]], 'attention_mask': [[1,1,1,1,1,1,1,1], [1,1,1,1,1,1]]}
dataset['train'] = dataset['train'].shuffle(1).select(range(20000))
dataset['validation'] = dataset['validation'].shuffle(1).select(range(1000))
dataset['test'] = dataset['test'].shuffle(1).select(range(1000))

print(dataset['train'][0])

{'document': "Clay, who has agreed a two-year deal, made 39 appearances for Scottish Premiership club Motherwell last season after joining them in June 2016.\nThe 25-year-old had spent the two previous seasons with Grimsby, playing 74 National League games.\nClay is Leyton Orient's ninth signing since being relegated from League Two last season.", 'summary': 'National League side Leyton Orient have signed Motherwell midfielder Craig Clay on a free transfer.', 'id': '40635923'}
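The document/summary/id fields are exactly those of the XSum summarization dataset. Assuming that is the source (the loading cell is not shown in this excerpt), the dataset was presumably obtained along these lines:

from datasets import load_dataset

# assumption: the XSum summarization dataset from the Hugging Face hub
dataset = load_dataset('xsum')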
5. Data preprocessing
def f(examples, tokenizer):
    # T5 expects the task prefix 'summarize:' in front of each document
    data = tokenizer.batch_encode_plus(
        ['summarize:' + i for i in examples['document']],
        max_length=1024,
        truncation=True)
    # encode the reference summaries as the labels
    data['labels'] = tokenizer.batch_encode_plus(examples['summary'],
                                                 max_length=128,
                                                 truncation=True)['input_ids']
    return data


dataset = dataset.map(f,
                      batched=True,
                      batch_size=1000,
                      num_proc=12,
                      remove_columns=['document', 'summary', 'id'],
                      fn_kwargs={'tokenizer': tokenizer})
dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 20000
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1000
    })
})

print(dataset['train'][0])

{'input_ids': [21603,10,254,5595,6,113,65,4686,3,9,192,18,1201,1154,6,263,6352,3179,7,21,12580,6552,2009,1886,8007,2091,336,774,227,6109,135,16,1515,4619,37,944,18,1201,18,1490,141,1869,8,192,1767,9385,28,23427,7,969,6,1556,3,4581,868,3815,1031,5,20988,19,312,21220,3,16495,31,7,24651,8097,437,271,3,60,8791,26,45,3815,2759,336,774,5,1], 'attention_mask': [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], 'labels': [868,3815,596,312,21220,3,16495,43,3814,8007,2091,2076,1846,49,12870,20988,30,3,9,339,2025,5,1]}
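As a quick sanity check (my addition, not part of the original pipeline), you can decode a processed example back to text to confirm the 'summarize:' prefix and the label encoding:

sample = dataset['train'][0]
# the decoded input should start with the 'summarize:' task prefix
print(tokenizer.decode(sample['input_ids'])[:60])
# the labels decode back to the reference summary
print(tokenizer.decode(sample['labels']))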
6. Rewriting collate_fn: batching the data and padding labels to a uniform length
import torch


def collate_fn(data):
    # find the longest labels sequence in the batch
    max_length = max([len(i['labels']) for i in data])
    # pad every labels sequence to max_length with -100,
    # which the loss function ignores
    for i in data:
        pads = [-100] * (max_length - len(i['labels']))
        i['labels'] = i['labels'] + pads
    # pad input_ids/attention_mask and collate the batch into tensors
    data = tokenizer.pad(encoded_inputs=data,
                         padding=True,
                         max_length=None,
                         pad_to_multiple_of=None,
                         return_tensors='pt')
    # build decoder_input_ids by shifting labels one step to the right,
    # with the decoder start token (0) in the first position
    data['decoder_input_ids'] = torch.zeros_like(data['labels'],
                                                 dtype=torch.long)
    data['decoder_input_ids'][:, 1:] = data['labels'][:, :-1]
    # replace the -100 padding with the pad token id 0
    data['decoder_input_ids'][data['decoder_input_ids'] == -100] = 0
    return data
data = [{
    'input_ids': [21603,10,37,3719,13],
    'attention_mask': [1,1,1,1,1],
    'labels': [10455,120,80]
}, {
    'input_ids': [21603,10,7086,8408,563],
    'attention_mask': [1,1,1,1,1],
    'labels': [301,53,4074,1669]
}]

collate_fn(data)

{'input_ids': tensor([[21603,10,37,3719,13],
                      [21603,10,7086,8408,563]]),
 'attention_mask': tensor([[1,1,1,1,1],
                           [1,1,1,1,1]]),
 'labels': tensor([[10455,120,80,-100],
                   [301,53,4074,1669]]),
 'decoder_input_ids': tensor([[0,10455,120,80],
                              [0,301,53,4074]])}
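For reference, transformers ships DataCollatorForSeq2Seq, which performs the same -100 label padding and, when handed the model, builds decoder_input_ids by the same shift-right rule; a minimal sketch of the equivalent setup:

from transformers import DataCollatorForSeq2Seq

# off-the-shelf collator mirroring the hand-written collate_fn above
collate_fn_builtin = DataCollatorForSeq2Seq(tokenizer,
                                            model=model,
                                            label_pad_token_id=-100,
                                            return_tensors='pt')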
7. Data loader
loader = torch.utils.data.DataLoader(dataset=dataset['train'],
                                     batch_size=4,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

# peek at a single batch
for data in loader:
    break

# batch of 4, source sequence length 809, labels length 47
for k, v in data.items():
    print(k, v.shape)

input_ids torch.Size([4, 809])
attention_mask torch.Size([4, 809])
labels torch.Size([4, 47])
decoder_input_ids torch.Size([4, 47])

len(loader)

5000

512 * 32100

16435200
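The two stray numbers deserve a note: with 20,000 training examples, batch_size=4 and drop_last=True, len(loader) is 5000, and 512 * 32100 = 16435200 looks like a rough count of the embedding parameters of t5-small (d_model=512, tokenizer vocabulary of 32,100 pieces); that reading is my assumption. A quick check:

# 20000 examples / batch size 4 = 5000 batches
print(len(loader))  # 5000
# the actual embedding matrix is padded to 32128 rows for efficiency
print(model.get_input_embeddings().weight.shape)  # torch.Size([32128, 512])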
def test(model):
    model.eval()
    loader_test = torch.utils.data.DataLoader(dataset=dataset['test'],
                                              batch_size=4,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    # take a single batch from the test set
    for data in loader_test:
        break

    with torch.no_grad():
        out = model(**data)

    for i in range(4):
        input_ids = tokenizer.decode(data['input_ids'][i])
        pred = tokenizer.decode(out['logits'].argmax(dim=2)[i])
        label = tokenizer.decode(data['decoder_input_ids'][i])
        print('pred:', pred)
        print('label:', label)
        print()
test(model)
pred: Annakeepero Oinomar was thes in uncle beat the round. the football mped out. 4ship of the fifth capital the:
label: <pad> Goal hero Rabin Omar made headlines when his club from the fourth tier of Scottish football dumped a Premiership side out of the Scottish Cup.</s>

pred: world Hamilton says startinghe was "actual about about finishing starting start in both race One race the the the the the the the the world world world world world world
label: <pad> Lewis Hamilton said he was "not worried" about his difficult start to the Formula 1 season.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

pred: scientistss have on's spaceetta probe to will comet'.P have they have have been of the they a structures formed formed the
label: <pad> Scientists working on Europe's Rosetta probe, which is tracking Comet 67P, say they may have found evidence for how such icy objects were formed.

pred: football's also latest thing of football football football in football football football football football football football
label: <pad> It's the ugly side of the beautiful game.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
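Decoding logits.argmax(dim=2) only shows teacher-forced predictions: at every step the decoder sees the gold prefix, which is why the outputs above loosely track the labels even though the model cannot yet write a coherent summary on its own. Real inference decodes autoregressively; a minimal sketch using the standard generate API (the beam width and length cap are arbitrary choices of mine, not values from this tutorial):

with torch.no_grad():
    generated = model.generate(input_ids=data['input_ids'],
                               attention_mask=data['attention_mask'],
                               max_length=128,
                               num_beams=4)
print(tokenizer.decode(generated[0], skip_special_tokens=True))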
10. Training
from transformers.trainer_pt_utils import get_parameter_names
from transformers import AdamW
from transformers.optimization import get_scheduler

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)
data

{'input_ids': tensor([[21603,10,667,...,0,0,0],
                      [21603,10,15743,...,5719,535,1],
                      [21603,10,634,...,0,0,0],
                      [21603,10,24607,...,0,0,0]]),
 'attention_mask': tensor([[1,1,1,...,0,0,0],
                           [1,1,1,...,1,1,1],
                           [1,1,1,...,0,0,0],
                           [1,1,1,...,0,0,0]]),
 'labels': tensor([[71,1249,17030,18,8861,26737,297,13,26238,31,7,13017,1635,7,107,2309,2050,65,118,3754,5,1,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100],
                   [17159,115,17,7,43,118,3279,81,823,2789,31,7,17804,91,18,858,18,5842,7,199,747,19,3,179,12,2862,2261,21154,16,502,6,227,3,9,1871,3977,13,1717,14566,53,826,3,9,5738,7952,5,1],
                   [15670,65,243,150,1139,331,47,1513,116,24041,10564,15,26,165,8379,5,1,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100],
                   [3,29367,377,1655,1703,845,8288,33,3,19874,21,2411,979,45,70,416,192,6407,581,27620,11,30629,5,1,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100,-100]]),
 'decoder_input_ids': tensor([[0,71,1249,17030,18,8861,26737,297,13,26238,31,7,13017,1635,7,107,2309,2050,65,118,3754,5,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
                              [0,17159,115,17,7,43,118,3279,81,823,2789,31,7,17804,91,18,858,18,5842,7,199,747,19,3,179,12,2862,2261,21154,16,502,6,227,3,9,1871,3977,13,1717,14566,53,826,3,9,5738,7952,5],
                              [0,15670,65,243,150,1139,331,47,1513,116,24041,10564,15,26,165,8379,5,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
                              [0,3,29367,377,1655,1703,845,8288,33,3,19874,21,2411,979,45,70,416,192,6407,581,27620,11,30629,5,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]])}
def train():
    # apply weight decay to everything except biases and LayerNorm weights
    parameter_names = get_parameter_names(model, [torch.nn.LayerNorm])
    parameter_names = [i for i in parameter_names if 'bias' not in i]
    parameter_names = [{
        'params':
        [p for i, p in model.named_parameters() if i in parameter_names],
        'weight_decay': 1e-2
    }, {
        'params':
        [p for i, p in model.named_parameters() if i not in parameter_names],
        'weight_decay': 0.0
    }]

    # optimizer and linear learning-rate schedule
    optimizer = AdamW(parameter_names, betas=(0.9, 0.999), eps=1e-8, lr=2e-5)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    model.to(device)
    model.train()
    for i, data in enumerate(loader):
        input_ids, attention_mask = data['input_ids'], data['attention_mask']
        labels, decoder_input_ids = data['labels'], data['decoder_input_ids']
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)
        decoder_input_ids = decoder_input_ids.to(device)

        out = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                    decoder_input_ids=decoder_input_ids)
        loss = out['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        model.zero_grad()

        if i % 50 == 0:
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            # out['logits'] has shape [4, 47, 32128]
            pred = tokenizer.decode(out['logits'].argmax(dim=2)[0])
            label = tokenizer.decode(data['decoder_input_ids'][0])
            print(i, loss.item(), lr)
            print('pred: ', pred)
            print('label: ', label)
            print()
train()
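After train() finishes, you would normally save the weights and score the model with a proper metric such as ROUGE. A sketch, assuming the separate evaluate package is installed; the output directory name is hypothetical:

import evaluate

model.save_pretrained('t5-xsum-summary')        # hypothetical output dir
tokenizer.save_pretrained('t5-xsum-summary')

rouge = evaluate.load('rouge')
model.eval()
preds, refs = [], []
for batch in torch.utils.data.DataLoader(dataset['test'],
                                         batch_size=4,
                                         collate_fn=collate_fn):
    with torch.no_grad():
        out = model.generate(input_ids=batch['input_ids'].to(device),
                             attention_mask=batch['attention_mask'].to(device),
                             max_length=128)
    preds += tokenizer.batch_decode(out, skip_special_tokens=True)
    labels = batch['labels'].clone()
    labels[labels == -100] = tokenizer.pad_token_id  # undo the -100 padding
    refs += tokenizer.batch_decode(labels, skip_special_tokens=True)
print(rouge.compute(predictions=preds, references=refs))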