Java 实现 K-means 图像聚类

// 图像聚类

package com.amberlovesx.kmeans;

import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import javax.imageio.ImageIO;

/**
 * Colour-based k-means clustering of the pixels of a single image.
 *
 * <p>Every pixel is treated as a point in RGB space. After clustering, an output
 * JPEG is written in which each pixel is painted with a colour identifying the
 * cluster it was assigned to.
 *
 *@author lsw
 *@since 1.0
 */
public class ImageCluster {
    /** Output location used by {@link #kmeans(String, int, int)}. */
    private static final String DEFAULT_OUTPUT_PATH = "D:\\temp\\DesertAfter.jpg";

    /** Display colours for clusters 0-3; higher cluster indices get generated colours. */
    private static final Color[] BASE_PALETTE = {
            new Color(255, 0, 0),
            new Color(0, 255, 0),
            new Color(0, 0, 255),
            new Color(128, 128, 128)
    };

    /**
     * Number of clusters (k).
     */
    private int k;
    /**
     * Number of center-update iterations (m).
     */
    private int m;
    /**
     * Pixel data, indexed as {@code source[x][y]}.
     */
    private DateItem[][] source;
    /**
     * Current cluster centers; {@code center[i]} holds the mean colour of cluster i.
     */
    private DateItem[] center;
    /**
     * Per-cluster accumulators for one assignment pass: r/g/b hold the colour sums
     * of the assigned pixels and {@code group} holds how many pixels were assigned.
     */
    private DateItem[] centerSum;

    /**
     * Reads an image file into a {@code [width][height]} array of packed ARGB values.
     *
     * @param path image file to read
     * @return pixel values indexed as {@code [x][y]}
     * @throws UncheckedIOException     if the file cannot be read
     * @throws IllegalArgumentException if no installed ImageIO reader can decode the file
     */
    private int[][] getImageData(String path){
        BufferedImage bufferedImage;
        try{
            bufferedImage = ImageIO.read(new File(path));
        } catch (IOException e){
            throw new UncheckedIOException("Failed to read image: " + path, e);
        }
        // ImageIO.read returns null (rather than throwing) when no reader supports the format.
        if (bufferedImage == null) {
            throw new IllegalArgumentException("Unsupported or unreadable image: " + path);
        }
        int width = bufferedImage.getWidth();
        int height = bufferedImage.getHeight();
        int[][] imageData = new int[width][height];
        for (int x = 0; x < width; x++){
            for (int y = 0; y < height; y++){
                imageData[x][y] = bufferedImage.getRGB(x, y);
            }
        }
        return imageData;
    }

    /**
     * Converts packed ARGB pixel values into {@link DateItem}s holding the RGB channels.
     *
     * @param data pixel values indexed as {@code [x][y]}; every column must have the same length
     * @return one unassigned item per pixel, indexed as {@code [x][y]}
     * @throws IllegalArgumentException if {@code data} is null or contains no pixels
     */
    private DateItem[][] initDate(int[][] data){
        // Null must be checked before any dereference of data.
        if (data == null || data.length == 0 || data[0].length == 0) {
            throw new IllegalArgumentException("Image contains no pixels");
        }
        int width = data.length;
        int height = data[0].length;
        DateItem[][] dateItems = new DateItem[width][height];
        for (int x = 0; x < width; x++){
            for (int y = 0; y < height; y++){
                Color color = new Color(data[x][y]);
                DateItem dateItem = new DateItem();
                dateItem.r = color.getRed();
                dateItem.g = color.getGreen();
                dateItem.b = color.getBlue();
                // Unassigned until clusterSet() runs.
                dateItem.group = -1;
                dateItems[x][y] = dateItem;
            }
        }
        return dateItems;
    }

