# Discriminator 1 and Q network 1 models.
# dis1 takes a feature1 vector as input and is compiled with two losses below.
input_shape = (feature1_dim,)
inputs = keras.layers.Input(shape=input_shape, name='discriminator1_input')
# NOTE(review): the builder is spelled `build_disciminator`; it must match the
# definition elsewhere in the file.
dis1 = build_disciminator(inputs, z_dim=z_dim)
# Loss functions (one per dis1 output):
#   1) binary cross-entropy on the probability that feature1 is real
#      (adversarial1 loss)
#   2) MSE reconstruction loss on z1 (Q1 network loss, a.k.a. entropy1 loss)
loss = ['binary_crossentropy', 'mse']
loss_weights = [1.0, 1.0]
# NOTE(review): `optimizer` is not created in the visible lines; it is assumed
# to be defined earlier in the file, before dis1 is compiled.
dis1.compile(loss=loss,
             loss_weights=loss_weights,
             optimizer=optimizer,
             metrics=['acc'])
dis1.summary()
# Generator models.
# Symbolic inputs shared by both generators; build_generator returns the
# pair (gen0, gen1).
# NOTE(review): the layer name keeps the original 'featue1_input' spelling;
# renaming it could break loading weights by layer name.
feature1 = keras.layers.Input(shape=feature1_shape, name='featue1_input')
labels = keras.layers.Input(shape=label_shape, name='labels')
z1 = keras.layers.Input(shape=z_shape, name='z1_input')
z0 = keras.layers.Input(shape=z_shape, name='z0_input')
latent_codes = (labels, z0, z1, feature1)
gen0, gen1 = build_generator(latent_codes, image_size)
gen0.summary()
gen1.summary()
# Encoder models.
# Single-channel square image input of side `image_size`.
input_shape = (image_size, image_size, 1)
inputs = keras.layers.Input(shape=input_shape, name='encoder_input')
enc0, enc1 = build_encoder((inputs, feature1), num_labels)
enc0.summary()
enc1.summary()
# Full encoder: the image goes through enc0, and enc0's output feeds enc1.
encoder = keras.Model(inputs, enc1(enc0(inputs)))
encoder.summary()
data = ((x_train, y_train), (x_test, y_test))
# The adversarial networks depend on an already-trained encoder network,
# so train the encoder before building them.
train_encoder(encoder, data, model_name=model_name)
# adversarial0 model = generator0 + discriminator0 + encoder0
# The adversarial models train with half the base learning rate and decay.
# NOTE(review): `lr` and `decay` are assumed to be defined earlier in the file.
optimizer = keras.optimizers.RMSprop(lr=lr * 0.5, decay=decay * 0.5)
# Freeze encoder0 and discriminator0 so that training adv0 updates only
# generator0 (Keras applies `trainable` flags when adv0 is compiled below).
enc0.trainable = False
dis0.trainable = False
gen0_inputs = [feature1, z0]
gen0_outputs = gen0(gen0_inputs)
# dis0 is expected to return a list of two outputs; enc0's output is appended
# as a third, matching the three-entry loss list below.
adv0_outputs = dis0(gen0_outputs) + [enc0(gen0_outputs)]
adv0 = keras.Model(gen0_inputs, adv0_outputs, name='adv0')
# Loss functions (one per adv0 output):
#   1) binary cross-entropy on the probability that gen0's output (the
#      generated image) is real (adversarial0 loss)
#   2) MSE Q network 0 loss (presumably z0 reconstruction), weighted 10x
#   3) MSE condition0 loss on enc0's output
loss = ['binary_crossentropy', 'mse', 'mse']
loss_weights = [1.0, 10.0, 1.0]
adv0.compile(loss=loss,
             loss_weights=loss_weights,
             optimizer=optimizer,
             metrics=['acc'])
adv0.summary()
#adversarial1 model = generator1 + discriminator1 + encoder1
# Freeze encoder1 so the adversarial1 model (presumably built in the lines that
# follow) does not update its weights during adversarial training.
enc1.trainable = False
dis1.trainable &