前段时间在实现BeautyGAN时,生成器中公共的残差块部分使用两个共享参数的残差块分支实现,跟论文有所出入,应该是一个分支,将两个分支的输入feature map合并为1个张量即可,对张量的合并和拆分操作简单做一下笔记。
Tensor的合并
tf.concat(values, axis, name="concat")
values:输入的张量
axis:待合并的维度
官方示例:
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0) # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 1) # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
Tensor的拆分
tf.split(value, num_or_size_splits, axis=0, num=None, name="split")
value:待拆分的张量
num_or_size_splits:拆分的批量大小
1)若为整数n,则将待拆分的维度均匀拆分为n个张量,待拆分维度的shape要是n的整数倍;
2)若为一维数组[n1, n2, ...],则其size表示拆分后的张量个数,第i个数表示第i个张量对应维度的值。
axis:待拆分的维度
官方示例:
# 'value' is a tensor with shape [5, 30]
# Split 'value' into 3 tensors with sizes [4, 15, 11] along dimension 1
split0, split1, split2 = tf.split(value, [4, 15, 11], 1)
tf.shape(split0) # [5, 4]
tf.shape(split1) # [5, 15]
tf.shape(split2) # [5, 11]
# Split 'value' into 3 tensors along dimension 1
split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1)
tf.shape(split0) # [5, 10]