模仿学习笔记:生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL

1 GAN (回顾)

GAIL 的设计基于生成判别网络 ( GAN)。这里简单地回顾一下GAN,详细的可见 NTU 课程笔记 7454 GAN_UQI-LIUWJ的博客-CSDN博客
GAN由生成器 (Generator) 和判别器 (Discriminator)组成,它们 各是一个神经网络。
——>生成器负责生成假的样本
——>判别器负责判定一个样本是真是假。
我们的目标是希望生成器生成的内容可以“以假乱真”

1.1 生成器

        生成器 记作 a = G ( s ; θ ) ,其中 θ 是参数。它的输入是向量 s ,向量的每一个元素从均匀分布U(-1,1)或标准正态分布 N (0 , 1) 中抽取。生成器的输出是数据(比如图片)x

 

1.2 判别器

判别器 记作\hat{p}=D(x;\phi),其中 ϕ 是参数。
它的输入是图片 x;输出 \hat{p} 是介于 0 1 之间的概率值,0 表示“假的”, 1 表示“真的”。
判别器的功能是二分类器。

1.3 训练生成器

        将生成器与判别器相连,固定住判别器的参数,只更新生成器的参数 θ,使得生成的图片 x = G(s; θ) 在判别器的眼里更像真的。

        对于任意一个随机生成的向量 s,应该改变 θ,使得判别器的输出\hat{p}=D(x;\phi)尽量接近 1

        可以用如下函数作为loss function:

 

         我们希望此时D(x;Φ)越大越好,也就是E(s;θ)越小越好

        所以我们用梯度下降来更新生成器的θ

1.4 训练判别器

  •  判别器的本质是个二分类器,它的输出值 \hat{p}=D(x;\phi)表示对图片真伪的预测;
    • \hat{p} 接近 1 表示“真”,
    • \hat{p}接近 0 表示“假”。
判别器的训练如下图所示。
  • 从真实数据集中抽取一个样本,记作x^{real}
  • 再随机生成一个向量 s,用生成器生成 x^{fake}=G(s;\theta)
  • 训练判别器的目标是改进参数 ϕ,让 D(x^{real};\phi) 更接近 1(真),让D(x^{fake};\phi)更接近 0 (假)。
  • ——>也就是说让判别器的分类结果更准确,更好区分真实图片和生成的假图片。

此时的损失函数如下所示

 不难发现,判别器越准确,损失函数F越小

所以我们也用梯度下降更新判别器的θ

 

 1.5 整体训练流程

 

 2 生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL

2.1 训练数据

GAIL 的训练数据是被模仿的对象(人类专家)操作智能体得到的轨迹

 

数据集中有 k 条轨迹,把数据集记作:

 

 2.2 生成器

GAIL 的生成器是策略网络 π ( a | s ; θ )
策略网络的输入是状态 s,输出是一个向量:

 

输出向量 f 的维度是动作空间的大小 A ,它的每个元素对应一个动作,表示执行该动作
的概率。
给定初始状态 s 1 ,并让智能体与环境交互,可以得到一条轨迹:

 

 其中动作是根据策略网络抽样得到的, a_t \sim \pi(\cdot|s_t;\theta), \forall t=1,\cdots, n

 下一时刻的状态是环境根据状态转移函数计算出来的

 

 2.3 判别器

GAIL 的判别器记作 D ( s, a ; ϕ )

判别器的输入是状态 s,输出是一个向量:

 

输出向量 \hat{p} 的维度是动作空间的大小 A ,它的每个元素对应一个动作 a ,把一个元素记作:

 

\hat{p_a}接近 1 表示 ( s, a ) 为“真”,即动作 a 是人类专家做的。
\hat{p_a}接近 0 表示 ( s, a ) 为“假”,即动作 a 是策略网络生成的。

 2.4 GAIL的训练

2.4.1 训练生成器

\theta_{now}是当前策略网络的参数。用策略网络\pi(a|s;\theta_{now})控制智能体与环境交互,得到一条轨迹:
用判别器评价 (s_t,a_t)的真实情况, D(s_t,a_t;\phi)越大,说明 (s_t,a_t)在判别器的眼里越真实。
我们记第t步的回报为:

 

于是我们的轨迹可以变成

 

 有不同的方法来更新策略网络的参数θ

