源码地址:
https://github.com/carpedm20/DCGAN-tensorflow
今天看了源码中,再debug调试时看到了conv_cond_concat这个函数,第一反应就是应该与tf.concat有关系,看了源码确实是。
以下是函数定义:
def conv_cond_concat(x, y):
"""Concatenate conditioning vector on feature map axis."""
x_shapes = x.get_shape()
y_shapes = y.get_shape()
return concat([
x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)
这里面有
y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])]
得注意一下,这个是tensorflow中的乘法表示,tf.ones是全1的矩阵。
这个单独测试一下,看效果比较直观:
#!/usr/bin/env python
#coding:utf8
import os,sys
import numpy as np
import tensorflow as tf
y=tf.ones([1,1,3])
y=np.array([1,2,3])
z=y*tf.ones([4,4,3])
with tf.Session():
print(z.eval())
结果为:4个4*3的矩阵,矩阵每一行都是一个y。
[[[ 1. 2. 3.]
[ 1. 2. 3.]
[ 1. 2. 3.]
[ 1. 2. 3.]]
[[ 1. 2. 3.]
[ 1. 2. 3.]
[ 1. 2. 3.]
[ 1. 2. 3.]]
[[ 1. 2. 3.]
[ 1. 2. 3.]
[ 1. 2. 3.]
[ 1. 2. 3.]]
[[ 1. 2. 3.]
[ 1. 2. 3.]
[ 1. 2. 3.]
[ 1. 2. 3.]]]
Process finished with exit code 0
所以,不难看出,这就是一个相当于一个复制粘贴改变大小的操作。
最后concat很简单就是将两个矩阵(只有要连接起来的那一维不一样,其他维度都一样的矩阵)连接起来。这里是将最后一维连接起来。