Spark Preprocessing&FE practice

7 篇文章 0 订阅
2 篇文章 0 订阅

最近因为要做推荐系统 ,为了熟悉 pyspark 的操作,并且熟悉一下处理日志数据 , 故尝试处理此数据集


数据集介绍
Ali_Display_Ad_Click是阿里巴巴提供的一个淘宝展示广告点击率预估数据集。 下载地址 https://tianchi.aliyun.com/dataset/dataDetail?dataId=56

数据名称说明属性
raw_sample原始的样本骨架用户ID,广告ID,时间,资源位,是否点击
ad_feature广告的基本信息广告ID,广告计划ID,类目ID,品牌ID
user_profile用户的基本信息用户ID,年龄层,性别等
raw_behavior_log用户的行为日志用户ID,行为类型,时间,商品类目ID,品牌ID

原始样本骨架raw_sample
我们从淘宝网站中随机抽样了114万用户8天内的广告展示/点击日志(2600万条记录),构成原始的样本骨架。
字段说明如下:
(1) user_id:脱敏过的用户ID;
(2) adgroup_id:脱敏过的广告单元ID;
(3) time_stamp:时间戳;
(4) pid:资源位;
(5) noclk:为1代表没有点击;为0代表点击;
(6) clk:为0代表没有点击;为1代表点击;
我们用前面7天的做训练样本(20170506-20170512),用第8天的做测试样本(20170513)。

广告基本信息表ad_feature
本数据集涵盖了raw_sample中全部广告的基本信息。字段说明如下:
(1) adgroup_id:脱敏过的广告ID;
(2) cate_id:脱敏过的商品类目ID;
(3) campaign_id:脱敏过的广告计划ID;
(4) customer_id:脱敏过的广告主ID;
(5) brand:脱敏过的品牌ID;
(6) price: 宝贝的价格
其中一个广告ID对应一个商品(宝贝),一个宝贝属于一个类目,一个宝贝属于一个品牌。

用户基本信息表user_profile
本数据集涵盖了raw_sample中全部用户的基本信息。字段说明如下:
(1) userid:脱敏过的用户ID;
(2) cms_segid:微群ID;
(3) cms_group_id:cms_group_id;
(4) final_gender_code:性别 1:男,2:女;
(5) age_level:年龄层次;
(6) pvalue_level:消费档次,1:低档,2:中档,3:高档;
(7) shopping_level:购物深度,1:浅层用户,2:中度用户,3:深度用户
(8) occupation:是否大学生 ,1:是,0:否
(9) new_user_class_level:城市层级

用户的行为日志behavior_log
本数据集涵盖了raw_sample中全部用户22天内的购物行为(共七亿条记录)。字段说明如下:
(1) user:脱敏过的用户ID;
(2) time_stamp:时间戳;
(3) btag:行为类型, 包括以下四种:
在这里插入图片描述
(4) cate:脱敏过的商品类目;
(5) brand: 脱敏过的品牌词;
这里以user + time_stamp为key,会有很多重复的记录;这是因为我们的不同的类型的行为数据是不同部门记录的,在打包到一起的时候,实际上会有小的偏差(即两个一样的time_stamp实际上是差异比较小的两个时间)。

Preprocessing & Feature Engineering

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('raw_sample').getOrCreate()
df = spark.read.csv(r'D:\阿里ctr预估数据集\raw_sample.csv', header=True)
df.show()
+------+----------+----------+-----------+------+---+
|  user|time_stamp|adgroup_id|        pid|nonclk|clk|
+------+----------+----------+-----------+------+---+
|581738|1494137644|         1|430548_1007|     1|  0|
|449818|1494638778|         3|430548_1007|     1|  0|
|914836|1494650879|         4|430548_1007|     1|  0|
|914836|1494651029|         5|430548_1007|     1|  0|
|399907|1494302958|         8|430548_1007|     1|  0|
|628137|1494524935|         9|430548_1007|     1|  0|
|298139|1494462593|         9|430539_1007|     1|  0|
|775475|1494561036|         9|430548_1007|     1|  0|
|555266|1494307136|        11|430539_1007|     1|  0|
|117840|1494036743|        11|430548_1007|     1|  0|
|739815|1494115387|        11|430539_1007|     1|  0|
|623911|1494625301|        11|430548_1007|     1|  0|
|623911|1494451608|        11|430548_1007|     1|  0|
|421590|1494034144|        11|430548_1007|     1|  0|
|976358|1494156949|        13|430548_1007|     1|  0|
|286630|1494218579|        13|430539_1007|     1|  0|
|286630|1494289247|        13|430539_1007|     1|  0|
|771431|1494153867|        13|430548_1007|     1|  0|
|707120|1494220810|        13|430548_1007|     1|  0|
|530454|1494293746|        13|430548_1007|     1|  0|
+------+----------+----------+-----------+------+---+
only showing top 20 rows

