#!/usr/bin/python # -*- coding: GBK -*- import torch from torch import nn import numpy as np import math import random class PositionEncoding(nn.Module): def __init__(self, d_model, dropout, max_len=5000): super(PositionEncoding, self).__init__() self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) self.register_buffer("pe", pe) def forward(self, x): x = x + self.pe[:, :x.size(1)].requires_grad_(False) return self.dropout(x) class CopyTaskModel(nn.Module): def __init__(self, device, d_model=128): super(CopyTaskModel, self).__init__() self.english_char_count = 100 self.chinese_char_count = 100 self.e_embedding = nn.Embedding(num_embeddings=self.english_char_count, embedding_dim=128) self.c_embedding = nn.Embedding(num_embeddings=self.chinese_char_count, embedding_dim=128) self.transformer = nn.Transformer(d_model=128, num_encoder_layers=3, num_decoder_layers=2, dim_feedforward=512, batch_first=True) self.position_encoding = PositionEncoding(d_model, dropout=0) self.predictor = nn.Linear(128, self.english_char_count) self.device = device def forward(self, src, tgt): tgt_mask = (nn.Transformer.generate_square_subsequent_mask(tgt.size()[-1]) == -torch.inf).to(self.device) # tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size()[-1]).to(self.device) # print("tgt_mask:", tgt_mask) src_key_padding_mask = self.get_key_padding_mask(src) tgt_key_padding_mask = self.get_key_padding_mask(tgt) src = self.c_embedding(src) tgt = self.e_embedding(tgt) src = self.position_encoding(src) tgt = self.position_encoding(tgt) out = self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask) return out # @staticmethod def get_key_padding_mask(self, 
tokens): key_padding_mask = torch.zeros(tokens.size()) key_padding_mask[tokens == 2] = -torch.inf key_padding_mask = key_padding_mask.to(self.device) return key_padding_mask class SentenceDataSet: def __init__(self): self.BOS = 0 self.EOS = 1 self.PAD = 2 pass def load_vec_sentence_sample(self, batch_size=16): sample_batch = [] while True: eng_len_max = 20 chinese_len_max = 20 sample_s = [random.randint(3, 99) for _ in range(random.randint(10, eng_len_max-2))] sample_d = [0, ] + sorted(sample_s) + [1,] sample_s = [0, ] + sample_s + [1,] sample_s += [2,]*(eng_len_max-len(sample_s)) sample_d += [2,]*(chinese_len_max-len(sample_d)) sample_batch.append([sample_s, sample_d]) #print("sample_s:", sample_s) #print("sample_d:", sample_d) if len(sample_batch) == batch_size: yield sample_batch, eng_len_max, chinese_len_max eng_len_max = 0 chinese_len_max = 0 sample_batch = [] def cvt_batch(self, batch): c_batch = [] e_batch = [] for c, e in batch: c_batch.append(c) e_batch.append(e) return c_batch, e_batch def evaluate(model, src, device): with torch.no_grad(): print("src:", src) dst = torch.LongTensor([0,]*21).to(device) for i in range(20): input_dst = dst[0:i + 1] out = model(src.reshape(1,*src.shape), input_dst.reshape(1, *input_dst.shape)) out = model.predictor(out) front = torch.argmax(out[0], dim=1) dst[1:2+i]=front[0:1+i] print("dst:", dst) def train(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # device = torch.device("cpu") load_from_file = False if load_from_file: print("load model...") net = torch.load("model/model0.pth") else: net = CopyTaskModel(device=device) for m in net.modules(): if isinstance(m, (nn.Conv2d, nn.Linear)): nn.init.xavier_normal_(m.weight) if torch.cuda.is_available(): net.cuda(0) criteria = nn.CrossEntropyLoss() lr = 2e-3 optim = torch.optim.Adam(net.parameters(), lr=lr, eps=1e-4) dataset = SentenceDataSet() for b in range(1): result_list = [] index = 0 batch_size = 16 for sample_batch, e_max, c_max in 
dataset.load_vec_sentence_sample(batch_size=batch_size): c_batch, e_batch = dataset.cvt_batch(sample_batch) c_tensor = torch.LongTensor(c_batch).to(device) e_tensor = torch.LongTensor(e_batch).to(device) src, tgt, tgt_y, n_tokens = c_tensor, e_tensor[:, 0:-1], e_tensor[:, 1:], torch.sum(e_tensor != 2) out = net(src, tgt) out = net.predictor(out) loss = criteria(out.contiguous().view(-1, out.size(-1)), tgt_y.contiguous().view(-1)) / n_tokens optim.zero_grad() loss.backward() optim.step() result_list.append(float(loss)) if len(result_list) > 5: result_list = result_list[-5:] avg_loss = np.sum(np.array(result_list)) / float(len(result_list)) index += 1 if index % 40 == 0: print("src[0]", src[0]) #print("dst[0]", out[0].shape) print("dst[0]", torch.argmax(out[0], dim=1)) print("tgt_y ", tgt_y[0]) print("batch %02d batch_idx %4d, loss %f" % (b, index, avg_loss)) if index>40000: break torch.save(net, 'model/model%d.pth' % (b)) for sample_batch, e_max, c_max in dataset.load_vec_sentence_sample(batch_size=batch_size): c_batch, e_batch = dataset.cvt_batch(sample_batch) c_tensor = torch.LongTensor(c_batch).to(device) e_tensor = torch.LongTensor(e_batch).to(device) src, tgt, tgt_y = c_tensor, e_tensor[:, 0:-1], e_tensor[:, 1:] print("****", tgt[0]) evaluate(net, src[0], device) train()
# Transformer sorting
# (Source page note: latest recommended article published on 2024-09-15 23:05:39.)