    /**
     * Seeds the k cluster centers with the colours of randomly chosen pixels and
     * allocates empty per-cluster accumulators.
     *
     * @param k number of clusters
     */
    private void initCenters(int k){
        center = new DateItem[k];
        centerSum = new DateItem[k];
        int width = source.length;
        int height = source[0].length;
        for (int i = 0; i < k; i++){
            // Cast the product, not Math.random() alone: (int) Math.random() is always 0.
            int x = (int) (Math.random() * width);
            int y = (int) (Math.random() * height);
            DateItem seed = source[x][y];

            DateItem cent = new DateItem();
            cent.group = i;
            cent.r = seed.r;
            cent.g = seed.g;
            cent.b = seed.b;
            center[i] = cent;

            // Accumulators start at zero so the later average is a true mean.
            centerSum[i] = new DateItem();
        }
        for (DateItem c : center){
            System.out.println("r:" + c.r + "  g:" + c.g + "  b:" + c.b);
        }
    }

    /**
     * Euclidean distance between two colours in RGB space.
     */
    private double calDistance(DateItem item1, DateItem item2){
        double dr = item1.r - item2.r;
        double dg = item1.g - item2.g;
        double db = item1.b - item2.b;
        return Math.sqrt(dr * dr + dg * dg + db * db);
    }

    /**
     * Returns the index of the smallest value in {@code distance}.
     *
     * <p>On an exact tie, the later index replaces the current minimum with 50%
     * probability, so pixels equidistant from duplicate centers get spread across them.
     *
     * @param distance non-empty array of distances
     * @return index of a minimal element
     */
    private int minDistance(double[] distance){
        double mindistance = distance[0];
        int minLocation = 0;
        for(int i = 1; i < distance.length; i++){
            if(distance[i] < mindistance){
                mindistance = distance[i];
                minLocation = i;
            } else if(distance[i] == mindistance){
                if(Math.random()*10 < 5){
                    minLocation = i;
                }
            }
        }
        return minLocation;
    }

    /**
     * Assigns every pixel to its nearest center and accumulates per-cluster colour
     * sums and counts into {@link #centerSum}.
     */
    private void clusterSet(){
        double[] distance = new double[center.length];
        for (DateItem[] column : source){
            for (DateItem pixel : column){
                for (int c = 0; c < center.length; c++){
                    distance[c] = calDistance(center[c], pixel);
                }
                int group = minDistance(distance);
                pixel.group = group;

                DateItem sum = centerSum[group];
                sum.r += pixel.r;
                sum.g += pixel.g;
                sum.b += pixel.b;
                sum.group += 1;
            }
        }
    }

    /**
     * Moves each center to the (truncated) mean colour of its assigned pixels and
     * resets the accumulators for the next pass. A cluster that received no pixels
     * keeps its previous center instead of dividing by zero.
     */
    private void setNewCenter(){
        for (int i = 0; i < centerSum.length; i++){
            DateItem sum = centerSum[i];
            System.out.println(i+":"+sum.group+":"+sum.r+":"+sum.g+":"+sum.b);
            if (sum.group > 0){
                // The mean of the assigned pixels becomes the new center.
                center[i].r = (int) (sum.r / sum.group);
                center[i].g = (int) (sum.g / sum.group);
                center[i].b = (int) (sum.b / sum.group);
            }
            // Reset the accumulators to empty for the next assignment pass.
            sum.r = 0;
            sum.g = 0;
            sum.b = 0;
            sum.group = 0;
        }
    }

    /**
     * Writes the clustering result as a JPEG in which each pixel is painted with the
     * display colour of its cluster.
     *
     * @param path output file
     * @throws UncheckedIOException if the file cannot be written
     */
    private void ImagedataOut(String path){
        int width = source.length;
        int height = source[0].length;
        int[] palette = buildPalette(center.length);
        BufferedImage nbi = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        for (int x = 0; x < width; x++){
            for (int y = 0; y < height; y++){
                int group = source[x][y].group;
                if (group >= 0 && group < palette.length){
                    nbi.setRGB(x, y, palette[group]);
                }
            }
        }
        try{
            if (!ImageIO.write(nbi, "jpg", new File(path))){
                System.err.println("No ImageIO writer available for jpg; nothing written to " + path);
            }
        }catch(IOException e){
            throw new UncheckedIOException("Failed to write image: " + path, e);
        }
    }

