先丢个 demo，拿来就能跑：
models.py
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
class LSTM(nn.Module):
    """Token-level LSTM language model.

    Embeds token ids, runs them through a stacked unidirectional LSTM over
    packed sequences (so padding positions are skipped), applies dropout and
    projects every timestep onto the vocabulary.

    Args:
        vocab_size: Number of vocabulary entries (embedding rows / output logits).
        embed_size: Embedding dimension.
        hidden_size: LSTM hidden-state size.
        dropout_rate: Dropout between stacked LSTM layers and before ``fc``.
        layer_num: Number of stacked LSTM layers.
        max_seq_len: Accepted for backward compatibility; not used.
    """

    def __init__(self, vocab_size, embed_size, hidden_size=128, dropout_rate=0.2, layer_num=2, max_seq_len=128):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.layer_num = layer_num
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=layer_num, batch_first=True,
                            dropout=dropout_rate, bidirectional=False)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.init_weights()

    def _run_lstm(self, x, hidden, lens):
        """Run the LSTM over padded batch-first ``x``; return ``(outputs, (h_n, c_n))``.

        ``outputs`` is batch-first and padded back to ``x.size(1)`` steps, so it
        lines up position-by-position with ``x`` (and with batch-first targets).
        """
        if torch.is_tensor(lens):
            # pack_padded_sequence requires the lengths to live on the CPU.
            lens = lens.cpu()
        # enforce_sorted=False: accept batches in any length order.
        packed = pack_padded_sequence(x, lens, batch_first=True, enforce_sorted=False)
        packed_out, hidden = self.lstm(packed, hidden)
        # batch_first=True is essential: without it the result is (seq, batch, ...)
        # and gets flattened against batch-first targets in the wrong order.
        out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))
        # Return the tensor itself -- wrapping it in torch.tensor() would copy
        # and detach it, cutting gradients to the LSTM and the embedding.
        return out, hidden

    def LSTM_leyer(self, x, hidden, lens):
        """Return the LSTM outputs for embedded, padded batch-first ``x``.

        Kept under its original (misspelled) name for compatibility.
        """
        out, _ = self._run_lstm(x, hidden, lens)
        return out

    def init_weights(self):
        """Xavier-normal init for LSTM weight matrices, zeros for LSTM biases.

        The embedding and ``fc`` layers keep PyTorch's default initialisation.
        """
        for p in self.lstm.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)
            else:
                nn.init.zeros_(p)

    def init_hidden(self, batch_size):
        """Return zero ``(h_0, c_0)``, each ``(layer_num, batch_size, hidden_size)``.

        The tensors are created with ``new_zeros`` on the first parameter, so they
        share the model's device and dtype.
        """
        weight = next(self.parameters())
        return (weight.new_zeros(self.layer_num, batch_size, self.hidden_size),
                weight.new_zeros(self.layer_num, batch_size, self.hidden_size))

    def forward(self, x, lens, hidden, return_hidden=False):
        """Compute next-token logits for a padded batch of token ids.

        Args:
            x: Token ids laid out batch-first (``(batch, seq)``, as produced by
                the ``batch_first=True`` Field in this project -- confirm for
                other callers).
            lens: Valid length of each sequence (list or tensor).
            hidden: Initial ``(h_0, c_0)``, e.g. from :meth:`init_hidden`, or None.
            return_hidden: If True, also return the final ``(h_n, c_n)``.

        Returns:
            Logits of shape ``(batch, seq, vocab_size)``, plus ``(h_n, c_n)``
            when ``return_hidden`` is True.
        """
        x = self.embed(x)
        x, hidden = self._run_lstm(x, hidden, lens)
        x = self.dropout(x)
        output = self.fc(x)
        if return_hidden:
            return output, hidden
        return output
