// 图像聚类
package com.amberlovesx.kmeans;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import javax.imageio.ImageIO;
/**
 * Clusters the pixels of an image by colour with k-means over RGB space, then writes a
 * false-colour JPEG in which every pixel is painted with a colour identifying its cluster.
 *
 * <p>Per-run state is kept in instance fields without synchronization, so an instance should
 * not be shared between concurrent {@code kmeans} calls.
 *
 * @author lsw
 * @since 1.0
 */
public class ImageCluster {
    /** Output path used by {@link #kmeans(String, int, int)}; kept for backward compatibility. */
    private static final String DEFAULT_OUTPUT_PATH = "D:\\temp\\DesertAfter.jpg";

    /** Fixed colours for clusters 0..3; clusters beyond that get generated hues. */
    private static final Color[] BASE_PALETTE = {
        new Color(255, 0, 0), new Color(0, 255, 0), new Color(0, 0, 255), new Color(128, 128, 128)
    };

    /** Number of clusters to form. */
    private int k;
    /** Number of centre-update iterations to run. */
    private int m;
    /** Pixel data indexed {@code [x][y]}; each item's {@code group} is its assigned cluster. */
    private DateItem[][] source;
    /** Current cluster centres; each item's {@code group} is its cluster index. */
    private DateItem[] center;
    /**
     * Per-cluster accumulators for the current pass: {@code r/g/b} hold colour sums and
     * {@code group} holds the number of pixels assigned to the cluster.
     */
    private DateItem[] centerSum;

