无监督式GAN(infoGAN) matlab实战

    一、infoGAN原理简介

        普通的GAN存在无约束、不可控、噪声信号z很难解释等问题。InfoGAN 主要特点是对GAN进行了一些改动,成功地让网络学到了可解释的特征,网络训练完成之后,我们可以通过设定输入生成器的隐含编码来控制生成数据的特征。InfoGAN的基本结构为:

        其中,真实数据Real_data只是用来跟生成的Fake_data混合在一起进行真假判断,并根据判断的结果更新生成器和判别器,从而使生成的数据与真实数据接近。生成数据既要参与真假判断,还需要和隐变量C_vector求互信息,并根据互信息更新生成器和判别器,从而使得生成图像中保留了更多隐变量C_vector的信息。

二、matlab代码实战

clear all; close all; clc;
%% Info Generative Adversarial Network
%% Load Data
load('mnistAll.mat')
trainX = preprocess(mnist.train_images); 
trainY = mnist.train_labels;
testX = preprocess(mnist.test_images); 
testY = mnist.test_labels;
%% Settings
args.maxepochs = 50; args.c_weight = 0.5; args.z_dim = 62; 
args.batch_size = 16; args.image_size = [28,28,1]; 
args.lrD = 0.0002; args.lrG = 0.001; args.beta1 = 0.5;
args.beta2 = 0.999; args.cc_dim = 1; args.dc_dim = 10; 
args.sample_size = 100;
%% Weights, Biases, Offsets and Scales
% Generator
paramsGen.FCW1 = dlarray(...
    initializeGaussian([1024,args.z_dim+args.cc_dim+args.dc_dim]));
paramsGen.FCb1 = dlarray(zeros(1024,1,'single'));
paramsGen.BNo1 = dlarray(zeros(1024,1,'single'));
paramsGen.BNs1 = dlarray(ones(1024,1,'single'));

paramsGen.FCW2 = dlarray(initializeGaussian([128*7*7,1024]));
paramsGen.FCb2 = dlarray(zeros(128*7*7,1,'single'));
paramsGen.BNo2 = dlarray(zeros(128*7*7,1,'single'));
paramsGen.BNs2 = dlarray(ones(128*7*7,1,'single'));

paramsGen.TCW1 = dlarray(initializeGaussian([4,4,64,128]));
paramsGen.TCb1 = dlarray(zeros(64,1,'single'));
paramsGen.BNo3 = dlarray(zeros(64,1,'single'));
paramsGen.BNs3 = dlarray(ones(64,1,'single'));
paramsGen.TCW2 = dlarray(initializeGaussian([4,4,1,64]));
paramsGen.TCb2 = dlarray(zeros(1,1,'single'));

%% Progress Plot
function progressplot(args,paramsGen,stGen)
fixednoise = zeros(args.z_dim,args.sample_size);
tmp = zeros(args.cc_dim,args.sample_size);
for i = 1:10
    tmp(1,(i-1)*10+1:i*10) = linspace(-2,2,10);
end
cc = tmp;

tmp = zeros(args.dc_dim,args.sample_size);
for i = 1:10
    tmp(i,(i-1)*10+1:i*10) = 1;
end
dc = tmp;

fake_data = gpuArray(dlarray(cat(1,fixednoise,cc,dc),'CB'));
fake_images = extractdata(Generator(fake_data,paramsGen,stGen));

fig = gcf;
if ~isempty(fig.Children)
    delete(fig.Children)
end

I = imtile(fake_images);
I = rescale(I);
imagesc(I)
title("Generated Images")

drawnow;

end
%% Report Progress
function [d_loss,g_loss] = reportprogress(x,z,paramsDis,...
    paramsGen,args,stDis,stGen)
fake_images = Generator(z,paramsGen,stGen);
d_output_real = Discriminator(x,paramsDis,args,stDis);
d_output_fake = Discriminator(fake_images,paramsDis,args,stDis);

% Loss due to true or not
d_loss_a = -mean(log(d_output_real(1,:))+log(1-d_output_fake(1,:)));
g_loss_a = -mean(log(d_output_fake(1,:)));

% cc loss
output_cc = d_output_fake(2,:);
d_loss_cc = mean((output_cc/0.5).^2);

% softmax classification loss
output_dc = d_output_fake(3:end,:);
d_loss_dc = -(mean(sum(z(args.z_dim+args.cc_dim+1:end,:).*output_dc,1))+...
    mean(sum(z(args.z_dim+args.cc_dim+1:end,:).*z(args.z_dim+args.cc_dim+1:end,:),1)));

% Discriminator Loss
d_loss = d_loss_a+args.c_weight*d_loss_cc+d_loss_dc;
% Generator Loss
g_loss = g_loss_a+args.c_weight*d_loss_cc+d_loss_dc;
end
%% Model Gradients
function [GradDis,GradGen,stDis,stGen] = modelGradients(x,z,paramsDis,...
    paramsGen,args,stDis,stGen)
[fake_images,stGen] = Generator(z,paramsGen,stGen);
d_output_real = Discriminator(x,paramsDis,args,stDis);
[d_output_fake,stDis] = Discriminator(fake_images,paramsDis,args,stDis);

% Loss due to true or not
d_loss_a = -mean(log(d_output_real(1,:))+log(1-d_output_fake(1,:)));
g_loss_a = -mean(log(d_output_fake(1,:)));

% cc loss
output_cc = d_output_fake(2,:);
d_loss_cc = mean((output_cc/0.5).^2);

% softmax classification loss
output_dc = d_output_fake(3:end,:);
d_loss_dc = -(mean(sum(z(args.z_dim+args.cc_dim+1:end,:).*output_dc,1))+...
    mean(sum(z(args.z_dim+args.cc_dim+1:end,:).*z(args.z_dim+args.cc_dim+1:end,:),1)));

% Discriminator Loss
d_loss = d_loss_a+args.c_weight*d_loss_cc+d_loss_dc;
% Generator Loss
g_loss = g_loss_a+args.c_weight*d_loss_cc+d_loss_dc;

% For each network, calculate the gradients with respect to the loss.
GradGen = dlgradient(g_loss,paramsGen,'RetainData',true);
GradDis = dlgradient(d_loss,paramsDis);

end

结果展示

 

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值