Attention is all you need Transformer和Attention实现和注释

这篇博客详细解析了《Attention is All You Need》论文中的Transformer模型,并对相关代码进行了注释,特别针对机器翻译任务。内容涵盖torchtext的使用和NMT模型的组成部分,帮助读者理解视觉领域的细节。
摘要由CSDN通过智能技术生成

参考:跟着论文《 Attention is All You Need》一步一步实现Attention和Transformer

 

png

对上面博客中提供的代码的一些细节进行注释。

由于是以机器翻译作为例子。对于没有接触过这方面的,特别是做视觉的会有很多细节不能理解,我花了一些时间,看了torchtext的使用以及机器翻译的过程,给代码做了写注释。

torchtext的使用:参考1参考2torchtext文档等等

代码分成两部分,一部分是NMT的部分,另一部分是模型

import numpy as np
import torch
import torch.nn as nn
import time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")
#%matplotlib inline
from torchtext import data, datasets
from model import *



#用于mask数据,产生source mask和target mask
class Batch:
    """ 在训练期间使用mask处理数据 """

    def __init__(self, src, trg=None, pad=0):
        #src.size = batch_size, q_len
        self.src = src
        #src_mask.size = batch_size, 1, q_len
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            self.trg = trg[:, :-1]
            self.trg_y = trg[:, 1:]
            self.trg_mask = self.make_std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).data.sum()

    @staticmethod
    def make_std_mask(tgt, pad):
        """ 创造一个mask来屏蔽补全词和字典外的词进行屏蔽"""
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
        return tgt_mask

#将优化器再包一层,更方便
class NoamOpt:
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        """ 更新参数和学习率 """
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """ lrate 实现"""
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))


#没用到,就是返回一个优化器,里面是一些设置
def get_std_up(model):
    return NoamOpt(model.src_embed[0].d_model, 2, 4000,
                   torch.optim.Adam(model.param_groups(),
                                    lr=0, betas=(0.9, 0.98), eps=1e-9))

'''size 是目标类别数目 smoothing这里使用,0.1'''
#平滑标签,将非真实目标的类别也给一个小的值
class LabelSmoothing(nn.Module):
    """ 标签平滑实现 """
    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        #改成这样,不会有warning
        self.criterion = nn.KLDivLoss(reduction='none')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    #x是generator的输出[n, vocab_size],也就是模型预测,target是真实目标,大小 n
    def forward(self, x, target):
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        #为什么减去2???? 要减去padding_idx和正确的label本身
        #size(x) = batch_size,
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0) #dim,index,val

        self.true_dist = true_dist
        return self.criterion(x, Variable(true_dist, requires_grad=False))

class MultiGPULossCompute:
    "A multi-gpu loss compute and train function."

    def __init__(self, generator, criterion, devices, opt=None, chunk_size=5):
        # Send out to different gpus.
        self.generator = generator
        self.criterion = nn.parallel.replicate(criterion,
                                               devices=devices)
        self.opt = opt
        self.devices = devices
        self.chunk_size = chunk_size

    #size(out) = batch_size, max_len, d_model
    def __call__(self, out, targets, normalize):
        total = 0.0
        generator = nn.parallel.replicate(self.generator,
                                          devices=self.devices)
        out_scatter = nn.parallel.scatter(out,
                                
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值