import torch
import random
from train import indexesFromSentence
from load import MAX_LENGTH, loadPrepareData, Voc
from model import *
from config import *
class Sentence:
    """One partial hypothesis tracked during beam search.

    Holds the token indexes chosen so far, one score per token, the index of
    the most recent token, and the decoder hidden state to continue from.
    """

    def __init__(self, decoder_hidden, last_idx=SOS_token, sentence_idxes=None, sentence_scores=None):
        """Create a hypothesis.

        Args:
            decoder_hidden: decoder hidden state to resume decoding from.
            last_idx: index of the most recently emitted token.
            sentence_idxes: token indexes chosen so far (defaults to empty).
            sentence_scores: one score per entry of ``sentence_idxes``.

        Raises:
            ValueError: if the two lists differ in length.
        """
        # None sentinels instead of literal [] defaults: a list default is
        # created once and would be shared by every instance built without it.
        sentence_idxes = [] if sentence_idxes is None else sentence_idxes
        sentence_scores = [] if sentence_scores is None else sentence_scores
        if len(sentence_idxes) != len(sentence_scores):
            raise ValueError("length of indexes and scores should be the same")
        self.decoder_hidden = decoder_hidden
        self.last_idx = last_idx
        self.sentence_idxes = sentence_idxes
        self.sentence_scores = sentence_scores

    def avgScore(self):
        """Return the mean of the per-token scores.

        Raises:
            ValueError: if the hypothesis has no tokens yet.
        """
        if len(self.sentence_scores) == 0:
            raise ValueError("Calculate average score of sentence, but got no word")
        return sum(self.sentence_scores) / len(self.sentence_scores)

    def addTopk(self, topi, topv, decoder_hidden, beam_size, voc):
        """Expand this hypothesis with the ``beam_size`` candidate tokens.

        Args:
            topi: candidate token indexes, read as ``topi[0][i]``.
            topv: candidate values, read as ``topv[0][i]`` after ``torch.log``
                (presumably probabilities - confirm against the decoder output).
            decoder_hidden: hidden state stored on every extended hypothesis.
            beam_size: number of candidates to consume from ``topi``/``topv``.
            voc: vocabulary exposing ``index2word``.

        Returns:
            ``(terminates, sentences)``: ``terminates`` is a list of
            ``(word_list, score)`` tuples for candidates that hit EOS, and
            ``sentences`` is a list of new ``Sentence`` objects, one per
            non-EOS candidate.
        """
        topv = torch.log(topv)
        terminates, sentences = [], []
        for i in range(beam_size):
            if topi[0][i] == EOS_token:
                words = [voc.index2word[idx.item()] for idx in self.sentence_idxes] + ['<EOS>']
                # A finished hypothesis is scored by the words before EOS. When
                # EOS is proposed before any word exists there is nothing to
                # average, so fall back to the EOS log-value instead of raising.
                score = self.avgScore() if self.sentence_scores else topv[0][i]
                terminates.append((words, score))  # tuple(word_list, score)
                continue
            # Copy so sibling hypotheses do not share (and mutate) one list.
            idxes = self.sentence_idxes[:]
            scores = self.sentence_scores[:]
            idxes.append(topi[0][i])
            scores.append(topv[0][i])
            sentences.append(Sentence(decoder_hidden, topi[0][i], idxes, scores))
        return terminates, sentences

    def toWordScore(self, voc):
        """Convert the hypothesis to ``(word_list, average_score)``.

        The word list always ends with ``'<EOS>'``; one is appended when the
        last stored index is not EOS.

        Raises:
            ValueError: if the hypothesis has no tokens (from ``avgScore``).
        """
        words = []
        for idx in self.sentence_idxes:
            if idx == EOS_token:
                words.append('<EOS>')
            else:
                words.append(voc.index2word[idx.item()])
        # Guard the [-1] lookup so an empty hypothesis reaches avgScore() and
        # fails with its descriptive ValueError rather than an IndexError.
        if not self.sentence_idxes or self.sentence_idxes[-1] != EOS_token:
            words.append('<EOS>')
        return words, self.avgScore()
def beam_decode(decoder, decoder_hidden, encoder_outputs, voc, beam_size, max_length=MAX_LENGTH):
    """Decode with beam search and return the best-scoring candidates.

    Args:
        decoder: callable invoked as
            ``decoder(decoder_input, decoder_hidden, encoder_outputs)``.
        decoder_hidden: initial decoder hidden state.
        encoder_outputs: encoder outputs passed through to the decoder.
        voc: vocabulary exposing ``index2word``.
        beam_size: number of hypotheses kept alive per step.
        max_length: maximum number of decoding steps.

    Returns:
        Up to 15 ``(word_list, score)`` tuples sorted by descending score.
    """
    terminal_sentences = []
    prev_top_sentences = [Sentence(decoder_hidden)]
    for _ in range(max_length):
        if not prev_top_sentences:
            # Every beam has already terminated; further steps would do nothing.
            break
        next_top_sentences = []
        for sentence in prev_top_sentences:
            decoder_input = torch.LongTensor([[sentence.last_idx]])
            decoder_input = decoder_input.to(device)
            # Only the first two return values are used; slicing tolerates a
            # decoder that additionally returns attention weights.
            decoder_output, step_hidden = decoder(
                decoder_input, sentence.decoder_hidden, encoder_outputs
            )[:2]
            topv, topi = decoder_output.topk(beam_size)
            term, top = sentence.addTopk(topi, topv, step_hidden, beam_size, voc)
            terminal_sentences.extend(term)
            next_top_sentences.extend(top)
        next_top_sentences.sort(key=lambda s: s.avgScore(), reverse=True)
        prev_top_sentences = next_top_sentences[:beam_size]
    # Hypotheses still open when the step budget runs out are closed with <EOS>.
    # (The original referenced a misspelled loop variable here -> NameError.)
    terminal_sentences += [sentence.toWordScore(voc) for sentence in prev_top_sentences]
    terminal_sentences.sort(key=lambda x: x[1], reverse=True)
    # Slicing already copes with fewer than 15 results.
    return terminal_sentences[:15]