utils.py
from torchtext.legacy import data
from torchtext.legacy.data import BucketIterator
import os
def read_data(path, max_length):
    """Read blank-line-separated poems from a UTF-8 text file.

    Every non-empty (stripped) line is appended as one element of the current
    poem, and a blank line closes it. Poems with more than ``max_length``
    elements are dropped, including the final poem at end of file.

    NOTE(review): each element is a whole line. Unless the Field re-tokenises
    list input, every line becomes a single vocabulary token. If
    character-level modelling is intended, use ``current.extend(line)`` so
    that ``max_length`` counts characters -- confirm.

    Returns:
        A list of poems, each a non-empty list of line strings.
    """
    poems = []
    current = []
    with open(path, 'r', encoding="utf8") as f:
        for raw in f:
            line = raw.strip()
            if line:
                current.append(line)
            else:
                # Always start a fresh poem at a blank line. Previously an
                # over-long poem was never reset and got merged into the next one.
                poems.append(current)
                current = []
    poems.append(current)
    return [poem for poem in poems if poem and len(poem) <= max_length]
class PoetryDataset(data.Dataset):
    """torchtext dataset holding one ``text`` example per poem from ``read_data``.

    Args:
        text_field: Field used to process each poem, registered as ``"text"``.
        path: Path of the poem file.
        max_length: Forwarded to ``read_data`` to drop over-long poems.
        **kwargs: Passed through to ``data.Dataset``.
    """

    def __init__(self, text_field, path, max_length, **kwargs):
        fields = [("text", text_field)]
        poems = read_data(path, max_length)
        examples = [data.Example.fromlist([poem], fields) for poem in poems]
        super().__init__(examples, fields, **kwargs)
def data_loader(eos_token="[EOS]", batch_size=32, device="cpu", data_path='data', max_length=128):
    """Build bucketed train/dev/test iterators over ``poetryFromTang.txt``.

    The poem file is split 0.8/0.1/0.1. The vocabulary is built from the
    training split only. Each batch yields ``text`` as ``(token_ids, lengths)``
    because the Field uses ``include_lengths=True``.

    Returns:
        ``(train_iter, dev_iter, test_iter, TEXT)``, where ``TEXT`` is the
        Field that owns the vocabulary.
    """
    TEXT = data.Field(eos_token=eos_token, batch_first=True, include_lengths=True)
    corpus_file = os.path.join(data_path, "poetryFromTang.txt")
    poems = PoetryDataset(TEXT, corpus_file, max_length)
    train_data, dev_data, test_data = poems.split([0.8, 0.1, 0.1])
    TEXT.build_vocab(train_data)
    iterators = BucketIterator.splits(
        (train_data, dev_data, test_data),
        batch_sizes=(batch_size,) * 3,
        device=device,
        sort_key=lambda example: len(example.text),
        # Presumably needed so each batch arrives length-sorted for packing.
        sort_within_batch=True,
        repeat=False,
        shuffle=True,
    )
    return (*iterators, TEXT)
main.py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm, trange
from tensorboardX import SummaryWriter
from models import LSTM
from utils import data_loader
import math
import matplotlib.pyplot as plt
import os
# Allow duplicate Intel OpenMP runtimes to load. Presumably added to avoid a
# duplicate-libiomp crash on the author's machine -- confirm before removing.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
torch.manual_seed(1)

# --- Runtime ------------------------------------------------------------------
# For GPU training use: torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
path = 'data'  # directory that contains poetryFromTang.txt

# --- Data / vocabulary --------------------------------------------------------
EOS_TOKEN = "[EOS]"
max_length = 128
batch_size = 128

# --- Model --------------------------------------------------------------------
embedding_size = 300
hidden_size = 512
num_layers = 2
drop_out = 0.2

# --- Optimisation -------------------------------------------------------------
epochs = 500
learning_rate = 0.01
decay_rate = 0.05  # per-epoch lr = learning_rate / (1 + decay_rate * (epoch + 1))
MOMENTUM = 0.9  # currently unused: the optimiser is Adam
CLIP = 5  # max gradient norm passed to clip_grad_norm_

