神经网络的MLPC(多层感知器分类器)

pom

<dependencies>
     <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>2.0.0</version>
        </dependency>

        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>2.11.12</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.11</artifactId>
            <version>2.0.0</version>
        </dependency>
</dependencies>
<!--打可执行jar包-->
<build>
    <plugins>
        <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-compiler-plugin</artifactId>
            <version>3.3</version>
            <configuration>
                <source>1.8</source>
                <target>1.8</target>
                <encoding>UTF-8</encoding>
            </configuration>
        </plugin>
    </plugins>
    <resources>
        <resource>
            <directory>src/main/resources</directory>
            <includes>
                <include>**/*.*</include>
            </includes>
        </resource>
    </resources>
</build>

 

代码

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.*;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * 多层感知器分类器
 */

public class Multilayer {

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().config("spark.sql.warehouse.dir", "D:\\code\\eye\\subject\\src\\test\\cache").appName("test").master("local[2]").getOrCreate();
        spark.sparkContext().setLogLevel("ERROR");//忽略日志
        List<String> myList = getDataList();//1.构造假数据
        JavaSparkContext jsc = JavaSparkContext.fromSparkContext(spark.sparkContext());
        JavaRDD<String> value = jsc.parallelize(myList);//2.数据转化为JavaRDD,线上环境直接使用jsc读取HDFS文件得到JavaRDD
        JavaRDD<Row> rows = getRowJavaRDD(value);//将JavaRDD中的String转化为Row
        StructType schema = getStructType();//获取JavaRDD转化DataSet的转化规则
        Dataset<Row> data = spark.createDataFrame(rows, schema);//转化为DataSet
        data.show();//查看和转化后的数据
        Dataset<Row>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 1234L);//划分训练集与测试集,划分为两部分比例为6:4,seed是一个随机数,可忽略
        int[] layers = new int[]{2, 5, 4, 2};//对于神经网络指定的层,输入层为2类,输出层为2类,隐藏层功两层,注意输入层的特征数和输出层的标签类型数都需要与原始数据保持一致
        //创建训练器和设置其参数
        MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier()
//                .setFeaturesCol("features")//默认为Row中的features,可以主动修改特征字段
//                .setLabelCol("label")//默认为Row中的label,可以主动修改标签字段
                .setLayers(layers)
                .setBlockSize(128)//在矩阵中堆叠输入数据的块大小。数据在分区内堆叠,如果块大小大于分区中的剩余数据,则会将其调整为该数据的大小。建议尺寸在10到1000之间
                .setMaxIter(100);//最大迭代次数
        MultilayerPerceptronClassificationModel model = trainer.fit(splits[0]);//使用训练数据训练模型
        //---------------------测试模型---------------------------
        double trainProbability = getProbability(splits[0], model);//使用训练数据测试模型,获取训练的精度
        System.out.println("训练正确率为" + trainProbability + "%");
        //---------------------测试模型---------------------------
        double testProbability = getProbability(splits[0], model);//使用测试数据测试模型,获取测试的精度
        System.out.println("测试正确率为" + testProbability + "%");
        //---------------------校验数据---------------------------
        //使用新数据对模型进行校验
        double predict = model.predict(new SparseVector(2, new int[]{0, 1}, new double[]{2D, 3D}));
        System.out.println(predict);
        //批量数据识别,可以将批量数转化为RDD,再将model广播到每一台服务器中,在执行者内部进行分布式并发识别
    }

    /**
     * 将JavaRDD中的String格式化为Row,并提取将特征与标签
     *
     * @param value
     * @return
     */
    private static JavaRDD<Row> getRowJavaRDD(JavaRDD<String> value) {
        return value.map((Function<String, Row>) line -> {//3.将String转化为Row类型数据,用于后续训练数据(训练数据api的输入为DataSet<Row>)
            String[] parts = line.split(",");
            Double label = Double.parseDouble(parts[0]);//标签定义
            Double one = Double.parseDouble(parts[1]);//特征定义,特征可以定义多个,这里值定义了两个
            Double two = Double.parseDouble(parts[2]);
            //将特征合并成特征向量,感知器的api中数据输入格式里面的特征要求是Vector类型,所以这里需要转化
            Vector sparseVector = new SparseVector(2, new int[]{0, 1}, new double[]{one, two});//参数解释 1:特征字段数组,2:特征字段下标集合,3:特征字段值数组
            return RowFactory.create(label, sparseVector);//创建Row(标签,特征向量)
        });
    }

    /**
     * 定义JavaRD转化DataSet的转化规则
     *
     * @return
     */
    private static StructType getStructType() {
        ArrayList<StructField> fields = new ArrayList<>();//使用schema生成方案用于JavaRDD转化为DataSet
        fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));//定义字段类型为标签类型
        fields.add(DataTypes.createStructField("features", org.apache.spark.ml.linalg.SQLDataTypes.VectorType(), true));//定义字段类型为特征类型
        return DataTypes.createStructType(fields);
    }

    /**
     * 校验识别结果是否正确
     *
     * @param test
     * @param model
     * @return
     */
    private static double getProbability(Dataset<Row> test, MultilayerPerceptronClassificationModel model) {
        Dataset<Row> transform = model.transform(test);//校验数据
        long sum = transform.count();
        Dataset<Row> trueData = transform.filter((FilterFunction<Row>) value1 -> value1.getDouble(0) == value1.getDouble(value1.length() - 1));
        return trueData.count() * 100D / sum;
    }

    /**
     * 假数据
     * 数据格式为 标签,特征,特征
     * 数据格式可以自行修改,改完格式之后需要将解析方式一并修改
     *
     * @return
     */
    private static List<String> getDataList() {
        List<String> myList1 = Arrays.asList("1,1,-1", "1,3,-4", "1,5,3");
        List<String> myList2 = Arrays.asList("0,1,3", "0,-3,2", "0,2,3");
        List<String> myList = new ArrayList<>();
        myList.addAll(myList1);
        myList.addAll(myList2);
        return myList;
    }
}

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小钻风巡山

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值