混合密度网络Mixture Density Networks(MDN)

简介

平方和或交叉熵误差函数的最小化导致网络输出近似目标数据的条件平均值,以输入向量为条件。对于分类问题,只要选择合适的目标编码方案,这些平均值表示类隶属度的后验概率,因此可以认为是最优的。然而,对于涉及连续变量预测的问题,条件平均只能对目标变量的性质提供非常有限的描述。对于要学习的映射是多值的问题尤其如此,就像反问题的解中经常出现的那样,因为几个正确目标值的平均值本身不一定是正确的值。为了获得数据的完整描述,为了预测与新输入向量对应的输出,我们必须对目标数据的条件概率分布进行建模,同样以输入向量为条件。本文介绍了将传统神经网络与混合密度模型相结合而得到的一类新的网络模型。完整的系统被称为混合密度网络,原则上可以像传统神经网络表示任意函数一样表示任意条件概率分布。我们用一个玩具问题和一个涉及机器人逆运动学的问题来证明混合密度网络的有效性。

作者:Bishop, Christopher M. (1994).  混合密度网络的提出者;
论文:Mixture density networks. 
出版:Technical Report. Aston University, Birmingham.
论文地址:https://publications.aston.ac.uk/id/eprint/373

论文地址

关注微信公众号,获取更多资讯内容:
在这里插入图片描述

1 介绍

混合密度网络通常作为神经网络的最后处理部分。将某种分布(通常是高斯分布)按照一定的权重进行叠加,从而拟合最终的分布。

如果选择高斯分布的MDN,那么它和GMM(高斯混合模型 Gaussian Mixture Model)有着相同的效果。但是他们有着很明显的区别

在这里插入图片描述
1 该部分案例参考该博客

  • MDN的均值、方差、每个模型的权重是通过神经网络产生的,利用最大似然估计作为Loss函数进行反向传播从而确定网络的权重(也就是确定一个较好的高斯分布参数)
  • GMM的均值、方差、每个模型的权重是通过估计出来的,通常使用EM算法来通过不断迭代确定。

MDN实现简单,而且可以直接模块化的连接到神经网络的后端。他的结果可以得到一个概率范围,相对有deterministic类只输出一个结果,往往有更好的健壮性。
[1]中有相关代码实现。

2 实现

假设我们要拟合如下一个带噪声的函数:
y = 7.0 s i n ( 0.75 x ) + 0.5 x + ϵ y=7.0sin(0.75x)+0.5x+ϵ y=7.0sin(0.75x)+0.5x+ϵ
原始图像为:
在这里插入图片描述
使用神经网络拟合得到:
在这里插入图片描述
对调x和y,再用神经网络拟合得到:
在这里插入图片描述
使用MDN:对于单一输入x,预测y的概率分布。DN的输出为服从混合高斯分布(Mixture Gaussian distributions),具体的输出值被建模为多个高斯随机值的和:
在这里插入图片描述

class MDN(nn.Module):
    def __init__(self, n_hidden, n_gaussians):
        super(MDN, self).__init__()
        self.z_h = nn.Sequential(
            nn.Linear(1, n_hidden),
            nn.Tanh()
        )
        self.z_pi = nn.Linear(n_hidden, n_gaussians)
        self.z_mu = nn.Linear(n_hidden, n_gaussians)
        self.z_sigma = nn.Linear(n_hidden, n_gaussians)
    def forward(self, x):
        z_h = self.z_h(x)
        pi = F.softmax(self.z_pi(z_h), -1)
        mu = self.z_mu(z_h)
        sigma = torch.exp(self.z_sigma(z_h))
        return pi, mu, sigma

由于输出本质上是概率分布,因此不能采用诸如L1损失、L2损失的硬损失函数。这里我们采用了对数似然损失(和交叉熵类似):
在这里插入图片描述
使用MDN得到的如下结果:
在这里插入图片描述
具体过程,请参考:
Github库
YoungTimes博客
xiongxyowo的CSDN博客

3 几个MDN的应用:

在这里插入图片描述
3 参考自博客

参考

[1] A Hitchhiker’s Guide to Mixture Density Networks

  • 13
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值