虽然已经接触deep learning很长一段时间了,也看了很久rnn相关的代码,但是突然想用tensorflow实现一些功能的时候,突然发现丝毫没有头绪,找了一些资料,学习了一波,记录一下。
一、tensorflow实现RNN cell
tensorflow由于不同的版本改动较大,在1.0版本之后,可以使用如下语句来创建一个cell:
from tensorflow.contrib import rnn
cell_fun = rnn.GRUCell(rnn_hidden_size)
在tensorflow中,上述GRUCell的实现如下(可以在GitHub上看到源码):
class GRUCell(RNNCell):
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""
def __init__(self, num_units, input_size=None, activation=tanh):
if input_size is not None:
logging.warn("%s: The input_size parameter is deprecated.", self)
self._num_units = num_units
self._activation = activation
@property
def state_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
def __call__