# --- Sampling -----------------------------------------------------------------
temperature = 0.8  # Higher temperature means more diversity.
def train(train_iter, dev_iter, loss_func, optimizer, epochs, clip):
    """Train the module-level ``model`` and plot per-epoch training perplexity.

    Uses teacher forcing: ``text[:, :-1]`` is fed in and ``text[:, 1:]`` are the
    targets. Every 10th epoch, the train perplexity is printed and the summed
    loss is logged to the module-level ``writer``, followed by a dev-set
    evaluation. After each epoch the learning rate becomes
    ``learning_rate / (1 + decay_rate * (epoch + 1))``.

    Args:
        train_iter: Batches whose ``text`` attribute is ``(token_ids, lengths)``.
        dev_iter: Dev batches in the same format, passed to ``eval``.
        loss_func: Loss over flattened logits/targets. Assumed to be summed
            over non-pad tokens (``reduction="sum"``) so that dividing by the
            token count gives a per-token loss -- confirm for other callers.
        optimizer: Optimizer over ``model.parameters()``.
        epochs: Number of passes over ``train_iter``.
        clip: Max gradient norm for ``clip_grad_norm_``.
    """
    perplexity = []
    for epoch in trange(epochs):
        model.train()
        total_loss = 0.0
        total_words = 0
        for batch in train_iter:
            text, lens = batch.text
            inputs = text[:, :-1]
            targets = text[:, 1:]
            init_hidden = model.init_hidden(inputs.size(0))
            # [EOS] is included in lens, so each sequence has lens - 1 input/target pairs.
            logits = model(inputs, lens - 1, init_hidden)
            loss = loss_func(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
            model.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            total_loss += loss.item()
            # Count only predicted tokens (lens - 1 per sequence) so the
            # denominator matches the tokens that contributed to the loss.
            total_words += (lens - 1).sum().item()
        train_ppl = math.exp(total_loss / max(total_words, 1))
        if epoch % 10 == 9:
            tqdm.write("Epoch: %d, Train perplexity: %.3f" % (epoch + 1, train_ppl))
            writer.add_scalar('Train_Loss', total_loss, epoch)
            eval(dev_iter, True, epoch)
        perplexity.append(train_ppl)
        lr = learning_rate / (1 + decay_rate * (epoch + 1))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    plt.plot(np.arange(len(perplexity)), np.array(perplexity))
    plt.xlabel('Epochs')  # one point is recorded per epoch
    plt.ylabel('Training Perplexity')
    plt.title('LSTM language model')
    plt.show()
def eval(data_iter, is_dev=False, epoch=None):
    """Compute, report and return the perplexity of the module-level ``model``.

    NOTE: this shadows the builtin ``eval``. The name is kept because ``train``
    calls it.

    Args:
        data_iter: Batches whose ``text`` attribute is ``(token_ids, lengths)``.
        is_dev: Labels the split "Dev" when True, otherwise "Test".
        epoch: When given, the report includes the 1-based epoch and the summed
            loss is logged to ``writer`` under ``Dev_Loss`` / ``Test_Loss``.

    Returns:
        ``exp(total_loss / predicted_tokens)`` as a float.
    """
    model.eval()
    split = "Dev" if is_dev else "Test"
    total_loss = 0.0
    total_words = 0
    with torch.no_grad():
        for batch in data_iter:
            text, lens = batch.text
            inputs = text[:, :-1]
            targets = text[:, 1:]
            init_hidden = model.init_hidden(inputs.size(0))
            logits = model(inputs, lens - 1, init_hidden)  # [EOS] is included in length.
            loss = loss_func(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
            total_loss += loss.item()
            # Each sequence contributes lens - 1 predicted tokens, not lens.
            total_words += (lens - 1).sum().item()
    # max(..., 1) avoids ZeroDivisionError on an empty iterator.
    ppl = math.exp(total_loss / max(total_words, 1))
    if epoch is not None:
        tqdm.write("Epoch: %d, %s perplexity %.3f" % (epoch + 1, split, ppl))
        # Tag by split so test-set loss is not logged as Dev_Loss.
        writer.add_scalar('%s_Loss' % split, total_loss, epoch)
    else:
        tqdm.write("%s perplexity %.3f" % (split, ppl))
    return ppl
def generate(eos_idx, word, temperature=0.8, max_steps=200):
    """Sample a poem from the module-level ``model`` starting at ``word``.

    If ``word`` is not in ``TEXT.vocab``, a start token is drawn uniformly at
    random. Tokens are then sampled one at a time from
    ``softmax(logits / temperature)`` until ``eos_idx`` is drawn or
    ``max_steps`` tokens have been generated.

    Args:
        eos_idx: Vocabulary index that ends generation.
        word: Start token.
        temperature: Higher values flatten the distribution (more diversity).
        max_steps: Upper bound on sampled tokens, so the loop always terminates.

    Returns:
        The generated text without a trailing EOS token. It is also printed.
    """
    model.eval()
    with torch.no_grad():
        if word in TEXT.vocab.stoi:
            idx = TEXT.vocab.stoi[word]
        else:
            print("%s is not in vocabulary, choose by random." % word)
            idx = torch.multinomial(torch.ones(len(TEXT.vocab.itos)), 1)[0].item()
        poetry = [TEXT.vocab.itos[idx]]
        inputs = torch.tensor([[idx]], device=device)  # (batch=1, seq=1)
        # model.init_hidden (not model.lstm.init_hidden, which does not exist).
        hidden = model.init_hidden(1)
        steps = 0
        while idx != eos_idx and steps < max_steps:
            # Step embed -> lstm -> fc one token at a time, carrying (h, c)
            # forward. This mirrors forward() (its dropout is a no-op in eval
            # mode), but without packing, which a length-1 batch doesn't need.
            out, hidden = model.lstm(model.embed(inputs), hidden)
            logits = model.fc(out[:, -1]).squeeze(0)
            # softmax is numerically stable, unlike exp() of raw scaled logits.
            probs = torch.softmax(logits / temperature, dim=-1).cpu()
            idx = torch.multinomial(probs, 1)[0].item()
            inputs.fill_(idx)
            poetry.append(TEXT.vocab.itos[idx])
            steps += 1
    if idx == eos_idx:
        poetry = poetry[:-1]  # drop the EOS token itself
    result = "".join(poetry)
    print(result)
    return result
if __name__ == "__main__":
    train_iter, dev_iter, test_iter, TEXT = data_loader(EOS_TOKEN, batch_size, device, path, max_length)
    pad_idx = TEXT.vocab.stoi[TEXT.pad_token]
    eos_idx = TEXT.vocab.stoi[EOS_TOKEN]
    model = LSTM(len(TEXT.vocab), embed_size=embedding_size, hidden_size=hidden_size,
                 dropout_rate=drop_out, layer_num=num_layers).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Summed loss over non-pad targets; train()/eval() divide by the token
    # count to obtain per-token perplexity.
    loss_func = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, reduction="sum")
    writer = SummaryWriter("logs")
    try:
        train(train_iter, dev_iter, loss_func, optimizer, epochs, CLIP)
        # Uncomment to report test-set perplexity after training:
        # eval(test_iter, is_dev=False)
        try:
            while True:
                word = input("Input the first word or press Ctrl-C to exit: ")
                generate(eos_idx, word.strip(), temperature)
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C / Ctrl-D end the interactive loop. Any other error now
            # propagates instead of being silently swallowed by a bare except.
            print()
    finally:
        writer.close()  # flush pending TensorBoard events
训练结果：
corpus 太小，overfitting is all you need!