PCA主成分分析基于sparkml采用Java语言开发

PCA主成分分析基于sparkml采用Java语言开发

Java代码在最后面

什么是PCA?

以下解释来源于知乎回答(10 封私信 / 76 条消息) 论智 - 知乎 (zhihu.com)

PCA主元分析,主要用于数据降维

寻找所有产品中很不相同的属性(特征),也就是寻找尽可能体现产品差异的属性,并且可以通过这些属性预测或者重建原本产品的特征(feature)

x轴为feature1,Y轴为feature2,根据产品的feature1和feature2值得到产品(蓝点)在坐标中的散列图(假设他们是相关的)

img

在蓝点中画一条直线,将所有点投影到这条直线上,红点为蓝点到直线上的投影

img

红点在直线上的分布为方差,方差最大时产品差异最大(x=(x1+x2+…+xn)/n)

在这里插入图片描述

基于新特性(红点的位置)重建原本的两个特性(蓝点的位置),红线的长度为重建误差(红线的均方根距离),当直线和两侧的粉短线位置重合时,重建误差最小且方差最大,该线即为我们要找的

如何进行标准化PCA

以下来自知乎马同学(10 封私信 / 78 条消息) 马同学 - 知乎 (zhihu.com)高赞回答(10 封私信 / 78 条消息) 如何通俗易懂地讲解什么是 PCA(主成分分析)? - 知乎 (zhihu.com)

比如有如下区域1的特征信息

区域1
a10
b2
c1
d7
e3

区域一特征信息的均值是
在这里插入图片描述

