Import the required packages
We use the text8 dataset, split into train, val, and test sets
import torchtext
from torchtext.vocab import Vectors
import torch
import torch.nn as nn
import numpy as np
import random
batch_size = 32
embedding_size = 50
max_vocab_size = 50000
hidden_size = 100
learn_rate = 0.001
TEXT = torchtext.data.Field(lower=True)
# LanguageModelingDataset is designed specifically for language-model corpora
# (note: this is the legacy torchtext API, pre-0.9)
train, val, test = torchtext.datasets.LanguageModelingDataset.splits(
    path='./text8', train='text8.train.txt',
    validation='text8.dev.txt', test='text8.test.txt', text_field=TEXT)
TEXT.build_vocab(train, max_size=max_vocab_size)  # build the vocab from the train set, ordered by word frequency
# print(TEXT.vocab.itos[:100])
# print(TEXT.vocab.stoi['<unk>'])
train_iter, val_iter, test_iter = torchtext.data.BPTTIterator.splits(
    (train, val, test),
    batch_size=batch_size, device='cpu', bptt_len=50, repeat=False, shuffle=True)
# it = iter(train_iter)
# batch = next(it)
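A quick sanity check on the iterator: pull one batch and confirm that target is simply text shifted by one token within the stream (a minimal sketch, assuming the iterators above were built; shapes are bptt_len * batch_size):
it = iter(train_iter)
batch = next(it)
print(batch.text.shape)    # torch.Size([50, 32]): bptt_len * batch_size
print(batch.target.shape)  # same shape; target[t] == text[t+1] within each stream
print(' '.join(TEXT.vocab.itos[i] for i in batch.text[:5, 0].tolist()))  # first words of the first stream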
Define the model: an LSTM
class RNNmodel(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(RNNmodel, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size)  # with batch_first=True the input could instead be batch_size * seq_length; nothing else changes
        self.decoder = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size

    def forward(self, text, hidden):
        emb = self.embed(text)  # text: seq_length * batch_size
        output, hidden = self.lstm(emb, hidden)  # output: seq_length * batch_size * hidden_size; hidden: (1 * batch_size * hidden_size, 1 * batch_size * hidden_size)
        decoded = self.decoder(output.view(-1, output.size(-1)))
        decoded = decoded.view(output.size(0), output.size(1), decoded.size(1))  # decoded: seq_length * batch_size * vocab_size
        return decoded, hidden

    def init_hidden(self, bsize, requires_grad=True):
        # create h0/c0 on the same device and dtype as the model parameters
        weight = next(self.parameters())
        return (weight.new_zeros((1, bsize, self.hidden_size), requires_grad=requires_grad),
                weight.new_zeros((1, bsize, self.hidden_size), requires_grad=requires_grad))
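Before training, it is worth confirming the forward-pass shapes with a throwaway batch (a hedged sketch; the dummy input is illustrative):
m = RNNmodel(len(TEXT.vocab), embedding_size, hidden_size)
h = m.init_hidden(batch_size)
dummy = torch.randint(0, len(TEXT.vocab), (50, batch_size))  # seq_length * batch_size of word indices
out, h = m(dummy, h)
print(out.shape)  # seq_length * batch_size * vocab_size, here torch.Size([50, 32, 50002])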
Create a model and train it
model = RNNmodel(len(TEXT.vocab), embedding_size, hidden_size)
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
loss_fn = nn.CrossEntropyLoss()
def repackage_hidden(h):
    # detach hidden states from the previous batch's graph (truncated BPTT)
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)
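repackage_hidden is what makes the backpropagation truncated: without the detach, each backward() would try to propagate through every earlier batch's graph, growing memory without bound and eventually erroring once an old graph is freed. A tiny check (sketch):
h = model.init_hidden(batch_size)
out, h = model(next(iter(train_iter)).text, h)  # after a forward pass, h carries a graph
print(h[0].grad_fn is not None)  # True: gradients would flow back into this batch
h = repackage_hidden(h)
print(h[0].grad_fn)              # None: the history has been cut off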
def evaluate(model, data):
    model.eval()
    total_loss = 0.
    total_count = 0.
    it = iter(data)
    with torch.no_grad():
        hidden = model.init_hidden(batch_size, requires_grad=False)
        for i, batch in enumerate(it):
            data, target = batch.text, batch.target
            hidden = repackage_hidden(hidden)
            output, hidden = model(data, hidden)
            loss = loss_fn(output.view(-1, len(TEXT.vocab)), target.view(-1))
            total_loss += loss.item() * np.multiply(*data.size())  # accumulate; the original overwrote these each batch
            total_count += np.multiply(*data.size())
    loss = total_loss / total_count
    model.train()
    return loss
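Language-model quality is conventionally reported as perplexity, which is just the exponential of the average cross-entropy that evaluate returns, so it drops out for free (a minimal sketch):
import math
val_loss = evaluate(model, val_iter)
print('val loss %.4f, perplexity %.2f' % (val_loss, math.exp(val_loss)))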
val_losses = []
GRAD_CLIP = 5.0
# create the scheduler once, outside the loop; each step() halves the learning rate
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.5)
for epoch in range(2):
    model.train()
    it = iter(train_iter)
    hidden = model.init_hidden(batch_size)
    for i, batch in enumerate(it):
        data, target = batch.text, batch.target
        hidden = repackage_hidden(hidden)
        output, hidden = model(data, hidden)
        loss = loss_fn(output.view(-1, len(TEXT.vocab)), target.view(-1))  # logits: (seq_length*batch_size) * vocab_size; target: (seq_length*batch_size)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)  # clip after backward, before step
        optimizer.step()
        if i % 100 == 0:
            print('loss', loss.item())
        if i % 1000 == 0:
            val_loss = evaluate(model, val_iter)
            if len(val_losses) == 0 or val_loss < min(val_losses):
                torch.save(model.state_dict(), 'lm.pth')
                print('best model saved to lm.pth')
            else:
                scheduler.step()  # no improvement: halve the learning rate (recreating the optimizer here, as the original did, would undo this)
            val_losses.append(val_loss)
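After training, reload the best checkpoint before measuring test loss, since the in-memory model may be a later, worse iterate (a minimal sketch, assuming lm.pth was saved at least once):
best_model = RNNmodel(len(TEXT.vocab), embedding_size, hidden_size)
best_model.load_state_dict(torch.load('lm.pth'))
test_loss = evaluate(best_model, test_iter)
print('test loss', test_loss)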
Generate text with the trained model
# generate a sentence with the best saved model
# best_model = RNNmodel(len(TEXT.vocab), embed_size=embedding_size, hidden_size=hidden_size)
# best_model.load_state_dict(torch.load('lm.pth'))
# hidden = best_model.init_hidden(1)
# input = torch.randint(len(TEXT.vocab), (1, 1), dtype=torch.long).to('cpu')
# words = []
# for i in range(100):
#     output, hidden = best_model(input, hidden)
#     word_weights = output.squeeze().exp().cpu()
#     word_idx = torch.multinomial(word_weights, 1)[0]
#     input.fill_(word_idx)
#     word = TEXT.vocab.itos[word_idx]
#     words.append(word)
# print(' '.join(words))
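A common extension of the sampling loop above is temperature sampling: dividing the logits by a temperature below 1.0 sharpens the distribution (more conservative text), while above 1.0 flattens it. A hedged sketch (generate and temperature are illustrative additions, not part of the original code):
def generate(model, n_words=100, temperature=1.0):
    model.eval()
    hidden = model.init_hidden(1, requires_grad=False)
    inp = torch.randint(len(TEXT.vocab), (1, 1), dtype=torch.long)
    words = []
    with torch.no_grad():
        for _ in range(n_words):
            output, hidden = model(inp, hidden)
            weights = (output.squeeze() / temperature).exp().cpu()  # softmax up to normalization
            idx = torch.multinomial(weights, 1)[0]  # multinomial normalizes the weights
            inp.fill_(idx)
            words.append(TEXT.vocab.itos[idx])
    return ' '.join(words)
# print(generate(best_model, temperature=0.8))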