贝叶斯网络参数学习(基于FullBNT-1.0.4的MATLAB实现)

题目:贝叶斯网络参数学习(基于FullBNT-1.0.4的MATLAB实现)

        贝叶斯网络学习分为结构学习和参数学习,前面用三篇分别介绍了两个工具箱共三个贝叶斯网络结构学习方法:

        贝叶斯网络结构学习之K2算法(基于FullBNT-1.0.4的MATLAB实现)

        贝叶斯网络结构学习之MCMC算法(基于FullBNT-1.0.4的MATLAB实现)

        贝叶斯网络结构学习(基于BDAGL工具箱的MATLAB实现)

        使用这些方法可以根据数据集学习得到贝叶斯网络结构DAG,对于部分应用来说这就足够了,因为得到了贝叶斯网络结构也就知道了各节点的相互依赖关系。但对某些应用来说不仅仅想知道相互依赖关系,更想知道相互依赖关系定量的描述,这时就要得到贝叶斯网络参数。什么是贝叶斯网络参数呢?其实就是每个节点的条件概率表(Conditional Probability Table, CPT)。看本篇内容最好简单浏览一下前面的《贝叶斯网络结构学习若干问题解释》,会对一些基本概念有一个初步的了解。

        工具箱FullBNT-1.0.4中包括两个参数学习函数:learn_params和bayes_update_params。

        本篇不再深究这两个参数学习函数的内部原理,目标仅是学会如何使用。

        两个函数的使用例子参见\FullBNT-1.0.4\BNT\examples\static\learn1.m

        例子learn1.m源代码如下:

% Lawn sprinker example from Russell and Norvig p454
% See www.cs.berkeley.edu/~murphyk/Bayes/usage.html for details.

N = 4; 
dag = zeros(N,N);
C = 1; S = 2; R = 3; W = 4;
dag(C,[R S]) = 1;
dag(R,W) = 1;
dag(S,W)=1;

false = 1; true = 2;
ns = 2*ones(1,N); % binary nodes

bnet = mk_bnet(dag, ns);
bnet.CPD{C} = tabular_CPD(bnet, C, [0.5 0.5]);
bnet.CPD{R} = tabular_CPD(bnet, R, [0.8 0.2 0.2 0.8]);
bnet.CPD{S} = tabular_CPD(bnet, S, [0.5 0.9 0.5 0.1]);
bnet.CPD{W} = tabular_CPD(bnet, W, [1 0.1 0.1 0.01 0 0.9 0.9 0.99]);

CPT = cell(1,N);
for i=1:N
  s=struct(bnet.CPD{i});  % violate object privacy
  CPT{i}=s.CPT;
end

% Generate training data
nsamples = 50;
samples = cell(N, nsamples);
for i=1:nsamples
  samples(:,i) = sample_bnet(bnet);
end
data = cell2num(samples);

% Make a tabula rasa
bnet2 = mk_bnet(dag, ns);
seed = 0;
rand('state', seed);
bnet2.CPD{C} = tabular_CPD(bnet2, C, 'clamped', 1, 'CPT', [0.5 0.5], ...
			   'prior_type', 'dirichlet', 'dirichlet_weight', 0);
bnet2.CPD{R} = tabular_CPD(bnet2, R, 'prior_type', 'dirichlet', 'dirichlet_weight', 0);
bnet2.CPD{S} = tabular_CPD(bnet2, S, 'prior_type', 'dirichlet', 'dirichlet_weight', 0);
bnet2.CPD{W} = tabular_CPD(bnet2, W, 'prior_type', 'dirichlet', 'dirichlet_weight', 0);


% Find MLEs from fully observed data
bnet4 = learn_params(bnet2, samples);

% Bayesian updating with 0 prior is equivalent to ML estimation
bnet5 = bayes_update_params(bnet2, samples);

CPT4 = cell(1,N);
for i=1:N
  s=struct(bnet4.CPD{i});  % violate object privacy
  CPT4{i}=s.CPT;
end

CPT5 = cell(1,N);
for i=1:N
  s=struct(bnet5.CPD{i});  % violate object privacy
  CPT5{i}=s.CPT;
  assert(approxeq(CPT5{i}, CPT4{i}))
end


if 0
% Find MLEs from partially observed data

% hide 50% of the nodes
samplesH = samples;
hide = rand(N, nsamples) > 0.5;
[I,J]=find(hide);
for k=1:length(I)
  samplesH{I(k), J(k)} = [];
end

