图像颜色 k-means 聚类(Image color k-means clustering)



import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;
import java.util.*;
import java.util.List;

/**
 * Image helpers for color k-means clustering with ND4J: reading images, converting pixels to
 * {@link INDArray}s, normalizing, clustering pixel colors, rendering cluster labels and
 * converting RGB to HSV.
 *
 * <p>Created by Administrator on 2017/9/21.
 */
public class ImageUtil {

    /** Safety cap for {@link #kmean}: stop even if floating-point noise keeps centers moving. */
    private static final int MAX_KMEANS_ITERATIONS = 300;

    /** Stop {@link #kmean} once no center moves farther than this (Euclidean, RGB units). */
    private static final double KMEANS_THRESHOLD = 0.1;

    /**
     * Reads an image file.
     *
     * @param path file-system path of the image
     * @return the decoded image, or {@code null} if reading failed or no installed reader
     *     understands the format
     */
    public static Image readImage(String path) {
        return readFileImage(path);
    }

    /**
     * Reads an image file as a {@link BufferedImage}.
     *
     * @param path file-system path of the image
     * @return the decoded image, or {@code null} if reading failed (the stack trace is printed)
     *     or no installed reader understands the format
     */
    public static BufferedImage readFileImage(String path) {
        try {
            return ImageIO.read(new File(path));
        } catch (IOException e) {
            // Best effort: callers treat null as "could not read".
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Converts an image into an array of shape (width, height, 3) indexed {@code [x][y][channel]},
     * with channel 0 = red, 1 = green, 2 = blue, each in [0, 255].
     *
     * <p>Side effect: sets the global ND4J data type to DOUBLE.
     *
     * @param image source image
     * @return the RGB array
     */
    public static INDArray convertImageToArray(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        Nd4j.factory().setDType(DataBuffer.Type.DOUBLE);
        INDArray rgbArray = Nd4j.zeros(width, height, 3);

        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                // getRGB packs the color model's red/green/blue into 0xAARRGGBB.
                int argb = image.getRGB(x, y);
                rgbArray.putScalar(x, y, 0, (argb >> 16) & 0xFF);
                rgbArray.putScalar(x, y, 1, (argb >> 8) & 0xFF);
                rgbArray.putScalar(x, y, 2, argb & 0xFF);
            }
        }
        return rgbArray;
    }

