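Keras's built-in Reshape layer takes a target shape that omits the batch dimension, whereas TensorFlow's tf.reshape needs the full shape, batch dimension included.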
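A minimal sketch of the difference, assuming TensorFlow 2.x with tf.keras. The input tensor and its sizes (batch of 2, 4x4 spatial, 3 channels) are made up for illustration:

```python
import tensorflow as tf

# Hypothetical input: a batch of 2 feature maps, 4x4 spatial size, 3 channels.
x = tf.zeros((2, 4, 4, 3))

# Keras layer: target_shape describes a single sample, so the batch axis is left out.
y_keras = tf.keras.layers.Reshape((-1, 3))(x)

# tf.reshape: the shape argument covers every axis, including the batch axis.
y_tf = tf.reshape(x, (2, -1, 3))

print(tuple(y_keras.shape), tuple(y_tf.shape))
```

Both calls flatten the two spatial axes into one; only the way the batch axis is specified differs.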
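The Keras Reshape layer also needs concrete dimensions. For example, when reproducing the non-local block, the reshape back to the original input size needs explicit values for the spatial dimensions. dim1 and dim2 cannot be passed straight through: Reshape((dim1, dim2, channels))(y) does not work when K.int_shape reports them as None, which happens if the input's spatial size is not fixed. The spatial size is therefore given explicitly, as pow(2, i) in the code below.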
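# Imports assume tf.keras (TensorFlow 2.x); x and i come from the surrounding code (not shown).
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Activation, Reshape, add, dot

ip_shape = K.int_shape(x)
bs, dim1, dim2, channels = ip_shape  # bs, dim1 and dim2 may be None
# x = tf.keras.layers.Conv2D(channels, 1, strides=1, kernel_initializer=initializer, padding='same', use_bias=False)
x1 = Reshape((-1, channels))(x)  # x_i: (batch, N, channels), N = dim1 * dim2
x2 = Reshape((-1, channels))(x)  # x_j: (batch, N, channels)
f = dot([x1, x2], axes=2)  # pairwise similarity: (batch, N, N)
f = Activation('softmax')(f)  # softmax over the last axis
g = Reshape((-1, channels))(x)  # (batch, N, channels)
y = dot([f, g], axes=[2, 1])  # weighted sum: (batch, N, channels)
y = Reshape((pow(2, i), pow(2, i), channels))(y)  # back to the input size, given explicitly as 2**i x 2**i
x = add([y, x])  # residual connection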
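To check the shapes at each step without building a model, the same computation can be traced in NumPy. This is only a sketch: the function name non_local_numpy and the input sizes are made up for illustration, and the softmax is written out by hand rather than taken from Keras.

```python
import numpy as np

def non_local_numpy(x):
    """NumPy sketch of the block above for x of shape (batch, dim1, dim2, channels)."""
    bs, dim1, dim2, channels = x.shape
    flat = x.reshape(bs, -1, channels)              # (bs, N, channels), N = dim1 * dim2
    f = flat @ flat.transpose(0, 2, 1)              # pairwise similarity: (bs, N, N)
    f = np.exp(f - f.max(axis=-1, keepdims=True))   # numerically stable softmax, last axis
    f = f / f.sum(axis=-1, keepdims=True)
    y = f @ flat                                    # weighted sum: (bs, N, channels)
    y = y.reshape(bs, dim1, dim2, channels)         # concrete sizes are always known in NumPy
    return y + x                                    # residual connection

# Hypothetical input: batch of 2, 4x4 spatial size, 8 channels.
x_np = np.random.rand(2, 4, 4, 8).astype("float32")
out = non_local_numpy(x_np)
print(out.shape)
```

In NumPy the array shape is always concrete, so the final reshape can reuse dim1 and dim2 directly; the Keras version cannot when those values are None at graph-construction time.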