"""Dong2019.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1YhkaQQ8xjocvH8vGHgQ0k6V0KTzngbO5
"""
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Spatial size of the single-channel keras.Input images (rows, columns)
# and the number of feature channels used throughout the network.
H = 576
W = 768
N = 32
# One third of the image width; CustomConcatenate uses d // 2 shift levels
# (presumably the disparity search range — confirm).
d = W // 3
def ResNet(h, w, n, num_blocks=8):
    """Build a stack of residual blocks that runs at half the given resolution.

    Args:
        h: Full-resolution height; the model input has ``h // 2`` rows.
        w: Full-resolution width; the model input has ``w // 2`` columns.
        n: Channel count of the input and of every convolution.
        num_blocks: Number of residual blocks (default 8, as originally
            hard-coded).

    Returns:
        A ``keras.Model`` named ``'ResNet'`` that maps an
        ``(h // 2, w // 2, n)`` tensor to a tensor of the same shape.
    """
    resnet_input = keras.Input(shape=(h // 2, w // 2, n))
    x = resnet_input
    for _ in range(num_blocks):
        shortcut = x
        x = layers.Conv2D(n, 3, activation='relu', padding='same')(x)
        x = layers.Conv2D(n, 3, activation='relu', padding='same')(x)
        # Identity shortcut: 'same' padding and n output channels keep the
        # shape unchanged, so the add needs no projection.
        x = layers.add([x, shortcut])
    return keras.Model(resnet_input, x, name='ResNet')
"""ResNet1模块"""
input = keras.Input(shape=(H,W,1))
layer1 = layers.Conv2D(N, 5, activation='relu', strides=2, padding='same')(input)
ResNet1 = ResNet(H, W, N)
layer_17 = ResNet1(layer1)
layer_18 = layers.Conv2D(N, 3, padding='same')(layer_17)
ResNet1 = keras.Model(input, layer_18, name='ResNet1')
class CustomConcatenate(layers.Layer):
    """Stack two feature maps into a 5-D volume of shifted copies.

    Called with ``[left, right]``, both 4-D tensors (presumably
    ``(batch, rows, cols, channels)`` from the ResNet1 branch — confirm).
    Both get a new axis 3. The output concatenates along that axis:
    first the left features, then ``d2 - 1`` versions of the right features
    (``d2 = int(d / 2)``). In version ``i`` the column axis (axis 2) is cut
    into three segments of width ``d2``. Each segment is shifted left by
    ``i`` columns and zero-filled on the right back to width ``d2``.

    The column axis of both inputs must have exactly ``3 * d2`` entries.
    Otherwise the axis-3 concatenation sees mismatched shapes.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, input):
        left = tf.expand_dims(input[0], 3)
        right = tf.expand_dims(input[1], 3)
        zeros = tf.zeros_like(right)
        d2 = int(d / 2)
        # Collect all slices and concatenate once. Re-concatenating onto the
        # growing volume each iteration copied it d2 times (quadratic).
        volume_slices = [left]
        for i in range(d2 - 1):
            segments = [
                tf.concat(
                    [right[:, :, k * d2 + i:(k + 1) * d2, :, :],
                     zeros[:, :, :i, :, :]],
                    2,
                )
                for k in range(3)
            ]
            volume_slices.append(tf.concat(segments, 2))
        return tf.concat(volume_slices, 3)
class AttentionMultiply(layers.Layer):
    """Scale the second half of a volume's channels by an attention map.

    Called with ``[attention, features]``, both 5-D. ``attention`` must have
    a single channel on its last axis. ``features`` must have ``N`` channels.
    The first ``int(N / 2)`` channels of ``features`` pass through unchanged.
    The remaining channels are multiplied by ``attention``, broadcast across
    channels.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, input):
        attention = input[0]
        features = input[1]
        half = int(N / 2)
        # Broadcasting the single attention channel gives the same values as
        # the former loop of int(N/2) - 1 SYMMETRIC pads that replicated it.
        # Multiplying the first half by ones was an identity.
        weighted = features[:, :, :, :, half:] * attention
        return tf.concat([features[:, :, :, :, :half], weighted], 4)
class SqueezeLayerDim(layers.Layer):
    """Remove the trailing size-1 channel axis from a tensor.

    Accepts either a tensor or a one-element list/tuple wrapping it.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, input):
        if isinstance(input, (list, tuple)):
            input = input[0]
        # Only the last axis is squeezed. The previous `tf.squeeze(input[0])`
        # indexed away the batch axis (keeping only sample 0). It also
        # dropped any other axis that happened to have size 1.
        return tf.squeeze(input, axis=-1)
"""网络模型搭建"""
ResNet1 = ResNet(H, W, N)
input_1 = keras.Input(shape=(H,W,1))
layer1_1 = layers.Conv2D(N, 5, activation='relu', strides=2, padding='same')(input_1)
layer_17_1 = ResNet1(layer1_1)
layer_18_1 = layers.Conv2D(N, 3, padding='same')(layer_17_1)
input_2 = keras.Input(shape=(H,W,1))
layer1_2 = layers.Conv2D(N, 5, activation='relu', strides=2, padding='same')(input_2)
layer_17_2 = ResNet1(layer1_2)
layer_18_2 = layers.Conv2D(N, 3, padding='same')(layer_17_2)
print(layer_18_1)
print(layer_18_2)
VF = CustomConcatenate()([layer_18_1, layer_18_2])
VF
layer_19 = layers.Conv3D(N, 1, activation='sigmoid')(VF)
A = layers.Conv3D(1, 1, activation='sigmoid')(layer_19)
VA = AttentionMultiply()([A, VF])
layer_21 = layers.Conv3D(N, 3, padding='same')(VA)
layer_22 = layers.Conv3D(N, 3, padding='same')(layer_21)
layer_23 = layers.Conv3D(2*N, 3, strides=2, padding='same')(layer_22)
layer_24 = layers.Conv3D(2*N, 3, padding='same')(layer_23)
layer_25 = layers.Conv3D(2*N, 3, padding='same')(layer_24)
layer_26 = layers.Conv3D(2*N, 3, strides=2, padding='same')(layer_25)
layer_27 = layers.Conv3D(2*N, 3, padding='same')(layer_26)
layer_28 = layers.Conv3D(2*N, 3, padding='same')(layer_27)
layer_29 = layers.Conv3D(2*N, 3, strides=2, padding='same')(layer_28)
layer_30 = layers.Conv3D(2*N, 3, padding='same')(layer_29)
layer_31 = layers.Conv3D(2*N, 3, padding='same')(layer_30)
layer_32 = layers.Conv3D(2*N, 3, strides=2, padding='same')(layer_31)
layer_33 = layers.Conv3D(2*N, 3, padding='same')(layer_32)
layer_34 = layers.Conv3D(2*N, 3, padding='same')(layer_33)
layer_35 = layers.Conv3DTranspose(2*N, 3, strides=2, padding='same')(layer_34)
layer_35_2 = layers.add([layer_31, layer_35])
layer_36 = layers.Conv3DTranspose(2*N, 3, strides=2, padding='same')(layer_35_2)
layer_36_2 = layers.add([layer_28, layer_36])
layer_37 = layers.Conv3DTranspose(2*N, 3, strides=2, padding='same')(layer_36_2)
layer_37_2 = layers.add([layer_25, layer_37])
layer_38 = layers.Conv3DTranspose(N, 3, strides=2, padding='same')(layer_37_2)
layer_38_2 = layers.add([layer_22, layer_38])
layer_39 = layers.Conv3DTranspose(1, 3, strides=2, padding='same')(layer_38_2)
W = SqueezeLayerDim()(layer_39)
Dong2019 = keras.Model([input_1,input_2], W, name='Dong2019')
Dong2019.summary()
class GetRoughResult(layers.Layer):
    """Unfinished layer meant to derive a rough result from ``[input1, input2]``.

    The original body started a nested loop over ``input1.shape[0]`` and
    ``input1.shape[1]``, but the inner loop had no body, so the file did not
    parse. Its only visible output was ``tf.zeros_like(input1)``.

    NOTE(review): for symbolic Keras inputs ``shape[0]`` is the batch axis
    and is likely ``None``, so a Python loop over it cannot work during
    model construction. The intended computation is unknown. This stub
    raises instead of silently returning zeros.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, input):
        raise NotImplementedError(
            'GetRoughResult.call is unfinished: the per-element loop over '
            'input[0] was never implemented.'
        )
