随机森林

pom

 

<dependencies>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_2.12</artifactId>
        <version>2.4.0</version>
    </dependency>

    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-streaming_2.12</artifactId>
        <version>2.4.0</version>
    </dependency>

    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_2.12</artifactId>
        <version>2.4.0</version>
    </dependency>

    <dependency>
        <groupId>com.thoughtworks.paranamer</groupId>
        <artifactId>paranamer</artifactId>
        <version>2.8</version>
    </dependency>
</dependencies>
<!--打可执行jar包-->
<build>
    <plugins>
        <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-compiler-plugin</artifactId>
            <version>3.3</version>
            <configuration>
                <source>1.8</source>
                <target>1.8</target>
                <encoding>UTF-8</encoding>
            </configuration>
        </plugin>
    </plugins>
    <resources>
        <resource>
            <directory>src/main/resources</directory>
            <includes>
                <include>**/*.*</include>
            </includes>
        </resource>
    </resources>
</build>

 

 

代码

 

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import scala.Tuple2;

public class MyRandomForest {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setAppName("app").setMaster("local[1]");
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);
        // 加载数据
        String path = "D:\\IdeaProjects\\SparkMLlib\\src\\test\\java\\data4";
        JavaRDD<String> javaRDD = jsc.textFile(path);
        JavaRDD<LabeledPoint> data = javaRDD.map(new Function<String, LabeledPoint>() {
            @Override
            public LabeledPoint call(String line) throws Exception {
                String[] split = line.split(",");
                String[] arr = split[1].split(" ");
                double[] vectors = new double[arr.length];
                for (int i = 0; i < arr.length; i++) {
                    vectors[i] = Double.parseDouble(arr[i]);
                }
                LabeledPoint labeledPoint = new LabeledPoint(Double.parseDouble(split[0]), Vectors.dense(vectors));
                return labeledPoint;
            }
        });
        // 将数据集划分为训练数据和测试数据
        JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});//将数据分成7:3
        JavaRDD<LabeledPoint> training = splits[0];
        JavaRDD<LabeledPoint> testData = splits[1];
        // 随机森林模型训练
        Integer numClasses = 2;//划分的类型数量
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        Integer numTrees = 1; // 树的数量
        String featureSubsetStrategy = "auto"; //算法自动选择 auto/all
        String impurity = "gini";//随机森林有三种方式,entropy,gini,variance,回归肯定就是variance
        Integer maxDepth = 10;//深度
        Integer maxBins = 32;//数据最大分端数
        Integer seed = 1000000;//采样种子,种子不变,采样结果不变
        //训练模型
        RandomForestModel model = RandomForest.trainClassifier(
                training,
                numClasses,
                categoricalFeaturesInfo,
                numTrees,
                featureSubsetStrategy,
                impurity,
                maxDepth,
                maxBins,
                seed
        );
        //测试数据预测
        JavaPairRDD<Double, Double> predictionAndLabel = testData
                .mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
        //计算错误率
        double testErr = predictionAndLabel.filter(pl -> !pl._1.equals(pl._2())).count() / (double) testData.count();
        System.out.println("Test err:" + testErr);
        //打印树形结构
        System.out.println(model.toDebugString());
        //新数据预测
        Vector v = Vectors.dense(new double[]{3, 8});
        System.out.println("预测为" + model.predict(v));
    }
}


 

 

 

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小钻风巡山

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值