目录
生成对抗网络(Generative Adversarial Networks, GANs)是一种深度学习架构,由Ian Goodfellow等人于2014年首次提出,主要用于生成逼真的随机样本数据。这里我们特别关注基于高斯分布的GAN变体——尽管GAN本身并不直接针对高斯分布设计,但在某些情况下,其生成的数据分布可能近似高斯分布。
GAN包含两个主要组成部分:生成器(Generator)G和判别器(Discriminator)D。它们共同构成一个动态博弈系统,如下图所示:
1.生成器(Generator)G
生成器G:它接受来自潜在空间(通常是高维正态分布或其他简单分布如均匀分布)的噪声变量 z,并试图将其转换为看起来像真实数据分布的数据样本x。用数学公式表述就是:
其中,pz(z) 表示潜在空间中的先验分布,通常选择为标准正态分布N(0,I)。
2.判别器(Discriminator)D
判别器D:它的目的是区分真实数据样本x 和生成器产生的假样本G(z)。对于任意输入 x,判别器输出一个介于0和1之间的实数值,代表该样本为真实数据的概率估计。判别器的决策边界可以用以下函数表示:
GAN的优化过程是通过最小极大博弈实现的,即生成器试图最小化欺骗判别器的成功率,而判别器则努力最大化正确分类样本的能力。
当博弈达到纳什均衡时,生成器生成的样本应该无法被最优的判别器区分,此时生成器所学习到的概率分布pg 尽可能接近真实数据分布pdata。
综上所述,G和D构成了一个动态对抗(或博弈过程),随着训练(对抗)的进行,G生成的数据越来越接近真实数据,D鉴别数据的水平越来越高。在理想的状态下,G可以生成足以“以假乱真”的数据;而对于D来说,它难以判定生成器生成的数据究竟是不是真实的,因此D(G(z)) = 0.5。训练完成后,我们得到了一个生成模型G,它可以用来生成以假乱真的数据。
3.GAN训练
第一阶段:固定「判别器D」,训练「生成器G」。使用一个性能不错的判别器,G不断生成“假数据”,然后给这个D去判断。开始时候,G还很弱,所以很容易被判别出来。但随着训练不断进行,G技能不断提升,最终骗过了D。这个时候,D基本属于“瞎猜”的状态,判断是否为假数据的概率为50%。
第二阶段:固定「生成器G」,训练「判别器D」。当通过了第一阶段,继续训练G就没有意义了。这时候我们固定G,然后开始训练D。通过不断训练,D提高了自己的鉴别能力,最终他可以准确判断出假数据。
重复第一阶段、第二阶段。通过不断的循环,「生成器G」和「判别器D」的能力都越来越强。最终我们得到了一个效果非常好的「生成器G」,就可以用它来生成数据。
在训练过程中,D会接收真数据和G产生的假数据,它的任务是判断图片是属于真数据的还是假数据的。对于最后输出的结果,可以同时对两方的参数进行调优。如果D判断正确,那就需要调整G的参数从而使得生成的假数据更为逼真;如果D判断错误,则需调节D的参数,避免下次类似判断出错。训练会一直持续到两者进入到一个均衡和谐的状态。
训练后的产物是一个质量较高的自动生成器和一个判断能力较强的分类器。前者可以用于机器创作,而后者则可以用来机器分类。
4.MATLAB程序
..............................................................................
% 开始迭代训练
for i = 1:Niter
% 数据准备
% 生成器输入数据:随机生成并归一化
Gdat = rand(10,bthsize);
Gdat = mapminmax(Gdat', 0, 1)';
GT = zeros([1, bthsize]); % 假样本标签设为0
% 真实数据:生成均值为5、方差为2的高斯分布,归一化并设置真样本标签为1
Rdat = 10*rand([100,bthsize]);
Rdat = sort(Rdat);
Rdaty = exp(-(Rdat-5).^2/4);
real_data = mapminmax(Rdaty', 0, 1)';
real_label = ones([1, bthsize]);
% G网络前向传播
netG = func_FW(netG, Gdat); % 使用G_data进行前向传播
netG_out = netG.o_o; % 获取G网络输出
% 实时观察G网络生成数据
if mod(i, 50) == 0 % 每隔50次迭代显示一次
figure(1);
plot(Rdat(:,1)/10, real_data(:,1)); % 绘制真实数据
hold on
plot([0.01:0.01:1], netG_out(:,1),'r.'); % 绘制G网络生成数据
hold off
pause(0.1); % 暂停0.1秒
end
% 准备D网络输入数据:将生成数据与真实数据拼接并打乱顺序
data_temp = [netG_out, real_data]; % 数据拼接
netD_label = [GT, real_label]; % 标签拼接
rand_idx = randperm(2*bthsize); % 打乱索引
D_data = data_temp(:,rand_idx); % 按打乱顺序选取数据
D_label = netD_label(rand_idx); % 对应标签也按打乱顺序选取
% D网络前向传播
netD = func_FW(netD, D_data); % 使用D_data进行前向传播
netD_out = netD.o_o; % 获取D网络输出
% 计算D网络和G网络损失
netD_loss = (netD_out - D_label); % D网络误差计算
netG_loss = (netD_out .* (D_label == 0) - (D_label==0)); % G网络误差计算
% 存储网络误差
D_L = [D_L; sum(1/2*(netD_loss).^2)/length(D_label)]; % D网络误差平均值
G_L = [G_L; sum(1/2*(netG_loss).^2)/length(D_label)*2]; % G网络误差平均值
% D网络反向传播
netD_D = func_bk(netD, netD_loss); % 使用D网络真实误差反向传播
netD_G = func_bk(netD, netG_loss); % 使用G网络误差反向传播
% 提取G网络需要的误差
netG_o_loss_temp = netD_G.w' * netD_G.d_hi; % G网络误差(包含真实数据误差)
temp_data = [rand_idx', netG_o_loss_temp']; % 将误差与原顺序对应
temp_data = sortrows(temp_data, 1); % 重新按照原顺序排列
netG_o_loss = temp_data(1:bthsize, 2:end)'; % 提取G网络真实误差
% G网络反向传播
netG = func_bk(netG, netG_o_loss); % 使用G网络真实误差反向传播
% 更新网络权值
netD = func_updata(netD_D); % 更新D网络权值
netG = func_updata(netG); % 更新G网络权值
end
% 绘制训练损失曲线
figure;
plot(D_L); % 绘制D网络损失曲线
hold on
plot(G_L); % 绘制G网络损失曲线
up4069