    /**
     * Reads the image at {@code path} into a {@code [x][y]} array of packed ARGB pixels.
     *
     * @param path image file to read
     * @return pixel array of size {@code [width][height]}
     * @throws UncheckedIOException if the file cannot be read
     * @throws IllegalArgumentException if no installed reader supports the file format
     */
    private int[][] getImageData(String path) {
        BufferedImage bufferedImage;
        try {
            bufferedImage = ImageIO.read(new File(path));
        } catch (IOException e) {
            // Fail fast instead of dereferencing a null image below.
            throw new UncheckedIOException("Failed to read image: " + path, e);
        }
        if (bufferedImage == null) {
            // ImageIO.read returns null (rather than throwing) when no reader matches the format.
            throw new IllegalArgumentException("Unsupported image format: " + path);
        }
        int width = bufferedImage.getWidth();
        int height = bufferedImage.getHeight();
        int[][] imageData = new int[width][height];
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                imageData[x][y] = bufferedImage.getRGB(x, y);
            }
        }
        return imageData;
    }

    /**
     * Converts packed RGB pixels into {@link DateItem}s holding separate channel values.
     *
     * @param data pixel array indexed {@code [x][y]}, as produced by {@link #getImageData}
     * @return items of the same shape, all initially in group 1
     * @throws IllegalArgumentException if {@code data} is null or contains no pixels
     */
    private DateItem[][] initData(int[][] data) {
        // Null must be tested before dereferencing; a zero-size image cannot be clustered.
        if (data == null || data.length == 0 || data[0].length == 0) {
            throw new IllegalArgumentException("Image contains no pixels");
        }
        int width = data.length;
        int height = data[0].length;
        DateItem[][] dateItems = new DateItem[width][height];
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                Color color = new Color(data[x][y]);
                DateItem dateItem = new DateItem();
                dateItem.r = color.getRed();
                dateItem.g = color.getGreen();
                dateItem.b = color.getBlue();
                dateItem.group = 1;
                dateItems[x][y] = dateItem;
            }
        }
        return dateItems;
    }

    /**
     * Seeds {@code k} centres with the colours of randomly chosen pixels and creates empty
     * accumulators for them.
     *
     * @param k number of centres to create
     */
    private void initCenters(int k) {
        center = new DateItem[k];
        centerSum = new DateItem[k];
        for (int i = 0; i < k; i++) {
            // Multiply before casting: "(int) Math.random() * n" is always 0.
            int x = (int) (Math.random() * source.length);
            int y = (int) (Math.random() * source[0].length);
            DateItem seed = source[x][y];
            DateItem cent = new DateItem();
            cent.group = i;
            cent.r = seed.r;
            cent.g = seed.g;
            cent.b = seed.b;
            center[i] = cent;
            // Sums and count both start at zero so the next mean covers assigned pixels only.
            centerSum[i] = new DateItem();
        }
        for (int i = 0; i < k; i++) {
            System.out.println("r:" + center[i].r + " g:" + center[i].g + " b:" + center[i].b);
        }
    }

    /**
     * Returns the Euclidean distance between two colours in RGB space.
     */
    private double calDistance(DateItem item1, DateItem item2) {
        return Math.sqrt(Math.pow(item1.r - item2.r, 2)
                + Math.pow(item1.g - item2.g, 2)
                + Math.pow(item1.b - item2.b, 2));
    }

    /**
     * Returns the index of the smallest value in {@code distance}.
     *
     * <p>On a tie, a coin flip decides whether the later index replaces the current best.
     * The choice among tied entries is therefore random, but not uniform: later entries are
     * favoured.
     */
    private int minDistance(double[] distance) {
        double mindistance = distance[0];
        int minLocation = 0;
        for (int i = 1; i < distance.length; i++) {
            if (distance[i] < mindistance) {
                mindistance = distance[i];
                minLocation = i;
            } else if (distance[i] == mindistance) {
                if (Math.random() * 10 < 5) {
                    minLocation = i;
                }
            }
        }
        return minLocation;
    }

    /**
     * Assigns every pixel to its nearest centre and adds it to that cluster's accumulator.
     */
    private void clusterSet() {
        double[] distance = new double[k];
        for (int x = 0; x < source.length; x++) {
            for (int y = 0; y < source[0].length; y++) {
                DateItem pixel = source[x][y];
                for (int c = 0; c < center.length; c++) {
                    distance[c] = calDistance(center[c], pixel);
                }
                int group = minDistance(distance);
                pixel.group = group;
                DateItem sum = centerSum[group];
                sum.r += pixel.r;
                sum.g += pixel.g;
                sum.b += pixel.b;
                sum.group += 1;
            }
        }
    }

    /**
     * Moves each centre to the mean colour of its assigned pixels, then clears the accumulators.
     * A cluster that received no pixels keeps its previous centre.
     */
    private void setNewCenter() {
        for (int i = 0; i < centerSum.length; i++) {
            DateItem sum = centerSum[i];
            System.out.println(i + ":" + sum.group + ":" + sum.r + ":" + sum.g + ":" + sum.b);
            // Skipping empty clusters avoids dividing by a zero count (NaN / Infinity centres).
            if (sum.group > 0) {
                center[i].r = sum.r / sum.group;
                center[i].g = sum.g / sum.group;
                center[i].b = sum.b / sum.group;
            }
            sum.r = 0;
            sum.g = 0;
            sum.b = 0;
            sum.group = 0;
        }
    }

    /** Prints every centre as {@code (r,g,b)}, one per line. */
    private void printCenters() {
        for (DateItem c : center) {
            System.out.println("(" + c.r + "," + c.g + "," + c.b + ")");
        }
    }

    /**
     * Returns {@code count} distinct display colours: the four base colours first, then
     * evenly spaced hues for any further clusters.
     */
    private static Color[] buildPalette(int count) {
        Color[] palette = new Color[count];
        int extra = count - BASE_PALETTE.length;
        for (int i = 0; i < count; i++) {
            if (i < BASE_PALETTE.length) {
                palette[i] = BASE_PALETTE[i];
            } else {
                float hue = (i - BASE_PALETTE.length + 0.5f) / extra;
                palette[i] = Color.getHSBColor(hue, 0.85f, 0.9f);
            }
        }
        return palette;
    }

    /**
     * Writes the clustered image as a JPEG, painting each pixel with its cluster's colour.
     *
     * @param path destination file
     * @throws UncheckedIOException if the file cannot be written
     * @throws IllegalStateException if no JPEG writer is available
     */
    private void imageDataOut(String path) {
        int width = source.length;
        int height = source[0].length;
        // One colour per cluster, so every k is rendered (not only groups 0..3).
        Color[] palette = buildPalette(center.length);
        BufferedImage nbi = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                nbi.setRGB(x, y, palette[source[x][y].group].getRGB());
            }
        }
        try {
            if (!ImageIO.write(nbi, "jpg", new File(path))) {
                throw new IllegalStateException("No JPEG writer available for " + path);
            }
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to write image: " + path, e);
        }
    }

    /**
     * Runs k-means on the image at {@code path} and writes the result to the default output
     * path {@code D:\temp\DesertAfter.jpg}.
     *
     * @param path input image
     * @param k number of clusters, at least 1
     * @param m number of centre-update iterations, at least 0
     * @throws IllegalArgumentException if {@code k < 1}, {@code m < 0}, or the image is unusable
     * @throws UncheckedIOException if the image cannot be read or the output cannot be written
     */
    public void kmeans(String path, int k, int m) {
        kmeans(path, k, m, DEFAULT_OUTPUT_PATH);
    }

    /**
     * Runs k-means on the image at {@code path} and writes the false-colour result to
     * {@code outputPath}.
     *
     * @param path input image
     * @param k number of clusters, at least 1
     * @param m number of centre-update iterations, at least 0
     * @param outputPath destination JPEG file
     * @throws IllegalArgumentException if {@code k < 1}, {@code m < 0}, or the image is unusable
     * @throws UncheckedIOException if the image cannot be read or the output cannot be written
     */
    public void kmeans(String path, int k, int m, String outputPath) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1: " + k);
        }
        if (m < 0) {
            throw new IllegalArgumentException("m must be non-negative: " + m);
        }
        source = initData(getImageData(path));
        this.k = k;
        this.m = m;
        initCenters(k);
        for (int level = 0; level < m; level++) {
            clusterSet();
            setNewCenter();
            printCenters();
        }
        // Final assignment pass so every pixel's group reflects the last centres.
        clusterSet();
        System.out.println("第" + m + "次迭代完成,聚类中心为:");
        printCenters();
        System.out.println("迭代总次数:" + m);
        imageDataOut(outputPath);
    }
}
// 数据集点
package com.amberlovesx.kmeans;
/**
 * A mutable colour sample used throughout the clustering: one image pixel, one cluster
 * centre, or one per-cluster accumulator.
 *
 * <p>For pixels and centres, {@code group} is a cluster index. For accumulators it is
 * reused as the count of pixels summed so far.
 */
