- 对抗生成网络的思想:关注损失从何而来
- 生成器、判别器
- nn.sequential容器:适合于按顺序运算的情况,简化前向传播写法
- leakyReLU介绍:避免relu的神经元失活现象
ps;如果你学有余力,对于gan的损失函数的理解,建议去找找视频看看,如果只是用,没必要学
作业:对于心脏病数据集,对于病人这个不平衡的样本用GAN来学习并生成病人样本,观察不用GAN和用GAN的F1分数差异。
一、 GAN对抗生成网络思想
先说一下gan的原理,我们之前就说过,无论是多么复杂的架构设计,你最需要把握的核心点就是:1. 损失从何而来 2. 损失如何定义。
假设现在有1个造假币的a,和一个警察b,他们都是初始化的神经网络。
为了让警察能分辨真币,我需要给他看真币的样子,也就是先训练警察知道什么是真币,但是只给他一个真币他也没法训练,我中间得掺杂一些无意义的噪声图片,无意义的肯定不可能是真币,所以他大概可以训练好。此时他就是一个基础的分类模型,能分别真币和假币
然后我随机初始化这个造假币的a,我每次给他随机的输入,他会生成不同的假币,每一张假币都会让警察来判断,警察能够分别出真假,如果说你是假的,那么造假币的就要更新参数,如果是真的,那么造假币的参数就不用更新,警察要更新参数。所以后续二者就在不断博弈中进步,知道最后假币专家造假以假乱真,只要更新的轮数多,即使骗不过专家,但是也很棒了。
我们把这个造假币的叫做生成器,把这个警察叫做判别器。
这个过程有点类似于二者互相对抗,所以叫做对抗生成网络,目的就是在对抗中找到一个可以模仿原有数据的模型。生成器基于随机噪声生成样本,判别器对样本(真实数据 + 生成数据)进行分类,双方根据判别结果更新参数(生成器尝试让判别器误判,判别器尝试提高准确率)。
他的损失由2部分构成,生成器的损失和判别器的损失。判别器的损失定义在先,生成器的损失定义基于判别器的反馈。
1. 判别器的损失:就是分类问题的损失,二分类是二元交叉熵损失BCE(本质分类输出的是概率,并不是非黑即白)
2. 生成器的损失:依靠判别器的损失,如果判别器说是假的,生成器的损失就大,反之亦然。
判别器损失同时包含 “真实数据判真” 和 “生成数据判假” 两部分,而生成器损失仅针对 “生成数据被判别为真” 的目标。实际训练中,两者的优化是交替进行的(先训判别器,再训生成器)。两者的损失共同推动 GAN 逼近 “以假乱真” 的平衡状态。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataL

最低0.47元/天 解锁文章
2012

被折叠的 条评论
为什么被折叠?