def decode(decoder, decoder_hidden, encoder_outputs, voc, max_length=MAX_LENGTH):
    """Greedy decoding: emit the top-ranked token at every step.

    Args:
        decoder: callable invoked as
            ``decoder(decoder_input, decoder_hidden, encoder_outputs)``.
        decoder_hidden: initial decoder hidden state.
        encoder_outputs: encoder outputs passed through to the decoder.
        voc: vocabulary exposing ``index2word``.
        max_length: maximum number of tokens to emit.

    Returns:
        ``(decoded_words, decoder_attentions)``. ``decoded_words`` ends with
        ``'<EOS>'`` only if EOS was predicted within ``max_length`` steps.
        ``decoder_attentions`` is a zero tensor with one row per executed
        step; nothing in this function writes attention values into it.
    """
    decoder_input = torch.LongTensor([[SOS_token]])
    decoder_input = decoder_input.to(device)
    decoded_words = []
    decoder_attentions = torch.zeros(max_length, max_length)  # TODO: or (MAX_LEN+1, MAX_LEN+1)
    # Track executed steps explicitly: the loop variable is unbound when
    # max_length == 0, which made the final slice raise.
    steps = 0
    for di in range(max_length):
        steps = di + 1
        # Only the first two return values are used; slicing tolerates a
        # decoder that additionally returns attention weights.
        decoder_output, decoder_hidden = decoder(
            decoder_input, decoder_hidden, encoder_outputs
        )[:2]
        # NOTE(review): the original author asked what topk(3) is for; only
        # the first-ranked index is used below.
        _, topi = decoder_output.topk(3)
        ni = topi[0][0]
        if ni == EOS_token:
            decoded_words.append('<EOS>')
            break
        decoded_words.append(voc.index2word[ni.item()])
        # Feed the chosen token back in as the next decoder input.
        decoder_input = torch.LongTensor([[ni]])
        decoder_input = decoder_input.to(device)
    return decoded_words, decoder_attentions[:steps]
def evaluate(encoder, decoder, voc, sentence, beam_size, max_length=MAX_LENGTH):
    """Encode one sentence and decode a reply.

    Args:
        encoder: callable invoked as ``encoder(input_batch, lengths, None)``
            returning ``(encoder_outputs, encoder_hidden)``.
        decoder: decoder model; its ``n_layers`` attribute is read here.
        voc: vocabulary passed to ``indexesFromSentence`` and the decoders.
        sentence: input sentence string.
        beam_size: 1 selects greedy decoding, anything else beam search.
        max_length: maximum number of decoding steps.

    Returns:
        The result of ``decode`` when ``beam_size == 1`` (a
        ``(words, attentions)`` pair), otherwise the result of
        ``beam_decode`` (a list of ``(words, score)`` tuples).
    """
    indexes_batch = [indexesFromSentence(voc, sentence)]  # [1, seq_len]
    lengths = [len(indexes) for indexes in indexes_batch]
    # Transpose so the sequence dimension comes first.
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    input_batch = input_batch.to(device)
    encoder_outputs, encoder_hidden = encoder(input_batch, lengths, None)
    # Adds the first n_layers slices of the encoder state to the remaining
    # ones - presumably merging the two directions of a bidirectional
    # encoder; confirm against the encoder definition.
    decoder_hidden = encoder_hidden[:decoder.n_layers] + encoder_hidden[decoder.n_layers:]
    # Forward max_length: it was previously accepted but silently ignored.
    if beam_size == 1:
        return decode(decoder, decoder_hidden, encoder_outputs, voc, max_length)
    return beam_decode(decoder, decoder_hidden, encoder_outputs, voc, beam_size, max_length)
