An English Translation System Based on LSTM

The code follows the paper: Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation

Link: https://arxiv.org/pdf/1406.1078.pdf?source=post_page---------------------------

import torch
if torch.cuda.is_available():
    # Tell PyTorch to use the GPU.
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
There are 1 GPU(s) available.
We will use the GPU: GeForce GTX 1070
Import the packages
import os
import sys
import math
import nltk
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from collections import Counter
Load the data
def load_data(in_file):
    cn = []
    en = []
    num_examples = 0
    with open(in_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip().split("\t")
            en.append(["BOS"] + nltk.word_tokenize(line[0].lower()) + ["EOS"])
            # split chinese sentence into characters
            cn.append(["BOS"] + [c for c in line[1]] + ["EOS"])
    return en, cn
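
As a quick sanity check, the snippet below tokenizes one tab-separated line exactly the way load_data does (a sketch: the sample line is a hypothetical raw form of the corpus example printed further below, and it assumes the NLTK "punkt" tokenizer data has been downloaded).

sample_line = "Playing cards is very interesting.\t打牌很有意思。"
en_part, cn_part = sample_line.strip().split("\t")
en_tokens = ["BOS"] + nltk.word_tokenize(en_part.lower()) + ["EOS"]
cn_tokens = ["BOS"] + [c for c in cn_part] + ["EOS"]
print(en_tokens)  # ['BOS', 'playing', 'cards', 'is', 'very', 'interesting', '.', 'EOS']
print(cn_tokens)  # ['BOS', '打', '牌', '很', '有', '意', '思', '。', 'EOS']
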
Build separate English and Chinese vocabularies
train_file = "./data/train.txt"
dev_file = "./data/dev.txt"
train_en, train_cn = load_data(train_file)
dev_en, dev_cn = load_data(dev_file)

UNK_IDX = 0
PAD_IDX = 1
def build_dict(sentences, max_words=50000):
    word_count = Counter()
    for sentence in sentences:
        for s in sentence:
            word_count[s] += 1
    ls = word_count.most_common(max_words)
    total_words = len(ls) + 2
    word_dict = {w[0]: index+2 for index, w in enumerate(ls)}
    word_dict["UNK"] = UNK_IDX
    word_dict["PAD"] = PAD_IDX
    return word_dict, total_words

en_dict, en_total_words = build_dict(train_en)
cn_dict, cn_total_words = build_dict(train_cn)
inv_en_dict = {v: k for k, v in en_dict.items()}
inv_cn_dict = {v: k for k, v in cn_dict.items()}

Encode the sentences
def encode(en_sentences, cn_sentences, en_dict, cn_dict, sort_by_len=True):
    '''
        Encode the sequences.
    '''
    length = len(en_sentences)
    out_en_sentences = [[en_dict.get(w, 0) for w in sent] for sent in en_sentences]
    out_cn_sentences = [[cn_dict.get(w, 0) for w in sent] for sent in cn_sentences]

    # sort sentences by English length
    def len_argsort(seq):
        return sorted(range(len(seq)), key=lambda x: len(seq[x]))

    # sort the Chinese sentences in the same order as the English ones
    if sort_by_len:
        sorted_index = len_argsort(out_en_sentences)
        out_en_sentences = [out_en_sentences[i] for i in sorted_index]
        out_cn_sentences = [out_cn_sentences[i] for i in sorted_index]

    return out_en_sentences, out_cn_sentences

train_en, train_cn = encode(train_en, train_cn, en_dict, cn_dict)
dev_en, dev_cn = encode(dev_en, dev_cn, en_dict, cn_dict)

Take a quick look at the parallel corpus
k = 10000
print(" ".join([inv_cn_dict[i] for i in train_cn[k]]))
print(" ".join([inv_en_dict[i] for i in train_en[k]]))
BOS 打 牌 很 有 意 思 。 EOS
BOS playing cards is very interesting . EOS
Prepare the data for training
def get_minibatches(n, minibatch_size, shuffle=True):
    idx_list = np.arange(0, n, minibatch_size)  # [0, minibatch_size, 2*minibatch_size, ...]
    if shuffle:
        np.random.shuffle(idx_list)
    minibatches = []
    for idx in idx_list:
        minibatches.append(np.arange(idx, min(idx + minibatch_size, n)))
    return minibatches
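
For intuition, here is what the helper returns on a tiny example (a sketch with made-up numbers):

print(get_minibatches(10, 4, shuffle=False))
# [array([0, 1, 2, 3]), array([4, 5, 6, 7]), array([8, 9])]
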

def prepare_data(seqs):
    lengths = [len(seq) for seq in seqs]
    n_samples = len(seqs)
    max_len = np.max(lengths)

    x = np.zeros((n_samples, max_len)).astype('int32')
    x_lengths = np.array(lengths).astype("int32")
    for idx, seq in enumerate(seqs):
        x[idx, :lengths[idx]] = seq
    return x, x_lengths

def gen_examples(en_sentences, cn_sentences, batch_size):
    minibatches = get_minibatches(len(en_sentences), batch_size)
    all_ex = []
    for minibatch in minibatches:
        mb_en_sentences = [en_sentences[t] for t in minibatch]
        mb_cn_sentences = [cn_sentences[t] for t in minibatch]
        mb_x, mb_x_len = prepare_data(mb_en_sentences)
        mb_y, mb_y_len = prepare_data(mb_cn_sentences)
        all_ex.append((mb_x, mb_x_len, mb_y, mb_y_len))
    return all_ex

batch_size = 64
train_data = gen_examples(train_en, train_cn, batch_size)
random.shuffle(train_data)
dev_data = gen_examples(dev_en, dev_cn, batch_size)
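
A quick way to check that batching worked is to look at the shapes of the first minibatch (a sketch; it assumes train_data is non-empty):

mb_x, mb_x_len, mb_y, mb_y_len = train_data[0]
print(mb_x.shape, mb_y.shape)      # (batch_size, max_en_len_in_batch), (batch_size, max_cn_len_in_batch)
print(mb_x_len[:5], mb_y_len[:5])  # true (unpadded) lengths of the first few sentences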

The encoder of the sequence-to-sequence model

class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, enc_hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(enc_hidden_size * 2, dec_hidden_size)

    def forward(self, x, lengths):
        sorted_len, sorted_idx = lengths.sort(0, descending=True)
        x_sorted = x[sorted_idx.long()]
        embedded = self.dropout(self.embed(x_sorted))

        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_len.long().cpu().data.numpy(),
                                                            batch_first=True)
        packed_out, hid = self.rnn(packed_embedded)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        _, original_idx = sorted_idx.sort(0, descending=False)
        out = out[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        hid = torch.cat([hid[-2], hid[-1]], dim=1)
        hid = torch.tanh(self.fc(hid)).unsqueeze(0)

        return out, hid
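
The shape handling above (packing, the bidirectional GRU, concatenating the last forward and backward hidden states) can be verified on random data. This is a minimal sketch with illustrative sizes, not the configuration used for training:

enc = Encoder(vocab_size=100, embed_size=8, enc_hidden_size=16, dec_hidden_size=16)
dummy_x = torch.randint(0, 100, (4, 7))   # batch of 4 sequences, max length 7
dummy_len = torch.tensor([7, 5, 4, 2])    # true lengths, one per sequence
enc_out, enc_hid = enc(dummy_x, dummy_len)
print(enc_out.shape)  # torch.Size([4, 7, 32]) -> batch, seq_len, 2*enc_hidden_size
print(enc_hid.shape)  # torch.Size([1, 4, 16]) -> 1, batch, dec_hidden_size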

Adding an attention mechanism to the decoder

class Attention(nn.Module):
    def __init__(self, enc_hidden_size, dec_hidden_size):
        super(Attention, self).__init__()

        self.enc_hidden_size = enc_hidden_size
        self.dec_hidden_size = dec_hidden_size

        self.linear_in = nn.Linear(enc_hidden_size * 2, dec_hidden_size, bias=False)
        self.linear_out = nn.Linear(enc_hidden_size * 2 + dec_hidden_size, dec_hidden_size)

    def forward(self, output, context, mask):
        # output: batch_size, output_len, dec_hidden_size
        # context: batch_size, context_len, 2*enc_hidden_size

        batch_size = output.size(0)
        output_len = output.size(1)
        input_len = context.size(1)

        context_in = self.linear_in(context.view(batch_size * input_len, -1)).view(
            batch_size, input_len, -1)  # batch_size, context_len, dec_hidden_size

        # context_in.transpose(1,2): batch_size, dec_hidden_size, context_len
        # output: batch_size, output_len, dec_hidden_size
        attn = torch.bmm(output, context_in.transpose(1, 2))
        # batch_size, output_len, context_len

        # block attention to masked positions before the softmax
        # (masked_fill is out of place, so the result must be assigned back)
        attn = attn.masked_fill(mask.bool(), -1e6)

        attn = F.softmax(attn, dim=2)
        # batch_size, output_len, context_len

        context = torch.bmm(attn, context)
        # batch_size, output_len, 2*enc_hidden_size

        output = torch.cat((context, output), dim=2)  # batch_size, output_len, 2*enc_hidden_size + dec_hidden_size

        output = output.view(batch_size * output_len, -1)
        output = torch.tanh(self.linear_out(output))
        output = output.view(batch_size, output_len, -1)
        return output, attn
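
To make the bookkeeping concrete, the sketch below runs the attention layer on toy tensors (illustrative sizes; the mask blocks nothing here):

attn_layer = Attention(enc_hidden_size=16, dec_hidden_size=16)
dec_out = torch.randn(4, 5, 16)                    # batch, output_len, dec_hidden_size
enc_ctx = torch.randn(4, 7, 32)                    # batch, context_len, 2*enc_hidden_size
no_mask = torch.zeros(4, 5, 7, dtype=torch.bool)   # nothing is masked
ctx_vec, attn_weights = attn_layer(dec_out, enc_ctx, no_mask)
print(ctx_vec.shape)               # torch.Size([4, 5, 16])
print(attn_weights.shape)          # torch.Size([4, 5, 7])
print(attn_weights.sum(dim=2)[0])  # each row of attention weights sums to 1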

The decoder of the sequence-to-sequence model

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(enc_hidden_size, dec_hidden_size)
        self.rnn = nn.GRU(embed_size, dec_hidden_size, batch_first=True)
        self.out = nn.Linear(dec_hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def create_mask(self, x_len, y_len):
        # mask of shape (batch_size, max_x_len, max_y_len); True marks positions to block in attention
        device = x_len.device
        max_x_len = x_len.max()
        max_y_len = y_len.max()
        x_mask = torch.arange(max_x_len, device=device)[None, :] < x_len[:, None]
        y_mask = torch.arange(max_y_len, device=device)[None, :] < y_len[:, None]
        # a position is masked unless both the decoder step and the encoder step are real (non-padded) tokens
        mask = ~(x_mask[:, :, None] & y_mask[:, None, :])
        return mask

    def forward(self, ctx, ctx_lengths, y, y_lengths, hid):
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()]

        y_sorted = self.dropout(self.embed(y_sorted))  # batch_size, output_length, embed_size

        packed_seq = nn.utils.rnn.pack_padded_sequence(y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
        out, hid = self.rnn(packed_seq, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        mask = self.create_mask(y_lengths, ctx_lengths)

        output, attn = self.attention(output_seq, ctx, mask)
        output = F.log_softmax(self.out(output), -1)

        return output, hid, attn
Define the full sequence-to-sequence model
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid, attn = self.decoder(ctx=encoder_out,
                                         ctx_lengths=x_lengths,
                                         y=y,
                                         y_lengths=y_lengths,
                                         hid=hid)
        return output, attn

    def translate(self, x, x_lengths, y, max_length=100):
        encoder_out, hid = self.encoder(x, x_lengths)
        preds = []
        batch_size = x.shape[0]
        attns = []
        for i in range(max_length):
            output, hid, attn = self.decoder(ctx=encoder_out,
                                             ctx_lengths=x_lengths,
                                             y=y,
                                             y_lengths=torch.ones(batch_size).long().to(y.device),
                                             hid=hid)
            y = output.max(2)[1].view(batch_size, 1)
            preds.append(y)
            attns.append(attn)
        return torch.cat(preds, 1), torch.cat(attns, 1)
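
A toy end-to-end forward pass shows how the pieces fit together (a sketch with small, made-up sizes, not the trained model):

toy_enc = Encoder(vocab_size=50, embed_size=8, enc_hidden_size=16, dec_hidden_size=16)
toy_dec = Decoder(vocab_size=60, embed_size=8, enc_hidden_size=16, dec_hidden_size=16)
toy_model = Seq2Seq(toy_enc, toy_dec)
src = torch.randint(0, 50, (4, 7))
src_len = torch.tensor([7, 6, 4, 3])
tgt = torch.randint(0, 60, (4, 5))
tgt_len = torch.tensor([5, 5, 4, 2])
logp, attn = toy_model(src, src_len, tgt, tgt_len)
print(logp.shape)  # torch.Size([4, 5, 60]) -> per-step log-probabilities over the target vocabulary
print(attn.shape)  # torch.Size([4, 5, 7])  -> attention over the 7 source positions at each target step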

Define the loss function (masked cross-entropy loss)
class LanguageModelCriterion(nn.Module):
    def __init__(self):
        super(LanguageModelCriterion, self).__init__()

    def forward(self, input, target, mask):
        # input: batch_size, seq_len, vocab_size (log-probabilities); flatten to (batch_size*seq_len, vocab_size)
        input = input.contiguous().view(-1, input.size(2))
        # target and mask: batch_size, seq_len; flatten to (batch_size*seq_len, 1)
        target = target.contiguous().view(-1, 1)
        mask = mask.contiguous().view(-1, 1)
        output = -input.gather(1, target) * mask
        output = torch.sum(output) / torch.sum(mask)

        return output
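
A tiny hand-computable example confirms that padded positions do not contribute to the loss (a sketch; the numbers are made up):

crit = LanguageModelCriterion()
logp = torch.log(torch.tensor([[[0.7, 0.3], [0.4, 0.6]],
                               [[0.5, 0.5], [0.9, 0.1]]]))  # batch=2, seq_len=2, vocab=2
tgt = torch.tensor([[0, 1], [1, 0]])
msk = torch.tensor([[1., 1.], [1., 0.]])                    # last token of the second sequence is padding
print(crit(logp, tgt, msk))  # mean of -log 0.7, -log 0.6, -log 0.5 over the 3 unmasked tokens (about 0.52)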
Initialize the model
dropout = 0.2
embed_size = hidden_size = 100
encoder = Encoder(vocab_size=en_total_words,
                  embed_size=embed_size,
                  enc_hidden_size=hidden_size,
                  dec_hidden_size=hidden_size,
                  dropout=dropout)
decoder = Decoder(vocab_size=cn_total_words,
                  embed_size=embed_size,
                  enc_hidden_size=hidden_size,
                  dec_hidden_size=hidden_size,
                  dropout=dropout)

model = Seq2Seq(encoder, decoder)
model = model.to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())
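
It can be handy to know how large the model is before training; this short sketch (not part of the original run) counts the trainable parameters:

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Trainable parameters: %d" % num_params)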
Define the evaluation function
def evaluate(model, data):
    model.eval()
    total_num_words = total_loss = 0.
    with torch.no_grad():
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
            mb_y_len = torch.from_numpy(mb_y_len-1).to(device).long()
            mb_y_len[mb_y_len<=0] = 1

            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()

            loss = loss_fn(mb_pred, mb_output, mb_out_mask)

            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words
    print("Evaluation loss", total_loss/total_num_words)
Define the training function
def train(model, data, num_epochs=20):
    for epoch in range(num_epochs):
        model.train()
        total_num_words = total_loss = 0.
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
            mb_y_len = torch.from_numpy(mb_y_len - 1).to(device).long()
            mb_y_len[mb_y_len <= 0] = 1

            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()

            loss = loss_fn(mb_pred, mb_output, mb_out_mask)

            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words

            # update the model parameters
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optimizer.step()

            if it % 500 == 0:
                print("Epoch", epoch, "iteration", it, "loss", loss.item())

        print("Epoch", epoch, "Training loss", total_loss / total_num_words)
        if epoch % 50 == 0:
            evaluate(model, dev_data)

Train the model
train(model, train_data, num_epochs=200)
Epoch 0 iteration 0 loss 7.928357124328613
Epoch 0 iteration 500 loss 3.90830397605896
Epoch 0 Training loss 4.815643868381794
Evaluation loss 4.091342107835811
Epoch 1 iteration 0 loss 3.849057912826538
Epoch 1 iteration 500 loss 3.227210283279419
Epoch 1 Training loss 3.766976896258373
Epoch 2 iteration 0 loss 3.2430152893066406
Epoch 2 iteration 500 loss 2.896683931350708
Epoch 2 Training loss 3.2947107675981218
Epoch 3 iteration 0 loss 2.8801681995391846
Epoch 3 iteration 500 loss 2.655324697494507
Epoch 3 Training loss 2.99818192814612
Epoch 4 iteration 0 loss 2.606170892715454
Epoch 4 iteration 500 loss 2.485543727874756
Epoch 4 Training loss 2.7858131342012973
Epoch 5 iteration 0 loss 2.4095005989074707
Epoch 5 iteration 500 loss 2.3372409343719482
Epoch 5 Training loss 2.61820455699905
Epoch 6 iteration 0 loss 2.2848551273345947
Epoch 6 iteration 500 loss 2.2205848693847656
Epoch 6 Training loss 2.4839356306734097
Epoch 7 iteration 0 loss 2.169241189956665
Epoch 7 iteration 500 loss 2.1018216609954834
Epoch 7 Training loss 2.370443155886292
Epoch 8 iteration 0 loss 2.06174635887146
Epoch 8 iteration 500 loss 2.020695447921753
Epoch 8 Training loss 2.2732634094208275
Epoch 9 iteration 0 loss 1.984137773513794
Epoch 9 iteration 500 loss 1.939084529876709
Epoch 9 Training loss 2.1883279392436283
Epoch 10 iteration 0 loss 1.9452708959579468
Epoch 10 iteration 500 loss 1.8818047046661377
Epoch 10 Training loss 2.112133346478979
Epoch 11 iteration 0 loss 1.8716343641281128
Epoch 11 iteration 500 loss 1.8141098022460938
Epoch 11 Training loss 2.0430460208444448
Epoch 12 iteration 0 loss 1.7709314823150635
Epoch 12 iteration 500 loss 1.7577903270721436
Epoch 12 Training loss 1.9812295165150986
Epoch 13 iteration 0 loss 1.7548656463623047
Epoch 13 iteration 500 loss 1.732089638710022
Epoch 13 Training loss 1.927129452975913
Epoch 14 iteration 0 loss 1.711237907409668
Epoch 14 iteration 500 loss 1.6911147832870483
Epoch 14 Training loss 1.8765310167380684
Epoch 15 iteration 0 loss 1.6728557348251343
Epoch 15 iteration 500 loss 1.6582821607589722
Epoch 15 Training loss 1.8305548231946718
Epoch 16 iteration 0 loss 1.6360520124435425
Epoch 16 iteration 500 loss 1.5719242095947266
Epoch 16 Training loss 1.7837277017431035
Epoch 17 iteration 0 loss 1.6439002752304077
Epoch 17 iteration 500 loss 1.524785041809082
Epoch 17 Training loss 1.7447902788038288
Epoch 18 iteration 0 loss 1.58423912525177
Epoch 18 iteration 500 loss 1.5078282356262207
Epoch 18 Training loss 1.7080127060983032
Epoch 19 iteration 0 loss 1.5390608310699463
Epoch 19 iteration 500 loss 1.5225342512130737
Epoch 19 Training loss 1.6723873219620053
Epoch 20 iteration 0 loss 1.494099497795105
Epoch 20 iteration 500 loss 1.5124914646148682
Epoch 20 Training loss 1.6367066513444974
Epoch 21 iteration 0 loss 1.4822521209716797
Epoch 21 iteration 500 loss 1.467223048210144
Epoch 21 Training loss 1.606666042107959
Epoch 22 iteration 0 loss 1.4411041736602783
Epoch 22 iteration 500 loss 1.4498487710952759
Epoch 22 Training loss 1.5750408884265574
Epoch 23 iteration 0 loss 1.3950568437576294
Epoch 23 iteration 500 loss 1.3538126945495605
Epoch 23 Training loss 1.551325583339711
Epoch 24 iteration 0 loss 1.3918793201446533
Epoch 24 iteration 500 loss 1.3291021585464478
Epoch 24 Training loss 1.5217847593300489
Epoch 25 iteration 0 loss 1.365149736404419
Epoch 25 iteration 500 loss 1.3761420249938965
Epoch 25 Training loss 1.4973933488412807
Epoch 26 iteration 0 loss 1.3466569185256958
Epoch 26 iteration 500 loss 1.371177077293396
Epoch 26 Training loss 1.4720650266576034
Epoch 27 iteration 0 loss 1.3281514644622803
Epoch 27 iteration 500 loss 1.331527590751648
Epoch 27 Training loss 1.4490924394659275
Epoch 28 iteration 0 loss 1.3383699655532837
Epoch 28 iteration 500 loss 1.306111454963684
Epoch 28 Training loss 1.4280537929261388
Epoch 29 iteration 0 loss 1.3682355880737305
Epoch 29 iteration 500 loss 1.2585397958755493
Epoch 29 Training loss 1.4075209773713206
Epoch 30 iteration 0 loss 1.297282338142395
Epoch 30 iteration 500 loss 1.2389955520629883
Epoch 30 Training loss 1.386655172488642
Epoch 31 iteration 0 loss 1.2280360460281372
Epoch 31 iteration 500 loss 1.241192102432251
Epoch 31 Training loss 1.3661337603685542
Epoch 32 iteration 0 loss 1.2098947763442993
Epoch 32 iteration 500 loss 1.1983294486999512
Epoch 32 Training loss 1.3482332145965288
Epoch 33 iteration 0 loss 1.2415146827697754
Epoch 33 iteration 500 loss 1.2350493669509888
Epoch 33 Training loss 1.3326907000024077
Epoch 34 iteration 0 loss 1.2163398265838623
Epoch 34 iteration 500 loss 1.1608344316482544
Epoch 34 Training loss 1.3135639723825614
Epoch 35 iteration 0 loss 1.228286623954773
Epoch 35 iteration 500 loss 1.1782830953598022
Epoch 35 Training loss 1.296640716271889
Epoch 36 iteration 0 loss 1.205100417137146
Epoch 36 iteration 500 loss 1.2059627771377563
Epoch 36 Training loss 1.2816641990295445
Epoch 37 iteration 0 loss 1.1419614553451538
Epoch 37 iteration 500 loss 1.0831267833709717
Epoch 37 Training loss 1.266050579365604
Epoch 38 iteration 0 loss 1.1397496461868286
Epoch 38 iteration 500 loss 1.0981181859970093
Epoch 38 Training loss 1.2505547716558623
Epoch 39 iteration 0 loss 1.1959482431411743
Epoch 39 iteration 500 loss 1.1356191635131836
Epoch 39 Training loss 1.2390485424490791
Epoch 40 iteration 0 loss 1.1450738906860352
Epoch 40 iteration 500 loss 1.0957298278808594
Epoch 40 Training loss 1.2206933174895096
Epoch 41 iteration 0 loss 1.1270146369934082
Epoch 41 iteration 500 loss 1.0484026670455933
Epoch 41 Training loss 1.209546530202252
Epoch 42 iteration 0 loss 1.1364833116531372
Epoch 42 iteration 500 loss 1.0879396200180054
Epoch 42 Training loss 1.1973414120195522
Epoch 43 iteration 0 loss 1.0954558849334717
Epoch 43 iteration 500 loss 1.0777392387390137
Epoch 43 Training loss 1.1830152717945142
Epoch 44 iteration 0 loss 1.1112639904022217
Epoch 44 iteration 500 loss 1.0734856128692627
Epoch 44 Training loss 1.1713028105248906
Epoch 45 iteration 0 loss 1.0966465473175049
Epoch 45 iteration 500 loss 1.0444735288619995
Epoch 45 Training loss 1.1590108651671145
Epoch 46 iteration 0 loss 1.0701093673706055
Epoch 46 iteration 500 loss 1.0505683422088623
Epoch 46 Training loss 1.1480596024264336
Epoch 47 iteration 0 loss 1.0556902885437012
Epoch 47 iteration 500 loss 0.998938262462616
Epoch 47 Training loss 1.136144786781804
Epoch 48 iteration 0 loss 1.1005561351776123
Epoch 48 iteration 500 loss 1.046020746231079
Epoch 48 Training loss 1.1245009635135628
Epoch 49 iteration 0 loss 1.0618535280227661
Epoch 49 iteration 500 loss 1.016508936882019
Epoch 49 Training loss 1.1140278308835816
Epoch 50 iteration 0 loss 1.0361311435699463
Epoch 50 iteration 500 loss 0.9854018092155457
Epoch 50 Training loss 1.103240595894343
Evaluation loss 1.1672870177730559
Epoch 51 iteration 0 loss 0.9864965081214905
Epoch 51 iteration 500 loss 0.9881776571273804
Epoch 51 Training loss 1.0957198824253787
Epoch 52 iteration 0 loss 1.0605281591415405
Epoch 52 iteration 500 loss 0.9425312876701355
Epoch 52 Training loss 1.0835173193434786
Epoch 53 iteration 0 loss 1.0059244632720947
Epoch 53 iteration 500 loss 1.012495756149292
Epoch 53 Training loss 1.0746916196923277
Epoch 54 iteration 0 loss 1.0311338901519775
Epoch 54 iteration 500 loss 0.9643288254737854
Epoch 54 Training loss 1.0639997538991715
Epoch 55 iteration 0 loss 0.9999511241912842
Epoch 55 iteration 500 loss 0.9228934645652771
Epoch 55 Training loss 1.0586248801784999
Epoch 56 iteration 0 loss 0.9765851497650146
Epoch 56 iteration 500 loss 0.9912199974060059
Epoch 56 Training loss 1.047446841975497
Epoch 57 iteration 0 loss 0.9578201770782471
Epoch 57 iteration 500 loss 0.9974180459976196
Epoch 57 Training loss 1.0401080069686683
Epoch 58 iteration 0 loss 0.9840922951698303
Epoch 58 iteration 500 loss 0.9113590121269226
Epoch 58 Training loss 1.0313254258577484
Epoch 59 iteration 0 loss 0.9581423997879028
Epoch 59 iteration 500 loss 0.9251217246055603
Epoch 59 Training loss 1.0204329001920651
Epoch 60 iteration 0 loss 0.9331505298614502
Epoch 60 iteration 500 loss 0.9187439680099487
Epoch 60 Training loss 1.016994911563405
Epoch 61 iteration 0 loss 0.981872022151947
Epoch 61 iteration 500 loss 0.9553562998771667
Epoch 61 Training loss 1.0054522911305408
Epoch 62 iteration 0 loss 0.946431040763855
Epoch 62 iteration 500 loss 0.95720374584198
Epoch 62 Training loss 1.0011483115534685
Epoch 63 iteration 0 loss 0.938338041305542
Epoch 63 iteration 500 loss 0.8557417392730713
Epoch 63 Training loss 0.9887521455122721
Epoch 64 iteration 0 loss 0.9388967752456665
Epoch 64 iteration 500 loss 0.8844293355941772
Epoch 64 Training loss 0.9827392641244528
Epoch 65 iteration 0 loss 0.9198037385940552
Epoch 65 iteration 500 loss 0.9046379923820496
Epoch 65 Training loss 0.9780227733258282
Epoch 66 iteration 0 loss 0.925284206867218
Epoch 66 iteration 500 loss 0.8914257884025574
Epoch 66 Training loss 0.9681966190858181
Epoch 67 iteration 0 loss 0.9229926466941833
Epoch 67 iteration 500 loss 0.874291718006134
Epoch 67 Training loss 0.9661976679190136
Epoch 68 iteration 0 loss 0.9343907833099365
Epoch 68 iteration 500 loss 0.8881480693817139
Epoch 68 Training loss 0.9585099018026118
Epoch 69 iteration 0 loss 0.9092890620231628
Epoch 69 iteration 500 loss 0.8629350066184998
Epoch 69 Training loss 0.9508208637017396
Epoch 70 iteration 0 loss 0.8868396282196045
Epoch 70 iteration 500 loss 0.8765168190002441
Epoch 70 Training loss 0.9432473274468882
Epoch 71 iteration 0 loss 0.8532478213310242
Epoch 71 iteration 500 loss 0.826164186000824
Epoch 71 Training loss 0.9381867783906953
Epoch 72 iteration 0 loss 0.878237783908844
Epoch 72 iteration 500 loss 0.8488359451293945
Epoch 72 Training loss 0.9334934804358592
Epoch 73 iteration 0 loss 0.8713428378105164
Epoch 73 iteration 500 loss 0.8377509713172913
Epoch 73 Training loss 0.9248339132723351
Epoch 74 iteration 0 loss 0.873065710067749
Epoch 74 iteration 500 loss 0.8362513184547424
Epoch 74 Training loss 0.9193484013685149
Epoch 75 iteration 0 loss 0.7983145713806152
Epoch 75 iteration 500 loss 0.8456605076789856
Epoch 75 Training loss 0.9143656854413621
Epoch 76 iteration 0 loss 0.8218435049057007
Epoch 76 iteration 500 loss 0.8419947624206543
Epoch 76 Training loss 0.9069888797938304
Epoch 77 iteration 0 loss 0.8373231291770935
Epoch 77 iteration 500 loss 0.8114780783653259
Epoch 77 Training loss 0.9042000555541497
Epoch 78 iteration 0 loss 0.8416014909744263
Epoch 78 iteration 500 loss 0.8539872169494629
Epoch 78 Training loss 0.8970924556067422
Epoch 79 iteration 0 loss 0.822470486164093
Epoch 79 iteration 500 loss 0.7818557024002075
Epoch 79 Training loss 0.8954583991136714
Epoch 80 iteration 0 loss 0.8531654477119446
Epoch 80 iteration 500 loss 0.8344308733940125
Epoch 80 Training loss 0.8861387732564512
Epoch 81 iteration 0 loss 0.860893189907074
Epoch 81 iteration 500 loss 0.7931697368621826
Epoch 81 Training loss 0.8816995690071902
Epoch 82 iteration 0 loss 0.8687481880187988
Epoch 82 iteration 500 loss 0.7674760222434998
Epoch 82 Training loss 0.8798118944121449
Epoch 83 iteration 0 loss 0.8461892008781433
Epoch 83 iteration 500 loss 0.790038526058197
Epoch 83 Training loss 0.8694967462520443
Epoch 84 iteration 0 loss 0.8096769452095032
Epoch 84 iteration 500 loss 0.7982391715049744
Epoch 84 Training loss 0.8649632915039455
Epoch 85 iteration 0 loss 0.8365834355354309
Epoch 85 iteration 500 loss 0.7786751985549927
Epoch 85 Training loss 0.8611366375337043
Epoch 86 iteration 0 loss 0.8260359168052673
Epoch 86 iteration 500 loss 0.7398275136947632
Epoch 86 Training loss 0.8586594959231909
Epoch 87 iteration 0 loss 0.8266860246658325
Epoch 87 iteration 500 loss 0.7295836806297302
Epoch 87 Training loss 0.8527324879214611
Epoch 88 iteration 0 loss 0.7657225728034973
Epoch 88 iteration 500 loss 0.7590709328651428
Epoch 88 Training loss 0.8472464314101893
Epoch 89 iteration 0 loss 0.7865862846374512
Epoch 89 iteration 500 loss 0.804772138595581
Epoch 89 Training loss 0.8433423174076622
Epoch 90 iteration 0 loss 0.7977380752563477
Epoch 90 iteration 500 loss 0.7664549350738525
Epoch 90 Training loss 0.8391216668559893
Epoch 91 iteration 0 loss 0.8242974281311035
Epoch 91 iteration 500 loss 0.7586526274681091
Epoch 91 Training loss 0.8350955967898613
Epoch 92 iteration 0 loss 0.7432764172554016
Epoch 92 iteration 500 loss 0.8062140345573425
Epoch 92 Training loss 0.8311378375884593
Epoch 93 iteration 0 loss 0.7997678518295288
Epoch 93 iteration 500 loss 0.7199326753616333
Epoch 93 Training loss 0.8252114816967426
Epoch 94 iteration 0 loss 0.7635942101478577
Epoch 94 iteration 500 loss 0.7079055905342102
Epoch 94 Training loss 0.8235959888185962
Epoch 95 iteration 0 loss 0.7742218971252441
Epoch 95 iteration 500 loss 0.738625168800354
Epoch 95 Training loss 0.8169184437844808
Epoch 96 iteration 0 loss 0.7915105223655701
Epoch 96 iteration 500 loss 0.7232112884521484
Epoch 96 Training loss 0.811307575136044
Epoch 97 iteration 0 loss 0.7911934852600098
Epoch 97 iteration 500 loss 0.745033323764801
Epoch 97 Training loss 0.8127360836219556
Epoch 98 iteration 0 loss 0.7677562832832336
Epoch 98 iteration 500 loss 0.7415921092033386
Epoch 98 Training loss 0.8080242738178477
Epoch 99 iteration 0 loss 0.7812052369117737
Epoch 99 iteration 500 loss 0.7319413423538208
Epoch 99 Training loss 0.8043054426229539
Epoch 100 iteration 0 loss 0.7394666075706482
Epoch 100 iteration 500 loss 0.7425164580345154
Epoch 100 Training loss 0.8013571414808753
Evaluation loss 0.8130458785905558
Epoch 101 iteration 0 loss 0.7540783882141113
Epoch 101 iteration 500 loss 0.7367751598358154
Epoch 101 Training loss 0.7957811313370509
Epoch 102 iteration 0 loss 0.7196013927459717
Epoch 102 iteration 500 loss 0.6945835947990417
Epoch 102 Training loss 0.7906426663184966
Epoch 103 iteration 0 loss 0.7630685567855835
Epoch 103 iteration 500 loss 0.6944538354873657
Epoch 103 Training loss 0.7882730398154664
Epoch 104 iteration 0 loss 0.7134157419204712
Epoch 104 iteration 500 loss 0.74537593126297
Epoch 104 Training loss 0.7860671504111497
Epoch 105 iteration 0 loss 0.7597986459732056
Epoch 105 iteration 500 loss 0.7277376651763916
Epoch 105 Training loss 0.7810746626260536
Epoch 106 iteration 0 loss 0.7150493860244751
Epoch 106 iteration 500 loss 0.7405678629875183
Epoch 106 Training loss 0.7778148535252268
Epoch 107 iteration 0 loss 0.7096103429794312
Epoch 107 iteration 500 loss 0.7220456004142761
Epoch 107 Training loss 0.7765383308579782
Epoch 108 iteration 0 loss 0.7453286647796631
Epoch 108 iteration 500 loss 0.6827071309089661
Epoch 108 Training loss 0.771828171381296
Epoch 109 iteration 0 loss 0.6829768419265747
Epoch 109 iteration 500 loss 0.7154251933097839
Epoch 109 Training loss 0.7695830236897697
Epoch 110 iteration 0 loss 0.722434401512146
Epoch 110 iteration 500 loss 0.7154301404953003
Epoch 110 Training loss 0.7648380817991757
Epoch 111 iteration 0 loss 0.7442435026168823
Epoch 111 iteration 500 loss 0.6423325538635254
Epoch 111 Training loss 0.7616600301560927
Epoch 112 iteration 0 loss 0.7071460485458374
Epoch 112 iteration 500 loss 0.7302472591400146
Epoch 112 Training loss 0.7588099653474677
Epoch 113 iteration 0 loss 0.731406033039093
Epoch 113 iteration 500 loss 0.6625105738639832
Epoch 113 Training loss 0.7561673885898474
Epoch 114 iteration 0 loss 0.724568247795105
Epoch 114 iteration 500 loss 0.6820351481437683
Epoch 114 Training loss 0.7539073121070888
Epoch 115 iteration 0 loss 0.6992641687393188
Epoch 115 iteration 500 loss 0.6801736354827881
Epoch 115 Training loss 0.7516719803910432
Epoch 116 iteration 0 loss 0.6960390210151672
Epoch 116 iteration 500 loss 0.6813892126083374
Epoch 116 Training loss 0.7479871914292215
Epoch 117 iteration 0 loss 0.6633367538452148
Epoch 117 iteration 500 loss 0.6134429574012756
Epoch 117 Training loss 0.7444758936016801
Epoch 118 iteration 0 loss 0.6863762736320496
Epoch 118 iteration 500 loss 0.6920523047447205
Epoch 118 Training loss 0.7400842625768795
Epoch 119 iteration 0 loss 0.689449667930603
Epoch 119 iteration 500 loss 0.7047742605209351
Epoch 119 Training loss 0.7379401736404623
Epoch 120 iteration 0 loss 0.7363123297691345
Epoch 120 iteration 500 loss 0.6468561887741089
Epoch 120 Training loss 0.7363262054784875
Epoch 121 iteration 0 loss 0.7053201198577881
Epoch 121 iteration 500 loss 0.6334596276283264
Epoch 121 Training loss 0.7331407053567305
Epoch 122 iteration 0 loss 0.7270336151123047
Epoch 122 iteration 500 loss 0.6559562683105469
Epoch 122 Training loss 0.729244610551721
Epoch 123 iteration 0 loss 0.6622068285942078
Epoch 123 iteration 500 loss 0.6599937081336975
Epoch 123 Training loss 0.7288847545296716
Epoch 124 iteration 0 loss 0.7071743607521057
Epoch 124 iteration 500 loss 0.6815555691719055
Epoch 124 Training loss 0.7274926877685234
Epoch 125 iteration 0 loss 0.7439572811126709
Epoch 125 iteration 500 loss 0.673474133014679
Epoch 125 Training loss 0.7217855175066641
Epoch 126 iteration 0 loss 0.6675475239753723
Epoch 126 iteration 500 loss 0.6556931734085083
Epoch 126 Training loss 0.7189844807094804
Epoch 127 iteration 0 loss 0.6661986708641052
Epoch 127 iteration 500 loss 0.6441670656204224
Epoch 127 Training loss 0.7177228928729062
Epoch 128 iteration 0 loss 0.6266466379165649
Epoch 128 iteration 500 loss 0.6757206320762634
Epoch 128 Training loss 0.712971736295992
Epoch 129 iteration 0 loss 0.6534844040870667
Epoch 129 iteration 500 loss 0.6352728009223938
Epoch 129 Training loss 0.7137009887221181
Epoch 130 iteration 0 loss 0.6692954301834106
Epoch 130 iteration 500 loss 0.6136940717697144
Epoch 130 Training loss 0.707448539007387
Epoch 131 iteration 0 loss 0.6911685466766357
Epoch 131 iteration 500 loss 0.6910461187362671
Epoch 131 Training loss 0.7072084314005552
Epoch 132 iteration 0 loss 0.6625894904136658
Epoch 132 iteration 500 loss 0.676133930683136
Epoch 132 Training loss 0.7046438151039027
Epoch 133 iteration 0 loss 0.6767389178276062
Epoch 133 iteration 500 loss 0.6022269129753113
Epoch 133 Training loss 0.7016673204264923
Epoch 134 iteration 0 loss 0.7002029418945312
Epoch 134 iteration 500 loss 0.6056179404258728
Epoch 134 Training loss 0.6982308283758855
Epoch 135 iteration 0 loss 0.6660704016685486
Epoch 135 iteration 500 loss 0.6187354326248169
Epoch 135 Training loss 0.6955358367493972
Epoch 136 iteration 0 loss 0.7052252888679504
Epoch 136 iteration 500 loss 0.6196271181106567
Epoch 136 Training loss 0.6983756656033084
Epoch 137 iteration 0 loss 0.6464589834213257
Epoch 137 iteration 500 loss 0.5709037780761719
Epoch 137 Training loss 0.6929460661092829
Epoch 138 iteration 0 loss 0.6531716585159302
Epoch 138 iteration 500 loss 0.6473918557167053
Epoch 138 Training loss 0.6909717761274798
Epoch 139 iteration 0 loss 0.661399781703949
Epoch 139 iteration 500 loss 0.5822514295578003
Epoch 139 Training loss 0.6877872386196606
Epoch 140 iteration 0 loss 0.6847108006477356
Epoch 140 iteration 500 loss 0.6045734882354736
Epoch 140 Training loss 0.6863274994355828
Epoch 141 iteration 0 loss 0.6634388566017151
Epoch 141 iteration 500 loss 0.6096868515014648
Epoch 141 Training loss 0.6831929687371145
Epoch 142 iteration 0 loss 0.6742739677429199
Epoch 142 iteration 500 loss 0.6183885335922241
Epoch 142 Training loss 0.6812224518252841
Epoch 143 iteration 0 loss 0.6753737926483154
Epoch 143 iteration 500 loss 0.5612662434577942
Epoch 143 Training loss 0.6791165840698635
Epoch 144 iteration 0 loss 0.6192144155502319
Epoch 144 iteration 500 loss 0.6639806032180786
Epoch 144 Training loss 0.6779386985280625
Epoch 145 iteration 0 loss 0.6651844382286072
Epoch 145 iteration 500 loss 0.5843113660812378
Epoch 145 Training loss 0.6790450120016149
Epoch 146 iteration 0 loss 0.7002686262130737
Epoch 146 iteration 500 loss 0.585415780544281
Epoch 146 Training loss 0.6749682782266198
Epoch 147 iteration 0 loss 0.5939975380897522
Epoch 147 iteration 500 loss 0.6044628024101257
Epoch 147 Training loss 0.6721406124393539
Epoch 148 iteration 0 loss 0.6048795580863953
Epoch 148 iteration 500 loss 0.6104022264480591
Epoch 148 Training loss 0.6710011421024045
Epoch 149 iteration 0 loss 0.6707007884979248
Epoch 149 iteration 500 loss 0.6155428290367126
Epoch 149 Training loss 0.668863014520299
Epoch 150 iteration 0 loss 0.6170852184295654
Epoch 150 iteration 500 loss 0.6271148324012756
Epoch 150 Training loss 0.6666583654311777
Evaluation loss 0.6567141524802815
Epoch 151 iteration 0 loss 0.6064229607582092
Epoch 151 iteration 500 loss 0.6067769527435303
Epoch 151 Training loss 0.6658764364017523
Epoch 152 iteration 0 loss 0.6033229827880859
Epoch 152 iteration 500 loss 0.6046653389930725
Epoch 152 Training loss 0.6627043828999308
Epoch 153 iteration 0 loss 0.6276624798774719
Epoch 153 iteration 500 loss 0.5686469674110413
Epoch 153 Training loss 0.6610188931476784
Epoch 154 iteration 0 loss 0.6161495447158813
Epoch 154 iteration 500 loss 0.5787109732627869
Epoch 154 Training loss 0.6608237031564234
Epoch 155 iteration 0 loss 0.6156296730041504
Epoch 155 iteration 500 loss 0.6136241555213928
Epoch 155 Training loss 0.6601708996267871
Epoch 156 iteration 0 loss 0.6367676258087158
Epoch 156 iteration 500 loss 0.6125684976577759
Epoch 156 Training loss 0.6572488561114832
Epoch 157 iteration 0 loss 0.5943743586540222
Epoch 157 iteration 500 loss 0.6088892817497253
Epoch 157 Training loss 0.6530513102432873
Epoch 158 iteration 0 loss 0.6312132477760315
Epoch 158 iteration 500 loss 0.5719904899597168
Epoch 158 Training loss 0.6542307047037696
Epoch 159 iteration 0 loss 0.6441388726234436
Epoch 159 iteration 500 loss 0.5410758256912231
Epoch 159 Training loss 0.6517251953436262
Epoch 160 iteration 0 loss 0.6008170247077942
Epoch 160 iteration 500 loss 0.6775693297386169
Epoch 160 Training loss 0.6494704568876846
Epoch 161 iteration 0 loss 0.6351537108421326
Epoch 161 iteration 500 loss 0.6094696521759033
Epoch 161 Training loss 0.6451660223800983
Epoch 162 iteration 0 loss 0.6162123084068298
Epoch 162 iteration 500 loss 0.6127580404281616
Epoch 162 Training loss 0.6455097573614572
Epoch 163 iteration 0 loss 0.63380366563797
Epoch 163 iteration 500 loss 0.5385963320732117
Epoch 163 Training loss 0.6419719247193757
Epoch 164 iteration 0 loss 0.6315809488296509
Epoch 164 iteration 500 loss 0.6105737686157227
Epoch 164 Training loss 0.6414774833428765
Epoch 165 iteration 0 loss 0.6115357279777527
Epoch 165 iteration 500 loss 0.5768589377403259
Epoch 165 Training loss 0.6398385553996117
Epoch 166 iteration 0 loss 0.6201842427253723
Epoch 166 iteration 500 loss 0.5616539716720581
Epoch 166 Training loss 0.6364566991862036
Epoch 167 iteration 0 loss 0.6097760796546936
Epoch 167 iteration 500 loss 0.5191922187805176
Epoch 167 Training loss 0.6377152620337613
Epoch 168 iteration 0 loss 0.6415517926216125
Epoch 168 iteration 500 loss 0.5646741986274719
Epoch 168 Training loss 0.6349046524174005
Epoch 169 iteration 0 loss 0.5960779190063477
Epoch 169 iteration 500 loss 0.5587354302406311
Epoch 169 Training loss 0.6327087651115453
Epoch 170 iteration 0 loss 0.573525071144104
Epoch 170 iteration 500 loss 0.5668031573295593
Epoch 170 Training loss 0.6311483931817605
Epoch 171 iteration 0 loss 0.6054481863975525
Epoch 171 iteration 500 loss 0.6193026900291443
Epoch 171 Training loss 0.6297210989871891
Epoch 172 iteration 0 loss 0.5725279450416565
Epoch 172 iteration 500 loss 0.5870577096939087
Epoch 172 Training loss 0.630531075839191
Epoch 173 iteration 0 loss 0.618507981300354
Epoch 173 iteration 500 loss 0.5655175447463989
Epoch 173 Training loss 0.6253485212726325
Epoch 174 iteration 0 loss 0.6010618805885315
Epoch 174 iteration 500 loss 0.5972918272018433
Epoch 174 Training loss 0.625110051075673
Epoch 175 iteration 0 loss 0.6050177812576294
Epoch 175 iteration 500 loss 0.5727249979972839
Epoch 175 Training loss 0.6245754901028149
Epoch 176 iteration 0 loss 0.609971821308136
Epoch 176 iteration 500 loss 0.5718697309494019
Epoch 176 Training loss 0.6247522305702191
Epoch 177 iteration 0 loss 0.6230632066726685
Epoch 177 iteration 500 loss 0.5268271565437317
Epoch 177 Training loss 0.6201899949967344
Epoch 178 iteration 0 loss 0.5248923897743225
Epoch 178 iteration 500 loss 0.5943036079406738
Epoch 178 Training loss 0.6193476752842509
Epoch 179 iteration 0 loss 0.6658679842948914
Epoch 179 iteration 500 loss 0.5524617433547974
Epoch 179 Training loss 0.6182718938479
Epoch 180 iteration 0 loss 0.5757883787155151
Epoch 180 iteration 500 loss 0.5122137665748596
Epoch 180 Training loss 0.6159584765731609
Epoch 181 iteration 0 loss 0.5369557738304138
Epoch 181 iteration 500 loss 0.6061877608299255
Epoch 181 Training loss 0.6155598271809033
Epoch 182 iteration 0 loss 0.5931233763694763
Epoch 182 iteration 500 loss 0.5716503262519836
Epoch 182 Training loss 0.6142580398110759
Epoch 183 iteration 0 loss 0.6230226755142212
Epoch 183 iteration 500 loss 0.5350344777107239
Epoch 183 Training loss 0.6115880344794106
Epoch 184 iteration 0 loss 0.5834236741065979
Epoch 184 iteration 500 loss 0.5867440104484558
Epoch 184 Training loss 0.6094208295693264
Epoch 185 iteration 0 loss 0.6023694276809692
Epoch 185 iteration 500 loss 0.5648134350776672
Epoch 185 Training loss 0.6096906789707265
Epoch 186 iteration 0 loss 0.5760602355003357
Epoch 186 iteration 500 loss 0.6042081713676453
Epoch 186 Training loss 0.6101944479163676
Epoch 187 iteration 0 loss 0.6048933267593384
Epoch 187 iteration 500 loss 0.5578716397285461
Epoch 187 Training loss 0.6078079864582462
Epoch 188 iteration 0 loss 0.5865512490272522
Epoch 188 iteration 500 loss 0.5678480863571167
Epoch 188 Training loss 0.6075730606190578
Epoch 189 iteration 0 loss 0.6308228373527527
Epoch 189 iteration 500 loss 0.5610036849975586
Epoch 189 Training loss 0.6028249191833381
Epoch 190 iteration 0 loss 0.6329002976417542
Epoch 190 iteration 500 loss 0.5683515667915344
Epoch 190 Training loss 0.6040577354080152
Epoch 191 iteration 0 loss 0.5902508497238159
Epoch 191 iteration 500 loss 0.5105716586112976
Epoch 191 Training loss 0.5994857476867721
Epoch 192 iteration 0 loss 0.5673339366912842
Epoch 192 iteration 500 loss 0.5264188647270203
Epoch 192 Training loss 0.6037306832121818
Epoch 193 iteration 0 loss 0.5514609813690186
Epoch 193 iteration 500 loss 0.5742241740226746
Epoch 193 Training loss 0.5991605134914191
Epoch 194 iteration 0 loss 0.5095831155776978
Epoch 194 iteration 500 loss 0.5610089302062988
Epoch 194 Training loss 0.5967281184137256
Epoch 195 iteration 0 loss 0.5875757932662964
Epoch 195 iteration 500 loss 0.5508366823196411
Epoch 195 Training loss 0.5941828790819974
Epoch 196 iteration 0 loss 0.6367263793945312
Epoch 196 iteration 500 loss 0.5381757616996765
Epoch 196 Training loss 0.5941237560591565
Epoch 197 iteration 0 loss 0.6024200916290283
Epoch 197 iteration 500 loss 0.5678775310516357
Epoch 197 Training loss 0.5922638683993326
Epoch 198 iteration 0 loss 0.564724326133728
Epoch 198 iteration 500 loss 0.5765736699104309
Epoch 198 Training loss 0.5923394315071917
Epoch 199 iteration 0 loss 0.5818908214569092
Epoch 199 iteration 500 loss 0.5664358139038086
Epoch 199 Training loss 0.5932606708257029
Test the model
def model_result(i):
    en_sent = " ".join([inv_en_dict[w] for w in dev_en[i]])
    print(en_sent)
    cn_sent = " ".join([inv_cn_dict[w] for w in dev_cn[i]])
    print(cn_sent)

    mb_x = torch.from_numpy(np.array(dev_en[i]).reshape(1, -1)).long().to(device)
    mb_x_len = torch.from_numpy(np.array([len(dev_en[i])])).long().to(device)
    bos = torch.Tensor([[cn_dict["BOS"]]]).long().to(device)

    translation, attn = model.translate(mb_x, mb_x_len, bos)
    translation = [inv_cn_dict[i] for i in translation.data.cpu().numpy().reshape(-1)]
    trans = []
    for word in translation:
        if word != "EOS":
            trans.append(word)
        else:
            break
    print("".join(trans))

for i in range(20,25):
    model_result(i)
    print()
BOS anything else ? EOS
BOS 还 有 别 的 吗 ? EOS
还好别吃吗?

BOS i 'm sleepy . EOS
BOS 我 困 了 。 EOS
我累了。

BOS i ate caviar . EOS
BOS 我 吃 了 鱼 子 酱 。 EOS
我吃了鱼式。

BOS i like sports . EOS
BOS 我 喜 欢 运 动 。 EOS
我喜欢运动。

BOS she may come . EOS
BOS 她 可 以 来 。 EOS
她可以来。
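
To back up the qualitative impression from the sample translations above with a number, a corpus-level BLEU score over the dev set could be computed roughly as follows. This is a minimal sketch: the helper name bleu_on_dev is hypothetical, it reuses the model and dictionaries defined above, and it scores at the character level with NLTK's corpus_bleu.

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

def bleu_on_dev(num_sentences=500):
    # hypothetical helper: greedily translate dev sentences and score them against the references
    references, hypotheses = [], []
    model.eval()
    with torch.no_grad():
        for i in range(min(num_sentences, len(dev_en))):
            mb_x = torch.from_numpy(np.array(dev_en[i]).reshape(1, -1)).long().to(device)
            mb_x_len = torch.from_numpy(np.array([len(dev_en[i])])).long().to(device)
            bos = torch.Tensor([[cn_dict["BOS"]]]).long().to(device)
            translation, _ = model.translate(mb_x, mb_x_len, bos)
            pred = []
            for idx in translation.data.cpu().numpy().reshape(-1):
                if inv_cn_dict[idx] == "EOS":
                    break
                pred.append(inv_cn_dict[idx])
            ref = [inv_cn_dict[w] for w in dev_cn[i][1:-1]]  # reference characters without BOS/EOS
            references.append([ref])
            hypotheses.append(pred)
    smooth = SmoothingFunction().method1
    return corpus_bleu(references, hypotheses, smoothing_function=smooth)

print("Dev BLEU (character level):", bleu_on_dev())
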
Summary: Overall, the translation system performs reasonably well and can handle basic translation tasks. With a larger training set and longer training time, the model would almost certainly perform even better.