My environment runs TF 1.9, while the original code was written for 1.12, and it raised an error at these lines:
from tensorflow.contrib.rnn.python.ops import rnn_cell
linear = rnn_cell._linear # pylint: disable=protected-access
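There is no direct fix for this import for now. However, this function only computes a simple linear sum, sum_i(args[i] * W[i]) with an optional bias, so you can re-implement it by hand in your own program: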
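The claim that this is "just a linear sum" can be checked outside TensorFlow. Below is a minimal NumPy sketch, assuming 2D inputs of shape batch x n_i. The helper name `linear_np`, the `seed` parameter and the random weight initialization are illustrative assumptions; the TF function creates W and the bias as variables inside a variable scope instead. The sketch concatenates all inputs along axis 1 and multiplies by one stacked weight matrix, which gives the same result as summing args[i] @ W[i] over the row blocks W[i].

```python
import numpy as np


def linear_np(args, output_size, bias, bias_start=0.0, seed=0):
    """Hypothetical NumPy sketch of sum_i(args[i] @ W[i]) (+ bias).

    Returns (result, weights) so the stacked weight matrix can be inspected.
    """
    if not isinstance(args, (list, tuple)):
        args = [args]
    for a in args:
        if a.ndim != 2:
            raise ValueError("linear_np expects 2D arrays, got shape %s" % (a.shape,))
    total_arg_size = sum(a.shape[1] for a in args)
    # Assumption: random weights stand in for TF's trainable variable.
    # Rows [offset_i, offset_i + n_i) of this matrix play the role of W[i].
    rng = np.random.default_rng(seed)
    weights = rng.standard_normal((total_arg_size, output_size))
    # One matmul on the concatenated inputs == sum of per-argument matmuls.
    res = np.concatenate(args, axis=1) @ weights
    if bias:
        # Bias initialized to the constant bias_start, broadcast over the batch.
        res = res + np.full(output_size, bias_start)
    return res, weights


x1 = np.ones((2, 3))
x2 = np.arange(8.0).reshape(2, 4)
out, W = linear_np([x1, x2], output_size=5, bias=True, bias_start=0.5)
manual = x1 @ W[:3] + x2 @ W[3:] + 0.5
print(np.allclose(out, manual))  # True
```

Stacking the weights into one matrix means a single matrix multiply per call instead of one per input tensor.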
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.util import nest
from tensorflow.python.ops import variable_scope as vs
def linear(args, output_size, bias, bias_start=0.0, scope=None):
    """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

    Args:
        args: a 2D Tensor or a list of 2D, batch x n, Tensors.
        output_size: int, second dimension of W[i].
        bias: boolean, whether to add a bias term or not.
        bias_start: starting value to initialize the bias; 0 by default.
        scope: (optional) Variable scope to create parameters in.

    Returns: