简单seq2seq代码 使用tensorflow的LSTMCell构造循环decoder

好多预测模型的论文都是用seq2seq实现的,具体是LSTM_encoder将输入序列编码为一个tensor(又叫output、H或Y),同时保留序列状态state(又叫w或c);
LSTM_decoder继承encoder的状态,将上层的output作为输入,得到的每个输出到embeding中找对应的词向量,然后再次调用LSTM_decoder刚才的输出作为这次的输入。一直循环,直到输出EOS为止。

tensorflow中并没有循环网络(可能有,我不知道)。因此决定用LSTMCell循环实现。
代码如下:

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.encodeLSTM = layers.LSTM(1, return_state=True)
        self.decodeLSTM = layers.LSTMCell(1) #简单试验 units=1

    def call(self, inputs):
        x, memory_state, carry_state = self.encodeLSTM(inputs)
        pred = tf.constant([], shape=(16, 0), dtype=tf.float32)
        for i in range(8): #循环预测之后的8个序列
            x,[memory_state, carry_state] = self.decodeLSTM(x,[memory_state, carry_state])
            pred = tf.concat((pred, x),axis=1 )
        return pred

model = MyModel()
model.build((None, 10, 1)) #设time_seq=10
model.summary()

在这里插入图片描述
LSTMCell一般与RNN组合使用,例

cell=[layer.LSTMCell(10) , layers.LSTMCell(5)]
layers.RNN(cell) #构造双层LSTM第一层10个单元 第二层5个单元

单独使用时注意一下几点:
①LSTMCell帮助文档中没有关于状态的参数,需要从**kwargs传入。
②LSTMCell的状态不能保留,因此它每一次运算都会返回当前状态,以便下一次继续使用。
③LSTMCell由于不处理时间序列time_seq,它的输入格式为(batch_size,units)和输出格式相同。对比LSTM输入(batch_size,time_seq,units)输出(batch_size,units)

感谢官方文档和github教我的用法
https://tensorflow.google.cn/api_docs/python/tf/keras/layers/LSTMCell
https://github.com/search?q=tensorflow+layers.LSTMCell&type=Code

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值