Spark Java 用 KMeans算法实现图片压缩

压缩前:981 KB
这里写图片描述
压缩后:111 KB
这里写图片描述

思路:
取得图片每一点的像素,组成向量Vector如下:(w,h,R,G,B);
设置目的K值,训练所有点,获得KMeansModel;
此遍历所有的点,利用模型预测每个点属于哪个 中心点,同时改变这个点的R,G,B值使这个点的颜色 与这个点所在的集合相同;
重新利用收集的数据画出图片。

一共需要两个类,一个处理跟图片相关,一个处理KMeans算法,好像pytho语言写就好简单,哎。

处理图片类如下:

import java.awt.Color;  
import java.awt.image.BufferedImage;  
import java.io.File;  
import java.io.FileOutputStream;  
import java.io.IOException;  
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;

import javax.imageio.ImageIO;

import scala.Tuple3;  

public class AnalyzePicture {  

    /**
     * 
     * @param oldFilePath:你想压缩的图片的地址
     * @param newFilePath:压缩后存放的地址
     * @param t:tuple类型(scala里面的,因为在学习spark所以就拿来用),<Integer,Integer,Integer>-><weight,high,RGB>
     */
    public static void generatePhoto(String oldFilePath,String newFilePath,List<Tuple3<Integer,Integer,Integer>>t)
    {
        try{
            BufferedImage imgOld = ImageIO.read(new File(oldFilePath));
            int w=imgOld.getWidth();
            int h=imgOld.getHeight();

            File out = new File(newFilePath);  
            if (!out.exists())  
                out.createNewFile();  
            OutputStream output = new FileOutputStream(out);  

            BufferedImage imgOut = new BufferedImage(w, h,  
                    BufferedImage.TYPE_3BYTE_BGR); 
            for(Tuple3<Integer,Integer,Integer> tt:t)
            {
                imgOut.setRGB( tt._1(), tt._2(),tt._3());
            }

            ImageIO.write(imgOut, "png", output);  

            output.close();

        }catch(Exception e)
        {
            e.printStackTrace();
        }
    }
    /**
     * 作用是返回图片每一点的RGP值
     * @param filePath :想要处理的图片的地址
     * @return String: with,high,R,G,P,这里我把RGP独立成三个方向的值
     */
    public static List<String> getImageGRBStr(String filePath) {
        File file  = new File(filePath);

        List<String>list=new ArrayList<String>();

        if (!file.exists()) {
            return null;
        }
        try {
            BufferedImage bufImg = ImageIO.read(file);
            int height = bufImg.getHeight();
            int width = bufImg.getWidth();

            for (int i = bufImg.getMinX(); i < width; i++) {
                for (int j = bufImg.getMinY(); j < height; j++) {
                    String str=i+","+j+","+((bufImg.getRGB(i, j) & 0xff0000) >> 16)+","+((bufImg.getRGB(i, j) & 0xff00) >> 8)+","+((bufImg.getRGB(i, j) & 0xff));
                    list.add(str);    
                }
            }

        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        return list;
    }
}  

KMeans算法类如下:

import java.util.ArrayList;
import java.util.List;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.SparkSession;

import scala.Tuple3;

public class PhotoKMeans {

    public static void main(String[]args)
    {
        //屏幕spark多余的log
        Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
        Logger.getLogger("org.apache.spark").setLevel(Level.OFF);

        //获取到到了图片的每个点的情况String->w,h,R,G,P
        List<String>list=AnalyzePicture.getImageGRBStr(args[0]);
        //设置spark为本地模式,这样就不要老是跑到linux集群上去跑了
        SparkSession spark=SparkSession.builder().master("local").appName("PhotoKMeans").getOrCreate();
        //目的,初始化非file的数据源
        JavaSparkContext javaSpark=new JavaSparkContext(spark.sparkContext());

        //首先把List<String>变成RDD<String>,RDD中的String为"w,h,R,G,P",去除","之后,RDD中的String映射为向量 {w,h,R,G,P}
        RDD<Vector>rdd=javaSpark.parallelize(list).map(new StringToVector(",")).rdd().cache();

        KMeansModel model=KMeans.train(rdd, Integer.parseInt(args[2]), Integer.parseInt(args[3]));//拿去处理

        //获得处理后的中心向量 
        Vector[]vectors=model.clusterCenters();

        //获得图片所有点的向量 ,为了下面获取每个点属于哪个中心点而准备
        List<Vector>points=rdd.toJavaRDD().collect();

        //存储重新调整之后图片每个点的w,h,RGP
        List<Tuple3<Integer,Integer,Integer>>tuple=new ArrayList<Tuple3<Integer,Integer,Integer>>();

        for(Vector v:points)
        {
            int cluster=model.predict(v);//start from 0

            double[]rgbs=vectors[cluster].toArray();
            double[]x=v.toArray();
            int ww=(int)x[0];
            int hh=(int)x[1];

            int rgb=((int) rgbs[2]<<16)+((int) rgbs[3]<<8)+((int) rgbs[4]);//这里是将R,G,P变成RGP
            Tuple3<Integer,Integer,Integer>t=new Tuple3<>(ww,hh,rgb);

            tuple.add(t);
        }
        //交给处理类生成新的图片
        AnalyzePicture.generatePhoto(args[0], args[1], tuple);

    }


    public static class StringToVector implements Function<String,Vector>
    {
        String target="";
        public StringToVector(String target)
        {
            this.target=target;
        }
        @Override
        public Vector call(String a) throws Exception {
            // TODO Auto-generated method stub
            String[]aa=a.split(target);
            double[]aaa=new double[aa.length];
            for(int i=0;i<aa.length;i++)
            {
                aaa[i]=Double.parseDouble(aa[i]);
            }
            return Vectors.dense(aaa);
        }   
    }
}

最用调用KMeans类,参数如下:
args[0] =你想处理的图片地址
args[1] =处理后新图片的地址
args[2] =K值(K越大图片越逼真)
args[3] =允许的最大迭代次数

C:/Users/Administrator/Pictures/fengjing.png C:/Users/Administrator/Pictures/2_564.png 600 5

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值