关键是 CD-k(contrastive_divergence)算法的实现。
// The CD-k (k-step contrastive divergence) update for a single training sample.
//
// Positive phase: the hidden layer is sampled from `input` (h0).
// Negative phase: k steps of block Gibbs sampling h -> v -> h starting at h0
// produce a reconstructed visible sample v_k (nv_samples) and the hidden
// probabilities p(h_k | v_k) (nh_means).
// W, hbias and vbias are then moved by
//   lr * (data statistics - reconstruction statistics) / N.
// The data statistics use the sampled hidden states h0 (ph_sample); the
// reconstruction statistics use the hidden probabilities nh_means and the
// sampled visible states nv_samples.
//
// NOTE(review): N is not declared in this function; presumably it is the
// number of training samples stored on the RBM - confirm in the class.
//
// If k < 1 no Gibbs step is run, so there is no reconstruction to contrast
// with the data and the call leaves all parameters unchanged.
void RBM::contrastive_divergence (
    int *input, // one visible sample (0-1 values, n_visible entries) -- input
    double lr,  // the learning rate
    int k       // the number of Gibbs steps, the k in CD-k
)
{
    // With no Gibbs step the reconstruction buffers used by the update below
    // would never be written; bail out instead of reading uninitialised memory.
    if (k < 1)
    {
        return;
    }

    // Buffers for the positive (data) and negative (reconstruction) phases.
    // p(h0|v): filled by sample_h_given_v; the update itself uses the samples.
    double *ph_mean = new double[n_hidden];
    // sampled 0-1 states of the hidden nodes given the data, h0
    int *ph_sample = new int[n_hidden];
    // p(v|h) of the most recent reconstruction
    double *nv_means = new double[n_visible];
    // sampled 0-1 states of the most recent reconstructed visible layer, v_k
    int *nv_samples = new int[n_visible];
    // p(h|v_k) of the most recent Gibbs step
    double *nh_means = new double[n_hidden];
    // sampled 0-1 states of the hidden nodes of the most recent Gibbs step
    int *nh_samples = new int[n_hidden];

    // step 1 (positive phase): sample h0 from the input visible sample.
    sample_h_given_v(input, ph_mean, ph_sample);

    // step 2 (negative phase): k steps of block Gibbs sampling
    //   h0 -> v1 -> h1 -> v2 -> ... -> v_k -> h_k
    // The first step starts from h0, every later step from the hidden sample
    // of the previous step. Passing nh_samples as both input and output is
    // safe: gibbs_hvh finishes reading it (to sample v) before it overwrites it.
    int *chain_h = ph_sample;
    for (int step = 0; step < k; step++)
    {
        gibbs_hvh(chain_h, nv_means, nv_samples, nh_means, nh_samples);
        chain_h = nh_samples;
    }

    // step 3: parameter update.
    for (int i = 0; i < n_hidden; i++)
    {
        for (int j = 0; j < n_visible; j++)
        {
            // W[i][j] += lr * (<v_j h_i>|data - <v_j h_i>|reconstruction) / N
            //   data term:           sampled h0_i * input_j
            //   reconstruction term: p(h_k,i | v_k) * sampled v_k,j
            W[i][j] += lr * (ph_sample[i] * input[j] - nh_means[i] * nv_samples[j]) / N;
        }
        // hbias[i] += lr * (<h_i>|data - <h_i>|reconstruction) / N
        hbias[i] += lr * (ph_sample[i] - nh_means[i]) / N;
    }
    for (int j = 0; j < n_visible; j++)
    {
        // vbias[j] += lr * (<v_j>|data - <v_j>|reconstruction) / N
        vbias[j] += lr * (input[j] - nv_samples[j]) / N;
    }

    // release the phase buffers
    delete[] ph_mean;
    delete[] ph_sample;
    delete[] nv_means;
    delete[] nv_samples;
    delete[] nh_means;
    delete[] nh_samples;
} // contrastive_divergence
// Sample the hidden layer given a visible sample.
//
// For every hidden unit, mean[] receives p(h = 1 | v) as computed by propup
// with that unit's weight row and bias, and sample[] receives the 0-1 state
// drawn from that probability with binomial(1, p).
void RBM::sample_h_given_v (
    int *v0_sample, // one input sample from visible nodes -- input
    double *mean,   // the output probability of hidden nodes -- output
    int *sample     // the calculated 0-1 state of hidden node -- output
)
{
    for (int unit = 0; unit < n_hidden; ++unit)
    {
        // bottom-up activation probability of this hidden unit
        const double p_on = propup(v0_sample, W[unit], hbias[unit]);
        mean[unit] = p_on;
        sample[unit] = binomial(1, p_on);
    }
}
// Sample (reconstruct) the visible layer given a hidden sample.
//
// For every visible unit, mean[] receives p(v = 1 | h) as computed by
// propdown with that unit's column of W and its bias, and sample[] receives
// the 0-1 state drawn from that probability with binomial(1, p).
void RBM::sample_v_given_h (
    int *h0_sample, // one input sample from hidden nodes -- input
    double *mean,   // the output probability of visible nodes -- output
    int *sample     // the calculated 0-1 state of visible node -- output
)
{
    for (int unit = 0; unit < n_visible; ++unit)
    {
        // top-down activation probability of this visible unit
        const double p_on = propdown(h0_sample, unit, vbias[unit]);
        mean[unit] = p_on;
        sample[unit] = binomial(1, p_on);
    }
}
// Bottom-up pass for one hidden unit i:
//   p(h_i = 1 | v) = sigmoid( sum_j(v_j * w_ij) + b_i )
// `w` is that unit's row of weights (n_visible entries).
double RBM::propup (
    int *v,    // one input sample from visible node -- input
    double *w, // the weights connecting one hidden node to all visible nodes -- input
    double b   // the bias for this hidden node -- input
)
{
    // weighted input, accumulated in index order before the bias is added
    double activation = 0.0;
    int j = 0;
    while (j < n_visible)
    {
        activation += w[j] * v[j];
        ++j;
    }
    return sigmoid(activation + b);
}
// Top-down pass for one visible unit i:
//   p(v_i = 1 | h) = sigmoid( sum_j(h_j * w_ji) + c_i )
// The weights are read from column i of W (W[j][i] for every hidden unit j).
double RBM::propdown(
    int *h,  // one input sample from hidden node -- input
    int i,   // the index of visible node in the W matrix -- input
    double b // the bias for this visible node -- input
)
{
    // weighted input, accumulated in index order before the bias is added
    double activation = 0.0;
    int j = 0;
    while (j < n_hidden)
    {
        activation += W[j][i] * h[j];
        ++j;
    }
    return sigmoid(activation + b);
}
void RBM::gibbs_hvh (
int *h0_sample, // one input sample from hidden node, h0 -- input
double *nv_means, // the output probability of visiable nodes -- output
int *nv_samples, // the calculated 0-1 state of visiable node -- output
double *nh_means, // the output probability of reconstructed hidden node h1 -- output
int *nh_samples // the calculated 0-1 state of reconstructed hidden node h1 -- output
)
{
// calculate p(v|h)
sample_v_given_h(h0_sample, nv_means, nv_samples);
// calculate p(h|v)
sample_h_given_v(nv_samples, nh_means, nh_samples);
}
// Deterministic reconstruction of a visible sample.
//
// Computes the hidden probabilities p(h_j = 1 | v) (no sampling), then
// propagates those probabilities back down through W and vbias:
//   reconstructed_v[i] = sigmoid( sum_j(W[j][i] * p(h_j|v)) + vbias[i] )
void RBM::reconstruct(int *v, double *reconstructed_v)
{
    // up pass: real-valued hidden activation probabilities
    double *hidden_p = new double[n_hidden];
    for (int j = 0; j < n_hidden; j++)
    {
        hidden_p[j] = propup(v, W[j], hbias[j]);
    }

    // down pass: visible probabilities driven by the hidden probabilities
    for (int i = 0; i < n_visible; i++)
    {
        double act = 0.0;
        for (int j = 0; j < n_hidden; j++)
        {
            act += W[j][i] * hidden_p[j];
        }
        reconstructed_v[i] = sigmoid(act + vbias[i]);
    }

    delete[] hidden_p;
}