Spark StringIndexer and IndexToString usage

StringIndexer

StringIndexer将一列labels转译成[0,labels基数)的index,labels基数即为labels的去重后总量,index的顺序为labels频次升序,因此出现最多次labels的index为0。如果输入的列时数字类型,我们会把它转化成string,并且使用string转译成index。当pipeline的下游组件例如Estimator或者Transformer使用生成的index时,需要将该组件的输入列名称设置为index的列名。在多数情况下,你可以使用setInputCol设置列名。

另外,当StringIndexer fit了一个dataset后,transfomer一个dataset遇到没见过的labels时,有两种处理策略:

  • 抛出异常(默认)
  • 跳过整行数据,setHandleInvalid(“skip”)

exmaple

import org.apache.spark.ml.feature.StringIndexer

val df = spark.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
).toDF("id", "category")

val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")

val indexed = indexer.fit(df).transform(df)
indexed.show()

IndexToString

与StringIndexer对称的,IndexToString将index映射回原先的labels。通常我们使用StringIndexer产生index,然后使用模型训练数据,最后使用IndexToString找回原先的labels。

example

import org.apache.spark.ml.feature.{IndexToString, StringIndexer}

val df = spark.createDataFrame(Seq(
  (0, "a"),
  (1, "b"),
  (2, "c"),
  (3, "a"),
  (4, "a"),
  (5, "c")
)).toDF("id", "category")

val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(df)
val indexed = indexer.transform(df)

val converter = new IndexToString()
  .setInputCol("categoryIndex")
  .setOutputCol("originalCategory")

val converted = converter.transform(indexed)
converted.select("id", "originalCategory").show()

官方的例子并不是很好,稍微修改一下或许你能更容易明白:

import org.apache.spark.ml.feature.{IndexToString, StringIndexer}

val df = spark.createDataFrame(Seq(
  (0, "a"),
  (1, "b"),
  (2, "c"),
  (3, "a"),
  (4, "a"),
  (5, "c")
)).toDF("id", "category")

val indexer = new StringIndexer().setInputCol("category").setOutputCol("categoryIndex").fit(df)
val indexed = indexer.transform(df)

//设置indexer的labels
val converter = new IndexToString().setInputCol("categoryIndex").setOutputCol("originalCategory").setLabels(indexer.labels)


val df1 = spark.createDataFrame(Seq(
  (10, 2.0),
  (11, 2.0),
  (12, 0.0),
  (13, 0.0),
  (14, 1.0),
  (15, 1.0)
)).toDF("id", "categoryIndex")

val converted = converter.transform(df1)
converted.select("id","categoryIndex""originalCategory").show()

引用

官方文档

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值