public class DateItem {
    /** Red, green and blue channel values (or running sums, for accumulators). */
    public double r, g, b;

    /** Cluster index, or member count when the item serves as an accumulator. */
    public int group;
}
// 测试类
package com.amberlovesx.kmeans;
import java.nio.file.Path;
import java.nio.file.Paths;
/**
 * Command-line driver for {@link ImageCluster#kmeans(String, int, int)}.
 *
 * <p>Usage: {@code Test [imagePath [k [m]]]}. Any omitted argument falls back to its
 * built-in default (image {@code D:\temp\Koala.jpg}, 10 clusters, 10 iterations).
 */
public class Test {
    private static final String DEFAULT_IMAGE = "D:\\temp\\Koala.jpg";
    private static final int DEFAULT_K = 10;
    private static final int DEFAULT_M = 10;

    /**
     * Parses optional arguments and runs the clustering.
     *
     * @param args optional image path, cluster count and iteration count, in that order
     * @throws NumberFormatException if {@code k} or {@code m} is not an integer
     */
    public static void main(String[] args) {
        String path = args.length > 0 ? args[0] : DEFAULT_IMAGE;
        int k = args.length > 1 ? Integer.parseInt(args[1]) : DEFAULT_K;
        int m = args.length > 2 ? Integer.parseInt(args[2]) : DEFAULT_M;
        ImageCluster ic = new ImageCluster();
        ic.kmeans(path, k, m);
    }
}