spark初步学习-scala代码

ReadDatasetExample

import org.apache.spark.sql.{SparkSession, DataFrame}
import org.apache.spark.sql.functions._

object ReadDatasetExample {

  /**
   * Loads the NHANES age-prediction CSV and prints an exploratory overview:
   * raw rows, descriptive statistics, a column subset, the schema, the
   * approximate median of `RIDAGEYR`, and the share of rows per `age_group`.
   *
   * @param args unused
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("ReadDatasetExample")
      .master("local")
      .getOrCreate()

    val filePath = "NHANES_age_prediction.csv" // replace with your actual file path

    // Read the CSV into a DataFrame. inferSchema types the numeric columns
    // (as in DataPreprocessing) instead of leaving everything as strings.
    val df: DataFrame = spark.read
      .format("csv")
      .option("header", "true")
      .option("inferSchema", "true")
      .load(filePath)

    // Show the DataFrame contents
    df.show()

    // Descriptive statistics
    df.describe().show()

    // Extract the columns of interest
    val selectedColumns = Seq("age_group", "RIDAGEYR", "RIAGENDR", "PAQ605", "BMXBMI", "LBXGLU", "DIQ010", "LBXGLT", "LBXIN")
    val selectedDF = df.select(selectedColumns.map(col): _*)
    selectedDF.show(10)

    df.show(5)
    df.columns.foreach(println)
    df.printSchema()

    // Cast explicitly so the aggregate is always DOUBLE, whatever type was
    // inferred for RIDAGEYR; the result is null when there are no non-null ages.
    val medianRow = df.selectExpr("percentile_approx(CAST(RIDAGEYR AS DOUBLE), 0.5)").first()
    if (medianRow.isNullAt(0)) println("Median age: n/a (no non-null RIDAGEYR values)")
    else println(s"Median age (approx.): ${medianRow.getDouble(0)}")

    // Count rows per age_group and each group's percentage of the total.
    // Ordering is applied last: a sort placed before the global window
    // (`over ()`) is not guaranteed to survive it.
    val ageDistribution = df.groupBy("age_group")
      .agg(count("*").alias("count"))
      .withColumn("percentage", expr("count / sum(count) over () * 100"))
      .orderBy("age_group")

    ageDistribution.show()

    spark.stop()
  }
}

SummaryStatistics

import org.apache.spark.sql.{SparkSession, DataFrame}

object SummaryStatistics {

  /**
   * Loads the NHANES age-prediction CSV and prints `describe()` statistics
   * (count, mean, stddev, min, max) for every numeric column.
   *
   * @param args unused
   */
  def main(args: Array[String]): Unit = {
    import org.apache.spark.sql.types.{NumericType, StructField}

    val spark = SparkSession.builder()
      .appName("SummaryStatistics")
      .master("local")
      .getOrCreate()

    val filePath = "NHANES_age_prediction.csv"
    // inferSchema is required: without it every CSV column is a string and
    // no column would ever be recognized as numeric.
    val df: DataFrame = spark.read
      .format("csv")
      .option("header", "true")
      .option("inferSchema", "true")
      .load(filePath)

    // Names of all numeric columns (integer and floating-point alike).
    val numericColumns = df.schema.fields.collect {
      case StructField(name, _: NumericType, _, _) => name
    }

    // describe() with no arguments would silently fall back to all columns,
    // so report the empty case explicitly instead.
    if (numericColumns.isEmpty) {
      println("No numeric columns found; nothing to summarize.")
    } else {
      val summaryStats = df.describe(numericColumns: _*)
      summaryStats.show()
    }

    spark.stop()
  }
}

DataPreprocessing

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.ml.feature.Imputer
import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoder}
import org.apache.spark.ml.feature.VectorAssembler

object DataPreprocessing {

