代码:
def convolve(img, W):
    """Convolve an image tensor with a kernel via ``tf.nn.conv2d``.

    Args:
        img: Image tensor. A rank-2 tensor is reshaped to
            1 x height x width x 1; a rank-3 tensor is reshaped to
            1 x height x width x channels. Any other rank is passed to
            ``tf.nn.conv2d`` unchanged.
        W: Kernel tensor. A rank-2 kernel is reshaped to
            height x width x 1 x 1 (n_input = n_output = 1).

    Returns:
        The output of ``tf.nn.conv2d`` with strides of 1 in every
        dimension and ``'SAME'`` padding.
    """
    # The W matrix is only 2D.
    # But conv2d will need a tensor which is 4D:
    # height x width x n_input x n_output
    if len(W.get_shape()) == 2:
        dims = W.get_shape().as_list() + [1, 1]
        W = tf.reshape(W, dims)

    if len(img.get_shape()) == 2:
        # num x height x width x channels
        dims = [1] + img.get_shape().as_list() + [1]
        img = tf.reshape(img, dims)
    elif len(img.get_shape()) == 3:
        dims = [1] + img.get_shape().as_list()
        img = tf.reshape(img, dims)
        # A rank-3 image is treated as having 3 channels, so the kernel is
        # repeated once per input channel along the n_input axis (axis 2).
        # NOTE: since TensorFlow 1.0 the signature is tf.concat(values, axis);
        # the older tf.concat(2, [W, W, W]) order raises
        # "TypeError: Expected int32, got list containing Tensors ...".
        W = tf.concat([W, W, W], axis=2)

    # Stride is how many values to skip for the dimensions of
    # num, height, width, channels
    convolved = tf.nn.conv2d(img, W,
                             strides=[1, 1, 1, 1], padding='SAME')
    return convolved
# Placeholder with the same static shape as `img`.
x = tf.placeholder(dtype=tf.float32, shape=img.shape)
# Build the convolution op on the placeholder, using whatever kernel
# gabor() returns (presumably a 2D Gabor filter -- confirm against gabor()).
gabor_kernel = gabor()
out = convolve(x, gabor_kernel)
报错:
TypeError: Expected int32, got list containing Tensors of type '_Message' instead.
解决办法:
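从 TensorFlow 1.0 开始，tf.concat 的参数顺序由 (concat_dim, values) 改为 (values, axis)，因此需要把张量列表放在前面、轴放在后面：
W = tf.concat([W, W, W], 2)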