Xiaohei's formula-and-code strength building: WordLSTMCell

1. WordCell schematic

(Figure: WordCell schematic)

2. WordCell gate equations (original form):

$$i_{b,e}^{w} = \sigma(W_{i,e}\, x_{b,e}^{w} + W_{i,b}\, h_{b}^{c} + b_{i})$$
$$f_{b,e}^{w} = \sigma(W_{f,e}\, x_{b,e}^{w} + W_{f,b}\, h_{b}^{c} + b_{f})$$
$$\widetilde{c}_{b,e}^{w} = \tanh(W_{c,e}\, x_{b,e}^{w} + W_{c,b}\, h_{b}^{c} + b_{c})$$

The resulting cell state (this is what the code below computes and returns) is
$$c_{b,e}^{w} = f_{b,e}^{w} \odot c_{b}^{c} + i_{b,e}^{w} \odot \widetilde{c}_{b,e}^{w}$$

Dimension analysis:
each per-gate weight $W_{\cdot,\cdot}$.shape: [hidden_size, hidden_size]
$x_{b,e}^{w}$.shape: [hidden_size, 1] (already projected from input_size to hidden_size)
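
To make the dimension bookkeeping concrete, here is a minimal PyTorch sketch of the expanded form. All names (W_ie, W_ib, b_i, ...) are hypothetical stand-ins for the symbols above, and x is assumed to be already projected to hidden_size:

import torch

hidden_size = 4
# hypothetical per-gate weight blocks, mirroring the subscripts in the formulas
W_ie, W_ib = torch.randn(hidden_size, hidden_size), torch.randn(hidden_size, hidden_size)
W_fe, W_fb = torch.randn(hidden_size, hidden_size), torch.randn(hidden_size, hidden_size)
W_ce, W_cb = torch.randn(hidden_size, hidden_size), torch.randn(hidden_size, hidden_size)
b_i = torch.zeros(hidden_size, 1)
b_f = torch.zeros(hidden_size, 1)
b_c = torch.zeros(hidden_size, 1)

x = torch.randn(hidden_size, 1)  # x_{b,e}^w, already projected to hidden_size
h = torch.randn(hidden_size, 1)  # h_b^c, the character hidden state at position b

i = torch.sigmoid(W_ie @ x + W_ib @ h + b_i)
f = torch.sigmoid(W_fe @ x + W_fb @ h + b_f)
c_tilde = torch.tanh(W_ce @ x + W_cb @ h + b_c)
print(i.shape, f.shape, c_tilde.shape)  # each torch.Size([4, 1])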

3. WordCell gate equations (simplified, stacked form):

$$\begin{bmatrix} i_{b,e}^{w} \\ f_{b,e}^{w} \\ \widetilde{c}_{b,e}^{w} \end{bmatrix} = \begin{bmatrix} \sigma \\ \sigma \\ \tanh \end{bmatrix} \left( W^{w\top} \begin{bmatrix} x_{b,e}^{w} \\ h_{b}^{c} \end{bmatrix} + b^{w} \right)$$
Dimension analysis:
$W^{w\top}$ stacks the per-gate blocks of the original form: $[W_{i,e}, W_{i,b};\ W_{f,e}, W_{f,b};\ W_{c,e}, W_{c,b}]$
$x_{b,e}^{w}$.shape: [hidden_size, 1]
$h_{b}^{c}$.shape: [hidden_size, 1]
$W^{w\top}$.shape: [3 × hidden_size, 2 × hidden_size]
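
As a quick sanity check (a toy-sized sketch; the block tensors are hypothetical stand-ins for the six weight blocks above), a single matmul with the stacked $W^{w\top}$ over the concatenated [x; h] reproduces the per-gate products of the original form:

import torch

hidden_size = 4
x = torch.randn(hidden_size, 1)  # x_{b,e}^w
h = torch.randn(hidden_size, 1)  # h_b^c

# six per-gate blocks, stacked into W^{w T}: [3*hidden_size, 2*hidden_size]
blocks = [[torch.randn(hidden_size, hidden_size) for _ in range(2)] for _ in range(3)]
W = torch.cat([torch.cat(row, dim=1) for row in blocks], dim=0)
b = torch.zeros(3 * hidden_size, 1)

pre = W @ torch.cat([x, h], dim=0) + b  # all three pre-activations in one matmul
i_pre, f_pre, c_pre = torch.split(pre, hidden_size, dim=0)

# the first slice equals W_{i,e} x + W_{i,b} h computed block by block
print(torch.allclose(i_pre, blocks[0][0] @ x + blocks[0][1] @ h))  # True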

4. Code implementation

import torch
from torch import nn
from torch.nn import init

class WordLSTMCell(nn.Module):
    """Word-level LSTM cell with input, forget and candidate gates (no output gate)."""

    def __init__(self, input_size, hidden_size, use_bias=True):
        super(WordLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        # input-to-hidden weights for the three gates (f, i, g), stacked along dim 1
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(input_size, 3 * hidden_size)
        )
        # hidden-to-hidden weights for the three gates
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, 3 * hidden_size)
        )
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # orthogonal initialization for the input-to-hidden weights
        init.orthogonal_(self.weight_ih.data)
        # identity initialization for the hidden-to-hidden weights, one identity block per gate
        weight_hh_data = torch.eye(self.hidden_size)    # [hidden_size, hidden_size]
        weight_hh_data = weight_hh_data.repeat(1, 3)    # [hidden_size, 3 * hidden_size]
        with torch.no_grad():
            self.weight_hh.set_(weight_hh_data)
        if self.use_bias:
            init.constant_(self.bias.data, val=0)
    
    def forward(self, input_, hx):
        # input_: [num_words, word_emb_dim]
        # h_0, c_0: [1, hidden_size]
        h_0, c_0 = hx
        batch_size = h_0.size(0)
        if self.bias is not None:
            # bias_batch: [1, 3 * hidden_size]
            bias_batch = self.bias.unsqueeze(0).expand(batch_size, *self.bias.size())
            # wh_b: [1, 3 * hidden_size]
            wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
        else:
            # guard the no-bias case, where self.bias is None
            wh_b = torch.mm(h_0, self.weight_hh)
        # wi: [num_words, 3 * hidden_size]
        wi = torch.mm(input_, self.weight_ih)
        # f, i, g: [num_words, hidden_size] each
        f, i, g = torch.split(wh_b + wi, split_size_or_sections=self.hidden_size, dim=1)
        # c_l: [num_words, hidden_size]; the forget gate weighs the character cell
        # state c_0, the input gate weighs the candidate; there is no output gate
        c_l = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
        return c_l

num_words = 5
word_emb_dim = 100
hidden_size = 100
input_ = torch.randn([num_words, word_emb_dim])
h_0 = torch.randn([1, hidden_size])
c_0 = torch.randn([1, hidden_size])
hx = (h_0, c_0)
word_lstm_cell = WordLSTMCell(word_emb_dim, hidden_size)
print(word_lstm_cell(input_, hx).shape)  # torch.Size([5, 100])
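
The forward pass returns one candidate cell state per word, hence the shape [num_words, hidden_size]; there is no hidden-state output because the cell has no output gate. With the bias guard in forward, the use_bias=False path also works:

cell_no_bias = WordLSTMCell(word_emb_dim, hidden_size, use_bias=False)
print(cell_no_bias(input_, hx).shape)  # torch.Size([5, 100])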