Because the same input passed through the model's dropout layers twice produces two different outputs, we can add a KL-divergence constraint that encourages the two output distributions to stay consistent. The total loss is therefore a weighted sum of the cross-entropy loss and the KL divergence.
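Concretely, writing $p_1$ and $p_2$ for the two predicted distributions of the same example, $y$ for its label, and $\alpha$ for the KL weight (0.5 in the code below), the objective can be sketched as

$$\mathcal{L} = \tfrac{1}{2}\big(\mathrm{CE}(p_1, y) + \mathrm{CE}(p_2, y)\big) + \tfrac{\alpha}{2}\big(\mathrm{KL}(p_1 \,\|\, p_2) + \mathrm{KL}(p_2 \,\|\, p_1)\big).$$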
Computing the KL divergence
Because KL divergence is asymmetric, it is computed in both directions and the two values are averaged.
import torch
import torch.nn.functional as F

def compute_kl_loss(p, q, pad_mask=None):
    # Bidirectional KL between the two predicted distributions; shape (num_samples, num_labels)
    p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
    q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')

    # pad_mask is for seq-level tasks
    if pad_mask is not None:
        p_loss.masked_fill_(pad_mask, 0.)
        q_loss.masked_fill_(pad_mask, 0.)

    # You can choose "sum" or "mean" depending on your task
    p_loss = p_loss.sum()
    q_loss = q_loss.sum()

    loss = (p_loss + q_loss) / 2
    return loss
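A quick sanity check (the shapes and values here are made up for illustration): feeding in two sets of logits for a 4-example, 3-label batch, as if they came from two stochastic forward passes, should return a single scalar loss.

# Hypothetical example: 4 samples, 3 labels, simulating two dropout-perturbed forward passes.
logits_a = torch.randn(4, 3)
logits_b = torch.randn(4, 3)
print(compute_kl_loss(logits_a, logits_b))  # scalar tensor; close to 0 only if the two distributions agree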
Training with R-Drop
import time

# AverageMeter, timeSince, and CFG are helper utilities defined elsewhere in this post.
def train_fn_r_drop(train_loader, model, optimizer, epoch, scheduler, device):
    model.train()
    losses = AverageMeter()
    start = end = time.time()
    global_step = 0
    for step, batch in enumerate(train_loader):
        label = batch[2].to(device)
        mask = batch[1].to(device)
        input_ids = batch[0].to(device)
        batch_size = label.size(0)

        # First forward pass: cross-entropy loss and logits
        output_0 = model(input_ids, mask, labels=label)
        loss_0 = output_0.loss
        logits_0 = output_0.logits  # (batch, num_labels)

        # Second forward pass: cross-entropy loss and logits
        output_1 = model(input_ids, mask, labels=label)
        loss_1 = output_1.loss
        logits_1 = output_1.logits  # (batch, num_labels)

        ce_loss = 0.5 * (loss_0 + loss_1)
        kl_loss = compute_kl_loss(logits_0, logits_1)  # KL divergence between the two sets of logits
        loss = ce_loss + 0.5 * kl_loss

        losses.update(loss.item(), batch_size)
        optimizer.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 500)
        optimizer.step()
        global_step += 1
        scheduler.step()
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  'Grad: {grad_norm:.4f} '
                  'LR: {lr:.8f} '
                  .format(epoch + 1, step, len(train_loader),
                          remain=timeSince(start, float(step + 1) / len(train_loader)),
                          loss=losses,
                          grad_norm=grad_norm,
                          lr=scheduler.get_lr()[0]))
    return losses.avg
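For completeness, a minimal driver sketch of how this training function might be called. It is not from the original post: the model name, CFG.num_labels, CFG.lr, and CFG.epochs are assumed placeholders, and train_loader is assumed to yield (input_ids, attention_mask, labels) tuples in the order the function expects.

# Hypothetical setup: Hugging Face classification model + AdamW + linear warmup schedule.
import torch
from torch.optim import AdamW
from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-chinese', num_labels=CFG.num_labels).to(device)  # example checkpoint, not prescribed by the post
optimizer = AdamW(model.parameters(), lr=CFG.lr)
num_training_steps = len(train_loader) * CFG.epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

for epoch in range(CFG.epochs):
    avg_loss = train_fn_r_drop(train_loader, model, optimizer, epoch, scheduler, device)
    print(f'epoch {epoch + 1}: avg train loss {avg_loss:.4f}')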