# 数据情况
print('样本数', df.count())
print('空值', df.count() - df.dropna().count())
样本数 26557961
空值 0
row1, row2 = df.groupBy("clk").count().collect()
r = row2.asDict()['count'] / row1.asDict()['count']
print('点击了的广告占比', r)
点击了的广告占比 0.05422599045209166
df.printSchema()
root
 |-- user: string (nullable = true)
 |-- time_stamp: string (nullable = true)
 |-- adgroup_id: string (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: string (nullable = true)
 |-- clk: string (nullable = true)

from pyspark.sql.types import StructField, StructType, IntegerType, LongType

# 更改列的数据类型
raw_sample_df = df.\
                withColumn('user', df.user.cast(IntegerType())).\
                withColumn('time_stamp', df.time_stamp.cast(LongType())).\
                withColumn('nonclk', df.nonclk.cast(IntegerType())).\
                withColumn('clk', df.clk.cast(IntegerType()))

raw_sample_df.printSchema()
raw_sample_df.show()
root
 |-- user: integer (nullable = true)
 |-- time_stamp: long (nullable = true)
 |-- adgroup_id: string (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: integer (nullable = true)
 |-- clk: integer (nullable = true)

+------+----------+----------+-----------+------+---+
|  user|time_stamp|adgroup_id|        pid|nonclk|clk|
+------+----------+----------+-----------+------+---+
|581738|1494137644|         1|430548_1007|     1|  0|
|449818|1494638778|         3|430548_1007|     1|  0|
|914836|1494650879|         4|430548_1007|     1|  0|
|914836|1494651029|         5|430548_1007|     1|  0|
|399907|1494302958|         8|430548_1007|     1|  0|
|628137|1494524935|         9|430548_1007|     1|  0|
|298139|1494462593|         9|430539_1007|     1|  0|
|775475|1494561036|         9|430548_1007|     1|  0|
|555266|1494307136|        11|430539_1007|     1|  0|
|117840|1494036743|        11|430548_1007|     1|  0|
|739815|1494115387|        11|430539_1007|     1|  0|
|623911|1494625301|        11|430548_1007|     1|  0|
|623911|1494451608|        11|430548_1007|     1|  0|
|421590|1494034144|        11|430548_1007|     1|  0|
|976358|1494156949|        13|430548_1007|     1|  0|
|286630|1494218579|        13|430539_1007|     1|  0|
|286630|1494289247|        13|430539_1007|     1|  0|
|771431|1494153867|        13|430548_1007|     1|  0|
|707120|1494220810|        13|430548_1007|     1|  0|
|530454|1494293746|        13|430548_1007|     1|  0|
+------+----------+----------+-----------+------+---+
only showing top 20 rows

特征工程

from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

# StringIndexer 指定某一个类型是字符串的列,进行编码 如该列有 'a','b', 'c' -> 0, 1, 2
stringindexer = StringIndexer(inputCol='pid', outputCol='pid_feature')
# 独热编码
encoder = OneHotEncoder(dropLast=False, inputCol='pid_feature', outputCol='pid_value')
# 用管道对编码步骤进行封装
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline = pipeline.fit(raw_sample_df)
df = pipeline.transform(raw_sample_df)
df.show()
+------+----------+----------+-----------+------+---+-----------+-------------+
|  user|time_stamp|adgroup_id|        pid|nonclk|clk|pid_feature|    pid_value|
+------+----------+----------+-----------+------+---+-----------+-------------+
|581738|1494137644|         1|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|449818|1494638778|         3|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494650879|         4|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494651029|         5|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|399907|1494302958|         8|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|628137|1494524935|         9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|298139|1494462593|         9|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|775475|1494561036|         9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|555266|1494307136|        11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|117840|1494036743|        11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|739815|1494115387|        11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|623911|1494625301|        11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|623911|1494451608|        11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|421590|1494034144|        11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|976358|1494156949|        13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|286630|1494218579|        13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|286630|1494289247|        13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|771431|1494153867|        13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|707120|1494220810|        13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|530454|1494293746|        13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
+------+----------+----------+-----------+------+---+-----------+-------------+
only showing top 20 rows

# pyspark.ml.feature.OneHotEncoder 返回的新的一列的 数据类型是
# 稀疏向量 pyspark.ml.linalg.SparseVector
# 向量 (1.0, 0.0, 1.0, 3.0) 的稠密向量表示是 [1.0, 0.0, 1.0, 3.0]
# 稀疏格式表示是(4, [0, 2, 3], [1.0, 1.0, 3.0]) => (向量长度, 元素索引, 值)
from pyspark.ml.linalg import SparseVector

print(SparseVector(4, [1, 3], [3.0, 4.0]))
print(SparseVector(4, [1, 3], [3.0, 4.0]).toArray()) # 转换为 numpy.ndarray
print(df.select("pid_value").first())
print(df.select("pid_value").first().pid_value.toArray())
(4,[1,3],[3.0,4.0])
[0. 3. 0. 4.]
Row(pid_value=SparseVector(2, {0: 1.0}))
[1. 0.]
df.describe('time_stamp').show()
+-------+--------------------+
|summary|          time_stamp|
+-------+--------------------+
|  count|            26557961|
|   mean|1.4943547981415155E9|
| stddev|  198755.26175048228|
|    min|          1494000000|
|    max|          1494691186|
+-------+--------------------+

# 时间间隔, 时间戳以秒为单位
time_temp = 1494691186 - 1494000000
# 最大时间戳和最小时间戳间隔 8天
time_temp / (24*60*60)
7.999837962962963
# 一共是 8 天的数据,前 7 天作为训练集,最后一天作为测试集

train_data = raw_sample_df.filter(raw_sample_df['time_stamp'] <= \
                                  (1494691186 - (24*60*60)))
test_data = raw_sample_df.filter(raw_sample_df['time_stamp'] > \
                                (1494691186 - (24*60*60)))

num1, num2 = train_data.count(), test_data.count()
num1, num2, num1 / (num1+num2)
(23249291, 3308670, 0.8754170171422422)
# 处理 ad_feature 数据集
df = spark.read.csv('D:/阿里ctr预估数据集/ad_feature.csv', header=True)
df.show()
+----------+-------+-----------+--------+------+-----+
|adgroup_id|cate_id|campaign_id|customer| brand|price|
+----------+-------+-----------+--------+------+-----+
|     63133|   6406|      83237|       1| 95471|170.0|
|    313401|   6406|      83237|       1| 87331|199.0|
|    248909|    392|      83237|       1| 32233| 38.0|
|    208458|    392|      83237|       1|174374|139.0|
|    110847|   7211|     135256|       2|145952|32.99|
|    607788|   6261|     387991|       6|207800|199.0|
|    375706|   4520|     387991|       6|  NULL| 99.0|
|     11115|   7213|     139747|       9|186847| 33.0|
|     24484|   7207|     139744|       9|186847| 19.0|
|     28589|   5953|     395195|      13|  NULL|428.0|
|     23236|   5953|     395195|      13|  NULL|368.0|
|    300556|   5953|     395195|      13|  NULL|639.0|
|     92560|   5953|     395195|      13|  NULL|368.0|
|    590965|   4284|      28145|      14|454237|249.0|
|    529913|   4284|      70206|      14|  NULL|249.0|
|    546930|   4284|      28145|      14|  NULL|249.0|
|    639794|   6261|      70206|      14| 37004| 89.9|
|    335413|   4284|      28145|      14|  NULL|249.0|
|    794890|   4284|      70206|      14|454237|249.0|
|    684020|   6261|      70206|      14| 37004| 99.0|
+----------+-------+-----------+--------+------+-----+
only showing top 20 rows

# 替换空值为 -1,在做处理
df = df.replace(to_replace='NULL', value='-1')
df.printSchema()
root
 |-- adgroup_id: string (nullable = true)
 |-- cate_id: string (nullable = true)
 |-- campaign_id: string (nullable = true)
 |-- customer: string (nullable = true)
 |-- brand: string (nullable = true)
 |-- price: string (nullable = true)

from pyspark.sql.types import FloatType

df = df.\
    withColumn("adgroup_id", df.adgroup_id.cast(IntegerType())).\
    withColumn("cate_id", df.cate_id.cast(IntegerType())).\
    withColumn("campaign_id", df.campaign_id.cast(IntegerType())).\
    withColumn("customer", df.customer.cast(IntegerType())).\
    withColumn("brand", df.brand.cast(IntegerType())).\
    withColumn("price", df.price.cast(FloatType()))

df.printSchema()
df.show()
root
 |-- adgroup_id: integer (nullable = true)
 |-- cate_id: integer (nullable = true)
 |-- campaign_id: integer (nullable = true)
 |-- customer: integer (nullable = true)
 |-- brand: integer (nullable = true)
 |-- price: float (nullable = true)

+----------+-------+-----------+--------+------+-----+
|adgroup_id|cate_id|campaign_id|customer| brand|price|
+----------+-------+-----------+--------+------+-----+
|     63133|   6406|      83237|       1| 95471|170.0|
|    313401|   6406|      83237|       1| 87331|199.0|
|    248909|    392|      83237|       1| 32233| 38.0|
|    208458|    392|      83237|       1|174374|139.0|
|    110847|   7211|     135256|       2|145952|32.99|
|    607788|   6261|     387991|       6|207800|199.0|
|    375706|   4520|     387991|       6|    -1| 99.0|
|     11115|   7213|     139747|       9|186847| 33.0|
|     24484|   7207|     139744|       9|186847| 19.0|
|     28589|   5953|     395195|      13|    -1|428.0|
|     23236|   5953|     395195|      13|    -1|368.0|
|    300556|   5953|     395195|      13|    -1|639.0|
|     92560|   5953|     395195|      13|    -1|368.0|
|    590965|   4284|      28145|      14|454237|249.0|
|    529913|   4284|      70206|      14|    -1|249.0|
|    546930|   4284|      28145|      14|    -1|249.0|
|    639794|   6261|      70206|      14| 37004| 89.9|
|    335413|   4284|      28145|      14|    -1|249.0|
|    794890|   4284|      70206|      14|454237|249.0|
|    684020|   6261|      70206|      14| 37004| 99.0|
+----------+-------+-----------+--------+------+-----+
only showing top 20 rows

df.describe().show()
+-------+-----------------+-----------------+------------------+------------------+------------------+------------------+
|summary|       adgroup_id|          cate_id|       campaign_id|          customer|             brand|             price|
+-------+-----------------+-----------------+------------------+------------------+------------------+------------------+
|  count|           846811|           846811|            846811|            846811|            846811|            846811|
|   mean|         423406.0|5868.593464185043|206552.60428005777|113180.40600559038|162566.00186464275|1838.8671081309947|
| stddev|244453.4237388931|2705.171203318181|125192.34090758237| 73435.83494972257|152482.73866344756| 310887.7001702612|
|    min|                1|                1|                 1|                 1|                -1|              0.01|
|    max|           846811|            12960|            423436|            255875|            461497|             1.0E8|
+-------+-----------------+-----------------+------------------+------------------+------------------+------------------+

# 除了广告显示的价格 price,其他都是离散型特征,查看唯一值的数量
for col in df.columns[:-1]:
    print(col, df.groupBy(col).count().count())
adgroup_id 846811
cate_id 6769
campaign_id 423436
customer 255875
brand 99815

这些特征属于高数量类别特征,不适合用独热编码进行处理,可以用 smoothing,咱这不是在搞竞赛,就算了哈

https://www.cnblogs.com/bjwu/p/9087071.html

价格的话,可以很好的反应广告的属性,也不需要进行标准化和归一化了

# user_profile 数据集
df = spark.read.csv('D:/阿里ctr预估数据集/user_profile.csv', header=True)
df.show()
df.count()
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level |
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|   234|        0|           5|                2|        5|        null|             3|         0|                    3|
|   523|        5|           2|                2|        2|           1|             3|         1|                    2|
|   612|        0|           8|                1|        2|           2|             3|         0|                 null|
|  1670|        0|           4|                2|        4|        null|             1|         0|                 null|
|  2545|        0|          10|                1|        4|        null|             3|         0|                 null|
|  3644|       49|           6|                2|        6|           2|             3|         0|                    2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                    2|
|  6211|        0|           9|                1|        3|        null|             3|         0|                    2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                    4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                    1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                    2|
|  9293|        0|           5|                2|        5|        null|             3|         0|                    4|
|  9510|       55|           8|                1|        2|           2|             2|         0|                    2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                    2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                 null|
| 10812|        0|           4|                2|        4|        null|             2|         0|                 null|
| 10912|        0|           4|                2|        4|           2|             3|         0|                 null|
| 10996|        0|           5|                2|        5|        null|             3|         0|                    4|
| 11256|        8|           2|                2|        2|           1|             3|         0|                    3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                    4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
only showing top 20 rows

1061768
df.printSchema()
root
 |-- userid: string (nullable = true)
 |-- cms_segid: string (nullable = true)
 |-- cms_group_id: string (nullable = true)
 |-- final_gender_code: string (nullable = true)
 |-- age_level: string (nullable = true)
 |-- pvalue_level: string (nullable = true)
 |-- shopping_level: string (nullable = true)
 |-- occupation: string (nullable = true)
 |-- new_user_class_level : string (nullable = true)

# 这里的 null 表示空值,前面的 NULL 是字符串
df.printSchema()
df = df.\
    withColumn('userid', df.userid.cast(IntegerType())).\
    withColumn('cms_segid', df.cms_segid.cast(IntegerType())).\
    withColumn('cms_group_id', df.cms_group_id.cast(IntegerType())).\
    withColumn('final_gender_code', df.final_gender_code.cast(IntegerType())).\
    withColumn('age_level', df.age_level.cast(IntegerType())).\
    withColumn('pvalue_level', df.pvalue_level.cast(IntegerType())).\
    withColumn('shopping_level', df.shopping_level.cast(IntegerType())).\
    withColumn('occupation', df.occupation.cast(IntegerType())).\
    withColumn('new_user_class_level ', df['new_user_class_level '].cast(IntegerType())) # 这里后面多了一个空格 'new_user_class_level '

df.printSchema()
df.show()
root
 |-- userid: string (nullable = true)
 |-- cms_segid: string (nullable = true)
 |-- cms_group_id: string (nullable = true)
 |-- final_gender_code: string (nullable = true)
 |-- age_level: string (nullable = true)
 |-- pvalue_level: string (nullable = true)
 |-- shopping_level: string (nullable = true)
 |-- occupation: string (nullable = true)
 |-- new_user_class_level : string (nullable = true)

root
 |-- userid: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: integer (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level : integer (nullable = true)

+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level |
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|   234|        0|           5|                2|        5|        null|             3|         0|                    3|
|   523|        5|           2|                2|        2|           1|             3|         1|                    2|
|   612|        0|           8|                1|        2|           2|             3|         0|                 null|
|  1670|        0|           4|                2|        4|        null|             1|         0|                 null|
|  2545|        0|          10|                1|        4|        null|             3|         0|                 null|
|  3644|       49|           6|                2|        6|           2|             3|         0|                    2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                    2|
|  6211|        0|           9|                1|        3|        null|             3|         0|                    2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                    4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                    1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                    2|
|  9293|        0|           5|                2|        5|        null|             3|         0|                    4|
|  9510|       55|           8|                1|        2|           2|             2|         0|                    2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                    2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                 null|
| 10812|        0|           4|                2|        4|        null|             2|         0|                 null|
| 10912|        0|           4|                2|        4|           2|             3|         0|                 null|
| 10996|        0|           5|                2|        5|        null|             3|         0|                    4|
| 11256|        8|           2|                2|        2|           1|             3|         0|                    3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                    4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
only showing top 20 rows

# 这个用户信息表的特征也全是离散值。。
na_col = ['pvalue_level', 'new_user_class_level ']
for col in df.columns:
    if col not in na_col:
        print(col, df.groupBy(col).count().count())
userid 1061768
cms_segid 97
cms_group_id 13
final_gender_code 2
age_level 7
shopping_level 3
occupation 2
# 查看 'pvalue_level', 'new_user_class_level ' 缺失值比例
for col in na_col:
    print(col, df.groupBy(col).count().show())
+------------+------+
|pvalue_level| count|
+------------+------+
|        null|575917|
|           1|154436|
|           3| 37759|
|           2|293656|
+------------+------+

pvalue_level None
+---------------------+------+
|new_user_class_level | count|
+---------------------+------+
|                 null|344920|
|                    1| 80548|
|                    3|173047|
|                    4|138833|
|                    2|324420|
+---------------------+------+

new_user_class_level  None
_sum = df.count()
# 空值比例
print('pvalue_level', 1 - df.dropna(subset=['pvalue_level']).count() / _sum)
print('new_user_class_level ', 1 - df.dropna(subset=['new_user_class_level ']).count() / _sum)
pvalue_level 0.5424132202138321
new_user_class_level  0.32485439380354275
df
DataFrame[userid: int, cms_segid: int, cms_group_id: int, final_gender_code: int, age_level: int, pvalue_level: int, shopping_level: int, occupation: int, new_user_class_level : int]
# 用随机森林填补 pvalue_level 列的缺失值
# 把 pvalue_level作为标签,其他特征作为特征向量进行训练
# 把用预测值填充缺失值,不为空的值作为 训练集的标签
from pyspark.mllib.regression import LabeledPoint

train = df.dropna(subset=['pvalue_level']).rdd.map(
    # LabeledPoint 得到 (标签, 特征向量) 的元组
    # 离散值编码是从0 开始,pvalue_level最小值是1, 1,2,3 
    lambda r:LabeledPoint(r.pvalue_level-1, [r.cms_segid, r.cms_group_id, \
                r.final_gender_code, r.age_level, r.shopping_level, r.occupation])
)

官方文档

classmethod trainClassifier(data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy=‘auto’, impurity=‘gini’, maxDepth=4, maxBins=32, seed=None)[source]

Train a random forest model for binary or multiclass classification.

  • Parameters

    • data – Training dataset: RDD of LabeledPoint. Labels should take values {0, 1, …, numClasses-1}.

    • numClasses – Number of classes for classification.

    • categoricalFeaturesInfo – Map storing arity of categorical features. An entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, …, k-1}.

    • numTrees – Number of trees in the random forest.

    • featureSubsetStrategy – Number of features to consider for splits at each node. Supported values: “auto”, “all”, “sqrt”, “log2”, “onethird”. If “auto” is set, this parameter is set based on numTrees: if numTrees == 1, set to “all”; if numTrees > 1 (forest) set to “sqrt”. (default: “auto”)

    • impurity – Criterion used for information gain calculation. Supported values: “gini” or “entropy”. (default: “gini”)

    • maxDepth – Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes). (default: 4)

    • maxBins – Maximum number of bins used for splitting features. (default: 32)

    • seed – Random seed for bootstrapping and choosing feature subsets. Set as None to generate seed based on system time. (default: None)

    Returns

    RandomForestModel that can be used for prediction.

%%time
from pyspark.mllib.tree import RandomForest

rfc = RandomForest.trainClassifier(data=train, numClasses=3, \
                                   categoricalFeaturesInfo={},numTrees=10)
Wall time: 12.7 s
# 筛选出 'pvalue_level' 存在缺失值的行并填充
pvalue_level_na_df = df.na.fill(-1).where('pvalue_level=-1')
pvalue_level_na_df.show()
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level |
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                    3|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                   -1|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                   -1|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                    2|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                    4|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                   -1|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                    4|
| 11602|        0|           5|                2|        5|          -1|             3|         0|                    2|
| 11727|        0|           3|                2|        3|          -1|             3|         0|                    1|
| 12195|        0|          10|                1|        4|          -1|             3|         0|                    2|
| 12620|        0|           4|                2|        4|          -1|             2|         0|                   -1|
| 12873|        0|           5|                2|        5|          -1|             3|         0|                    2|
| 14027|        0|          10|                1|        4|          -1|             3|         0|                    3|
| 14437|        0|           5|                2|        5|          -1|             3|         0|                   -1|
| 14574|        0|           1|                2|        1|          -1|             2|         0|                   -1|
| 14985|        0|          11|                1|        5|          -1|             2|         0|                   -1|
| 15525|        0|           3|                2|        3|          -1|             3|         0|                    1|
| 17025|        0|           5|                2|        5|          -1|             3|         0|                   -1|
| 17097|        0|           4|                2|        4|          -1|             2|         0|                   -1|
| 18799|        0|           5|                2|        5|          -1|             3|         0|                    4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
only showing top 20 rows

def feature_row(r):
    '''筛选出作为特征向量的列'''
    return r.cms_segid, r.cms_group_id, r.final_gender_code, r.age_level, r.shopping_level, r.occupation

# 筛选出要进行预测的特征向量
rdd = pvalue_level_na_df.rdd.map(feature_row)
pred = rfc.predict(rdd)
pred
MapPartitionsRDD[373] at mapPartitions at PythonMLLibAPI.scala:1336
# 对 标签/预测值 进行 +1 处理
pred_df = pred.map(lambda value:value + 1).collect()
pred_df
[2.0,
 2.0,
 2.0,
 ...
 ]
type(pred_df)
list
# 转为 pd.dataframe 操作,spark 的 dataframe 合并两个 df 很麻烦,还得同一个df的另一部分才能合并
p_obj = pvalue_level_na_df.toPandas()
p_obj['pvalue_level'] = pred_df
pdf = spark.createDataFrame(p_obj)
pdf.printSchema()
root
 |-- userid: long (nullable = true)
 |-- cms_segid: long (nullable = true)
 |-- cms_group_id: long (nullable = true)
 |-- final_gender_code: long (nullable = true)
 |-- age_level: long (nullable = true)
 |-- pvalue_level: double (nullable = true)
 |-- shopping_level: long (nullable = true)
 |-- occupation: long (nullable = true)
 |-- new_user_class_level : long (nullable = true)

pdf = pdf.\
    withColumn('userid', pdf.userid.cast(IntegerType())).\
    withColumn('cms_segid', pdf.cms_segid.cast(IntegerType())).\
    withColumn('cms_group_id', pdf.cms_group_id.cast(IntegerType())).\
    withColumn('final_gender_code', pdf.final_gender_code.cast(IntegerType())).\
    withColumn('age_level', pdf.age_level.cast(IntegerType())).\
    withColumn('pvalue_level', pdf.pvalue_level.cast(IntegerType())).\
    withColumn('shopping_level', pdf.shopping_level.cast(IntegerType())).\
    withColumn('occupation', pdf.occupation.cast(IntegerType())).\
    withColumn('new_user_class_level ', pdf['new_user_class_level '].cast(IntegerType())) # 这里后面多了一个空格 'new_user_class_level '

pdf.printSchema()
root
 |-- userid: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: integer (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level : integer (nullable = true)

new_df = df.dropna(subset=['pvalue_level']).unionAll(pdf)
new_df.show()
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level |
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|   523|        5|           2|                2|        2|           1|             3|         1|                    2|
|   612|        0|           8|                1|        2|           2|             3|         0|                 null|
|  3644|       49|           6|                2|        6|           2|             3|         0|                    2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                    2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                    4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                    1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                    2|
|  9510|       55|           8|                1|        2|           2|             2|         0|                    2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                    2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                 null|
| 10912|        0|           4|                2|        4|           2|             3|         0|                 null|
| 11256|        8|           2|                2|        2|           1|             3|         0|                    3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                    4|
| 11739|       20|           3|                2|        3|           2|             3|         0|                    4|
| 12549|       33|           4|                2|        4|           2|             3|         0|                    2|
| 15155|       36|           5|                2|        5|           2|             1|         0|                 null|
| 15347|       20|           3|                2|        3|           2|             3|         0|                    3|
| 15455|        8|           2|                2|        2|           2|             3|         0|                    3|
| 15783|        0|           4|                2|        4|           2|             3|         0|                 null|
| 16749|        5|           2|                2|        2|           1|             3|         1|                    4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
only showing top 20 rows

%%time
# 独热编码
df = df.withColumnRenamed('new_user_class_level ', 'new_user_class_level') # 我忍这个空格很久了,现在去掉
df = df.na.fill(-1)
df.show()
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
only showing top 20 rows

Wall time: 492 ms
# 要进行独热编码必须先把该列值转为字符串类型
from pyspark.sql.types import StringType

df = df.withColumn('new_user_class_level', df['new_user_class_level'].cast(StringType()))

stringindexer = StringIndexer(inputCol='new_user_class_level',
                              outputCol='nucl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='nucl_onehot_feature', outputCol='nucl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(df)

df = pipeline_fit.transform(df)
df.show()
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-------------------+-----------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|nucl_onehot_feature|nucl_onehot_value|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-------------------+-----------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|                2.0|    (5,[2],[1.0])|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|                1.0|    (5,[1],[1.0])|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|                0.0|    (5,[0],[1.0])|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|                1.0|    (5,[1],[1.0])|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|                1.0|    (5,[1],[1.0])|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|                1.0|    (5,[1],[1.0])|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|                4.0|    (5,[4],[1.0])|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|                1.0|    (5,[1],[1.0])|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|                1.0|    (5,[1],[1.0])|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|                1.0|    (5,[1],[1.0])|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|                0.0|    (5,[0],[1.0])|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|                2.0|    (5,[2],[1.0])|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-------------------+-----------------+
only showing top 20 rows

df = df.withColumn('pvalue_level', df['pvalue_level'].cast(StringType()))

stringindexer = StringIndexer(inputCol='pvalue_level',
                              outputCol='pvalue_level_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='pvalue_level_onehot_feature', outputCol='pl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(df)

df = pipeline_fit.transform(df)
df.printSchema()
df.columns
root
 |-- userid: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: string (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level: string (nullable = true)
 |-- nucl_onehot_feature: double (nullable = false)
 |-- nucl_onehot_value: vector (nullable = true)
 |-- pvalue_level_onehot_feature: double (nullable = false)
 |-- pl_onehot_value: vector (nullable = true)

['userid',
 'cms_segid',
 'cms_group_id',
 'final_gender_code',
 'age_level',
 'pvalue_level',
 'shopping_level',
 'occupation',
 'new_user_class_level',
 'nucl_onehot_feature',
 'nucl_onehot_value',
 'pvalue_level_onehot_feature',
 'pl_onehot_value']
# 特征合并
from pyspark.ml.feature import VectorAssembler

feature_df = VectorAssembler().setInputCols(['age_level', 'pl_onehot_value', 'nucl_onehot_value']).\
                                setOutputCol('features').transform(df)
feature_df.show()
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-------------------+-----------------+---------------------------+---------------+--------------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|nucl_onehot_feature|nucl_onehot_value|pvalue_level_onehot_feature|pl_onehot_value|            features|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-------------------+-----------------+---------------------------+---------------+--------------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|                2.0|    (5,[2],[1.0])|                        0.0|  (4,[0],[1.0])|(10,[0,1,7],[5.0,...|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|                1.0|    (5,[1],[1.0])|                        2.0|  (4,[2],[1.0])|(10,[0,3,6],[2.0,...|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,5],[2.0,...|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|                0.0|    (5,[0],[1.0])|                        0.0|  (4,[0],[1.0])|(10,[0,1,5],[4.0,...|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|                        0.0|  (4,[0],[1.0])|(10,[0,1,5],[4.0,...|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|                1.0|    (5,[1],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,6],[6.0,...|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|                1.0|    (5,[1],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,6],[5.0,...|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|                1.0|    (5,[1],[1.0])|                        0.0|  (4,[0],[1.0])|(10,[0,1,6],[3.0,...|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|                        2.0|  (4,[2],[1.0])|(10,[0,3,8],[1.0,...|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|                4.0|    (5,[4],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,9],[5.0,...|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|                1.0|    (5,[1],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,6],[2.0,...|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|                        0.0|  (4,[0],[1.0])|(10,[0,1,8],[5.0,...|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|                1.0|    (5,[1],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,6],[2.0,...|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|                1.0|    (5,[1],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,6],[4.0,...|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,5],[4.0,...|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|                0.0|    (5,[0],[1.0])|                        0.0|  (4,[0],[1.0])|(10,[0,1,5],[4.0,...|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,5],[4.0,...|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|                        0.0|  (4,[0],[1.0])|(10,[0,1,8],[5.0,...|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|                2.0|    (5,[2],[1.0])|                        2.0|  (4,[2],[1.0])|(10,[0,3,7],[2.0,...|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|                        2.0|  (4,[2],[1.0])|(10,[0,3,8],[4.0,...|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-------------------+-----------------+---------------------------+---------------+--------------------+
only showing top 20 rows
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值