实验二 势函数算法的迭代训练
一.实验目的
通过本实验的学习,使学生了解或掌握模式识别中利用势函数思想设计非线性判别函数的方法,能够实现模式的分类。学会运用已学习的先导课程如数据结构和算法设计知识,选用合适的数据结构完成算法的设计和程序的实现。并通过训练数据来建立非线性判别函数,通过代待分类样本进行分类预测,通过检查预测结果和数据的几何分布特性检验分类器的正确性。通过选用此种分类方法进行分类器设计实验,强化学生对非线性分类器的了解和应用,从而牢固掌握模式识别课程内容知识。
二.实验内容
假定对病人3项主要指标检查得到正常(w1类)和非正常(w2类)的数据如下:
三.实验步骤
1、选定势函数(3个双变量对称基函数中选1;或做成多选的,实现人工自动选择);
2、确定合适数据结构,以便分别完成势函数和判别函数的正确表示;
3、对训练样本加以训练学习,建立判别函数,使其满足分类要求
4、记录并输出训练轮次;
5、对所有样本的类别用你的分类器加以判断(分类决策),比较与实际类别的差异;
6、对待分类样本进行判断,得到其类别(预测),如可能,以几何分布情况加以说明;
7、输出你的判别函数的表达形式(注意:表达形式要求便于阅读理解)。
四.测试
1、先测试已有样本的正确性。
2、用待分类数据加以分类。这里,对样本: (2,3,5),(6,7,10)
分别测试,检查它们几何分布情况是否与得到的分别属于w1类和w2类的结果相符,从而确认所设计的分类器是正确的。
五.实现提示
1)样本存放在矩阵s中,s的每一行是一个样本,为方便编程,可将类别号增加在每个样本中,作为最后一维;
2)为了保存和计算判别函数,可使用一个辅助的结构数组ftbl,该数组的每个分量含两个成分:index和symbol。 index记录对应样本下标号,symbol记录该项的符号。
六、实验代码
#include <iostream>
#include <math.h>
using namespace std;
#define n 6
//n表示样本总数。这里n=6,前3个样本属于第一类,后三个样本属于第二类
#define m 30
#define d 3
//d表示维长
struct sample {
int x[d];
int cl;
};
struct func{
int symbol;
int index; //用于记录样本号,即对应的样本的下标号
};
struct func ftbl[m];
int k,r=-1,tag=1,i,j;
float g=0,temp;
struct sample s[n]={
{1,2,5,1}, //1表示属于第1类
{1,1,2,1}, //1表示属于第1类
{3,3,6,1}, //1表示属于第1类
{5,6,10,2}, //2表示属于第2类
{7,6,11,2}, //2表示属于第2类
{8,7,12,2}}; //2表示属于第2类
int main()
{
while(tag==1)
{ tag=0;
for (k=0;k<n;k++)
{ if (r==-1){
r++; //r为项数
ftbl[r].symbol=1; //该项的符号。 1--正
ftbl[r].index=0; //该项对应的样本下标号
continue;
}
else{
g=0;
for(i=0;i<=r;i++)
{ temp=0;
for(j=0;j<d;j++) //d表示维长
temp+=(s[k].x[j]-s[ftbl[i].index].x[j])*(s[k].x[j]-s[ftbl[i].index].x[j]);
g+=ftbl[i].symbol*exp(-temp); //共r项,每项都是一指数形式
}
if((g>0&&s[k].cl==1)||(g<0&&s[k].cl==2)) continue; //正确分类时,不修改判别函数
else {
r++;
ftbl[r].index=k;
tag=1;
if(g>0&&s[k].cl==2)
ftbl[r].symbol=-1;
else if(g<0&&s[k].cl==1)
ftbl[r].symbol=1;
} // end of else
} //end of else
} // end of for
}// end of while
cout<<"\n\n\n";
for(i=0;i<=r;i++)
{if(ftbl[i].symbol==1)
if(i==0) cout<<"exp{-[(x1";
else cout<<"+exp{-[(x1";
else
cout<<"-exp{-[(x1";
if (s[ftbl[i].index].x[0]>0)
cout<<"-"<<s[ftbl[i].index].x[0]<<")^2+(x2";
else if(s[ftbl[i].index].x[0]<0)
cout<<"+"<<-s[ftbl[i].index].x[0]<<")^2+(x2";
else //s[ftbl[i].index].x[0]==0
cout<<")^2+(x2";
if (s[ftbl[i].index].x[1]>0)
cout<<"-"<<s[ftbl[i].index].x[1]<<")^2+(x3";
else if(s[ftbl[i].index].x[1]<0)
cout<<"+"<<-s[ftbl[i].index].x[1]<<")^2+(x3";
else
cout<<")^2+(x3";
if (s[ftbl[i].index].x[2]>0)
cout<<"-"<<s[ftbl[i].index].x[2]<<")^2]}";
else if(s[ftbl[i].index].x[2]<0)
cout<<"+"<<-s[ftbl[i].index].x[2]<<")^2]}";
else
cout<<")^2]}";
cout<<endl;
} // end of for
} // end of main()