RNN模型详细实现(pytorch+jupyter)

RNN网络构建

从零构建一个RNN网络模型

1、初始化模型参数

RNN模型计算表达式为 o t = H t W h q + b q o_t=H_{t}W_{hq}+b_q ot=HtWhq+bq ,其中 o t o_t ot t t t时间步的输出,维度为 n × d n \times d n×d(批量大小 乘以 时间步数), H t H_{t} Ht为隐变量,维度为 d × h d\times h d×h(批量大小 乘以 隐变量大小)。隐变量的更新公式为 H t = ϕ ( X t W x h + H t − 1 W h h + b h ) H_t=\phi(X_tW_{xh}+H_{t-1}W_{hh}+b_h) Ht=ϕ(XtWxh+Ht1Whh+bh) ,其中 X t X_t Xt的维度为 n × d n\times d n×d W x h W_{xh} Wxh的维度为 d × h d\times h d×h

需要哪些参数:

  • 隐藏层参数: W x h ∈ R d × h W_{xh}\in R^{d\times h} WxhRd×h W h h ∈ R h × h W_{hh}\in R^{h\times h} WhhRh×h b h ∈ R h × 1 b_h\in R^{h\times 1} bhRh×1
  • 输出层参数: W h q ∈ R h × d W_{hq}\in R^{h\times d} WhqRh×d b q ∈ R d × 1 b_q\in R^{d\times 1} bqRd×1

定义初始化函数:

  • Input: num_inputs(即上式中的d),num_hiddens(隐藏层数量)
  • Output: 初始化后的参数列表
# 初始化模型参数
def get_params(num_inputs, num_hiddens, device):
    num_outputs = num_inputs
    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01
    
    # 初始化隐藏层参数
    w_xh = normal(num_inputs, num_hiddens)
    w_hh = normal(num_hiddens, num_hiddens)
    b_h = torch.zeros(num_hiddens, device=device)

    # 初始化输出层参数
    w_hq = normal(num_hiddens, num_outputs)
    b_q = normal(num_outputs, device=device)

    # 初始化输出
    outpus = [w_xh, w_hh, b_h, w_hq, b_q]

    for output in outpus:
        output.requires_grad_(True)
        
    return outpus

2、循环神经网络模型

建立一个循环神经网络模型,首先需要初始化隐变量(即 H 0 H_0 H0),然后将数据输入网络中进行计算。

  1. 首先初始化隐变量状态,以及如何进行前向传播计算
# 初始化隐变量H_0,维度为d*h批量大小*隐变量大小)
def init_rnn(batch_size, num_hiddens, device):
    return (torch.zeros(batch_size, num_hiddens, device=device),) # 这里为什么把返回值设置为一个tuple呢?

def rnn(inputs, state, params):
    w_xh, w_hh, b_h, w_hq, b_q = params
    H, = state # 因为要传入的state也是一个元组
    outputs = []
    for X in inputs:
        H = torch.tanh(torch.mm(X, w_xh) + torch.mm(H, w_hh) + b_h)
        Y = torch.mm(H, w_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)
  1. 有了计算函数,就可以建立模型类,用于实例化模型网络了

要实例化一个网络模型,需要哪些参数?

  • 数据集维度、隐藏层个数、设备、参数、初始化隐状态、计算函数

代码如下所示:

# 建立模型

# 初始化隐变量H_0,维度为d*h批量大小*隐变量大小)
def init_rnn_state(batch_size, num_hiddens, device):
    return (torch.zeros(batch_size, num_hiddens, device=device),) # 这里为什么把返回值设置为一个tuple呢?

def rnn(inputs, state, params):
    w_xh, w_hh, b_h, w_hq, b_q = params
    H, = state # 因为要传入的state也是一个元组
    outputs = []
    for X in inputs:
        H = torch.tanh(torch.mm(X, w_xh) + torch.mm(H, w_hh) + b_h)
        Y = torch.mm(H, w_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)

# 封装成类
class RNNModelScratch:
    def __init__(self, vocab_size, num_hiddens, device, 
                 get_params, init_state, forward_in):
        self.vocab_size = vocab_size
        self.num_hiddens = num_hiddens
        self.params = get_params(vocab_size, num_hiddens, device)
        self.init_state = init_state
        self.forward_in = forward_in

    def __call__(self, X, state):
        X = F.one_hot(X.T, self.vocab_size).type(torch.float32)
        return self.forward_in(X, state, self.params)
    
    def begin_state(self, batch_size, device):
        return self.init_state(batch_size, self.num_hiddens, device)

3、使用网络模型进行预测

网络模型预测即使用训练出的模型进行预测,使用者给出初始化语句以及预测的步长,便可根据给出的语句实现续写,具体实现如下:

# 如何使用网络进行预测
def predict_char(prefix, num_preds, net, vocab, device):
    # prefix: 预热语句
    # num_preds:预测步长,单位:字母
    # net:网络模型
    # vocab:词表
    # device:设备
    # 预热部分,更新隐变量参数
    state = net.begin_state(batch_size=1, device=device)
    outputs = [vocab[prefix[0]]]
    for y in prefix[1:]:
        x = torch.tensor([outputs[-1]], device=device).reshape((1, 1))     
        _, state = net(x, state)
        outputs.append(vocab[y])
  
    for i in range(num_preds):
        final_char = torch.tensor(outputs[-1],device=device).reshape(1,1)
        final_char, state = net(final_char, state)
        outputs.append(int(final_char.argmax(dim=1).reshape(1)))
    return ''.join([vocab.idx_to_token[i] for i in outputs])

编写测试代码:

num_hiddens = 512
device = 'cuda:0'
net = RNNModelScratch(len(vocab), num_hiddens, device, 
                      get_params, init_rnn_state, rnn)
predict_char('hello', 10, net, vocab, device)

输出结果如下:

'hellonznqfkkkkk'

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值