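首先，这段代码在较新版本的 TensorFlow（2.1 及以上）下会运行失败，原因是 Keras 2.3.1 获取 GPU 列表时调用了新版本中已不可用的 `tf.config.experimental_list_devices`。根据 https://github.com/keras-team/keras/pull/13712/commits ，需要把文件 C:\ProgramData\Miniconda3\pkgs\keras-base-2.3.1-py37_0\Lib\site-packages\keras\backend\tensorflow_backend.py（路径以本机 Keras 的实际安装位置为准）中的函数 `_get_available_gpus` 由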
def _get_available_gpus():
    """Get a list of available gpu devices (formatted as strings).

    Original Keras 2.3.1 implementation, quoted here for reference.

    # Returns
        A list of available GPU devices.
    """
    global _LOCAL_DEVICES
    # Device names are looked up only once and cached in the module-level
    # `_LOCAL_DEVICES`; later calls just filter the cached list.
    if _LOCAL_DEVICES is None:
        if _is_tf_1():
            devices = get_session().list_devices()
            _LOCAL_DEVICES = [x.name for x in devices]
        else:
            # This is the call that fails on newer TensorFlow 2.x releases,
            # where `tf.config.experimental_list_devices` is no longer available.
            _LOCAL_DEVICES = tf.config.experimental_list_devices()
    # Keep only names containing 'device:gpu' (case-insensitive match).
    return [x for x in _LOCAL_DEVICES if 'device:gpu' in x.lower()]
修改为：
def _get_available_gpus():
    """Get a list of available gpu devices (formatted as strings).

    Patched version that works with TensorFlow 1.x, 2.0 and 2.1+.
    The device list is computed once and cached in the module-level
    `_LOCAL_DEVICES`; subsequent calls reuse the cache.

    # Returns
        A list of available GPU devices.
    """
    global _LOCAL_DEVICES
    if _LOCAL_DEVICES is None:
        if _is_tf_1():
            devices = get_session().list_devices()
            _LOCAL_DEVICES = [x.name for x in devices]
        elif hasattr(tf.config, 'list_logical_devices'):
            # Newer TF 2.x exposes the stable `list_logical_devices` API and no
            # longer has `experimental_list_devices`. Detecting the API by
            # presence (instead of parsing only the minor component of
            # `tf.__version__`) also stays correct for a new major version.
            devices = tf.config.list_logical_devices()
            _LOCAL_DEVICES = [x.name for x in devices]
        else:
            # TF 2.0: only the experimental API is available.
            _LOCAL_DEVICES = tf.config.experimental_list_devices()
    # Keep only names containing 'device:gpu' (case-insensitive match).
    return [x for x in _LOCAL_DEVICES if 'device:gpu' in x.lower()]
神经网络结构（`model.summary()` 的输出，特征图为 channels_first 格式，即 (通道, 高, 宽)）如下：
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 1, 32, 32) 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 8, 32, 32) 80 input_1[0][0]
__________________________________________________________________________________________________
elu_1 (ELU) (None, 8, 32, 32) 0 conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 8, 32, 32) 72 elu_1[0][0]
__________________________________________________________________________________________________
elu_2 (ELU) (None, 8, 32, 32) 0 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 8, 32, 32) 72 elu_2[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 8, 32, 32) 0 conv2d_1[0][0]
conv2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 8, 16, 16) 0 add_1[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 16, 16, 16) 1168 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
elu_3 (ELU) (None, 16, 16, 16) 0 conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 16, 16, 16) 272 elu_3[0][0]
__________________________________________________________________________________________________
elu_4 (ELU) (None, 16, 16, 16) 0 conv2d_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 16, 16, 16) 272 elu_4[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 16, 16, 16) 0 conv2d_4[0][0]
conv2d_6[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 16, 8, 8) 0 add_2[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 32, 8, 8) 4640 max_pooling2d_2[0][0]
__________________________________________________________________________________________________
elu_5 (ELU) (None, 32, 8, 8) 0 conv2d_7[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 32, 8, 8) 1056 elu_5[0][0]
__________________________________________________________________________________________________
elu_6 (ELU) (None, 32, 8, 8) 0 conv2d_8[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 32, 8, 8) 1056 elu_6[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 32, 8, 8) 0 conv2d_7[0][0]
conv2d_9[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D) (None, 32, 4, 4) 0 add_3[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 64, 4, 4) 18496 max_pooling2d_3[0][0]
__________________________________________________________________________________________________
elu_7 (ELU) (None, 64, 4, 4) 0 conv2d_10[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 64, 4, 4) 4160 elu_7[0][0]
__________________________________________________________________________________________________
elu_8 (ELU) (None, 64, 4, 4) 0 conv2d_11[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 64, 4, 4) 4160 elu_8[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 64, 4, 4) 0 conv2d_10[0][0]
conv2d_12[0][0]
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D) (None, 64, 2, 2) 0 add_4[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 128, 2, 2) 73856 max_pooling2d_4[0][0]
__________________________________________________________________________________________________
elu_9 (ELU) (None, 128, 2, 2) 0 conv2d_13[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, 128, 2, 2) 16512 elu_9[0][0]
__________________________________________________________________________________________________
elu_10 (ELU) (None, 128, 2, 2) 0 conv2d_14[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, 128, 2, 2) 16512 elu_10[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 128, 2, 2) 0 conv2d_13[0][0]
conv2d_15[0][0]
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D) (None, 128, 1, 1) 0 add_5[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 128, 2, 2) 0 max_pooling2d_5[0][0]
__________________________________________________________________________________________________
lambda_5 (Lambda) (None, 128, 2, 2) 0 add_5[0][0]
max_pooling2d_5[0][0]
__________________________________________________________________________________________________
multiply_1 (Multiply) (None, 128, 2, 2) 0 up_sampling2d_1[0][0]
lambda_5[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D) (None, 64, 2, 2) 73792 multiply_1[0][0]
__________________________________________________________________________________________________
elu_11 (ELU) (None, 64, 2, 2) 0 conv2d_16[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D) (None, 64, 2, 2) 4160 elu_11[0][0]
__________________________________________________________________________________________________
elu_12 (ELU) (None, 64, 2, 2) 0 conv2d_17[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D) (None, 64, 2, 2) 4160 elu_12[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 64, 2, 2) 0 conv2d_16[0][0]
conv2d_18[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D) (None, 64, 4, 4) 0 add_6[0][0]
__________________________________________________________________________________________________
lambda_4 (Lambda) (None, 64, 4, 4) 0 add_4[0][0]
max_pooling2d_4[0][0]
__________________________________________________________________________________________________
multiply_2 (Multiply) (None, 64, 4, 4) 0 up_sampling2d_2[0][0]
lambda_4[0][0]
__________________________________________________________________________________________________
conv2d_19 (Conv2D) (None, 32, 4, 4) 18464 multiply_2[0][0]
__________________________________________________________________________________________________
elu_13 (ELU) (None, 32, 4, 4) 0 conv2d_19[0][0]
__________________________________________________________________________________________________
conv2d_20 (Conv2D) (None, 32, 4, 4) 1056 elu_13[0][0]
__________________________________________________________________________________________________
elu_14 (ELU) (None, 32, 4, 4) 0 conv2d_20[0][0]
__________________________________________________________________________________________________
conv2d_21 (Conv2D) (None, 32, 4, 4) 1056 elu_14[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, 32, 4, 4) 0 conv2d_19[0][0]
conv2d_21[0][0]
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D) (None, 32, 8, 8) 0 add_7[0][0]
__________________________________________________________________________________________________
lambda_3 (Lambda) (None, 32, 8, 8) 0 add_3[0][0]
max_pooling2d_3[0][0]
__________________________________________________________________________________________________
multiply_3 (Multiply) (None, 32, 8, 8) 0 up_sampling2d_3[0][0]
lambda_3[0][0]
__________________________________________________________________________________________________
conv2d_22 (Conv2D) (None, 16, 8, 8) 4624 multiply_3[0][0]
__________________________________________________________________________________________________
elu_15 (ELU) (None, 16, 8, 8) 0 conv2d_22[0][0]
__________________________________________________________________________________________________
conv2d_23 (Conv2D) (None, 16, 8, 8) 272 elu_15[0][0]
__________________________________________________________________________________________________
elu_16 (ELU) (None, 16, 8, 8) 0 conv2d_23[0][0]
__________________________________________________________________________________________________
conv2d_24 (Conv2D) (None, 16, 8, 8) 272 elu_16[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, 16, 8, 8) 0 conv2d_22[0][0]
conv2d_24[0][0]
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D) (None, 16, 16, 16) 0 add_8[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda) (None, 16, 16, 16) 0 add_2[0][0]
max_pooling2d_2[0][0]
__________________________________________________________________________________________________
multiply_4 (Multiply) (None, 16, 16, 16) 0 up_sampling2d_4[0][0]
lambda_2[0][0]
__________________________________________________________________________________________________
conv2d_25 (Conv2D) (None, 8, 16, 16) 1160 multiply_4[0][0]
__________________________________________________________________________________________________
elu_17 (ELU) (None, 8, 16, 16) 0 conv2d_25[0][0]
__________________________________________________________________________________________________
conv2d_26 (Conv2D) (None, 8, 16, 16) 72 elu_17[0][0]
__________________________________________________________________________________________________
elu_18 (ELU) (None, 8, 16, 16) 0 conv2d_26[0][0]
__________________________________________________________________________________________________
conv2d_27 (Conv2D) (None, 8, 16, 16) 72 elu_18[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, 8, 16, 16) 0 conv2d_25[0][0]
conv2d_27[0][0]
__________________________________________________________________________________________________
up_sampling2d_5 (UpSampling2D) (None, 8, 32, 32) 0 add_9[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda) (None, 8, 32, 32) 0 add_1[0][0]
max_pooling2d_1[0][0]
__________________________________________________________________________________________________
multiply_5 (Multiply) (None, 8, 32, 32) 0 up_sampling2d_5[0][0]
lambda_1[0][0]
__________________________________________________________________________________________________
conv2d_28 (Conv2D) (None, 1, 32, 32) 73 multiply_5[0][0]
__________________________________________________________________________________________________
elu_19 (ELU) (None, 1, 32, 32) 0 conv2d_28[0][0]
__________________________________________________________________________________________________
conv2d_29 (Conv2D) (None, 1, 32, 32) 2 elu_19[0][0]
__________________________________________________________________________________________________
elu_20 (ELU) (None, 1, 32, 32) 0 conv2d_29[0][0]
__________________________________________________________________________________________________
conv2d_30 (Conv2D) (None, 1, 32, 32) 2 elu_20[0][0]
__________________________________________________________________________________________________
add_10 (Add) (None, 1, 32, 32) 0 conv2d_28[0][0]
conv2d_30[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 1, 32, 32) 0 add_10[0][0]
==================================================================================================
Total params: 251,621
Trainable params: 251,621
Non-trainable params: 0
__________________________________________________________________________________________________
这个神经网络的实际意义我还没有完全弄明白。可以确定的是，它是一个自编码器（编码器-解码器结构）：训练时输入和目标输出相同，例如都是 x_train。其结构看起来与 Keras 官方示例 mnist_swwae.py（Stacked What-Where Auto-Encoders，SWWAE）一致。据一些资料介绍，这种自编码器的重建效果会更清晰，训练完成后可以对比查看效果。
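之所以效果更好，是因为解码器利用了编码器中的位置（where）信息。函数 getwhere 对 MaxPooling2D 的输出关于池化前的输入求导，得到一个与池化前特征图形状相同的掩码：每个池化窗口中最大值所在的位置非零，其余位置为 0。从上面的结构可以看到，每个 lambda_k 层以 add_k（池化前）和 max_pooling2d_k（池化后）为输入，输出即该掩码。解码器再把 UpSampling2D 的结果与该掩码逐元素相乘（Multiply），相当于把特征放回池化时最大值所在的位置（即“反池化”），从而保留了更多空间细节。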
其他细节，以后再慢慢体会吧。