"""
CAAE网络的encoder部分
"""
def encoder(self, image, reuse_variables=False):
    """Encode an image batch into the latent vector of the CAAE encoder.

    The input goes through ``int(log2(self.size_image)) - int(self.size_kernel / 2)``
    convolution layers. The original comment says these use stride 2; the
    stride itself is set inside the ``conv2d`` helper. Each layer is followed
    by ReLU, and layer ``i`` has ``self.num_encoder_channels * 2**i`` output
    channels. The result is reshaped to ``[self.size_batch, -1]``, projected
    by a fully connected layer to ``self.num_z_channels`` units, and passed
    through tanh.

    Args:
        image: input image tensor. Presumably shaped
            (size_batch, size_image, size_image, channels); TODO confirm
            against the caller.
        reuse_variables: if True, the current variable scope is switched to
            reuse mode before any layers are built.

    Returns:
        The tanh-activated output of the ``E_fc`` fully connected layer.
    """
    if reuse_variables:
        tf.get_variable_scope().reuse_variables()

    # Depth shrinks as the kernel grows: log2(image side) minus half the kernel size.
    layer_count = int(np.log2(self.size_image)) - int(self.size_kernel / 2)

    feature_map = image
    for layer_idx in range(layer_count):
        # Channel count doubles with every layer.
        conv_out = conv2d(
            input_map=feature_map,
            num_output_channels=self.num_encoder_channels * (2 ** layer_idx),
            size_kernel=self.size_kernel,
            name='E_conv' + str(layer_idx),
        )
        feature_map = tf.nn.relu(conv_out)

    # Flatten every sample of the (fixed-size) batch before the FC projection.
    flattened = tf.reshape(feature_map, [self.size_batch, -1])
    latent = fc(
        input_vector=flattened,
        num_output_length=self.num_z_channels,
        name='E_fc',
    )
    return tf.nn.tanh(latent)
"""
ops.py的load_image部分
Args: path to an arbitrary image file
Returns:
一张 image_size*image_size（默认64*64）的 np.float32 image（默认RGB；is_gray=True 时为灰度图）
"""
def load_image(
        image_path,  # path of a image
        image_size=64,  # expected size of the image
        image_value_range=(-1, 1),  # expected pixel value range of the image
        is_gray=False,  # gray scale or color image
):
    """Load an image file, resize it to a square, and rescale its pixel values.

    Args:
        image_path: path of the image file to read.
        image_size: side length, in pixels, of the square output image.
        image_value_range: (low, high) target range. Pixel value 0 maps to
            ``image_value_range[0]`` and 255 maps to ``image_value_range[-1]``.
        is_gray: if True the file is read in mode 'L' (grayscale), otherwise
            in mode 'RGB'.

    Returns:
        np.float32 array of size image_size x image_size. Presumably 2-D for
        grayscale and (image_size, image_size, 3) for RGB; confirm against
        the imread implementation in use.
    """
    # NOTE(review): assumes imread/imresize are the scipy.misc functions
    # imported elsewhere in this file.
    # Keep the decoded image as uint8 until after resizing. scipy.misc.imresize
    # passes non-uint8 input through bytescale, which stretches each image's
    # own min..max to 0..255. Casting to float first would therefore distort
    # the contrast before the value-range normalization below.
    image = imread(image_path, mode='L' if is_gray else 'RGB')
    image = imresize(image, [image_size, image_size])
    low, high = image_value_range[0], image_value_range[-1]
    return image.astype(np.float32) * (high - low) / 255.0 + low