pytorch LSTM从头开始训练一个语言模型代码及其注释

利用lstm 和gru 训练一个语言模型 

这个语言模型 就是输入一个词预测下一个词是什么 

  
        **********************************************************************************************************
        emb: torch.Size([32, 32, 650])
        hidden ([2,32,650],[2,32,650])
        这里的Hidden 是包括  hidden  和cell  (hidden,cell)
        
        output torch.Size([32, 32, 650]) 是 [seq_len,batch_Size, embed_size]  
        RNN 的输出 是前面的Hidden  和当前的Input 预测出来的 prey  shape  自认和输入的shape 一样  [seq_len,batch_size,embed_szie]
        抛开批维度 来看 输入就是  [seq_lenth,embed_Size]>>>>output 输出 [seq_len,embed_Size]  
        hidden  则是最后的隐藏状态 维度 [1,hidden_Size] 这里我们一般 hidden_size==embed_size
        因为每次一个序列输出之后  我们都只是拿到最后的隐藏状态    中间状态我们都没有拿  和cell状态 
        
        如果加上批处理维度  在加上2层  [layers,batch_size,hidden_size]  cell 的维度一样  GRu 没有这cell 只有一个状态 
        
        如果是双层的化  hidden 的是 size ==[layers*2,batch_Size,hidden_size] 
        
        一般我们会进行一个双向的合并   hidden[-1]+hidden[-2]  进行相加  
        
        模型的本质就是上一个隐藏状态 [1,1,hidden_Size]+[1,1,embed_size]>>>>output[1,1,embed_Size]
        
        根据 LsTM的推导公式是可以看出 我么计算当前  Hidden的时候 只用到上个Hidden 没有用到cell  cell是根据 ft  it  ct算出来的 
        ************************************************************************************************************
        

"""
https://github.com/pytorch/text

学习语言模型,以及如何训练一个语言模型
学习torchtext的基本使用方法
构建 vocabulary
word to inde 和 index to word
学习torch.nn的一些基本模型
Linear
RNN
LSTM
GRU
RNN的训练技巧
Gradient Clipping
如何保存和读取模型
我们会使用 torchtext 来创建vocabulary, 然后把数据读成batch的格式。请大家自行阅读README来学习torchtext。

"""


import torchtext
from torchtext.vocab import Vectors
import torch
import numpy as np
import random

USE_CUDA=torch.cuda.is_available()
device=torch.device('cuda' if USE_CUDA else 'cpu')

#为了保证实验结果可以复现  我们经常会吧各种random seed 固定在某一个值
random.seed(53113)
np.random.seed(53113)
torch.manual_seed(53113)

if USE_CUDA:
    torch.cuda.manual_seed(53113)
    
BATCH_SIZE=32
EMBEDDING_SIZE=650
MAX_VOCAB_SIZE=50000

"""
我们会继续使用上次的text8作为我们的训练,验证和测试数据
TorchText的一个重要概念是Field,它决定了你的数据会如何被处理。我们使用TEXT这个field来处理文本数据。
我们的TEXT field有lower=True这个参数,所以所有的单词都会被lowercase。
torchtext提供了LanguageModelingDataset这个class来帮助我们处理语言模型数据集。
build_vocab可以根据我们提供的训练数据集来创建最高频单词的单词表,max_size帮助我们限定单词总量。
BPTTIterator可以连续地得到连贯的句子,BPTT的全程是back propagation through time。
"""
TEXT=torchtext.data.Field(lower=True)
train,val,test=torchtext.datasets.LanguageModelingDataset.splits(path='.',
                                                                train='/root/torch/data/text8/text8.train.txt',
                                                                validation='/root/torch/data/text8/text8.dev.txt',
                                                                test='/root/torch/data/text8/text8.test.txt',text_field=TEXT)

TEXT.build_vocab(train,max_size=MAX_VOCAB_SIZE)
print('vocabulary size:{}'.format(len(TEXT.vocab)))
#vocabulary size:50002

VOCAB_SIZE=len(TE
  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值