engine = jtree_inf_engine(bnet2);
max_iter = 5;
[bnet6, LL] = learn_params_em(engine, samplesH, max_iter);

CPT6 = cell(1,N);
for i=1:N
  s=struct(bnet6.CPD{i});  % violate object privacy
  CPT6{i}=s.CPT;
end

end

        前18行代码是生成一个贝叶斯网络,其中第4到9行是定义贝叶斯网络结构DAG,第12行给出每个节点离散值的个数ns向量,第14行使用工具箱的函数mk_bnet根据DAG和ns生成一个贝叶斯网络bnet(理解为工具箱自己定义的一种贝叶斯网络存储方式吧),第15到18行定义该贝叶斯网络bnet的条件概率表,即贝叶斯网络参数。

        第20到24行代码是将贝叶斯网络bnet的参数取出存在CPT里,主要是可以用来和后面的函数learn_params 学得的参数CPT4和函数bayes_update_params学得的参数CPT5作对比。

        第26到32行代码是根据贝叶斯网络bnet生成一些数据samples,供函数learn_params和函数bayes_update_params学习使用。

        小结一下,前32行代码主要是给定一个贝叶斯网络(包括结构和参数),然后根据这个贝叶斯网络生成一些数据。接下来要做的是根据这些数据,分别使用函数learn_params和函数bayes_update_params在已知结构DAG的前提下学习贝叶斯网络参数。

        第34到42行代码先是生成一个贝叶斯网络bnet2(与第14行完全相同),然后使用函数tabular_CPD初始化该贝叶斯网络参数,bnet2与bnet的区别仅在于网络参数,至于说为什么要生成网络bnet2,个人感觉应该是函数learn_params和函数bayes_update_params的学习过程可能是一个迭代过程,需要从某个起点开始迭代(这两个函数内部细节没研究,所以是纯属瞎猜的,别当真^_^)。

        再插一句第36行到第37行两句代码的含义:这个主要是保证每次运行结果都一样。大家都知道rand函数可以产生随机数,但实际也只能是伪随机数而已,比如rand(3,4)是产生一个3*4的矩阵,矩阵每个元素均在0到1之间。但是,若在matlab中执行以下代码:

rand('state', 0);rand(3,4)

你会发现每次生的结果均是一样的,具体可以参见:Matlab中rand('state',s)和rand('state',0)表示什么意思?,链接:http://www.ilovematlab.cn/thread-57952-1-1.html

        第46行代码调用函数learn_params学习贝叶斯网络参数,从第45行注释可以知道这个函数的原理是最大似然估计(maximumlikelihood estimator, MLE)。

        第49行代码调用函数bayes_update_params学习贝叶斯网络参数,第48行注释似乎在说若没有先验时等价于最大似然估计。

        第51到55行代码将函数learn_params学得的参数存到CPT4当中,第57到62行代码将函数bayes_update_params学得的参数存到CPT5当中,与第20到第24行所做事情完全一样。

        第65行之后的代码是在数据集并非完全可见时调用函数learn_params_em学习贝叶斯网络参数(函数名称最后的em肯定是指期望最大化(ExpectationMaximization)的英文首字母了),当然从第66行的注释其实就可以知道代码在做什么事情。其中第68到74行是手动将前面的数据集samples隐藏50%的元素(第73行将其置为“空”,即不可见)。这段代码由第65行的if语句选择执行,默认为”if 1”,上面我将其改为了”if 0”,因为我执行这段代码报错了,而对于部分可见数据集的参数学习问题暂时也不在我关注范围之内,所以就先不去琢磨了。(注:工具箱BNT并没有问题,之所以执行这段代码报错是由于后来我又添加了BDAGL工具箱,对BNT工具箱造成影响所致,去除BDAGL工具箱运行就没有任何问题了@20180304)

 

        其实细想一下可知,对于数据集中所有数据完全可见的贝叶斯网络来说,在已知结构DAG的前提下,完全可以通过对数据集手动计数的方式计算出贝叶斯网络参数。具体来说,对于例子中的贝叶斯网络,结构如下:

所有节点取值均为1和2共两个离散值。

        对于节点C来说,没有parents,所以该节点的条件概率表即为

P(C=1)

P(C=2)

要估计这两个概率只需要分别统计数据集中C=1的数据个数#(C=1)和C=2的数据个数#(C=2),则

P(C=1)= #(C=1) / [#(C=1) + #(C=2)],

P(C=2)= #(C=2) / [#(C=1) + #(C=2)];

        对于节点S来说,parents为节点C,所以该节点的条件概率表即为

