Pytorch和Tensorflow在实现RNN上的区别

TF:

单个RNN单元可以调用:

rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)  # num_units是隐藏状态的特征维数,如果直接将h当作输出,则输出特征的维数是128
(也可以调用tf.nn.rnn_cell.BasicRNNCell(num_units=128)或tf.nn.rnn_cell.GRUCell(num_units=128)等rnn_cell_impl.py里的类)

RNN单元纵向堆叠(多层RNN网络):

multi_cell = tf.nn.rnn_cell.MultiRNNCell(
                [rnn, rnn, rnn])

 RNN单元横向扩展(时间维度上):

lstm_outputs, final_state = tf.nn.dynamic_rnn(multi_cell, lstm_inputs, initial_state=initial_state)
# inputs是输入x: shape=(batch_size, 序列长度/时间步/句子长度, embedding_size(可能没有embedding))
# initial_state是初始隐藏状态h0: shape=(batch_size, multi_cell.state_size)
# lstm_outputs是 序列长度/时间步/句子长度 所有步的输出: shape=(batch_size, 序列长度/时间步/句子长度, cell.output_size),如果对hidden_state没有做特殊输出处理,那么output_size=hidden_size

Pytorch:

单个RNN单元可以调用torch.nn.RNNCell(), LSTMCell(), GRUCell()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值