  /**
   * Preprocesses the NHANES age-prediction data:
   * imputes missing numeric values, indexes the `age_group` label, assembles a
   * `features` vector, and one-hot encodes `RIAGENDR`, then shows the result.
   *
   * @param args unused
   */
  def main(args: Array[String]): Unit = {
    // Create the SparkSession
    val spark = SparkSession.builder()
      .appName("Data Preprocessing")
      .master("local")
      .getOrCreate()

    // Read the dataset into a DataFrame
    val filePath = "NHANES_age_prediction.csv"
    val df = spark.read.format("csv")
      .option("header", "true")
      .option("inferSchema", "true")
      .load(filePath)

    // Missing-value handling: Imputer writes filled copies into `imputed_<col>`
    // columns (Spark's default strategy is the column mean).
    val columnsToImpute = Array("RIDAGEYR", "BMXBMI", "LBXGLU", "LBXGLT", "LBXIN")
    def imputedName(column: String): String = s"imputed_$column"
    val imputer = new Imputer()
      .setInputCols(columnsToImpute)
      .setOutputCols(columnsToImpute.map(imputedName))
    val imputedData = imputer.fit(df).transform(df)

    // Label encoding: age_group -> numeric "label"
    val labelIndexer = new StringIndexer()
      .setInputCol("age_group")
      .setOutputCol("label")
    val labeledData = labelIndexer.fit(imputedData).transform(imputedData)

    // Feature vectorization. Use the imputed columns where available so the
    // imputation actually takes effect. NOTE(review): RIAGENDR, PAQ605 and
    // DIQ010 are not imputed; if they contain nulls, VectorAssembler's default
    // handleInvalid ("error") will fail — confirm against the data.
    val featureCols = Array("RIDAGEYR", "RIAGENDR", "PAQ605", "BMXBMI", "LBXGLU", "DIQ010", "LBXGLT", "LBXIN")
      .map(c => if (columnsToImpute.contains(c)) imputedName(c) else c)
    val assembler = new VectorAssembler()
      .setInputCols(featureCols)
      .setOutputCol("features")
    val assembledData = assembler.transform(labeledData)

    // One-hot encoding of gender. In Spark 3.x OneHotEncoder is an Estimator,
    // so it must be fitted before it can transform.
    val encoder = new OneHotEncoder()
      .setInputCol("RIAGENDR")
      .setOutputCol("encoded_gender")
    val encodedData = encoder.fit(assembledData).transform(assembledData)

    // Show the processed data
    encodedData.show()

    // Stop the SparkSession
    spark.stop()
  }
}

AgeDistributionHistogram

import org.apache.spark.sql.functions._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.SparkSession

import org.jfree.chart.ChartFactory
import org.jfree.chart.plot.PlotOrientation
import org.jfree.data.category.DefaultCategoryDataset

object AgeDistributionHistogram {

  /**
   * Loads the NHANES age-prediction CSV, builds a 10-bucket histogram of
   * `RIDAGEYR`, and displays it as a JFreeChart bar chart in a Swing window.
   *
   * @param args unused
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("AgeDistributionHistogram")
      .master("local")
      .getOrCreate()

    val filePath = "NHANES_age_prediction.csv"

    // Read the CSV file into a DataFrame
    val df: DataFrame = spark.read
      .format("csv")
      .option("header", "true")
      .load(filePath)

    // Cast the age column to Double. Unparseable values become null, which
    // would make Row.getDouble throw, and NaN makes RDD.histogram fail, so
    // both are dropped before building the histogram.
    val ages = df
      .withColumn("age", col("RIDAGEYR").cast(DataTypes.DoubleType))
      .select("age")
      .where(col("age").isNotNull && !isnan(col("age")))
      .rdd
      .map(_.getDouble(0))

    // Compute the histogram directly on the distributed data (no collect and
    // re-parallelize). histogram() fails on an empty RDD, hence the guard.
    val bucketCount = 10 // adjust the number of bins
    val histogram: Option[(Array[Double], Array[Long])] =
      if (ages.isEmpty()) None else Some(ages.histogram(bucketCount))

    // Everything the chart needs is now on the driver.
    spark.stop()

    histogram match {
      case None =>
        println("No valid RIDAGEYR values found; nothing to plot.")

      case Some((bounds, counts)) =>
        // bounds has counts.length + 1 entries: bucket i spans bounds(i)..bounds(i + 1).
        val dataset = new DefaultCategoryDataset()
        counts.indices.foreach { i =>
          dataset.addValue(counts(i).toDouble, "", f"${bounds(i)}%.1f-${bounds(i + 1)}%.1f")
        }

        // Create the histogram chart
        val chart = ChartFactory.createBarChart(
          "Age Distribution",        // title
          "Age Group",               // x-axis label
          "Frequency",               // y-axis label
          dataset,                   // dataset
          PlotOrientation.VERTICAL,  // orientation
          true,                      // show legend
          true,                      // generate tooltips
          false                      // generate URLs
        )

        // Display the chart
        val frame = new org.jfree.chart.ChartFrame("Age Distribution Histogram", chart)
        frame.pack()
        frame.setVisible(true)
    }
  }
}

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值