P(S=1|C=1)

P(S=2|C=1)

P(S=1|C=2)

P(S=2|C=2)

要估计这四个概率需要分别统计数据集中当C=1时S=1的个数#(S=1|C=1)和S=2的个数#(S=2|C=1),以及当C=2时S=1的个数#(S=1|C=2)和S=2的个数#(S=2|C=2),则

P(S=1|C=1)= #(S=1|C=1) / [ #(S=1|C=1) + #(S=2|C=1)],

P(S=2|C=1)= #(S=2|C=1) / [ #(S=1|C=1) + #(S=2|C=1)],

P(S=1|C=2)= #(S=1|C=2) / [ #(S=1|C=2) + #(S=2|C=2)],

P(S=2|C=2)= #(S=2|C=2) / [ #(S=1|C=2) + #(S=2|C=2)];

        对于节点R来说,情况与节点S一样,因此只需将有关节点S的叙述中的S改为R即可;

        对于节点W来说,parents为R和S,所以该节点的条件概率表即为

P(W=1|R=1,S=1)

P(W=2|R=1,S=1)

P(W=1|R=1,S=2)

P(W=2|R=1,S=2)

P(W=1|R=2,S=1)

P(W=2|R=2,S=1)

P(W=1|R=2,S=2)

P(W=2|R=2,S=2)

要估计这八个概率需要分别统计数据集中当R=1且S=1时W=1的个数#(W=1|R=1,S=1)和W=2的个数#(W=2|R=1,S=1),当R=1且S=2时W=1的个数#(W=1|R=1,S=2)和W=2的个数#(W=2|R=1,S=2),当R=2且S=1时W=1的个数#(W=1|R=2,S=1)和W=2的个数#(W=2|R=2,S=1),当R=2且S=2时W=1的个数#(W=1|R=2,S=2)和W=2的个数#(W=2|R=2,S=2),则

P(W=1|R=1,S=1)= #(W=1|R=1,S=1) / [#(W=1|R=1,S=1) + #(W=2|R=1,S=1)],

P(W=2|R=1,S=1)= #(W=2|R=1,S=1) / [#(W=1|R=1,S=1) + #(W=2|R=1,S=1)],

P(W=1|R=1,S=2)= #(W=1|R=1,S=2) / [#(W=1|R=1,S=2) + #(W=2|R=1,S=2)],

P(W=2|R=1,S=2)= #(W=2|R=1,S=2) / [#(W=1|R=1,S=2) + #(W=2|R=1,S=2)],

P(W=1|R=2,S=1)= #(W=1|R=2,S=1) / [#(W=1|R=2,S=1) + #(W=2|R=2,S=1)],

P(W=2|R=2,S=1)= #(W=2|R=2,S=1) / [#(W=1|R=2,S=1) + #(W=2|R=2,S=1)],

P(W=1|R=2,S=2)= #(W=1|R=2,S=2) / [#(W=1|R=2,S=2) + #(W=2|R=2,S=2)],

P(W=2|R=2,S=2)= #(W=2|R=2,S=2) / [#(W=1|R=2,S=2) + #(W=2|R=2,S=2)];

        为了有更加直观的认识,我将例子中的数据集(samples)中的10个数据摘了出来,如下:

1       1       2       1       2       1       1       1       1       2

2       2       1       1       1       1       2       1       2       2

1       1       2       1       1       1       1       1       1       2

2       2       2       1       1       1       2       1       2       2

数据集中第一行对应C节点,第二行对应S节点,第三行对应R节点,第四行对应W节点,每一列为一组数据,共10列因此称为共10个数据。若仅从这10个数据来看,#(C=1)=7,#(C=2)=3,因为第一行共有7个1和3个2;#(S=1|C=1)=3,#(S=2|C=1)=4,因为第二行当中对应第一行等于1的列当中共有3个1(第2行的第4、6、8列)和4个2(第2行的第1、2、7、9列),剩下的数据以此类推就好了,具体可参见接下来的matlab程序。

        为了验证上面的叙述是否正确,我专门编写了程序learn1_statics.m,该程序需要执行完learn1.m后直接运行,因为该程序中使用了learn1.m运行时产生的数据。

%mylearn1 统计数据@20171225am by jbb0523
%C = 1; S = 2; R = 3; W = 4;
%   C
%  / \
% S   R
% \   /
%   W

