# Language model
# LSTM
import torchtext
from torchtext.vocab import Vectors
import torch
import numpy as np
import random
BATCH_SIZE = 32        # number of sequences per training batch
EMBEDDING_SIZE = 100   # dimensionality of the word-embedding vectors
MAX_VOCAB_SIZE = 5000  # presumably caps the vocabulary when built from the corpus — confirm against the vocab-building code
import torch.nn as nn
class RNNModel(nn.Module):
    """Recurrent language model: embedding -> RNN -> linear decoder over the vocab.

    Args:
        rnn_type: "LSTM" or "GRU". Any other value falls back to LSTM,
            matching the original code, which always built an LSTM.
        vocab_size: size of the vocabulary (embedding rows / decoder outputs).
        embed_size: dimensionality of the word embeddings.
        hidden_size: dimensionality of the recurrent hidden state.
    """

    def __init__(self, rnn_type, vocab_size, embed_size, hidden_size):
        super(RNNModel, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        # Honor rnn_type (the original accepted it but ignored it and
        # always built an LSTM); unknown values keep the old behavior.
        if rnn_type in ("LSTM", "GRU"):
            self.lstm = getattr(nn, rnn_type)(embed_size, hidden_size)
        else:
            self.lstm = nn.LSTM(embed_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size

    def forward(self, text, hidden):
        """Run one pass over a batch of token sequences.

        Args:
            text: LongTensor of token ids, shape (seq_len, batch_size).
            hidden: initial recurrent state — (h, c) tuple for LSTM,
                a single tensor for GRU (see ``init_hidden``).

        Returns:
            (logits, hidden) where logits has shape
            (seq_len, batch_size, vocab_size).
        """
        emb = self.embed(text)                    # (seq_len, batch, embed_size)
        output, hidden = self.lstm(emb, hidden)   # (seq_len, batch, hidden_size)
        # Flatten time and batch so the Linear layer maps every position
        # to vocabulary logits in one matmul.  (The original reshaped twice
        # and referenced an undefined name `out`.)
        decoded = self.decoder(output.view(-1, output.size(2)))
        # Restore the (seq_len, batch, vocab_size) layout for the caller.
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_hidden(self, batch_size, requires_grad=True):
        """Return a zero initial state on the model's device/dtype.

        Gives an (h, c) tuple for LSTM, a single tensor for GRU.
        """
        weight = next(self.parameters())
        shape = (1, batch_size, self.hidden_size)
        if isinstance(self.lstm, nn.LSTM):
            return (weight.new_zeros(shape, requires_grad=requires_grad),
                    weight.new_zeros(shape, requires_grad=requires_grad))
        return weight.new_zeros(shape, requires_grad=requires_grad)