# Scratch check of an in-place update on a tf.Variable.
a = tf.Variable(tf.zeros([2, 2]))
# Variables are updated via sliced `.assign`. Python item assignment
# (`a[0][0] = 1`) raises TypeError because `a[0]` is an immutable tensor.
a[0, 0].assign(1.0)
print(a)
# Requires `W` to still be the integer image width from the top of the file.
input_3 = keras.Input(shape=(H, W))
# NOTE(review): the original passed `layer_39_2`, which is never defined
# (NameError). The decoder output `layer_39` is assumed to be the intended
# tensor — confirm.
GetRoughResult()([input_3, layer_39])
# Scratch checks of tf.pad modes. Results are printed explicitly; a bare
# `t2` expression only displays inside a notebook and is a no-op in a script.
t = tf.constant([[1], [4]])
paddings = tf.constant([[0, 0], [1, 0]])
print(t)
# SYMMETRIC allows padding up to the full axis length (here 1 column).
t2 = tf.pad(t, paddings, "SYMMETRIC")
print(t2)
t = tf.constant([[1, 2, 3], [4, 5, 6]])
print(t)
# REFLECT padding must be smaller than the axis length (2 < 3 columns).
t2 = tf.pad(t, [[0, 0], [2, 1]], 'REFLECT')
print(t2)
"""拼接FY和FYR"""
Y = keras.Input(shape=(H,W,1))
YR = keras.Input(shape=(H,W,1))
FY = ResNet1(Y)
FYR = ResNet1(YR)
VF = CustomConcatenate()(FY, FYR)
Concatenate = keras.Model(inputs=[Y,YR], outputs=VF, name='Concatenate')
keras.utils.plot_model(Concatenate, 'Concatenate.png', show_shapes=True)
"""Attention模块"""
input = keras.Input(shape=(int(H/2), int(W/2), int(d/2), N))
x = layers.Conv3D(N, 1, activation='sigmoid')(input)
x = layers.Conv3D(1, 1, activation='sigmoid')(x)
Attention = keras.Model(input, x, name='Attention')
keras.utils.plot_model(Attention, 'Attention.png', show_shapes=True)
# Scratch check: fill one 8-wide slot of x with the concatenation of the
# matching 4-wide slots of a and b.
a = np.array(range(24)).reshape(2, 3, 4)
b = np.array(range(24)).reshape(2, 3, 4)
x = np.array(range(48)).reshape(2, 3, 8)
# np.expand_dims (not tf.expand_dims) keeps these as mutable NumPy arrays.
# TensorFlow tensors are immutable, so the item assignment below raised
# TypeError on the tf version.
a = np.expand_dims(a, 2)
b = np.expand_dims(b, 2)
x = np.expand_dims(x, 2)
print(a.shape)
print(b.shape)
x[0, 0, 0, :] = np.concatenate([a[0, 0, 0, :], b[0, 0, 0, :]])
print(x.shape)