P_C1 = sum(data(C,:)==1)/length(data(C,:));
P_C2 = sum(data(C,:)==2)/length(data(C,:));
% disp(['验证:P(C=1)=',num2str(P_C1),',P(C=2)=',num2str(P_C2)]);
% disp(['CPT4:P(C=1)=',num2str(CPT4{C}(1)),',P(C=2)=',num2str(CPT4{C}(2))]);
% disp(['CPT5:P(C=1)=',num2str(CPT5{C}(1)),',P(C=2)=',num2str(CPT5{C}(2))]);
P_C=[P_C1;P_C2];

position = (data(C,:)==1);
P_S1_C1 = sum(data(S,position)==1)/length(data(S,position));
P_S2_C1 = sum(data(S,position)==2)/length(data(S,position));
position = (data(C,:)==2);
P_S1_C2 = sum(data(S,position)==1)/length(data(S,position));
P_S2_C2 = sum(data(S,position)==2)/length(data(S,position));
% disp(['验证:P(S=1|C=1)=',num2str(P_S1_C1),',P(S=2|C=1)=',num2str(P_S2_C1)]);
% disp(['CPT4:P(S=1|C=1)=',num2str(CPT4{S}(1,1)),',P(S=2|C=1)=',num2str(CPT4{S}(1,2))]);
% disp(['CPT5:P(S=1|C=1)=',num2str(CPT5{S}(1,1)),',P(S=2|C=1)=',num2str(CPT5{S}(1,2))]);
% disp(['验证:P(S=1|C=2)=',num2str(P_S1_C2),',P(S=2|C=2)=',num2str(P_S2_C2)]);
% disp(['CPT4:P(S=1|C=2)=',num2str(CPT4{S}(2,1)),',P(S=2|C=2)=',num2str(CPT4{S}(2,2))]);
% disp(['CPT5:P(S=1|C=2)=',num2str(CPT5{S}(2,1)),',P(S=2|C=2)=',num2str(CPT5{S}(2,2))]);
P_S_C=[P_S1_C1 P_S2_C1;P_S1_C2 P_S2_C2];

position = (data(C,:)==1);
P_R1_C1 = sum(data(R,position)==1)/length(data(R,position));
P_R2_C1 = sum(data(R,position)==2)/length(data(R,position));
position = (data(C,:)==2);
P_R1_C2 = sum(data(R,position)==1)/length(data(R,position));
P_R2_C2 = sum(data(R,position)==2)/length(data(R,position));
% disp(['验证:P(R=1|C=1)=',num2str(P_R1_C1),',P(R=2|C=1)=',num2str(P_R2_C1)]);
% disp(['CPT4:P(R=1|C=1)=',num2str(CPT4{R}(1,1)),',P(R=2|C=1)=',num2str(CPT4{R}(1,2))]);
% disp(['CPT5:P(R=1|C=1)=',num2str(CPT5{R}(1,1)),',P(R=2|C=1)=',num2str(CPT5{R}(1,2))]);
% disp(['验证:P(R=1|C=2)=',num2str(P_R1_C2),',P(R=2|C=2)=',num2str(P_R2_C2)]);
% disp(['CPT4:P(R=1|C=2)=',num2str(CPT4{R}(2,1)),',P(R=2|C=2)=',num2str(CPT4{R}(2,2))]);
% disp(['CPT5:P(R=1|C=2)=',num2str(CPT5{R}(2,1)),',P(R=2|C=2)=',num2str(CPT5{R}(2,2))]);
P_R_C = [P_R1_C1 P_R2_C1;P_R1_C2 P_R2_C2];

