RNN-RBM网络介绍及代码分析

本文介绍了RNN-RBM网络,详细讲解了RBM的参数、训练过程和损失函数,并探讨了如何将RNN与RBM结合。通过Theano库提供了RNN-RBM的构建代码,包括训练和序列生成的函数。同时,还提供了一个简单的RnnRbm类,用于训练和生成MIDI文件的序列。
摘要由CSDN通过智能技术生成

官方学习教程的地址:http://deeplearning.net/tutorial/rnnrbm.html#rnnrbm

关于RBM的介绍及代码:http://deeplearning.net/tutorial/rbm.html#rbm

但只从介绍很难明白RNN与RBM究竟怎么在训练过程中互相结合的。

————————————————————————————————————

首先RBM网络:

其参数主要有隐层与显层的关系矩阵W,隐层的偏置bh,显层的偏置bv。

对于输入数据V矩阵,V的每一行代表一个sample,这里针对每个sample,训练的都是同样的一组bh与bv,即从始至终bh与bv都只是一个一维向量。

训练过程中的loss函数如何得到:有样本V,根据Gibbs采样生成的样本V‘,计算V对应的能量函数值,计算V’对应的能量函数值,二者相减即为loss函数。

————————————————————————————————————

LSTM网络学习:http://deeplearning.net/tutorial/lstm.html

————————————————————————————————————

而如何将RNN与RBM相结合呢?

在原有的RBM中,一组bh与bv针对是所有时序的采样,而这里为了体现出时序的特点,采样每个时序对应一组bh与bv。

而具体bh与bv则由RNN的循环递归生成(V也参与其中)。

一个更容易理解的图:


代码:

# Author: Nicolas Boulanger-Lewandowski
# University of Montreal (2012)
# RNN-RBM deep learning tutorial
# More information at http://deeplearning.net/tutorial/rnnrbm.html


from __future__ import print_function


import glob
import os
import sys


import numpy
try:
    import pylab
except ImportError:
    print ("pylab isn't available. If you use its functionality, it will crash.")
    print("It can be installed with 'pip install -q Pillow'")


from midi.utils import midiread, midiwrite
import theano
import theano.tensor as T
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams


#Don't use a python long as this don't work on 32 bits computers.
numpy.random.seed(0xbeef)
rng = RandomStreams(seed=numpy.random.randint(1 << 30))
theano.config.warn.subtensor_merge_bug = False




def build_rbm(v, W, bv, bh, k):
    '''Construct a k-step Gibbs chain starting at v for an RBM.


    v : Theano vector or matrix
        If a matrix, multiple chains will be run in parallel (batch).
    W : Theano matrix
        Weight matrix of the RBM.
    bv : Theano vector
        Visible bias vector of the RBM.
    bh : Theano vector
        Hidden bias vector of the RBM.
    k : scalar or Theano scalar
        Length of the Gibbs chain.


    Return a (v_sample, cost, monitor, updates) tuple:


    v_sample : Theano vector or matrix with the same shape as `v`
        Corresponds to the generated sample(s).
    cost : Theano scalar
        Expression whose gradient with respect to W, bv, bh is the CD-k
        approximation to the log-likelihood of `v` (training

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值