以均值为原点0,那么得到区域1的特征信息与均值的差值
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-drHTJVzE-1645689076115)(https://www.zhihu.com/equation?tex=X_i-%5Coverline%7BX%7D)]

a10 -4.6 = 5.4
b2 -4.6 = - 2.6
c1 -4.6= -3.6
d7 -4.6= 2.4
e3 -4.6 = -1.6

这个过程叫 去中心化

用去中心化的数据可以直接计算出区域1的样本方差


在这里插入图片描述

现在新采集了区域2的特征信息,可以看出两者完全正相关,有一列其实是多余的

区域1区域2
a1010
b22
c11
d77
e33

求出区域1区域2特征信息的均值,分别对区域1区域2进行中心化后得到

区域1区域2
a5.45.4
b-2.6-2.6
c-3.6-3.6
d2.42.4
e-1.6-1.6

区域1区域2的特征信息的协方差是

把这个二维数据画在坐标轴上,横纵坐标分别为区域1区域2的特征信息,可以看出它们排列为一条直线

img

如果旋转坐标系,让横坐标和这条直线重合:

img

旋转后的坐标系,横纵坐标不再代表区域1区域2的特征信息了,而是两者的混合(术语是线性组合),这里把它们称作主元1主元2坐标值很容易用勾股定理计算出来,比如[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-BgyyPa3v-1645689076117)(https://www.zhihu.com/equation?tex=a)]在“主元1”的坐标值为

img

很显然a在“主元2”上的坐标为0,把所有的区域1算到新的坐标系上

因为“主元2”全都为0,完全是多余的,我们只需要“主元1”就够了,这样就又把数据降为了一维,而且没有丢失任何信息:

img

非理想情况下的PCA

过程复杂,具体可移步知乎


在这里插入图片描述

主元2整体来看,数值很小,丢掉损失的信息也非常少,这样就实现了非理想情况下的降维。

Java代码实现

spark工程的pom依赖(CDH5.7)

       <!--spark ml-->
       <dependency>
           <groupId>org.apache.spark</groupId>
           <artifactId>spark-mllib_2.11</artifactId>
           <version>2.1.0.cloudera1</version>
       </dependency>

测试类

   @Test
   public void testpca() {
       List<Row> list = new ArrayList<>();
       list.add(RowFactory.create("1","47","34","50"));
       list.add(RowFactory.create("1","53","34","50"));
       list.add(RowFactory.create("1","54","38","50"));
       list.add(RowFactory.create("1","54","38","54"));
       list.add(RowFactory.create("1","54","54","54"));
       list.add(RowFactory.create("1","54","54","54"));
       list.add(RowFactory.create("1","54","54","54"));
       list.add(RowFactory.create("0","54","54","66"));
       list.add(RowFactory.create("0","58","54","66"));
       list.add(RowFactory.create("0","66","62","80"));
       //快速创建测试DataSet,也可以直接从hive取得
       Dataset<Row> rowDataset = DatasetCreateUtils.quickCreateStrDs(list, Lists.newArrayList("OID", "feature1","feature2","feature3"));
       //仅可计算Double类型数据
       Dataset<Row> data = rowDataset.withColumn("feature1D", rowDataset.col("feature1").cast(DataTypes.DoubleType))
               .withColumn("feature2D", rowDataset.col("feature2").cast(DataTypes.DoubleType))
               .withColumn("feature3D", rowDataset.col("feature3").cast(DataTypes.DoubleType))
               ;
       String[] transClos = (String[]) Arrays.asList("feature1D", "feature2D", "feature3D").toArray();
       /**必须创建VectorAssembler,不然会报
       	Column qty_2 must be of type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 but was actually StringType
       	sparkml计算仅接收此种数据类型
       */
       VectorAssembler vectorAssembler = new VectorAssembler().setInputCols(transClos).setOutputCol("fea_vector");

       Dataset<Row> dataset = vectorAssembler.transform(data);
		//调用主成分分析方法 PrincipalComponenAnalysis.pca
       Map<String, Object> pca = PrincipalComponenAnalysis.pca(dataset,"fea_vector");
       for (String s : pca.keySet()) {
           System.out.println(s);
           System.out.println(pca.get(s));
       }
   }

快速创建测试DataSet

import org.apache.commons.collections.CollectionUtils;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;
//快速创建dataset
public class DatasetCreateUtils {
   //按照自己的方式获取session即可
   final private static SparkSessionTemplate SESSION_TEMPLATE = SparkSessionTemplate.getInstance();
   final private static SparkSession SESSION = SESSION_TEMPLATE.getSession();
   final private static JavaSparkContext CONTEXT = SESSION_TEMPLATE.getContext();

   /**
    * 快速创建字符schema制造的Dataset
    *
    * @param lines     行数据
    * @param colNames  每个行对应的字符名
    * @return  快速创建字符schema制造的Dataset
    */
   public static Dataset<Row> quickCreateStrDs(List<Row> lines, List<String> colNames) {
       if (CollectionUtils.isEmpty(lines)||CollectionUtils.isEmpty(colNames)){
           throw new RuntimeException("创建的行数据或者对应的列名为空");
       }

       List<StructField> struct = new ArrayList<>();
       for (String colName:colNames){
           struct.add(DataTypes.createStructField(colName, DataTypes.StringType, true));
       }
       StructType schema = DataTypes.createStructType(struct);

       return SESSION.createDataFrame(lines, schema);
   }
}

主成分分析方法

import org.apache.spark.ml.feature.PCA;
import org.apache.spark.ml.feature.PCAModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

//主成分分析
public class PrincipalComponenAnalysis {
   public static Map<String, Object> pca(Dataset<Row> source,String features) {
       SparkSessionTemplate template = SparkSessionTemplate.getInstance();
       Map<String, Object> map = new HashMap<>();
       //训练数据
       map.put("training", toList(source));
       //设置算法参数
       PCA pca = new PCA()
               .setInputCol(features)//计算的字段
               .setOutputCol("pcaFeatures")//输出的字段
               .setK(3);//最小是1     迭代次数  维度
       //训练模型
       PCAModel pcaModel = pca.fit(source);

       pcaModel.pc();
       //转化数据
       Dataset<Row> predictions = pcaModel.transform(source).select("pcaFeatures");
       predictions.show(false);//此处可以回写到hive中,此处仅展示

       map.put("data", toList(predictions));
       return map;
   }

   /**
    * dataset数据转化为list数据
    *
    * @param record 数据
    * @return 数据集合
    */
   private static List<Map<String, String>> toList(Dataset<Row> record) {
       List<Map<String, String>> list = new ArrayList<>();
       String[] columns = record.columns();
       List<Row> rows = record.collectAsList();
       for (Row row : rows) {
           Map<String, String> obj = new HashMap<>();
           for (int j = 0; j < columns.length; j++) {
               String col = columns[j];
               Object rowAs = row.getAs(col);
               String val = "";
               val = rowAs.toString();

               obj.put(col, val);
           }
           list.add(obj);
       }
       return list;
   }
}

计算结果

±----------------------------------------------------------+
|pcaFeatures |
±----------------------------------------------------------+
|[-70.29375705173524,19.52136312893162,-23.296010412515088] |
|[-72.05630490597821,21.011030408503178,-28.834450306354988]|
|[-75.12068734584295,18.37619359234054,-29.65127049124429] |
|[-77.7556332856957,20.964895807465595,-28.116447786349624] |
|[-88.838131142326,9.432437023100675,-27.691435263346882] |
|[-88.838131142326,9.432437023100675,-27.691435263346882] |
|[-88.838131142326,9.432437023100675,-27.691435263346882] |
|[-96.74296896188423,17.198543668475843,-23.086967148662886]|
|[-97.91800086471287,18.191655188190218,-26.779260411222815]|
|[-115.0316243881699,23.472106588374192,-28.579461207709983]|
±----------------------------------------------------------+

创建的vectorAssembler数据fea_vector

{feature2=34, feature3=50, feature1D=47.0, feature1=47, feature2D=34.0, LOT_ID=1, feature3D=50.0, fea_vector=[47.0,34.0,50.0]}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值