在写模型的时候,我们希望一个模型的参数是灵活的,例如矩阵乘的时候可以根据输入最后一维的大小来定义一个W。
获得矩阵的维度
tf.shape(input)
input为所求矩阵,返回该矩阵的维度,但是是一个Tensor。经常取出来的维度值并不能直接用,因为会出现类似这样的报错:
TypeError: int() argument must be a string or a number, not ‘Tensor’input.get_shape()
这样得到的是Dimension类型的对象。
解决办法:
使用as_list()函数将Dimention
k=input.get_shape().as_list()[-1]
例子:
u=tf.reshape(np.arange(0,6),[3,2])
k=u.get_shape().as_list()[-1]
w=tf.Variable(tf.random_uniform([k,4]))
prod=tf.matmul(tf.cast(u,tf.float32),w)
with tf.Session()as s:
s.run(tf.initialize_all_variables())
print s.run(prod)
output:
[[ 0.61596847 0.58131492 0.46035814 0.58667159]
[ 3.24906611 2.39435339 1.6604023 3.1891706 ]
[ 5.882164 4.20739174 2.86044645 5.79166985]]