def evaluateInput(encoder, decoder, beam_size, voc):
    """Interactive loop: read a sentence, print the model's reply.

    Typing ``q`` or closing the input stream ends the loop. A ``KeyError``
    raised while evaluating a sentence is reported as "Incorrect spelling."
    and the loop continues.

    Args:
        encoder: encoder model passed to ``evaluate``.
        decoder: decoder model passed to ``evaluate``.
        beam_size: 1 prints a single greedy reply; otherwise every beam
            candidate is printed with its score.
        voc: vocabulary passed to ``evaluate``.
    """
    while True:
        try:
            pair = input('> ')
        except EOFError:
            # End of input (Ctrl-D / exhausted pipe): leave the loop instead
            # of crashing with a traceback.
            break
        if pair == 'q':
            break
        try:
            if beam_size == 1:
                output_words, _ = evaluate(encoder, decoder, voc, pair, beam_size)
                output_sentence = ' '.join(output_words)
                print('<', output_sentence)
            else:
                output_words_list = evaluate(encoder, decoder, voc, pair, beam_size)
                for output_words, score in output_words_list:
                    output_sentence = ' '.join(output_words)
                    print("{:.3f} < {}".format(score, output_sentence))
        except KeyError:
            print("Incorrect spelling.")
def evaluateRandomly(encoder, decoder, voc, pairs, reverse, beam_size, n=10):
    """Print model replies for ``n`` randomly chosen pairs.

    For each sample the first element of the pair is printed (with its word
    order reversed for display when ``reverse`` is truthy) and then evaluated.
    With ``beam_size == 1`` a single reply is printed; otherwise every
    candidate is printed together with its score.
    """
    for _sample in range(n):
        chosen = random.choice(pairs)
        print("=============================================================")
        source = chosen[0]
        shown = " ".join(reversed(source.split())) if reverse else source
        print('>', shown)
        result = evaluate(encoder, decoder, voc, source, beam_size)
        if beam_size == 1:
            # Greedy path returns (words, attentions); attentions are unused.
            tokens, _ = result
            print('<', ' '.join(tokens))
            continue
        for tokens, score in result:
            print("{:.3f} < {}".format(score, ' '.join(tokens)))
def runTest(n_layers, hidden_size, reverse, modelFile, corpus, beam_size=1, inp=True):
    """Rebuild the encoder/decoder, load a checkpoint and run an evaluation.

    Args:
        n_layers: layer count passed to both model constructors.
        hidden_size: hidden size for the embedding and both models.
        reverse: forwarded to ``evaluateRandomly`` for display.
        modelFile: checkpoint path; the file must hold ``'en'`` and ``'de'``
            state dicts.
        corpus: corpus identifier passed to ``loadPrepareData``.
        beam_size: 1 for greedy decoding, otherwise the beam width.
        inp: truthy for the interactive prompt, falsy for 20 random samples.
    """
    # Disables autograd globally for the rest of the process.
    torch.set_grad_enabled(False)

    voc, pairs = loadPrepareData(corpus)

    # A single embedding instance is handed to both encoder and decoder.
    shared_embedding = nn.Embedding(voc.num_words, hidden_size)
    encoder = EncoderRNN(voc.num_words, hidden_size, shared_embedding, n_layers)
    decoder = LuongAttnDecoderRNN('dot', shared_embedding, hidden_size, voc.num_words, n_layers)

    # Load onto CPU first; the modules are moved to `device` afterwards.
    saved_state = torch.load(modelFile, map_location='cpu')
    encoder.load_state_dict(saved_state['en'])
    decoder.load_state_dict(saved_state['de'])

    # Evaluation mode: only affects layers such as dropout and batch norm.
    encoder.eval()
    decoder.eval()

    encoder = encoder.to(device)
    decoder = decoder.to(device)

    if not inp:
        evaluateRandomly(encoder, decoder, voc, pairs, reverse, beam_size, 20)
        return
    evaluateInput(encoder, decoder, beam_size, voc)
def parseFilename(filename):
    """Extract model hyper-parameters from a checkpoint path.

    The path is split on ``'/'``. The last component, minus its final four
    characters (the ``.tar`` suffix), is split on ``'_'`` and checked for the
    token ``reverse``. The second-to-last component must look like
    ``<layers>_<hidden>``, where ``<layers>`` may itself be ``a-b``; only the
    part before the first ``-`` is used.

    Returns:
        ``(n_layers, hidden_size, reverse)`` as ``(int, int, bool)``.
    """
    path_parts = filename.split('/')
    checkpoint_stem = path_parts[-1][:-4]  # remove '.tar'
    is_reversed = 'reverse' in checkpoint_stem.split('_')
    layer_spec, hidden_spec = path_parts[-2].split('_')
    layer_count = layer_spec.partition('-')[0]
    return int(layer_count), int(hidden_spec), is_reversed
if __name__ == '__main__':
    # Checkpoint to evaluate; parseFilename reads the layer count, hidden
    # size and reverse flag out of this path.
    modelFile = './save/model/formatted_movie_lines/2-2_500/10000_checkpoint.tar'
    # NOTE(review): `corpus` is not defined in this file - presumably it is
    # supplied by one of the star imports; confirm.
    runTest(*parseFilename(modelFile), modelFile, corpus)
# encoder.eval()
# decoder.eval()
# searcher = GreedySearchDecoder(encoder, decoder)
# evaluateInput(encoder, decoder, searcher, voc)
# Language training model
# (Web-page footer captured with the code: "Latest recommended article published on 2024-07-16 19:26:54")