k-均值(k-means)聚类算法的 Java 实现(注:k-means 是聚类算法,与 kNN 近邻分类算法不同)

importjava.io.BufferedReader;

importjava.io.FileNotFoundException;

importjava.io.FileReader;

importjava.io.IOException;

publicclassKAverage {

privateintsampleCount =0;

privateintdimensionCount =0;

privateintcenterCount =0;

privatedouble[][] sampleValues;

privatedouble[][] centers;

privatedouble[][] tmpCenters;

privateString dataFile ="";

/**

* 通过构造器传人数据文件

*/

publicKAverage(String dataFile)throwsNumberInvalieException {

this.dataFile = dataFile;

}

/**

* 第一行为s;d;c含义分别为样例的数目,每个样例特征的维数,聚类中心个数 文件格式为d[,d]...;d[,d]... 如:1,2;2,3;1,5

* 每一维之间用,隔开,每个样例间用;隔开。结尾没有';' 可以有多行

*/

privateintinitData(String fileName) {

String line;

String samplesValue[];

String dimensionsValue[] =newString[dimensionCount];

BufferedReader in;

try{

in =newBufferedReader(newFileReader(fileName));

}catch(FileNotFoundException e) {

e.printStackTrace();

return-1;

}

/*

* 预处理样本,允许后面几维为0时,不写入文件

*/

for(inti =0; i 

for(intj =0; j 

sampleValues[i][j] =0;

}

}

inti =0;

doubletmpValue =0.0;

try{

line = in.readLine();

String params[] = line.split(";");

if(params.length !=3) {// 必须为3个参数,否则错误

return-1;

}

/**

* 获取参数

*/

this.sampleCount = Integer.parseInt(params[0]);

this.dimensionCount = Integer.parseInt(params[1]);

this.centerCount = Integer.parseInt(params[2]);

if(sampleCount <=0|| dimensionCount <=0|| centerCount <=0) {

thrownewNumberInvalieException("input number <= 0.");

}

if(sampleCount 

thrownewNumberInvalieException(

"sample number 

}

sampleValues =newdouble[sampleCount][dimensionCount +1];

centers =newdouble[centerCount][dimensionCount];

tmpCenters =newdouble[centerCount][dimensionCount];

while((line = in.readLine()) !=null) {

samplesValue = line.split(";");

for(intj =0; j 

dimensionsValue = samplesValue[j].split(",");

for(intk =0; k 

tmpValue = Double.parseDouble(dimensionsValue[k]);

sampleValues[i][k] = tmpValue;

}

i++;

}

}

}catch(IOException e) {

e.printStackTrace();

return-2;

}catch(Exception e) {

e.printStackTrace();

return-3;

}

return1;

}

/**

* 返回样本中第s1个和第s2个间的欧式距离

*/

privatedoublegetDistance(ints1,ints2)throwsNumberInvalieException {

doubledistance =0.0;

if(s1 <0|| s1 >= sampleCount || s2 <0|| s2 >= sampleCount) {

thrownewNumberInvalieException("number out of bound.");

}

for(inti =0; i 

distance += (sampleValues[s1][i] - sampleValues[s2][i])

* (sampleValues[s1][i] - sampleValues[s2][i]);

}

returndistance;

}

/**

* 返回给定两个向量间的欧式距离

*/

privatedoublegetDistance(doubles1[],doubles2[]) {

doubledistance =0.0;

for(inti =0; i 

distance += (s1[i] - s2[i]) * (s1[i] - s2[i]);

}

returndistance;

}

/**

* 更新样本中第s个样本的最近中心

*/

privateintgetNearestCenter(ints) {

intcenter =0;

doubleminDistance = Double.MAX_VALUE;

doubledistance =0.0;

for(inti =0; i 

distance = getDistance(sampleValues[s], centers[i]);

if(distance 

minDistance = distance;

center = i;

}

}

sampleValues[s][dimensionCount] = center;

returncenter;

}

/**

* 更新所有中心

*/

privatevoidupdateCenters() {

doublecenter[] =newdouble[dimensionCount];

for(inti =0; i 

center[i] =0;

}

intcount =0;

for(inti =0; i 

count =0;

for(intj =0; j 

if(sampleValues[j][dimensionCount] == i) {

count++;

for(intk =0; k 

center[k] += sampleValues[j][k];

}

}

}

for(intj =0; j 

centers[i][j] = center[j] / count;

}

}

}

/**

* 判断算法是否终止

*/

privatebooleantoBeContinued() {

for(inti =0; i 

for(intj =0; j 

if(tmpCenters[i][j] != centers[i][j]) {

returntrue;

}

}

}

returnfalse;

}

/**

* 关键方法,调用其他方法,处理数据

*/

publicvoiddoCaculate() {

initData(dataFile);

for(inti =0; i 

for(intj =0; j 

centers[i][j] = sampleValues[i][j];

}

}

for(inti =0; i 

for(intj =0; j 

tmpCenters[i][j] =0;

}

}

while(toBeContinued()) {

for(inti =0; i 

getNearestCenter(i);

}

for(inti =0; i 

for(intj =0; j 

tmpCenters[i][j] = centers[i][j];

}

}

updateCenters();

System.out

.println("******************************************************");

showResultData();

}

}

/*

* 显示数据

*/

privatevoidshowSampleData() {

for(inti =0; i 

for(intj =0; j 

if(j ==0) {

System.out.print(sampleValues[i][j]);

}else{

System.out.print(","+ sampleValues[i][j]);

}

}

System.out.println();

}

}

/*

* 分组显示结果

*/

privatevoidshowResultData() {

for(inti =0; i 

System.out.println("第"+ (i +1) +"个分组内容为:");

for(intj =0; j 

if(sampleValues[j][dimensionCount] == i) {

for(intk =0; k <= dimensionCount; k++) {

if(k ==0) {

System.out.print(sampleValues[j][k]);

}else{

System.out.print(","+ sampleValues[j][k]);

}

}

System.out.println();

}

}

}

}

publicstaticvoidmain(String[] args) {

/*

*也可以通过命令行得到参数

*/

String fileName ="D://eclipsejava//K-Average//src//sample.txt";

if(args.length >0){

fileName = args[0];

}

try{

KAverage ka =newKAverage(fileName);

ka.doCaculate();

System.out

.println("***************************<>**************************");

ka.showResultData();

}catch(Exception e) {

e.printStackTrace();

}

}

}

import java.io.BufferedReader;

import java.io.FileNotFoundException;

import java.io.FileReader;

import java.io.IOException;

/**
 * k-means (Lloyd's algorithm) clustering of samples read from a text file.
 *
 * <p>Input file format: the first line is {@code s;d;c} -- the number of samples, the number of
 * feature dimensions per sample, and the number of cluster centers. Every following line holds
 * samples separated by {@code ;}, with each sample's dimensions separated by {@code ,}, e.g.
 * {@code 1,2;2,3;1,5}. There is no trailing {@code ;}; samples may span several lines, and
 * trailing dimensions equal to 0 may be omitted.
 */
public class KAverage {

    /** Upper bound on assignment/update rounds, guarding against non-termination. */
    private static final int MAX_ITERATIONS = 1000;

    private int sampleCount = 0;

    private int dimensionCount = 0;

    private int centerCount = 0;

    /**
     * One row per sample: columns 0..dimensionCount-1 hold the features, column dimensionCount
     * holds the index of the cluster the sample is currently assigned to.
     */
    private double[][] sampleValues;

    /** Current cluster centers, one row of dimensionCount values per cluster. */
    private double[][] centers;

    /** Centers from the previous round, compared against {@link #centers} to detect convergence. */
    private double[][] tmpCenters;

    private String dataFile = "";

    /**
     * Creates a clusterer for the given data file; the file is only read by {@link #doCaculate()}.
     *
     * @param dataFile path of the input file
     * @throws NumberInvalieException not thrown by this constructor; kept in the signature so
     *     existing callers still compile
     */
    public KAverage(String dataFile) throws NumberInvalieException {
        this.dataFile = dataFile;
    }

    /**
     * Reads the {@code s;d;c} header and all samples from {@code fileName} and allocates the
     * working arrays. Omitted trailing dimensions stay 0 because Java zero-initializes arrays.
     *
     * @param fileName path of the input file
     * @return 1 on success; -1 if the file is missing or the header line is absent/malformed;
     *     -2 on an I/O error; -3 on any other problem (non-numeric values, non-positive counts,
     *     fewer samples than centers, or a sample/dimension count that disagrees with the header)
     */
    private int initData(String fileName) {
        // try-with-resources: the original never closed the reader.
        try (BufferedReader in = new BufferedReader(new FileReader(fileName))) {
            String line = in.readLine();
            if (line == null) {
                return -1; // empty file: no header line
            }
            String[] params = line.split(";");
            if (params.length != 3) { // header must contain exactly s;d;c
                return -1;
            }
            this.sampleCount = Integer.parseInt(params[0].trim());
            this.dimensionCount = Integer.parseInt(params[1].trim());
            this.centerCount = Integer.parseInt(params[2].trim());
            if (sampleCount <= 0 || dimensionCount <= 0 || centerCount <= 0) {
                throw new NumberInvalieException("input number <= 0.");
            }
            if (sampleCount < centerCount) {
                throw new NumberInvalieException(
                        "sample number < center number");
            }
            // Extra column holds each sample's cluster label.
            sampleValues = new double[sampleCount][dimensionCount + 1];
            centers = new double[centerCount][dimensionCount];
            tmpCenters = new double[centerCount][dimensionCount];

            int i = 0;
            while ((line = in.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank lines instead of failing on parseDouble("")
                }
                for (String sample : line.split(";")) {
                    if (i >= sampleCount) {
                        throw new NumberInvalieException(
                                "more samples in file than declared: " + sampleCount);
                    }
                    String[] dims = sample.split(",");
                    if (dims.length > dimensionCount) {
                        // Would otherwise overwrite the cluster-label column or overflow the row.
                        throw new NumberInvalieException("sample " + (i + 1)
                                + " has more than " + dimensionCount + " dimensions");
                    }
                    for (int k = 0; k < dims.length; k++) {
                        sampleValues[i][k] = Double.parseDouble(dims[k]);
                    }
                    i++;
                }
            }
            if (i < sampleCount) {
                // Missing rows would silently become phantom samples at the origin.
                throw new NumberInvalieException(
                        "only " + i + " samples in file, " + sampleCount + " declared");
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            return -1;
        } catch (IOException e) {
            e.printStackTrace();
            return -2;
        } catch (Exception e) {
            e.printStackTrace();
            return -3;
        }
        return 1;
    }

    /**
     * Returns the squared Euclidean distance between samples {@code s1} and {@code s2}; the square
     * root is not taken.
     *
     * @throws NumberInvalieException if either index is outside [0, sampleCount)
     */
    private double getDistance(int s1, int s2) throws NumberInvalieException {
        if (s1 < 0 || s1 >= sampleCount || s2 < 0 || s2 >= sampleCount) {
            throw new NumberInvalieException("number out of bound.");
        }
        return getDistance(sampleValues[s1], sampleValues[s2]);
    }

    /**
     * Returns the squared Euclidean distance over the first dimensionCount entries of two vectors.
     * The square root is skipped because it does not change which center is nearest.
     */
    private double getDistance(double s1[], double s2[]) {
        double distance = 0.0;
        for (int i = 0; i < dimensionCount; i++) {
            double diff = s1[i] - s2[i];
            distance += diff * diff;
        }
        return distance;
    }

    /**
     * Assigns sample {@code s} to its nearest center; ties go to the lowest center index. The
     * label is stored in column dimensionCount of the sample's row.
     *
     * @return index of the nearest center
     */
    private int getNearestCenter(int s) {
        int center = 0;
        double minDistance = Double.MAX_VALUE;
        for (int i = 0; i < centerCount; i++) {
            double distance = getDistance(sampleValues[s], centers[i]);
            if (distance < minDistance) {
                minDistance = distance;
                center = i;
            }
        }
        sampleValues[s][dimensionCount] = center;
        return center;
    }

    /**
     * Moves every center to the mean of the samples currently assigned to it. A center with no
     * assigned samples keeps its previous position.
     */
    private void updateCenters() {
        for (int i = 0; i < centerCount; i++) {
            // Fresh accumulator per cluster: the original reused one across clusters, so each
            // center absorbed the sums of all earlier clusters.
            double[] sum = new double[dimensionCount];
            int count = 0;
            for (int j = 0; j < sampleCount; j++) {
                if (sampleValues[j][dimensionCount] == i) {
                    count++;
                    for (int k = 0; k < dimensionCount; k++) {
                        sum[k] += sampleValues[j][k];
                    }
                }
            }
            if (count == 0) {
                continue; // empty cluster: 0/0 would yield NaN and NaN != NaN never converges
            }
            for (int k = 0; k < dimensionCount; k++) {
                centers[i][k] = sum[k] / count;
            }
        }
    }

    /** Returns true while any center differs from its value in the previous round. */
    private boolean toBeContinued() {
        for (int i = 0; i < centerCount; i++) {
            for (int j = 0; j < dimensionCount; j++) {
                if (tmpCenters[i][j] != centers[i][j]) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Loads the data file, then runs k-means until the centers stop changing (or MAX_ITERATIONS
     * rounds pass), printing the grouping after every round. The first centerCount samples seed
     * the centers.
     *
     * @throws IllegalStateException if the data file could not be loaded
     */
    public void doCaculate() {
        int status = initData(dataFile);
        if (status != 1) {
            // Without valid data the arrays are null or stale; fail clearly instead of an NPE.
            throw new IllegalStateException(
                    "failed to load data file " + dataFile + " (status " + status + ")");
        }
        for (int i = 0; i < centerCount; i++) {
            System.arraycopy(sampleValues[i], 0, centers[i], 0, dimensionCount);
        }
        // do-while guarantees at least one assignment pass. A plain while loop skipped every pass
        // when the seed centers happened to equal the zero-filled tmpCenters.
        int iterations = 0;
        do {
            for (int i = 0; i < sampleCount; i++) {
                getNearestCenter(i);
            }
            for (int i = 0; i < centerCount; i++) {
                System.arraycopy(centers[i], 0, tmpCenters[i], 0, dimensionCount);
            }
            updateCenters();
            System.out
                    .println("******************************************************");
            showResultData();
            iterations++;
        } while (toBeContinued() && iterations < MAX_ITERATIONS);
    }

    /** Prints every sample's features (without the cluster label), one sample per line. */
    private void showSampleData() {
        for (int i = 0; i < sampleCount; i++) {
            for (int j = 0; j < dimensionCount; j++) {
                if (j == 0) {
                    System.out.print(sampleValues[i][j]);
                } else {
                    System.out.print("," + sampleValues[i][j]);
                }
            }
            System.out.println();
        }
    }

    /** Prints the samples of each cluster; each line includes the trailing cluster label. */
    private void showResultData() {
        for (int i = 0; i < centerCount; i++) {
            System.out.println("第" + (i + 1) + "个分组内容为:");
            for (int j = 0; j < sampleCount; j++) {
                if (sampleValues[j][dimensionCount] == i) {
                    for (int k = 0; k <= dimensionCount; k++) {
                        if (k == 0) {
                            System.out.print(sampleValues[j][k]);
                        } else {
                            System.out.print("," + sampleValues[j][k]);
                        }
                    }
                    System.out.println();
                }
            }
        }
    }

    /**
     * Entry point.
     *
     * @param args optional; args[0] overrides the default data-file path
     */
    public static void main(String[] args) {
        String fileName = "D://eclipsejava//K-Average//src//sample.txt";
        if (args.length > 0) {
            fileName = args[0];
        }
        try {
            KAverage ka = new KAverage(fileName);
            ka.doCaculate();
            System.out
                    .println("***************************<>**************************");
            ka.showResultData();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

Java代码

icon_copy.gif

/**
 * Checked exception signalling an invalid numeric input (for example a non-positive count, or an
 * index out of range) in the k-means implementation.
 */
public class NumberInvalieException extends Exception {

    private static final long serialVersionUID = 1L;

    /** Reason text; "unknow" when the caller supplied none. */
    private final String reason;

    /**
     * @param cause human-readable reason; {@code null} or empty is replaced with "unknow"
     */
    public NumberInvalieException(String cause) {
        // Also hand the reason to Throwable so getMessage() no longer returns null.
        super(normalize(cause));
        this.reason = normalize(cause);
    }

    /** Maps a missing (null or empty) reason to the placeholder "unknow". */
    private static String normalize(String cause) {
        return (cause == null || "".equals(cause)) ? "unknow" : cause;
    }

    @Override
    public String toString() {
        return "Number Invalie!Cause by " + reason;
    }
}

/**
 * Checked exception signalling an invalid numeric input (for example a non-positive count, or an
 * index out of range) in the k-means implementation.
 */
public class NumberInvalieException extends Exception {

    private static final long serialVersionUID = 1L;

    /** Reason text; "unknow" when the caller supplied none. */
    private final String reason;

    /**
     * @param cause human-readable reason; {@code null} or empty is replaced with "unknow"
     */
    public NumberInvalieException(String cause) {
        // Also hand the reason to Throwable so getMessage() no longer returns null.
        super(normalize(cause));
        this.reason = normalize(cause);
    }

    /** Maps a missing (null or empty) reason to the placeholder "unknow". */
    private static String normalize(String cause) {
        return (cause == null || "".equals(cause)) ? "unknow" : cause;
    }

    @Override
    public String toString() {
        return "Number Invalie!Cause by " + reason;
    }
}

测试数据:

20;2;4
0,0;1,0;0,1;1,1;2,1;1,2;2,2;3,2;6,6;7,6
8,6;6,7;7,7;8,7;9,7;7,8;8,8;9,8;8,9;9,9

测试结果:

***************************<>**************************
第1个分组内容为:
0.0,0.0,0.0
1.0,0.0,0.0
0.0,1.0,0.0
1.0,1.0,0.0
2.0,1.0,0.0
1.0,2.0,0.0
2.0,2.0,0.0
3.0,2.0,0.0
第2个分组内容为:
6.0,6.0,1.0
7.0,6.0,1.0
8.0,6.0,1.0
6.0,7.0,1.0
7.0,7.0,1.0
8.0,7.0,1.0
9.0,7.0,1.0
7.0,8.0,1.0
8.0,8.0,1.0
9.0,8.0,1.0
8.0,9.0,1.0
9.0,9.0,1.0

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值