出错代码:
mask_windows = window_partition(mask_array, self.window_size)
# Flatten each window: shape becomes [num_windows, window_size * window_size].
mask_windows = tf.reshape(mask_windows, shape=[-1, self.window_size * self.window_size])
# Pairwise difference within each window: [N, 1, L] - [N, L, 1] -> [N, L, L];
# entry [w, i, j] is non-zero iff positions i and j of window w have different
# mask values. (In the error below this is [196, 16, 16].)
attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(mask_windows, axis=2)
attn_mask = tf.cast(attn_mask, tf.float32)
# BUG (this snippet reproduces the error below): in TF 1.x, `attn_mask != 0` is
# not an element-wise comparison -- Tensor equality falls back to Python object
# comparison and yields a plain Python bool, which is the rank-0 `[]` condition
# seen in the error. In addition, tf.where (v1, the `Select` op) requires x and
# y to have identical shapes and does not broadcast the scalar -100.0 against the
# rank-3 attn_mask, hence "Shapes must be equal rank, but are 0 and 3".
attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
# Same problem as above; also redundant, since entries equal to 0 are already 0.
attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
报错:
ValueError: Shapes must be equal rank, but are 0 and 3 for 'generator_swin_unet/swin_transformer_block_1/Select' (op: 'Select') with input shapes: [], [], [196,16,16].
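原因: 在 TF 1.x 中，Tensor 的 `==` / `!=` 不做逐元素比较（退化为 Python 对象比较，得到一个 Python bool，即报错中形状为 [] 的条件）；同时 `tf.where`（v1，对应 Select op）要求 x 与 y 形状完全一致、不支持广播，因此标量 -100.0 无法与形状为 [196,16,16] 的 attn_mask 配对。
解决方案: 用 `tf.equal` 做逐元素比较，并把 tf.where 的两个分支都构造成与 attn_mask 相同的形状: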
mask_windows = window_partition(mask_array, self.window_size)
# Flatten each window: shape becomes [num_windows, window_size * window_size].
mask_windows = tf.reshape(mask_windows, shape=[-1, self.window_size * self.window_size])
# Pairwise difference within each window: [N, 1, L] - [N, L, 1] -> [N, L, L];
# entry [w, i, j] is non-zero iff positions i and j of window w have different
# mask values.
attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(mask_windows, axis=2)
# Map equal-value pairs to 0.0 and differing pairs to -100.0.
# - tf.equal is used instead of `==`/`!=`: in TF 1.x those operators do not
#   compare element-wise and yield a Python bool (a rank-0 condition).
# - tf.where v1 (`Select`) does not broadcast, so both branches are built with
#   the full shape of attn_mask rather than passed as scalars.
attn_mask = tf.where(
    tf.equal(attn_mask, 0),
    tf.zeros_like(attn_mask, dtype=tf.float32),
    tf.fill(tf.shape(attn_mask), -100.0),
)