// ==== ReadDatasetExample ====
import org.apache.spark.sql.{SparkSession, DataFrame}
import org.apache.spark.sql.functions._
object ReadDatasetExample {
  /** Loads the NHANES age-prediction CSV, prints basic exploratory views
    * (sample rows, schema, descriptive statistics), then reports the
    * approximate median age and the per-group age distribution.
    */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("ReadDatasetExample")
      .master("local")
      .getOrCreate()

    val filePath = "NHANES_age_prediction.csv" // adjust to your actual file location

    // Read the CSV into a DataFrame. inferSchema is required so numeric
    // columns (e.g. RIDAGEYR) are typed as numbers instead of strings —
    // percentile_approx below fails at analysis time on a string column.
    val df: DataFrame = spark.read
      .format("csv")
      .option("header", "true")
      .option("inferSchema", "true")
      .load(filePath)

    // Show the contents and descriptive statistics of the DataFrame.
    df.show()
    df.describe().show()

    // Keep only the columns relevant to age prediction.
    val selectedColumns = Seq("age_group", "RIDAGEYR", "RIAGENDR", "PAQ605", "BMXBMI", "LBXGLU", "DIQ010", "LBXGLT", "LBXIN")
    val selectedDF = df.select(selectedColumns.map(col): _*)
    selectedDF.show(10)

    df.show(5)
    df.columns.foreach(println)
    df.printSchema()

    // Approximate median age. percentile_approx returns the input column's
    // numeric type, so cast to double to make getDouble safe for int columns.
    val medianAge = df.selectExpr("cast(percentile_approx(RIDAGEYR, 0.5) as double)").first().getDouble(0)
    println(s"Median age (approx): $medianAge")

    // Row count and percentage share per age group.
    val ageDistribution = df.groupBy("age_group")
      .agg(count("*").alias("count"))
      .sort("age_group")
      .withColumn("percentage", expr("count / sum(count) over () * 100"))
    ageDistribution.show()

    // Release local Spark resources; the other examples build their own sessions.
    spark.stop()
  }
}
// ==== SummaryStatistics ====
import org.apache.spark.sql.{SparkSession, DataFrame}
object SummaryStatistics {
  /** Prints describe() summary statistics (count/mean/stddev/min/max) for
    * every numeric column of the NHANES dataset.
    */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("SummaryStatistics")
      .master("local")
      .getOrCreate()

    val filePath = "NHANES_age_prediction.csv"

    // inferSchema is essential here: without it every column is StringType,
    // the numeric-column filter below matches nothing, and describe() would
    // silently fall back to describing all (string) columns.
    val df: DataFrame = spark.read
      .format("csv")
      .option("header", "true")
      .option("inferSchema", "true")
      .load(filePath)

    // df.dtypes yields (name, typeName) pairs such as ("BMXBMI", "DoubleType").
    // Accept every numeric type, not only DoubleType — age/flag columns are
    // commonly inferred as IntegerType. DecimalType prints with a precision
    // suffix (e.g. "DecimalType(10,2)"), hence the startsWith check.
    val numericTypePrefixes =
      Seq("DoubleType", "FloatType", "IntegerType", "LongType", "ShortType", "ByteType", "DecimalType")
    val numericColumns = df.dtypes.collect {
      case (column, dataType) if numericTypePrefixes.exists(dataType.startsWith) => column
    }

    // Compute and print the summary statistics for the numeric columns.
    val summaryStats = df.describe(numericColumns: _*)
    summaryStats.show()

