Sorting with a Transformer
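The script below trains a small nn.Transformer on a toy sequence-to-sequence task: the source is a random list of integers in [3, 99], framed by BOS/EOS and padded with PAD, and the target is the same list in ascending order. After training, greedy decoding is used to check that the model really sorts its input.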

#!/usr/bin/python
# -*- coding: GBK -*-
import torch
from torch import nn
import numpy as np
import math
import random
import os

class PositionEncoding(nn.Module):
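    """Fixed sinusoidal positional encoding (Vaswani et al., 2017) added to the token embeddings."""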
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionEncoding, self).__init__()

        self.dropout = nn.Dropout(p=dropout)
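        # Build a (max_len, d_model) table: PE[pos, 2i]   = sin(pos / 10000^(2i/d_model)),
        #                                    PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model)).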
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
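        # x is (batch, seq_len, d_model) because the model is batch_first; add the first seq_len encodings.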
        x = x + self.pe[:, :x.size(1)].requires_grad_(False)
        return self.dropout(x)


class CopyTaskModel(nn.Module):
    def __init__(self, device, d_model=128):
        super(CopyTaskModel, self).__init__()

        # Source and target share the same toy vocabulary: 0 = BOS, 1 = EOS, 2 = PAD, 3..99 = data.
        self.english_char_count = 100
        self.chinese_char_count = 100

        self.e_embedding = nn.Embedding(num_embeddings=self.english_char_count, embedding_dim=d_model)
        self.c_embedding = nn.Embedding(num_embeddings=self.chinese_char_count, embedding_dim=d_model)
        self.transformer = nn.Transformer(d_model=d_model, num_encoder_layers=3, num_decoder_layers=2,
                                          dim_feedforward=512, batch_first=True)
        self.position_encoding = PositionEncoding(d_model, dropout=0)

        # Projects decoder outputs back to vocabulary logits.
        self.predictor = nn.Linear(d_model, self.english_char_count)
        self.device = device

    def forward(self, src, tgt):
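        # generate_square_subsequent_mask returns a float mask with -inf above the diagonal;
        # comparing it with -inf yields the boolean causal mask (True = may not be attended).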
        tgt_mask = (nn.Transformer.generate_square_subsequent_mask(tgt.size()[-1]) == -torch.inf).to(self.device)
        # tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size()[-1]).to(self.device)
        # print("tgt_mask:", tgt_mask)
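        # Key-padding masks put -inf at PAD positions so attention assigns them zero weight.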
        src_key_padding_mask = self.get_key_padding_mask(src)
        tgt_key_padding_mask = self.get_key_padding_mask(tgt)

        src = self.c_embedding(src)
        tgt = self.e_embedding(tgt)
        src = self.position_encoding(src)
        tgt = self.position_encoding(tgt)

        out = self.transformer(src, tgt, tgt_mask=tgt_mask,
                               src_key_padding_mask=src_key_padding_mask,
                               tgt_key_padding_mask=tgt_key_padding_mask)
        return out

    def get_key_padding_mask(self, tokens):
        # Float mask: -inf at PAD positions (token 2), 0 elsewhere, placed on the model's device.
        key_padding_mask = torch.zeros(tokens.size())
        key_padding_mask[tokens == 2] = -torch.inf
        key_padding_mask = key_padding_mask.to(self.device)
        return key_padding_mask


class SentenceDataSet:
    def __init__(self):
        # Special tokens shared by the source and target vocabularies.
        self.BOS = 0
        self.EOS = 1
        self.PAD = 2

    def load_vec_sentence_sample(self, batch_size=16):

        sample_batch = []
        while True:
            eng_len_max = 20
            chinese_len_max = 20

            # Source: BOS + 10..18 random integers in [3, 99] + EOS; target: the same integers
            # sorted ascending with the same framing. Both are padded with PAD to length 20.
            sample_s = [random.randint(3, 99) for _ in range(random.randint(10, eng_len_max - 2))]
            sample_d = [self.BOS] + sorted(sample_s) + [self.EOS]
            sample_s = [self.BOS] + sample_s + [self.EOS]
            sample_s += [self.PAD] * (eng_len_max - len(sample_s))
            sample_d += [self.PAD] * (chinese_len_max - len(sample_d))

            sample_batch.append([sample_s, sample_d])
            #print("sample_s:", sample_s)
            #print("sample_d:", sample_d)
            if len(sample_batch) == batch_size:
                yield sample_batch, eng_len_max, chinese_len_max
                sample_batch = []

    def cvt_batch(self, batch):
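        # Split a batch of (source, target) pairs into two parallel lists.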
        c_batch = []
        e_batch = []
        for c, e in batch:
            c_batch.append(c)
            e_batch.append(e)
        return c_batch, e_batch


def evaluate(model, src, device):
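    # Greedy decoding: start from BOS and repeatedly feed the current prefix to the decoder,
    # taking the argmax token at each position (there is no early stop at EOS here).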
    with torch.no_grad():
        print("src:", src)
        dst = torch.LongTensor([0,]*21).to(device)
        for i in range(20):
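            # Re-run the decoder on the current prefix and copy its argmax predictions into dst[1:i+2].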
            input_dst = dst[0:i + 1]
            out = model(src.reshape(1,*src.shape), input_dst.reshape(1, *input_dst.shape))
            out = model.predictor(out)
            front = torch.argmax(out[0], dim=1)
            dst[1:2 + i] = front[0:1 + i]
        print("dst:", dst)

def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cpu")
    load_from_file = False
    if load_from_file:
        print("load model...")
        net = torch.load("model/model0.pth", map_location=device)
    else:
        net = CopyTaskModel(device=device)
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight)
    net.to(device)
    # Sum the per-token losses and skip PAD targets; dividing the sum by n_tokens below
    # gives an average loss per real (non-PAD) token.
    criteria = nn.CrossEntropyLoss(reduction="sum", ignore_index=2)
    lr = 2e-3
    optim = torch.optim.Adam(net.parameters(), lr=lr, eps=1e-4)

    dataset = SentenceDataSet()

    for b in range(1):
        result_list = []

        index = 0
        batch_size = 16
        for sample_batch, e_max, c_max in dataset.load_vec_sentence_sample(batch_size=batch_size):

            c_batch, e_batch = dataset.cvt_batch(sample_batch)
            c_tensor = torch.LongTensor(c_batch).to(device)
            e_tensor = torch.LongTensor(e_batch).to(device)

            src, tgt, tgt_y = c_tensor, e_tensor[:, 0:-1], e_tensor[:, 1:]
            n_tokens = torch.sum(tgt_y != 2)  # number of non-PAD targets, used to normalize the loss

            out = net(src, tgt)
            out = net.predictor(out)
            loss = criteria(out.contiguous().view(-1, out.size(-1)), tgt_y.contiguous().view(-1)) / n_tokens

            optim.zero_grad()
            loss.backward()
            optim.step()
            result_list.append(float(loss))

            if len(result_list) > 5:
                result_list = result_list[-5:]

            avg_loss = float(np.mean(result_list))

            index += 1
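            # Every 40 batches, print one example and the running loss; training stops after 40000 batches.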
            if index % 40 == 0:
                print("src[0]", src[0])
                #print("dst[0]", out[0].shape)
                print("dst[0]", torch.argmax(out[0], dim=1))
                print("tgt_y ", tgt_y[0])
                print("batch %02d batch_idx %4d, loss %f" % (b, index, avg_loss))
                if index>40000:
                    break

        os.makedirs("model", exist_ok=True)
        torch.save(net, 'model/model%d.pth' % b)

    for sample_batch, e_max, c_max in dataset.load_vec_sentence_sample(batch_size=batch_size):
        c_batch, e_batch = dataset.cvt_batch(sample_batch)
        c_tensor = torch.LongTensor(c_batch).to(device)
        e_tensor = torch.LongTensor(e_batch).to(device)

        src, tgt, tgt_y = c_tensor, e_tensor[:, 0:-1], e_tensor[:, 1:]

        print("****", tgt[0])
        evaluate(net, src[0], device)

if __name__ == "__main__":
    train()
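As a quick sanity check, here is a minimal sketch, assuming the training run above has produced model/model0.pth and that it runs in the same file, so the CopyTaskModel class and the evaluate helper defined above are in scope. It reloads the saved model and greedily decodes one hand-made, padded sequence:

# Minimal sanity-check sketch: reload the checkpoint saved by train() and sort one sequence.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = torch.load("model/model0.pth", map_location=device)  # newer PyTorch versions may need weights_only=False
net.eval()

numbers = [42, 7, 99, 3, 15, 88, 60, 21, 34, 5]           # ten data tokens in [3, 99]
src = [0] + numbers + [1]                                  # frame with BOS ... EOS
src += [2] * (20 - len(src))                               # pad with PAD to length 20
evaluate(net, torch.LongTensor(src).to(device), device)    # prints the greedily decoded (sorted) sequence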
