PG-net model training

Note: the coverage loss is only added to the training objective during the last two epochs (see the epoch >= args.num_train_epochs - 2 check in the training loop below).
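For reference: the coverage loss of a pointer-generator network (See et al., 2017) penalizes the decoder for re-attending to source positions it has already covered. Below is a minimal sketch of that computation with illustrative tensor names; the real implementation lives in model_dir/modeling.

import torch

def coverage_loss_fn(attn_dists):
    # attn_dists: one attention distribution per decoder step, each of shape (batch, src_len)
    coverage = torch.zeros_like(attn_dists[0])                 # running sum of past attention
    cov_loss = attn_dists[0].new_zeros(attn_dists[0].size(0))  # per-example accumulator
    for attn in attn_dists:
        # overlap between the current attention and what was already covered
        cov_loss = cov_loss + torch.min(attn, coverage).sum(dim=1)
        coverage = coverage + attn
    return cov_loss.mean()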

import os
import args
import torch
import random
from tqdm import tqdm
import evaluate
from optimizer import BertAdam
from dataset.dataloader import Dureader
from dataset.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from model_dir.modeling import BertForQuestionAnswering,BertConfig

# fix random seeds for reproducibility
random.seed(args.seed)
torch.manual_seed(args.seed)
device = args.device
device_ids = [0]    # GPU ids; listing more than one enables DataParallel below

if len(device_ids) > 0:
    torch.cuda.manual_seed_all(args.seed)
# the custom BertForQuestionAnswering (defined in model_dir/modeling) is
# initialized from Chinese RoBERTa-wwm-ext weights
model = BertForQuestionAnswering.from_pretrained('./chinese_roberta_wwm_ext')

if len(device_ids) > 1:
    model = torch.nn.DataParallel(model,device_ids = device_ids)
    model = model.cuda(device = device_ids[0])
elif len(device_ids) == 1:
    model.to(device)
    
param_optimizer = list(model.named_parameters())
# the pooler is not used by this task, so its parameters are excluded from optimization
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
# biases and LayerNorm parameters are exempt from weight decay, as in the original BERT recipe
no_decay = ['bias','LayerNorm.bias','LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params':[p for n,p in param_optimizer if not any(nd in n for nd in no_decay)],'weight_decay':0.01},
    {'params':[p for n,p in param_optimizer if any(nd in n for nd in no_decay)]}
]
optimizer = BertAdam(optimizer_grouped_parameters,lr = args.learning_rate,warmup = 0.1,t_total = args.num_train_optimization_steps)
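# BertAdam warms the learning rate up linearly over the first warmup * t_total updates
# and then decays it linearly, so t_total should equal the real number of optimizer
# steps; it is assumed to be precomputed in args, typically as
#   len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs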


data = Dureader()
train_dataloader,dev_dataloader = data.train_iter,data.dev_iter
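# train_iter/dev_iter yield batches whose named fields (input_ids, input_mask,
# input_ids_q, input_mask_q, answer_ids, answer_mask, segment_ids, can_answer)
# are unpacked at the top of the training loop below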
best_loss = 100000.0    # best validation loss so far; a checkpoint is saved whenever it improves
model.train()
for epoch in range(args.num_train_epochs):
    train_loss,train_loss_total = 0.0,0.0
    n_words,n_words_total = 0,0
    n_sents,n_sents_total = 0,0
    
    for step,batch in enumerate(tqdm(train_dataloader,desc = 'Epoch')):
            input_ids, input_mask, input_ids_q, input_mask_q, answer_ids, answer_mask, \
            segment_ids, can_answer = \
                batch.input_ids, batch.input_mask, batch.input_ids_q, batch.input_mask_q,\
                batch.answer_ids, batch.answer_mask, batch.segment_ids,batch.can_answer
            answer_inputs = answer_ids[:,:-1]    # decoder inputs: drop the final EOS token
            answer_targets = answer_ids[:,1:]    # decoder targets: drop the leading BOS token
            answer_len = answer_mask.sum(1) - 1
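            # e.g. answer_ids = [BOS, y1, y2, EOS] yields
            #   answer_inputs  = [BOS, y1, y2]   (teacher forcing)
            #   answer_targets = [y1, y2, EOS]
            # and answer_len counts the remaining tokens (mask sum minus the dropped position)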
            if len(device_ids) > 1:
                input_ids, input_mask, input_ids_q, input_mask_q, answer_inputs, answer_len, answer_targets,\
                segment_ids, can_answer = \
                    input_ids.cuda(device=device_ids[0]), input_mask.cuda(device=device_ids[0]), \
                    input_ids_q.cuda(device=device_ids[0]), input_mask_q.cuda(device=device_ids[0]), \
                    answer_inputs.cuda(device=device_ids[0]), answer_len.cuda(device=device_ids[0]), \
                    answer_targets.cuda(device=device_ids[0]), \
                    segment_ids.cuda(device=device_ids[0]), can_answer.cuda(device=device_ids[0])
            elif len(device_ids) == 1:
                input_ids, input_mask, input_ids_q, input_mask_q, answer_inputs, answer_len, answer_targets, \
                segment_ids, can_answer = \
                    input_ids.to(device), input_mask.to(device), input_ids_q.to(device), input_mask_q.to(device), \
                    answer_inputs.to(device), answer_len.to(device), answer_targets.to(device), \
                    segment_ids.to(device), can_answer.to(device)
            
            loss,coverage_loss = model(input_ids,input_ids_q,token_type_ids = segment_ids,attention_mask = input_mask,
                                       attention_mask_q = input_mask_q,dec_inputs = answer_inputs,dec_inputs_len = answer_len,
                                       dec_targets = answer_targets,can_answer = can_answer)
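            # the forward pass returns the main generation loss and a separate coverage term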
            if len(device_ids) > 1:
                # DataParallel gathers one loss per GPU; reduce to scalars before .item()/backward()
                loss = loss.mean()
                coverage_loss = coverage_loss.mean()
            if epoch >= args.num_train_epochs - 2:
                # only add the coverage loss in the last two epochs
                loss = loss + coverage_loss
            train_loss_total += float(loss.item())
            n_words_total += torch.sum(answer_len)
            n_sents_total += answer_len.size(0)
            if step % args.display_freq == 0:
                # report the average per-token loss since the last report
                loss_int = train_loss_total - train_loss
                n_words_int = n_words_total - n_words
                avg_loss = loss_int / n_words_int.item()
                print('Epoch {0:<3}'.format(epoch),
                      'Step {0:<10}'.format(step),
                      'Avg_loss {0:<10.2f}'.format(avg_loss))
                train_loss, n_words, n_sents = (train_loss_total, n_words_total.item(), n_sents_total)
            
            if args.gradient_accumulation_steps > 1:
                # scale the loss so gradients average over the accumulation window
                loss = loss / args.gradient_accumulation_steps
            
            loss.backward()
            
            
            # update parameters once every gradient_accumulation_steps batches
            if (step+1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
            
            
            # periodic validation; checkpoint whenever the dev loss improves
            if step % args.log_step == 4:
                eval_loss = evaluate.evaluate(model,dev_dataloader,device_ids)
                if eval_loss < best_loss:
                    best_loss = eval_loss
                    if len(device_ids) > 1:
                        # unwrap DataParallel before saving
                        torch.save(model.module.state_dict(), './model_dir/debug')
                    elif len(device_ids) == 1:
                        torch.save(model.state_dict(), './model_dir/debug')
                model.train()    # evaluation may leave the model in eval mode; switch back to training
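
Once training finishes, the best checkpoint can be restored into a freshly constructed model for inference; a minimal sketch, reusing the save path from the loop above:

model = BertForQuestionAnswering.from_pretrained('./chinese_roberta_wwm_ext')
model.load_state_dict(torch.load('./model_dir/debug'))
model.eval()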

