pytorch task04: Dive into PyTorch - Machine Translation

This post is PyTorch task 04. It covers the basics of machine translation, dataset preprocessing, the Encoder-Decoder structure, and the Sequence to Sequence model, with a focus on the principles of and experiments with the attention mechanism, including the Softmax masking operation.

1. Machine Translation and the Dataset

Machine translation (MT): automatically translating a piece of text from one language to another. Solving this problem with neural networks is usually called neural machine translation (NMT).
Key characteristics: the output is a sequence of words rather than a single word, and the length of the output sequence may differ from that of the source sequence.
The dataset is the fra-eng dataset from http://www.manythings.org/anki/.

1.1 Dataset Preprocessing
# Imports used throughout this post
import collections

import torch

# Vocabulary: maps tokens to indices and indices back to tokens
class Vocab(object):
    def __init__(self, tokens, min_freq=0, use_special_tokens=False):
        counter = collections.Counter(tokens)
        self.token_freqs = list(counter.items())
        self.idx_to_token = []
        if use_special_tokens:
            # padding, begin of sentence, end of sentence, unknown
            self.pad, self.bos, self.eos, self.unk = (0, 1, 2, 3)
            self.idx_to_token += ['<pad>', '<bos>', '<eos>', '<unk>']
        else:
            self.unk = 0
            self.idx_to_token += ['<unk>']
        self.idx_to_token += [token for token, freq in self.token_freqs
                        if freq >= min_freq and token not in self.idx_to_token]
        self.token_to_idx = dict()
        for idx, token in enumerate(self.idx_to_token):
            self.token_to_idx[token] = idx

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]
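
A quick usage sketch of Vocab on a made-up token list (the tokens and the commented results are illustrative, not from the original post):

# Hypothetical usage of Vocab on a toy token list
tokens = ['go', '.', 'hi', '.', 'run', '.', 'go', 'away', '.']
vocab = Vocab(tokens, min_freq=0, use_special_tokens=True)
print(len(vocab))               # 9: four special tokens + five distinct words
print(vocab['go'])              # the index assigned to 'go'
print(vocab[['go', 'away']])    # a list of tokens maps to a list of indices
print(vocab.to_tokens([4, 5]))  # indices map back to tokens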

# Clean the raw text, tokenize it, and build the source/target vocabularies
class TextPreprocessor():
    def __init__(self, text, num_lines):
        self.num_lines = num_lines
        text = self.clean_raw_text(text)
        self.src_tokens, self.tar_tokens = self.tokenize(text)
        self.src_vocab = self.build_vocab(self.src_tokens)
        self.tar_vocab = self.build_vocab(self.tar_tokens)
    
    def clean_raw_text(self, text):
        # Replace narrow no-break spaces and non-breaking spaces, then lowercase
        text = text.replace('\u202f', ' ').replace('\xa0', ' ')
        out = ''
        for i, char in enumerate(text.lower()):
            # Insert a space before punctuation so it becomes its own token
            if char in (',', '!', '.') and i > 0 and text[i-1] != ' ':
                out += ' '
            out += char
        return out
        
    def tokenize(self, text):
        sources, targets = [], []
        for i, line in enumerate(text.split('\n')):
            # Keep at most num_lines sentence pairs
            if i >= self.num_lines:
                break
            parts = line.split('\t')
            if len(parts) >= 2:
                sources.append(parts[0].split(' '))
                targets.append(parts[1].split(' '))
        return sources, targets
        
    def build_vocab(self, tokens):
        tokens = [token for line in tokens for token in line]
        return Vocab(tokens, min_freq=3, use_special_tokens=True)
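
A minimal sketch of the preprocessor on two fabricated tab-separated lines in the fra-eng format (english<TAB>french); the real file is far larger:

# Two invented sample lines in the "english\tfrench" format
raw_text = 'Go.\tVa !\nHi.\tSalut !'
tp = TextPreprocessor(raw_text, num_lines=2)
print(tp.src_tokens[0])  # ['go', '.'] after punctuation is split off
print(tp.tar_tokens[0])  # ['va', '!']
# With min_freq=3, rare tokens drop out, so this tiny sample's vocabularies
# contain only the four special tokens
print(len(tp.src_vocab))  # 4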
1.2 Creating the DataLoader
# Pad the sequences, build a TensorDataset, and create the DataLoader
class TextUtil():
    def __init__(self, tp, max_len):
        self.src_vocab, self.tar_vocab = tp.src_vocab, tp.tar_vocab
        src_arr, src_valid_len = self.build_array(
            tp.src_tokens, tp.src_vocab, max_len=max_len,
            padding_token=tp.src_vocab.pad, is_source=True)
        tar_arr, tar_valid_len = self.build_array(
            tp.tar_tokens, tp.tar_vocab, max_len=max_len,
            padding_token=tp.tar_vocab.pad, is_source=False)
        self.dataset = torch.utils.data.TensorDataset(
            src_arr, src_valid_len, tar_arr, tar_valid_len)
        
    def build_array(self, lines, vocab, max_len, padding_token, is_source):
        # Truncate long lines and pad short ones so every row has length max_len
        def _pad(line):
            if len(line) > max_len:
                return line[:max_len]
            else:
                return line + (max_len - len(line)) * [padding_token]
        lines = [vocab[line] for line in lines]
        if not is_source:
            # Wrap each target sequence with <bos> and <eos> markers
            lines = [[vocab.bos] + line + [vocab.eos] for line in lines]
        arr = torch.tensor([_pad(line) for line in lines])
        # valid_len counts the non-padding tokens in each row
        valid_len = (arr != vocab.pad).sum(1)
        return arr, valid_len
        
    def load_data_nmt(self, batch_size):
        train_loader = torch.utils.data.DataLoader(self.dataset, batch_size, shuffle = True)
        return self.src_vocab, self.tar_vocab, train_loader
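
Putting the pieces together, here is a hedged end-to-end sketch; it assumes fra.txt from http://www.manythings.org/anki/ has been downloaded and extracted next to the script, and the num_lines, max_len, and batch_size values are illustrative choices:

# Assumes a local copy of fra.txt; adjust the path to your setup
with open('fra.txt', 'r', encoding='utf-8') as f:
    raw_text = f.read()
tp = TextPreprocessor(raw_text, num_lines=50000)
util = TextUtil(tp, max_len=10)
src_vocab, tar_vocab, train_loader = util.load_data_nmt(batch_size=64)
for src, src_valid_len, tar, tar_valid_len in train_loader:
    print(src.shape, src_valid_len.shape, tar.shape, tar_valid_len.shape)
    break  # inspect a single batch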

2. Encoder-Decoder

encoder: maps the input to a hidden state
decoder: maps the hidden state to the output

(figure: the Encoder-Decoder architecture)

3. Sequence to Sequence

3.1 Structure

Training:
(figure: the Seq2Seq model at training time)

Prediction:
(figure: the Seq2Seq model at prediction time)

Detailed structure:
(figure: the layer-level structure of the Seq2Seq model)

3.2 Code Implementation
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)
    
    def forward(self, X, *args):
        raise NotImplementedError
    
class Decoder(nn.Module):
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)
    
    def init_state(self, encoded_state, *args):
        raise NotImplementedError
        
    def forward(self, X, state):
        raise NotImplementedError

class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        # Encode the source, initialize the decoder state from the encoder
        # output, then run the decoder on the target inputs
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
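
The post's code stops at these abstract interfaces. For completeness, below is a minimal GRU-based sketch of concrete subclasses in the spirit of the d2l course; the layer sizes and the choice to seed the decoder with the encoder's final hidden state are assumptions of mine, not taken from the original text:

class Seq2SeqEncoder(Encoder):
    # A minimal GRU encoder sketch (illustrative, not the post's original code)
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, **kwargs):
        super(Seq2SeqEncoder, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers)

    def forward(self, X, *args):
        X = self.embedding(X)        # (batch, seq_len, embed_size)
        X = X.permute(1, 0, 2)       # nn.GRU expects (seq_len, batch, embed_size)
        output, state = self.rnn(X)  # state: (num_layers, batch, num_hiddens)
        return output, state

class Seq2SeqDecoder(Decoder):
    # A matching GRU decoder that starts from the encoder's final hidden state
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, **kwargs):
        super(Seq2SeqDecoder, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, encoded_state, *args):
        return encoded_state[1]  # the encoder's final hidden state

    def forward(self, X, state):
        X = self.embedding(X).permute(1, 0, 2)
        output, state = self.rnn(X, state)
        output = self.dense(output).permute(1, 0, 2)  # (batch, seq_len, vocab_size)
        return output, state

# Smoke test on dummy data
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
model = EncoderDecoder(encoder, decoder)
X = torch.zeros((4, 7), dtype=torch.long)
output, state = model(X, X)
print(output.shape)  # torch.Size([4, 7, 10])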