Python训练营打卡Day50-对抗生成网络

知识点回顾:
  1. 对抗生成网络的思想:关注损失从何而来
  2. 生成器、判别器
  3. nn.sequential容器:适合于按顺序运算的情况,简化前向传播写法
  4. leakyReLU介绍:避免relu的神经元失活现象

ps;如果你学有余力对于gan的损失函数的理解,建议去找找视频看看,如果只是用,没必要学

作业:对于心脏病数据集,对于病人这个不平衡的样本用GAN来学习并生成病人样本,观察不用GAN和用GAN的F1分数差异。

一、 GAN对抗生成网络思想

先说一下gan的原理,我们之前就说过,无论是多么复杂的架构设计,你最需要把握的核心点就是:1. 损失从何而来 2. 损失如何定义。

假设现在有1个造假币的a,和一个警察b,他们都是初始化的神经网络。

为了让警察能分辨真币,我需要给他看真币的样子,也就是先训练警察知道什么是真币,但是只给他一个真币他也没法训练,我中间得掺杂一些无意义的噪声图片,无意义的肯定不可能是真币,所以他大概可以训练好。此时他就是一个基础的分类模型,能分别真币和假币

然后我随机初始化这个造假币的a,我每次给他随机的输入,他会生成不同的假币,每一张假币都会让警察来判断,警察能够分别出真假,如果说你是假的,那么造假币的就要更新参数,如果是真的,那么造假币的参数就不用更新,警察要更新参数。所以后续二者就在不断博弈中进步,知道最后假币专家造假以假乱真,只要更新的轮数多,即使骗不过专家,但是也很棒了。

我们把这个造假币的叫做生成器,把这个警察叫做判别器。

这个过程有点类似于二者互相对抗,所以叫做对抗生成网络,目的就是在对抗中找到一个可以模仿原有数据的模型。生成器基于随机噪声生成样本,判别器对样本(真实数据 + 生成数据)进行分类,双方根据判别结果更新参数(生成器尝试让判别器误判,判别器尝试提高准确率)。

他的损失由2部分构成,生成器的损失和判别器的损失。判别器的损失定义在先,生成器的损失定义基于判别器的反馈。

1. 判别器的损失:就是分类问题的损失,二分类是二元交叉熵损失BCE(本质分类输出的是概率,并不是非黑即白)

2. 生成器的损失:依靠判别器的损失,如果判别器说是假的,生成器的损失就大,反之亦然。

判别器损失同时包含 “真实数据判真” 和 “生成数据判假” 两部分,而生成器损失仅针对 “生成数据被判别为真” 的目标。实际训练中,两者的优化是交替进行的(先训判别器,再训生成器)。两者的损失共同推动 GAN 逼近 “以假乱真” 的平衡状态。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataL
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值