    /**
     * Resizes an image to {@code w x h} and converts it to an array of shape
     * {@code [1, channel, w, h]}, indexed {@code [0][channel][x][y]}, for image recognition.
     *
     * <p>Only channels 0..2 (red, green, blue, each in [0, 255]) are filled. Any extra channels
     * stay zero. Side effect: sets the global ND4J data type to DOUBLE.
     *
     * @param image source image
     * @param w target width
     * @param h target height
     * @param channel number of channels in the output; must be at least 3
     * @return the image array
     * @throws IllegalArgumentException if {@code channel < 3}
     */
    public static INDArray imageToArray(BufferedImage image, int w, int h, int channel) {
        if (channel < 3) {
            throw new IllegalArgumentException("channel must be >= 3 to hold RGB, got " + channel);
        }
        BufferedImage resized = new BufferedImage(w, h, BufferedImage.TYPE_INT_RGB);
        Graphics graphics = resized.getGraphics();
        try {
            graphics.drawImage(image, 0, 0, w, h, null);
        } finally {
            // Release the native graphics context.
            graphics.dispose();
        }

        Nd4j.factory().setDType(DataBuffer.Type.DOUBLE);
        INDArray rgbArray = Nd4j.zeros(1, channel, w, h);

        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                int rgb = resized.getRGB(x, y);
                rgbArray.putScalar(0, 0, x, y, (rgb >> 16) & 0xFF);
                rgbArray.putScalar(0, 1, x, y, (rgb >> 8) & 0xFF);
                rgbArray.putScalar(0, 2, x, y, rgb & 0xFF);
            }
        }
        return rgbArray;
    }

    /**
     * Normalizes a (width, height, 3) RGB array in place using the fixed 8-bit range.
     *
     * <p>Each value becomes {@code (v - 0) / (255 - 0 + 1)}, so results lie in [0, 255/256],
     * not [0, 1]. Any failure is reported by printing the stack trace, and the possibly
     * partially normalized array is still returned.
     *
     * @param imageArr array to normalize; modified in place
     * @return {@code imageArr}
     */
    public static INDArray getImageNormalization(INDArray imageArr) {
        final double minValue = 0;
        final double maxValue = 255;
        final double scale = maxValue - minValue + 1;

        try {
            for (int i = 0; i < imageArr.size(0); i++) {
                for (int j = 0; j < imageArr.size(1); j++) {
                    for (int c = 0; c < 3; c++) {
                        imageArr.putScalar(i, j, c, (imageArr.getDouble(i, j, c) - minValue) / scale);
                    }
                }
            }
        } catch (Exception e) {
            // Best effort, as before: report and return what we have.
            e.printStackTrace();
        }
        return imageArr;
    }

    /**
     * Min-max normalizes a {@code [1, 3, w, h]} array into a new array of the same shape.
     *
     * <p>Each value becomes {@code (v - min) / (max - min + 1)}. The {@code + 1} guards against
     * division by zero for constant input, so the upper bound is slightly below 1. Any failure
     * is reported by printing the stack trace.
     *
     * @param imageArr source array; not modified
     * @return the normalized copy
     */
    public static INDArray getNormalization(INDArray imageArr) {
        // True max/min. IAMax/IAMin select by absolute value, which is wrong for negative data.
        double maxValue = imageArr.maxNumber().doubleValue();
        double minValue = imageArr.minNumber().doubleValue();
        double scale = maxValue - minValue + 1;

        INDArray normalization = Nd4j.zeros(1, 3, imageArr.size(2), imageArr.size(3));
        try {
            for (int i = 0; i < imageArr.size(2); i++) {
                for (int j = 0; j < imageArr.size(3); j++) {
                    for (int c = 0; c < 3; c++) {
                        normalization.putScalar(0, c, i, j, (imageArr.getDouble(0, c, i, j) - minValue) / scale);
                    }
                }
            }
        } catch (Exception e) {
            // Best effort, as before: report and return what we have.
            e.printStackTrace();
        }
        return normalization;
    }

    /**
     * Clusters the pixel colors of a (rows, columns, 3) array with k-means (Lloyd's algorithm).
     *
     * <p>The initial centers are the colors of {@code number} randomly chosen pixels. Iteration
     * stops when no center moves farther than 0.1, or after {@value #MAX_KMEANS_ITERATIONS}
     * iterations. A cluster that receives no pixels keeps its previous center. Each iteration's
     * largest center shift is printed, followed by the final centers.
     *
     * @param imageArr pixel colors, indexed {@code [row][column][channel]}
     * @param number number of clusters (k); must be positive
     * @return a (rows, columns) array of 1-based cluster labels in {@code 1..number}
     * @throws IllegalArgumentException if {@code number < 1}
     */
    public static INDArray kmean(INDArray imageArr, int number) {
        if (number < 1) {
            throw new IllegalArgumentException("number of clusters must be positive, got " + number);
        }
        int rows = imageArr.size(0);
        int columns = imageArr.size(1);

        INDArray centers = Nd4j.zeros(number, 3);
        Random random = new Random();
        for (int k = 0; k < number; k++) {
            int row = random.nextInt(rows);
            int column = random.nextInt(columns);
            for (int c = 0; c < 3; c++) {
                centers.putScalar(k, c, imageArr.getDouble(row, column, c));
            }
        }

        INDArray result = Nd4j.zeros(rows, columns);
        double distance = Double.POSITIVE_INFINITY;
        for (int iteration = 0; distance > KMEANS_THRESHOLD && iteration < MAX_KMEANS_ITERATIONS; iteration++) {
            INDArray newCenters = assignAndUpdate(imageArr, centers, result);
            distance = maxCenterShift(newCenters, centers);
            centers = newCenters;
            System.out.println(distance);
        }
        System.out.println(centers);
        return result;
    }

    /**
     * Labels every pixel with its nearest given center, without iterating.
     *
     * <p>Also prints how far one k-means update step would move the centers (diagnostic only).
     *
     * @param imageArr pixel colors, indexed {@code [row][column][channel]}
     * @param number expected number of centers; kept for compatibility, the count is taken
     *     from {@code centers}
     * @param centers (k, 3) matrix of center colors; not modified
     * @return a (rows, columns) array of 1-based labels in {@code 1..centers.rows()}
     */
    public static INDArray kmeanByCenter(INDArray imageArr, int number, INDArray centers) {
        INDArray result = Nd4j.zeros(imageArr.size(0), imageArr.size(1));
        INDArray refined = assignAndUpdate(imageArr, centers, result);
        System.out.println(maxCenterShift(refined, centers));
        return result;
    }

    /**
     * Renders a label map as an image, one random color per label.
     *
     * <p>Despite the name, the output is colored, not grey.
     *
     * @param sourceArray (width, height) array of 1-based labels in {@code 1..number}
     * @param number number of distinct labels
     * @return an RGB image of size width x height
     */
    public static BufferedImage arrayToGreyImage(INDArray sourceArray, int number) {
        int width = sourceArray.rows();
        int height = sourceArray.columns();
        BufferedImage targetImage = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);

        Random random = new Random();
        int[] palette = new int[number];
        for (int k = 0; k < number; k++) {
            int red = random.nextInt(256);
            int green = random.nextInt(256);
            int blue = random.nextInt(256);
            palette[k] = (red << 16) | (green << 8) | blue;
        }

        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                // Labels are 1-based.
                targetImage.setRGB(x, y, palette[sourceArray.getInt(x, y) - 1]);
            }
        }
        return targetImage;
    }

    /**
     * Converts an RGB array (values in [0, 255]) to HSV.
     *
     * <p>Output channels: 0 = hue in degrees [0, 360), 1 = saturation in [0, 1],
     * 2 = value in [0, 1].
     *
     * @param image (d0, d1, 3) RGB array; not modified
     * @return a new (d0, d1, 3) HSV array
     */
    public static INDArray getRGBTOHSV(INDArray image) {
        int height = image.size(0);
        int width = image.size(1);
        INDArray imageHSV = Nd4j.zeros(height, width, 3);

        for (int i = 0; i < height; i++) {
            for (int j = 0; j < width; j++) {
                double r = image.getDouble(i, j, 0);
                double g = image.getDouble(i, j, 1);
                double b = image.getDouble(i, j, 2);
                double max = Math.max(r, Math.max(g, b));
                double min = Math.min(r, Math.min(g, b));
                double delta = max - min;

                double hue = 0;
                double saturation = 0;
                if (delta > 0) {
                    if (r == max) {
                        hue = (g - b) / delta;
                    } else if (g == max) {
                        hue = 2 + (b - r) / delta;
                    } else {
                        hue = 4 + (r - g) / delta;
                    }
                    hue *= 60;
                    if (hue < 0) {
                        hue += 360;
                    }
                    saturation = delta / max;
                }
                double value = max / 255.0;

                imageHSV.putScalar(i, j, 0, hue);
                imageHSV.putScalar(i, j, 1, saturation);
                imageHSV.putScalar(i, j, 2, value);
            }
        }
        return imageHSV;
    }

    /**
     * Runs one k-means step. Writes each pixel's 1-based nearest-center label into
     * {@code labels} and returns the per-cluster mean colors. An empty cluster keeps its
     * previous center.
     */
    private static INDArray assignAndUpdate(INDArray imageArr, INDArray centers, INDArray labels) {
        int rows = imageArr.size(0);
        int columns = imageArr.size(1);
        int k = centers.rows();
        double[][] sums = new double[k][3];
        int[] counts = new int[k];

        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < columns; j++) {
                double red = imageArr.getDouble(i, j, 0);
                double green = imageArr.getDouble(i, j, 1);
                double blue = imageArr.getDouble(i, j, 2);
                int label = getClassCode(centers, red, green, blue);
                int slot = label - 1;
                sums[slot][0] += red;
                sums[slot][1] += green;
                sums[slot][2] += blue;
                counts[slot]++;
                labels.putScalar(i, j, label);
            }
        }

        INDArray newCenters = Nd4j.zeros(k, 3);
        for (int c = 0; c < k; c++) {
            for (int channel = 0; channel < 3; channel++) {
                double mean = counts[c] == 0 ? centers.getDouble(c, channel) : sums[c][channel] / counts[c];
                newCenters.putScalar(c, channel, mean);
            }
        }
        return newCenters;
    }

    /** Returns the 1-based index of the center (row of {@code centers}) closest to the color. */
    private static int getClassCode(INDArray centers, double red, double green, double blue) {
        double best = Double.POSITIVE_INFINITY;
        int bestIndex = 0;
        for (int c = 0; c < centers.rows(); c++) {
            double dr = centers.getDouble(c, 0) - red;
            double dg = centers.getDouble(c, 1) - green;
            double db = centers.getDouble(c, 2) - blue;
            // Squared distance orders identically to Euclidean distance.
            double squared = dr * dr + dg * dg + db * db;
            if (squared < best) {
                best = squared;
                bestIndex = c;
            }
        }
        return bestIndex + 1;
    }

    /** Returns the largest Euclidean distance between corresponding rows of two (k, 3) matrices. */
    private static double maxCenterShift(INDArray newCenters, INDArray oldCenters) {
        double largest = 0;
        for (int c = 0; c < newCenters.rows(); c++) {
            double dr = newCenters.getDouble(c, 0) - oldCenters.getDouble(c, 0);
            double dg = newCenters.getDouble(c, 1) - oldCenters.getDouble(c, 1);
            double db = newCenters.getDouble(c, 2) - oldCenters.getDouble(c, 2);
            largest = Math.max(largest, Math.sqrt(dr * dr + dg * dg + db * db));
        }
        return largest;
    }
}
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值