    /**
     * Builds one display colour per cluster: the fixed base palette first, then
     * hues spaced by the golden ratio so that any number of clusters stays visible.
     */
    private static int[] buildPalette(int count){
        int[] palette = new int[count];
        for (int i = 0; i < count; i++){
            if (i < BASE_PALETTE.length){
                palette[i] = BASE_PALETTE[i].getRGB();
            } else {
                float hue = (i * 0.618034f) % 1f;
                palette[i] = Color.HSBtoRGB(hue, 0.8f, 0.9f);
            }
        }
        return palette;
    }

    /** Prints every center as {@code (r,g,b)}. */
    private void printCenters(){
        for (DateItem c : center){
            System.out.println("(" + c.r + "," + c.g + "," + c.b + ")");
        }
    }

    /**
     * Runs k-means on the image at {@code path} and writes the cluster map to the
     * default output path.
     *
     * @param path input image file
     * @param k    number of clusters, at least 1
     * @param m    number of center-update iterations, at least 0
     */
    public void kmeans(String path,int k,int m) {
        kmeans(path, k, m, DEFAULT_OUTPUT_PATH);
    }

    /**
     * Runs k-means on the image at {@code path}: m rounds of assign-then-update,
     * a final assignment, and then the cluster map is written as a JPEG.
     *
     * @param path       input image file
     * @param k          number of clusters, at least 1
     * @param m          number of center-update iterations, at least 0
     * @param outputPath where the cluster-map JPEG is written
     * @throws IllegalArgumentException if k or m is out of range, or the image cannot be decoded
     * @throws UncheckedIOException     if the input cannot be read or the output cannot be written
     */
    public void kmeans(String path, int k, int m, String outputPath) {
        if (k < 1){
            throw new IllegalArgumentException("k must be at least 1, got " + k);
        }
        if (m < 0){
            throw new IllegalArgumentException("m must not be negative, got " + m);
        }
        source = initDate(getImageData(path));
        this.k = k;
        this.m = m;

        // Seed the cluster centers.
        initCenters(k);

        // m rounds of assignment followed by center update.
        for (int level = 0; level < m; level++) {
            clusterSet();
            setNewCenter();
            printCenters();
        }
        // Final assignment so each pixel's group reflects the latest centers.
        clusterSet();
        System.out.println("第" + m + "次迭代完成,聚类中心为:");
        printCenters();
        System.out.println("迭代总次数:" + m);
        ImagedataOut(outputPath);
    }
}

// 数据点类 DateItem(表示一个像素的 RGB 值及其所属簇)

package com.amberlovesx.kmeans;

/**
 * Mutable holder for one colour sample.
 *
 * <p>For a pixel, {@code group} is the index of the cluster it is assigned to.
 * For an accumulator entry, {@code group} is used as a pixel counter instead.
 */
public class DateItem {
    /** Red, green and blue channel values. */
    public double r, g, b;

    /** Cluster index (or counter, when used as an accumulator). */
    public int group;
}

// 测试类

package com.amberlovesx.kmeans;

import java.nio.file.Path;
import java.nio.file.Paths;

/**
 * Command-line demo for {@link ImageCluster}.
 *
 * <p>Usage: {@code Test [inputImage [k [iterations]]]}. Omitted arguments fall back
 * to the original demo values ({@code D:\temp\Koala.jpg}, k=10, 10 iterations).
 */
public class Test {
    public static void main(String[] args){
        String input = args.length > 0 ? args[0] : "D:\\temp\\Koala.jpg";
        int k = args.length > 1 ? Integer.parseInt(args[1]) : 10;
        int m = args.length > 2 ? Integer.parseInt(args[2]) : 10;

        ImageCluster ic = new ImageCluster();
        ic.kmeans(input, k, m);
    }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值