在GAIL中,使用的是TRPO

 强化学习笔记:置信域策略优化 TRPO_UQI-LIUWJ的博客-CSDN博客

即目标函数为

通过解带约束的最大化问题,得到新的参数

 

 2.4.2 训练判别器

训练判别器的目的是让它能区分真的轨迹与生成的轨迹
我们从训练数据中抽样一条轨迹:

同时用策略网络控制智能体和环境交互,得到另一条轨迹,记作

 注意real和fake轨迹的长度可能不一样

同样地,我们希望D(s_t^{real},a_t^{real};\phi)尽量趋近于1,D(s_t^{fake},a_t^{fake};\phi)尽量趋近于0

于是我们定义损失函数

 我们希望损失函数尽量小,也就是说判别器能区分开真假轨迹。可以做梯度下降来更新判别器的参数Φ

 

2.4.3 整体训练流程

每一轮训练更新一个生成器,更新一次判别器。训练重复以下步骤,直 到收敛。

 

 

  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
### 回答1: ESRGAN是增强型超分辨率生成对抗网络的缩写,它是一种深度学习模型,用于将低分辨率图像转换为高分辨率图像。它使用生成对抗网络(GAN)的方法,通过训练生成器和判别器来提高图像的质量。ESRGAN是目前最先进的超分辨率算法之一,它可以生成更加真实、细节更加丰富的高分辨率图像。 ### 回答2: ESRGAN是一种增强超分辨率生成对抗网络(Enhanced Super-Resolution Generative Adversarial Networks)的算法,它采用了图像增强技术和深度学习的方法,可以将低分辨率(LR)的图像转化为高分辨率(HR)的图像。该算法主要的贡献在于,它可以生成更加逼真的HR图像,从而更好地应用于实际的图像处理领域。 该算法主要是由两个子模型组成的,一个是生成模型(Generator),另一个是判别模型(Discriminator)。生成模型通过学习HR图像和相应的LR图像之间的关系,生成更加逼真的HR图像。而判别模型则评估生成模型生成的HR图像是否真实,从而提高生成模型的准确度。 ESRGAN算法采用特殊的损失函数,即感知损失和自适应增强损失,来优化生成模型。感知损失通过VGG网络来计算生成模型和HR图像之间的差异,以此来调整生成模型的参数。自适应增强损失则用于动态调整生成模型的输出图像的细节层次,使生成模型产生更加真实的输出图像。 ESRGAN算法在图像增强领域取得了显著的成果,其生成的HR图像质量要比先前的SRGAN算法有了很大的提升。因此,ESRGAN算法在实际应用中具有广泛的前景,可以为图像处理领域提供更加高效、准确和可靠的方法。 ### 回答3: ESRGAN(Enhanced Super-Resolution Generative Adversarial Networks)是一种利用深度学习算法进行图像超分辨率的技术。其主要思路是基于GAN模型,通过训练一个生成器去从低分辨率图像生成高分辨率图像,同时以高分辨率的真实图片为样本来训练判别器模型,使其能够区分出生成生成的图像是否为真实高清图像。 ESRGAN相对于传统的超分辨率算法,具有以下几个优点: 1.超分辨率效果更好。传统的超分辨率算法往往是基于一些数学模型进行插值运算,因此往往会出现图像模糊、失真等问题。而ESRGAN能够通过深度学习算法学习到更加准确的纹理特征,从而可以生成更为真实的高清图像。 2.可扩展性更强。ESRGAN的GAN模型可以通过增加网络深度、增加训练数据等方式对模型进行优化,从而提高图像超分辨率效果。 3.针对性更强。ESRGAN可以针对不同种类的图像进行训练,从而能够对各种类型的图像进行超分辨率处理,具有广泛的适用性。 4.易于应用。ESRGAN训练出的模型可以很方便地应用到实际生产环境中,对于需要进行图像超分辨率处理的应用场景具有很大的帮助作用。 虽然ESRGAN在图像超分辨率方面具有较为突出的优势,但其也存在一些缺点和挑战。比如需要大量的高清图像数据用于训练,需要考虑到训练时间和计算资源的问题;还需要解决一些局部纹理复杂的图像超分辨率问题。总之,ESRGAN是一种非常有潜力的图像超分辨率算法,将有助于推动图像处理技术的进一步发展。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

UQI-LIUWJ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值