    spark.stop()
  }
}
// ==== DataPreprocessing ====
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.ml.feature.Imputer
import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoder}
import org.apache.spark.ml.feature.VectorAssembler
object DataPreprocessing {
  /** ML preprocessing pipeline for the NHANES dataset: mean-imputes missing
    * numeric values, label-encodes the target (age_group), assembles a
    * feature vector, and one-hot encodes the gender column.
    */
  def main(args: Array[String]): Unit = {
    // Create the SparkSession.
    val spark = SparkSession.builder()
      .appName("Data Preprocessing")
      .master("local")
      .getOrCreate()

    // Read the dataset into a DataFrame (inferSchema so numeric columns are typed).
    val filePath = "NHANES_age_prediction.csv"
    val df = spark.read.format("csv")
      .option("header", "true")
      .option("inferSchema", "true")
      .load(filePath)

    // Missing-value handling: Imputer (default strategy: mean) fills nulls/NaNs.
    val imputer = new Imputer()
      .setInputCols(Array("RIDAGEYR", "BMXBMI", "LBXGLU", "LBXGLT", "LBXIN"))
      .setOutputCols(Array("imputed_RIDAGEYR", "imputed_BMXBMI", "imputed_LBXGLU", "imputed_LBXGLT", "imputed_LBXIN"))
    val imputedData = imputer.fit(df).transform(df)

    // Label encoding: map age_group strings to numeric label indices.
    val labelIndexer = new StringIndexer()
      .setInputCol("age_group")
      .setOutputCol("label")
    val indexedData = labelIndexer.fit(imputedData).transform(imputedData)

    // Feature vectorization. Use the imputed versions of the numeric columns —
    // otherwise the Imputer output would be discarded and VectorAssembler
    // would fail on rows with nulls in the raw columns.
    val featureCols = Array(
      "imputed_RIDAGEYR", "RIAGENDR", "PAQ605", "imputed_BMXBMI",
      "imputed_LBXGLU", "DIQ010", "imputed_LBXGLT", "imputed_LBXIN")
    val assembler = new VectorAssembler()
      .setInputCols(featureCols)
      .setOutputCol("features")
    val assembledData = assembler.transform(indexedData)

    // One-hot encoding of the gender column. (Note: the original code
    // redeclared `encodedData` here, which did not compile.)
    val encoder = new OneHotEncoder()
      .setInputCol("RIAGENDR")
      .setOutputCol("encoded_gender")
    val encodedData = encoder.transform(assembledData)

    // Show the fully preprocessed data.
    encodedData.show()

    // Stop the SparkSession.
    spark.stop()
  }
}
// ==== AgeDistributionHistogram ====
import org.apache.spark.sql.functions._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.SparkSession
import org.jfree.chart.ChartFactory
import org.jfree.chart.plot.PlotOrientation
import org.jfree.data.category.DefaultCategoryDataset
object AgeDistributionHistogram {
  /** Reads the NHANES dataset, buckets the age column into a 10-bin
    * histogram via RDD.histogram, and renders it as a JFreeChart bar chart
    * in a Swing frame.
    */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("AgeDistributionHistogram")
      .master("local")
      .getOrCreate()

    val filePath = "NHANES_age_prediction.csv"

    // Read the CSV into a DataFrame.
    val df: DataFrame = spark.read
      .format("csv")
      .option("header", "true")
      .load(filePath)

    // Cast the age column to Double. Values that fail to parse become null.
    val dfWithAge = df.withColumn("age", col("RIDAGEYR").cast(DataTypes.DoubleType))

    // Collect the ages, dropping null rows first — getDouble(0) would throw
    // a NullPointerException on a null cell.
    val ages = dfWithAge.select("age").na.drop().rdd.map(_.getDouble(0)).collect()

    // Bucket into 10 bins. binEdges has one more element than binCounts,
    // so indices.dropRight(1) keeps the loop within binCounts' range.
    val (binEdges, binCounts) = spark.sparkContext.parallelize(ages).histogram(10)

    // Build the chart dataset: one bar per bin, labeled "lower-upper".
    val dataset = new DefaultCategoryDataset()
    for (i <- binEdges.indices.dropRight(1)) {
      dataset.addValue(binCounts(i).toDouble, "", s"${binEdges(i)}-${binEdges(i + 1)}")
    }

    // Create the bar chart.
    val chart = ChartFactory.createBarChart(
      "Age Distribution",       // title
      "Age Group",              // X-axis label
      "Frequency",              // Y-axis label
      dataset,                  // data
      PlotOrientation.VERTICAL, // orientation
      true,                     // show legend
      true,                     // generate tooltips
      false                     // generate URLs
    )

    // Display the chart in a Swing frame; the frame outlives spark.stop()
    // because Swing runs on its own (non-daemon) event-dispatch thread.
    val frame = new org.jfree.chart.ChartFrame("Age Distribution Histogram", chart)
    frame.pack()
    frame.setVisible(true)

    spark.stop()
  }
}