目录
0. 相关文章链接
1. 概述
使用RFormula选择特征列在Spark2.1.0版本只支持一部分R操作,包括:~’, ‘.’, ‘:’, ‘+’, and ‘-‘.
~ separate target and terms 分割标签与特征
+ concat terms, “+ 0” means removing intercept 将两个特征相加
- remove a term, “- 1” means removing intercept 减去一个特征
: interaction (multiplication for numeric values, or binarized categorical values) 将多个特征相乘变成一个特征
. all columns except target 选取所有特征
举个小例子,假设有a,b两列作为2个特征,y是应变量。
y ~ a + b 表示建立这样的线性模型:y ~ w0 + w1 * a + w2 * b ,其中w0是截距。
y ~ a + b + a:b - 1 表示线性模型:y ~ w1 * a + w2 * b + w3 * a * b
(-1表示去掉截距,所以模型中没有w0了,a:b表示将ab两个特征相乘生成新的特征)
也就是说,我们可以通过这些简单的符号去表示线性模型。
RFormula可以生成多组列向量来表示特征,和一组double或string类型的列来标签。
就像在R中使用公式来建立线性模型一样,字符串类型的特征会被One-hot编码,数值类型的特征会被转换成double类型。如果标签列是字符串类型,会先将它转换成双精度的字符串索引。
如果在dataframe中不存在标签列,将会根据公式中的自变量去生成标签应变量作为输出。
来看下面的例子,假设有一个4列的dataframe:
id | country | hour | clicked |
---|---|---|---|
7 | “US” | 18 | 1.0 |
8 | “CA” | 12 | 0.0 |
9 | “NZ” | 15 | 0.0 |
如果使用RFormula,并且构建公式: clicked ~ country + hour,它表示通过country, hour这两个特征去预测clicked这个应变量。于是我们会得到以下dataframe:
id | country | hour | clicked | features | label |
---|---|---|---|---|---|
7 | “US” | 18 | 1.0 | [0.0, 0.0, 18.0] | 1.0 |
8 | “CA” | 12 | 0.0 | [0.0, 1.0, 12.0] | 0.0 |
9 | “NZ” | 15 | 0.0 | [1.0, 0.0, 15.0] | 0.0 |
features列为转换后的特征表示,因为country是字符串类型的类编变量,故进行one-hot编码变成了两列, hour是数值型的,故转换成double类型。
label列是应变量click列,双精度的类型保持不变。
2. Spark代码
/**
* Created by cc on 17-1-11.
*/
object FeatureSelection {
def main(args: Array[String]) {
Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
val conf = new SparkConf().setAppName("FeatureSelection").setMaster("local")
val sc = new SparkContext(conf)
val spark = SparkSession
.builder()
.appName("Feature Extraction")
.config("spark.some.config.option", "some-value")
.getOrCreate()
// 认为创建一个dataframe,有3行4列
val dataset = spark.createDataFrame(Seq(
(7, "US", 18, 1.0),
(8, "CA", 12, 0.0),
(9, "NZ", 15, 0.0)
)).toDF("id", "country", "hour", "clicked")
// 训练
val formula = new RFormula() //创建一个对象
.setFormula("clicked ~ country + hour") //设置公式
.setFeaturesCol("feature") //设置选择出来的特征的列名
.setLabelCol("label") //设置标签列的列名
val model = formula.fit(dataset)
//转换
val output = model.transform(dataset)
output.show(false)
output.select("feature", "label").show(false)
spark.close()
}
}
打印结果:
+---+-------+----+-------+--------------+-----+
|id |country|hour|clicked|feature |label|
+---+-------+----+-------+--------------+-----+
|7 |US |18 |1.0 |[0.0,0.0,18.0]|1.0 |
|8 |CA |12 |0.0 |[1.0,0.0,12.0]|0.0 |
|9 |NZ |15 |0.0 |[0.0,1.0,15.0]|0.0 |
+---+-------+----+-------+--------------+-----+
+--------------+-----+
|feature |label|
+--------------+-----+
|[0.0,0.0,18.0]|1.0 |
|[1.0,0.0,12.0]|0.0 |
|[0.0,1.0,15.0]|0.0 |
+--------------+-----+
注:其他相关文章链接由此进 -> 算法文章汇总