positionR1 = (data(R,:)==1);positionS1 = (data(S,:)==1);
positionR2 = (data(R,:)==2);positionS2 = (data(S,:)==2);
positionR1S1 = positionR1&positionS1;%逻辑与&
positionR1S2 = positionR1&positionS2;%逻辑与&
positionR2S1 = positionR2&positionS1;%逻辑与&
positionR2S2 = positionR2&positionS2;%逻辑与&
P_W1_R1S1 = sum(data(W,positionR1S1)==1)/length(data(W,positionR1S1));
P_W1_R1S2 = sum(data(W,positionR1S2)==1)/length(data(W,positionR1S2));
P_W1_R2S1 = sum(data(W,positionR2S1)==1)/length(data(W,positionR2S1));
P_W1_R2S2 = sum(data(W,positionR2S2)==1)/length(data(W,positionR2S2));
P_W2_R1S1 = sum(data(W,positionR1S1)==2)/length(data(W,positionR1S1));
P_W2_R1S2 = sum(data(W,positionR1S2)==2)/length(data(W,positionR1S2));
P_W2_R2S1 = sum(data(W,positionR2S1)==2)/length(data(W,positionR2S1));
P_W2_R2S2 = sum(data(W,positionR2S2)==2)/length(data(W,positionR2S2));
% disp(['验证:P(W=1|R=1,S=1)=',num2str(P_W1_R1S1),',P(W=2|R=1,S=1)=',num2str(P_W2_R1S1)]);
% disp(['验证:P(W=1|R=1,S=2)=',num2str(P_W1_R1S2),',P(W=2|R=1,S=2)=',num2str(P_W2_R1S2)]);
% disp(['验证:P(W=1|R=2,S=1)=',num2str(P_W1_R2S1),',P(W=2|R=2,S=1)=',num2str(P_W2_R2S1)]);
% disp(['验证:P(W=1|R=2,S=2)=',num2str(P_W1_R2S2),',P(W=2|R=2,S=2)=',num2str(P_W2_R2S2)]);
P_W_RS = [P_W1_R1S1 P_W2_R1S1;P_W1_R1S2 P_W2_R1S2;P_W1_R2S1 P_W2_R2S1;P_W1_R2S2 P_W2_R2S2];
disp('------------------------------P(C)------------------------------')
disp('手动计算P(C):')
disp(P_C)
disp('learn_params(CPT4)P(C):')
disp(CPT4{C})
disp('bayes_update_params(CPT5)P(C):')
disp(CPT5{C})
disp('------------------------------P(S|C)------------------------------')
disp('手动计算P(S|C):')
disp(P_S_C)
disp('learn_params(CPT4)P(S|C):')
disp(CPT4{S})
disp('bayes_update_params(CPT4)P(S|C):')
disp(CPT5{S})
disp('------------------------------P(R|C)------------------------------')
disp('手动计算P(R|C):')
disp(P_R_C)
disp('learn_params(CPT4)P(R|C):')
disp(CPT4{R})
disp('bayes_update_params(CPT4)P(R|C):')
disp(CPT5{R})
disp('------------------------------P(W|S,R)------------------------------')
disp('手动计算P(W|R,S):')
disp(P_W_RS)
disp('learn_params(CPT4)P(W|R,S):')
a=CPT4{4}(:,:,1);b=CPT4{4}(:,:,2);P_W_RS_CPT4=[a(:),b(:)];disp(P_W_RS_CPT4);
disp('bayes_update_params(CPT5)P(W|R,S):')
a=CPT5{4}(:,:,1);b=CPT5{4}(:,:,2);P_W_RS_CPT5=[a(:),b(:)];disp(P_W_RS_CPT5);

运行learn1.m后再运行该程序,可得如下输出:

------------------------------P(C)------------------------------

手动计算P(C):

   0.5400

   0.4600

 

learn_params(CPT4)P(C):

   0.5400

   0.4600

 

bayes_update_params(CPT5)P(C):

   0.5400

   0.4600

 

------------------------------P(S|C)------------------------------

手动计算P(S|C):

   0.4444    0.5556

   0.9565    0.0435

 

learn_params(CPT4)P(S|C):

   0.4444    0.5556

   0.9565    0.0435

 

bayes_update_params(CPT4)P(S|C):

   0.4444    0.5556

   0.9565    0.0435

 

------------------------------P(R|C)------------------------------

手动计算P(R|C):

   0.8889    0.1111

   0.2609    0.7391

 

learn_params(CPT4)P(R|C):

   0.8889    0.1111

   0.2609    0.7391

 

bayes_update_params(CPT4)P(R|C):

   0.8889    0.1111

   0.2609    0.7391

 

------------------------------P(W|S,R)------------------------------

手动计算P(W|R,S):

   1.0000         0

   0.1538    0.8462

    0.1765   0.8235

        0    1.0000

 

learn_params(CPT4)P(W|R,S):

   1.0000         0

   0.1538    0.8462

   0.1765    0.8235

        0    1.0000

 

bayes_update_params(CPT5)P(W|R,S):

   1.0000         0

   0.1538    0.8462

   0.1765    0.8235

        0    1.0000

可以发现手动统计的结果与函数learn_params和函数bayes_update_params的输出结果一模一样。

        那么,既然可以手动统计学习贝叶斯网络参数,为什么工具箱还要写这两个函数呢?

        你是在问我么?

        我哪儿知道该问谁去^_^

        对了,能不能自己写一个手动统计学习贝叶斯网络参数的函数呢?

  • 35
    点赞
  • 202
    收藏
    觉得还不错? 一键收藏
  • 27
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 27
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值