% Symbolic setup for checking vectorized forms of a loss over three samples.
% Every matrix below is an array of real-valued symbolic variables:
%   lambda : nlabel  x nlabel
%   d, lnd : nsample x nlabel
%   lnp, p : nsample x nlabel
nlabel = 5;
nsample = 3;
% sym(name, [rows cols], 'real') creates the whole matrix in one call, with
% elements named like lambda1_2. Compared with concatenating
% num2str(i) and num2str(j) inside loops, this:
%   - preallocates instead of growing the array element by element;
%   - never keeps stale entries left in the workspace by an earlier run with
%     larger sizes;
%   - cannot produce colliding names once an index reaches 10
%     (e.g. indices (1,11) and (11,1) would both become 'lambda111').
lambda = sym('lambda', [nlabel nlabel], 'real');
d      = sym('d',      [nsample nlabel], 'real');
lnd    = sym('lnd',    [nsample nlabel], 'real');
lnp    = sym('lnp',    [nsample nlabel], 'real');
p      = sym('p',      [nsample nlabel], 'real');
%%%%%%%%%%%%%%%%%%%%%%%%% Term 1 %%%%%%%%%%%%%%%%%%%%%%%%%%
% Term 1: sum over samples i and labels j, m of lambda(j,m) * d(i,j) * lnd(i,j).
% Shapes: d .* lnd is nsample x nlabel, and (d .* lnd) * lambda is
% [nsample x nlabel] * [nlabel x nlabel] = nsample x nlabel.
% Reference value from explicit loops. The label loops are outermost here;
% the symbolic sum does not depend on the accumulation order.
standard_result = 0;
for j = 1:nlabel
    for m = 1:nlabel
        for i = 1:nsample
            standard_result = standard_result + d(i,j) * lnd(i,j) * lambda(j,m);
        end
    end
end
% Version 1: collapse lambda over its second index first.
% sum_lambda(j) = sum_m lambda(j,m). Each row of d .* lnd is weighted by
% sum_lambda, then the per-sample values are added up.
sum_lambda = sum(lambda,2);
result1 = sum((d .* lnd) * sum_lambda);
disp(simplify(result1 - standard_result));
% Version 2: full matrix product, which gives the same value as version 1.
% ((d .* lnd) * lambda)(i,m) = sum_j d(i,j)*lnd(i,j)*lambda(j,m).
% Summing every entry reproduces the term.
result1 = sum(sum((d .* lnd) * lambda));
disp(simplify(result1 - standard_result));
%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Term 2 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Term 2: sum over i, j, m of lambda(j,m) * d(i,j) * lnp(i,m).
% Reference value from explicit loops.
standard_result = 0;
for j = 1:nlabel
    for m = 1:nlabel
        for i = 1:nsample
            standard_result = standard_result + d(i,j) * lnp(i,m) * lambda(j,m);
        end
    end
end
% (lambda * lnp')(j,i) = sum_m lambda(j,m) * lnp(i,m). Multiplying elementwise
% by d'(j,i) = d(i,j) and summing every entry gives the term.
% The matrix product must be formed before the elementwise product.
result2 = sum(sum((lambda * lnp') .* d'));
disp(simplify(result2 - standard_result));
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Term 3 %%%%%%%%%%%%%%%%%%%%%%%%
% Term 3: sum over i, j, m of lambda(j,m) * p(i,m) * lnd(i,j).
% Reference value from explicit loops, so this term is verified like the
% other three (it previously had no check).
standard_result = 0;
for i = 1:nsample
    for j = 1:nlabel
        for m = 1:nlabel
            standard_result = standard_result + lambda(j,m)*p(i,m)*lnd(i,j);
        end
    end
end
% (lambda * p')(j,i) = sum_m lambda(j,m) * p(i,m). Multiplying elementwise by
% lnd'(j,i) = lnd(i,j) and summing every entry gives the term.
% The matrix product must be formed before the elementwise product.
result3 = (lambda*p').*lnd';
result3 = sum(sum(result3));
disp(simplify(result3 - standard_result));
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Term 4 %%%%%%%%%%%%%%%%%%%%%%%%
% Term 4: sum over i, j, m of lambda(j,m) * p(i,m) * lnp(i,m).
% Reference value from explicit loops.
standard_result = 0;
for j = 1:nlabel
    for m = 1:nlabel
        for i = 1:nsample
            standard_result = standard_result + p(i,m) * lnp(i,m) * lambda(j,m);
        end
    end
end
% (lambda * (p .* lnp)')(j,i) = sum_m lambda(j,m) * p(i,m) * lnp(i,m).
% Summing every entry gives the term.
result4 = sum(sum(lambda * (lnp .* p)'));
disp(simplify(result4 - standard_result));
% The four terms as vectorized expressions. Terms 2 and 3 enter with a minus
% sign. Each inner matrix is nlabel x nsample and is summed over all entries.
% In every product below, the matrix product (*) must be applied before the
% elementwise product (.*); the two cannot swap places.
result1 = sum(sum(lambda' * (d .* lnd)'));
result2 = -sum(sum((lambda * lnp') .* d'));
result3 = -sum(sum((lambda * p') .* lnd'));
result4 = sum(sum(lambda * (p .* lnp)'));
% The same four terms combined into a single matrix before summing.
result_all1 = sum(sum( lambda' * (d .* lnd)' ...
                     - (lambda * lnp') .* d' ...
                     - (lambda * p') .* lnd' ...
                     + lambda * (p .* lnp)' ));
tresult = sum([result1, result2, result3, result4]);
% Reference value of the full expression:
% sum over i, j, m of lambda(j,m) * (d(i,j) - p(i,m)) * (lnd(i,j) - lnp(i,m)).
standard_result = 0;
for j = 1:nlabel
    for m = 1:nlabel
        for i = 1:nsample
            standard_result = standard_result + (d(i,j) - p(i,m)) * (lnd(i,j) - lnp(i,m)) * lambda(j,m);
        end
    end
end
disp('-----------------------');
disp(simplify(tresult - standard_result));
disp(simplify(result_all1 - standard_result));
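% 如何矢量化编程 (How to vectorize these computations)
% 最新推荐文章于 2023-10-08 14:26:20 发布 (web page footer: latest recommended article published 2023-10-08 14:26:20)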