```python
import math

import torch
from torch import nn

# Save to the d2l package.
class DotProductAttention(nn.Module):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # query: (batch_size, #queries, d)
    # key: (batch_size, #kv_pairs, d)
    # value: (batch_size, #kv_pairs, dim_v)
    # valid_length: either (batch_size,) or (batch_size, #queries)
    def forward(self, query, key, value, valid_length=None):
        d = query.shape[-1]
        # Swap the last two dimensions of key, then scale by sqrt(d)
        # so the dot products do not grow with the dimension
        scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(d)
        attention_weights = self.dropout(masked_softmax(scores, valid_length))
        # Print the attention weights for inspection
        print("attention_weight\n", attention_weights)
        return torch.bmm(attention_weights, value)
```
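Here each output is $\mathrm{softmax}(\mathbf{Q}\mathbf{K}^\top/\sqrt{d})\mathbf{V}$, a convex combination of the values weighted by query-key similarity. The `masked_softmax` helper is defined elsewhere in the chapter; the sketch below shows one plausible way to implement it (a reconstruction, not necessarily the exact d2l version), followed by a toy invocation in which the two examples attend only to their first 2 and 6 key-value pairs:

```python
def masked_softmax(X, valid_length):
    # Softmax over the last axis, ignoring entries past each row's valid length.
    # X: (batch_size, #queries, #kv_pairs)
    # valid_length: (batch_size,) or (batch_size, #queries), or None
    if valid_length is None:
        return torch.softmax(X, dim=-1)
    shape = X.shape
    if valid_length.dim() == 1:
        # One length per example: broadcast it to every query
        valid_length = valid_length.repeat_interleave(shape[1])
    else:
        valid_length = valid_length.reshape(-1)
    X = X.reshape(-1, shape[-1])
    # Positions at or beyond the valid length get a large negative score,
    # which becomes ~0 after the softmax
    mask = torch.arange(shape[-1], device=X.device)[None, :] >= valid_length[:, None]
    X = X.masked_fill(mask, -1e6)
    return torch.softmax(X, dim=-1).reshape(shape)

# Toy check: a batch of 2 examples, each with 1 query over 10 key-value pairs
atten = DotProductAttention(dropout=0)
keys = torch.ones(2, 10, 2)
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
atten(torch.ones(2, 1, 2), keys, values, torch.tensor([2, 6]))
```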
```python
import collections

import torch
from torch.utils import data


# This class is saved in d2l.
class Vocab(object):
    def __init__(self, tokens, min_freq=0, use_special_tokens=False):
        # Sort by token, then (stably) by frequency in descending order
        counter = collections.Counter(tokens)
        token_freqs = sorted(counter.items(), key=lambda x: x[0])
        token_freqs.sort(key=lambda x: x[1], reverse=True)
        if use_special_tokens:
            # padding, beginning of sentence, end of sentence, unknown
            self.pad, self.bos, self.eos, self.unk = (0, 1, 2, 3)
            tokens = ['<pad>', '<bos>', '<eos>', '<unk>']
        else:
            self.unk = 0
            tokens = ['<unk>']
        tokens += [token for token, freq in token_freqs if freq >= min_freq]
        self.idx_to_token = []
        self.token_to_idx = dict()
        for token in tokens:
            self.idx_to_token.append(token)
            self.token_to_idx[token] = len(self.idx_to_token) - 1

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        else:
            return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        else:
            return [self.idx_to_token[index] for index in indices]
```
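Before applying `Vocab` to the translation data, a quick check on a toy token list makes its behavior concrete (the tokens below are illustrative only):

```python
toy_tokens = ['go', '.', 'run', '.', 'go', 'now', '.']
toy_vocab = Vocab(toy_tokens, min_freq=1, use_special_tokens=True)
print(len(toy_vocab))              # vocabulary size, including the 4 special tokens
print(toy_vocab['.'])              # index of the most frequent token
print(toy_vocab[['go', 'xyzzy']])  # unknown tokens map to toy_vocab.unk
print(toy_vocab.to_tokens([0, 1, 2, 3]))  # ['<pad>', '<bos>', '<eos>', '<unk>']
```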
```python
def load_data_nmt(batch_size, max_len, num_examples=1000):
    """Load an English-French NMT dataset; return its vocabularies and data iterator."""
    # Preprocess: normalize whitespace and separate punctuation from words
    def preprocess_raw(text):
        text = text.replace('\u202f', ' ').replace('\xa0', ' ')
        out = ''
        for i, char in enumerate(text.lower()):
            if char in (',', '!', '.') and text[i - 1] != ' ':
                out += ' '
            out += char
        return out

    with open('/home/kesci/input/fraeng6506/fra.txt', 'r') as f:
        raw_text = f.read()
    text = preprocess_raw(raw_text)

    # Tokenize: each line holds an English sentence and its French
    # translation, separated by a tab
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        if i >= num_examples:
            break
        parts = line.split('\t')
        if len(parts) >= 2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))

    # Build a vocabulary for each language
    def build_vocab(tokens):
        tokens = [token for line in tokens for token in line]
        return Vocab(tokens, min_freq=3, use_special_tokens=True)

    src_vocab, tgt_vocab = build_vocab(source), build_vocab(target)

    # Convert token lists to fixed-length index arrays
    def pad(line, max_len, padding_token):
        if len(line) > max_len:
            return line[:max_len]
        return line + [padding_token] * (max_len - len(line))

    def build_array(lines, vocab, max_len, is_source):
        lines = [vocab[line] for line in lines]
        if not is_source:
            # Wrap target sentences with <bos> and <eos>
            lines = [[vocab.bos] + line + [vocab.eos] for line in lines]
        array = torch.tensor([pad(line, max_len, vocab.pad) for line in lines])
        valid_len = (array != vocab.pad).sum(1)
        return array, valid_len

    src_array, src_valid_len = build_array(source, src_vocab, max_len, True)
    tgt_array, tgt_valid_len = build_array(target, tgt_vocab, max_len, False)
    train_data = data.TensorDataset(src_array, src_valid_len,
                                    tgt_array, tgt_valid_len)
    train_iter = data.DataLoader(train_data, batch_size, shuffle=True)
    return src_vocab, tgt_vocab, train_iter
```
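A small smoke test of the loader (this assumes the `fra.txt` file referenced above is present; the batch size and `max_len` values are illustrative):

```python
src_vocab, tgt_vocab, train_iter = load_data_nmt(batch_size=2, max_len=8)
for X, X_valid_len, Y, Y_valid_len in train_iter:
    print('X:', X.shape)                    # (2, 8) source index array
    print('X valid lengths:', X_valid_len)  # non-padding tokens per source sentence
    print('Y:', Y.shape)                    # (2, 8) target array with <bos>/<eos>
    print('Y valid lengths:', Y_valid_len)
    break
```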
```python
# Try the trained model on a few sentences; `model`, `num_steps`, and `ctx`
# come from the training code earlier in the chapter.
for sentence in ['Go .', 'Good Night !', "I'm OK .", 'I won !']:
    print(sentence + ' => ' + d2l.predict_s2s_ch9(
        model, sentence, src_vocab, tgt_vocab, num_steps, ctx))
```
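`predict_s2s_ch9` performs greedy decoding: it encodes the source sentence, then repeatedly feeds the most probable token back into the decoder until `<eos>` is produced or `num_steps` tokens are emitted. A minimal sketch of that idea, using a hypothetical encoder-decoder interface (`model.encoder`, `model.decoder`, and `model.decoder.init_state` are assumptions here, not necessarily the exact d2l API):

```python
def greedy_predict(model, src_tokens, src_vocab, tgt_vocab, num_steps, device):
    # src_tokens: a list of source-language tokens.
    # Encode the source sentence (padding omitted for brevity;
    # hypothetical model interface).
    src = torch.tensor([src_vocab[src_tokens]], device=device)
    enc_outputs = model.encoder(src)
    state = model.decoder.init_state(enc_outputs)
    # Start decoding from <bos>; stop at <eos> or after num_steps tokens
    dec_X = torch.tensor([[tgt_vocab.bos]], device=device)
    output_tokens = []
    for _ in range(num_steps):
        Y, state = model.decoder(dec_X, state)
        dec_X = Y.argmax(dim=2)  # greedy choice: most probable next token
        token = int(dec_X.squeeze())
        if token == tgt_vocab.eos:
            break
        output_tokens.append(token)
    return ' '.join(tgt_vocab.to_tokens(output_tokens))
```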