基于LR的点击率预测模型训练

4.2 基于LR的点击率预测模型训练

  • 本小节主要根据广告点击样本数据集(raw_sample)、广告基本特征数据集(ad_feature)、用户基本信息数据集(user_profile)构建出了一个完整的样本数据集,并按日期划分为了训练集(前七天)和测试集(最后一天),利用逻辑回归进行训练。

    训练模型时,通过对类别特征数据进行处理,一定程度达到提高了模型的效果

'''从HDFS中加载样本数据信息'''
_raw_sample_df1 = spark.read.csv("hdfs://localhost:8020/csv/raw_sample.csv", header=True)
# _raw_sample_df1.show()    # 展示数据,默认前20条
# 更改表结构,转换为对应的数据类型
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType, LongType, StringType
  
# 更改df表结构:更改列类型和列名称
_raw_sample_df2 = _raw_sample_df1.\
    withColumn("user", _raw_sample_df1.user.cast(IntegerType())).withColumnRenamed("user", "userId").\
    withColumn("time_stamp", _raw_sample_df1.time_stamp.cast(LongType())).withColumnRenamed("time_stamp", "timestamp").\
    withColumn("adgroup_id", _raw_sample_df1.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("pid", _raw_sample_df1.pid.cast(StringType())).\
    withColumn("nonclk", _raw_sample_df1.nonclk.cast(IntegerType())).\
    withColumn("clk", _raw_sample_df1.clk.cast(IntegerType()))
_raw_sample_df2.printSchema()
_raw_sample_df2.show()

# 样本数据pid特征处理
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

stringindexer = StringIndexer(inputCol='pid', outputCol='pid_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='pid_feature', outputCol='pid_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(_raw_sample_df2)
raw_sample_df = pipeline_fit.transform(_raw_sample_df2)
raw_sample_df.show()

'''pid和特征的对应关系
430548_1007:0
430549_1007:1
'''

显示结果:

root
 |-- userId: integer (nullable = true)
 |-- timestamp: long (nullable = true)
 |-- adgroupId: integer (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: integer (nullable = true)
 |-- clk: integer (nullable = true)

+------+----------+---------+-----------+------+---+
|userId| timestamp|adgroupId|        pid|nonclk|clk|
+------+----------+---------+-----------+------+---+
|581738|1494137644|        1|430548_1007|     1|  0|
|449818|1494638778|        3|430548_1007|     1|  0|
|914836|1494650879|        4|430548_1007|     1|  0|
|914836|1494651029|        5|430548_1007|     1|  0|
|399907|1494302958|        8|430548_1007|     1|  0|
|628137|1494524935|        9|430548_1007|     1|  0|
|298139|1494462593|        9|430539_1007|     1|  0|
|775475|1494561036|        9|430548_1007|     1|  0|
|555266|1494307136|       11|430539_1007|     1|  0|
|117840|1494036743|       11|430548_1007|     1|  0|
|739815|1494115387|       11|430539_1007|     1|  0|
|623911|1494625301|       11|430548_1007|     1|  0|
|623911|1494451608|       11|430548_1007|     1|  0|
|421590|1494034144|       11|430548_1007|     1|  0|
|976358|1494156949|       13|430548_1007|     1|  0|
|286630|1494218579|       13|430539_1007|     1|  0|
|286630|1494289247|       13|430539_1007|     1|  0|
|771431|1494153867|       13|430548_1007|     1|  0|
|707120|1494220810|       13|430548_1007|     1|  0|
|530454|1494293746|       13|430548_1007|     1|  0|
+------+----------+---------+-----------+------+---+
only showing top 20 rows

+------+----------+---------+-----------+------+---+-----------+-------------+
|userId| timestamp|adgroupId|        pid|nonclk|clk|pid_feature|    pid_value|
+------+----------+---------+-----------+------+---+-----------+-------------+
|581738|1494137644|        1|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|449818|1494638778|        3|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494650879|        4|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494651029|        5|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|399907|1494302958|        8|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|628137|1494524935|        9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|298139|1494462593|        9|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|775475|1494561036|        9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|555266|1494307136|       11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|117840|1494036743|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|739815|1494115387|       11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|623911|1494625301|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|623911|1494451608|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|421590|1494034144|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|976358|1494156949|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|286630|1494218579|       13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|286630|1494289247|       13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|771431|1494153867|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|707120|1494220810|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|530454|1494293746|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
+------+----------+---------+-----------+------+---+-----------+-------------+
only showing top 20 rows

'pid和特征的对应关系\n430548_1007:0\n430549_1007:1\n'
  • 从HDFS中加载广告基本信息数据
_ad_feature_df = spark.read.csv("hdfs://localhost:9000/datasets/ad_feature.csv", header=True)

# 更改表结构,转换为对应的数据类型
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType

# 替换掉NULL字符串
_ad_feature_df = _ad_feature_df.replace("NULL", "-1")
 
# 更改df表结构:更改列类型和列名称
ad_feature_df = _ad_feature_df.\
    withColumn("adgroup_id", _ad_feature_df.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("cate_id", _ad_feature_df.cate_id.cast(IntegerType())).withColumnRenamed("cate_id", "cateId").\
    withColumn("campaign_id", _ad_feature_df.campaign_id.cast(IntegerType())).withColumnRenamed("campaign_id", "campaignId").\
    withColumn("customer", _ad_feature_df.customer.cast(IntegerType())).withColumnRenamed("customer", "customerId").\
    withColumn("brand", _ad_feature_df.brand.cast(IntegerType())).withColumnRenamed("brand", "brandId").\
    withColumn("price", _ad_feature_df.price.cast(FloatType()))
ad_feature_df.printSchema()
ad_feature_df.show()

显示结果:

root
 |-- adgroupId: integer (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- campaignId: integer (nullable = true)
 |-- customerId: integer (nullable = true)
 |-- brandId: integer (nullable = true)
 |-- price: float (nullable = true)

+---------+------+----------+----------+-------+-----+
|adgroupId|cateId|campaignId|customerId|brandId|price|
+---------+------+----------+----------+-------+-----+
|    63133|  6406|     83237|         1|  95471|170.0|
|   313401|  6406|     83237|         1|  87331|199.0|
|   248909|   392|     83237|         1|  32233| 38.0|
|   208458|   392|     83237|         1| 174374|139.0|
|   110847|  7211|    135256|         2| 145952|32.99|
|   607788|  6261|    387991|         6| 207800|199.0|
|   375706|  4520|    387991|         6|     -1| 99.0|
|    11115|  7213|    139747|         9| 186847| 33.0|
|    24484|  7207|    139744|         9| 186847| 19.0|
|    28589|  5953|    395195|        13|     -1|428.0|
|    23236|  5953|    395195|        13|     -1|368.0|
|   300556|  5953|    395195|        13|     -1|639.0|
|    92560|  5953|    395195|        13|     -1|368.0|
|   590965|  4284|     28145|        14| 454237|249.0|
|   529913|  4284|     70206|        14|     -1|249.0|
|   546930|  4284|     28145|        14|     -1|249.0|
|   639794|  6261|     70206|        14|  37004| 89.9|
|   335413|  4284|     28145|        14|     -1|249.0|
|   794890|  4284|     70206|        14| 454237|249.0|
|   684020|  6261|     70206|        14|  37004| 99.0|
+---------+------+----------+----------+-------+-----+
only showing top 20 rows
  • 从HDFS加载用户基本信息数据
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, FloatType

# 构建表结构schema对象
schema = StructType([
    StructField("userId", IntegerType()),
    StructField("cms_segid", IntegerType()),
    StructField("cms_group_id", IntegerType()),
    StructField("final_gender_code", IntegerType()),
    StructField("age_level", IntegerType()),
    StructField("pvalue_level", IntegerType()),
    StructField("shopping_level", IntegerType()),
    StructField("occupation", IntegerType()),
    StructField("new_user_class_level", IntegerType())
])
# 利用schema从hdfs加载
_user_profile_df1 = spark.read.csv("hdfs://localhost:9000/datasets/user_profile.csv", header=True, schema=schema)
# user_profile_df.printSchema()
# user_profile_df.show()

'''对缺失数据进行特征热编码'''
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

# 使用热编码转换pvalue_level的一维数据为多维,增加n-1个虚拟变量,n为pvalue_level的取值范围

# 需要先将缺失值全部替换为数值,便于处理,否则会抛出异常
from pyspark.sql.types import StringType
_user_profile_df2 = _user_profile_df1.na.fill(-1)
# _user_profile_df2.show()

# 热编码时,必须先将待处理字段转为字符串类型才可处理
_user_profile_df3 = _user_profile_df2.withColumn("pvalue_level", _user_profile_df2.pvalue_level.cast(StringType()))\
    .withColumn("new_user_class_level", _user_profile_df2.new_user_class_level.cast(StringType()))
# _user_profile_df3.printSchema()

# 对pvalue_level进行热编码,求值
# 运行过程是先将pvalue_level转换为一列新的特征数据,然后对该特征数据求出的热编码值,存在了新的一列数据中,类型为一个稀疏矩阵
stringindexer = StringIndexer(inputCol='pvalue_level', outputCol='pl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='pl_onehot_feature', outputCol='pl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(_user_profile_df3)
_user_profile_df4 = pipeline_fit.transform(_user_profile_df3)
# pl_onehot_value列的值为稀疏矩阵,存储热编码的结果
# _user_profile_df4.printSchema()
# _user_profile_df4.show()

# 使用热编码转换new_user_class_level的一维数据为多维
stringindexer = StringIndexer(inputCol='new_user_class_level', outputCol='nucl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='nucl_onehot_feature', outputCol='nucl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(_user_profile_df4)
user_profile_df = pipeline_fit.transform(_user_profile_df4)
user_profile_df.show()

显示结果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|pl_onehot_feature|pl_onehot_value|nucl_onehot_feature|nucl_onehot_value|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|              0.0|  (4,[0],[1.0])|                2.0|    (5,[2],[1.0])|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|              2.0|  (4,[2],[1.0])|                1.0|    (5,[1],[1.0])|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|              0.0|  (4,[0],[1.0])|                1.0|    (5,[1],[1.0])|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|              1.0|  (4,[1],[1.0])|                4.0|    (5,[4],[1.0])|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|              2.0|  (4,[2],[1.0])|                2.0|    (5,[2],[1.0])|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
only showing top 20 rows

  • 热编码中:"pvalue_level"特征对应关系:
+------------+----------------------+
|pvalue_level|pl_onehot_feature     |
+------------+----------------------+
|          -1|                   0.0|
|           3|                   3.0|
|           1|                   2.0|
|           2|                   1.0|
+------------+----------------------+
  • “new_user_class_level”的特征对应关系
+--------------------+------------------------+
|new_user_class_level|nucl_onehot_feature     |
+--------------------+------------------------+
|                  -1|                     0.0|
|                   3|                     2.0|
|                   1|                     4.0|
|                   4|                     3.0|
|                   2|                     1.0|
+--------------------+------------------------+
user_profile_df.groupBy("pvalue_level").min("pl_onehot_feature").show()
user_profile_df.groupBy("new_user_class_level").min("nucl_onehot_feature").show()

显示结果:

+------------+----------------------+
|pvalue_level|min(pl_onehot_feature)|
+------------+----------------------+
|          -1|                   0.0|
|           3|                   3.0|
|           1|                   2.0|
|           2|                   1.0|
+------------+----------------------+

+--------------------+------------------------+
|new_user_class_level|min(nucl_onehot_feature)|
+--------------------+------------------------+
|                  -1|                     0.0|
|                   3|                     2.0|
|                   1|                     4.0|
|                   4|                     3.0|
|                   2|                     1.0|
+--------------------+------------------------+

# raw_sample_df和ad_feature_df合并条件
condition = [raw_sample_df.adgroupId==ad_feature_df.adgroupId]
_ = raw_sample_df.join(ad_feature_df, condition, 'outer')

# _和user_profile_df合并条件
condition2 = [_.userId==user_profile_df.userId]
datasets = _.join(user_profile_df, condition2, "outer")
# 查看datasets的结构
datasets.printSchema()
# 查看datasets条目数
print(datasets.count())

显示结果:

root
 |-- userId: integer (nullable = true)
 |-- timestamp: long (nullable = true)
 |-- adgroupId: integer (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: integer (nullable = true)
 |-- clk: integer (nullable = true)
 |-- pid_feature: double (nullable = true)
 |-- pid_value: vector (nullable = true)
 |-- adgroupId: integer (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- campaignId: integer (nullable = true)
 |-- customerId: integer (nullable = true)
 |-- brandId: integer (nullable = true)
 |-- price: float (nullable = true)
 |-- userId: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: string (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level: string (nullable = true)
 |-- pl_onehot_feature: double (nullable = true)
 |-- pl_onehot_value: vector (nullable = true)
 |-- nucl_onehot_feature: double (nullable = true)
 |-- nucl_onehot_value: vector (nullable = true)

26557961
  • 训练CTRModel_Normal:直接将对应的特征的特征值组合成对应的特征向量进行训练
# 剔除冗余、不需要的字段
useful_cols = [
    # 
    # 时间字段,划分训练集和测试集
    "timestamp",
    # label目标值字段
    "clk",  
    # 特征值字段
    "pid_value",       # 资源位的特征向量
    "price",    # 广告价格
    "cms_segid",    # 用户微群ID
    "cms_group_id",    # 用户组ID
    "final_gender_code",    # 用户性别特征,[1,2]
    "age_level",    # 年龄等级,1-
    "shopping_level",
    "occupation",
    "pl_onehot_value",
    "nucl_onehot_value"
]
# 筛选指定字段数据,构建新的数据集
datasets_1 = datasets.select(*useful_cols)
# 由于前面使用的是outer方式合并的数据,产生了部分空值数据,这里必须先剔除掉
datasets_1 = datasets_1.dropna()
print("剔除空值数据后,还剩:", datasets_1.count())

显示结果:

剔除空值数据后,还剩: 25029435

  • 根据特征字段计算出特征向量,并划分出训练数据集和测试数据集
from pyspark.ml.feature import VectorAssembler
# 根据特征字段计算特征向量
datasets_1 = VectorAssembler().setInputCols(useful_cols[2:]).setOutputCol("features").transform(datasets_1)
# 训练数据集: 约7天的数据
train_datasets_1 = datasets_1.filter(datasets_1.timestamp<=(1494691186-24*60*60))
# 测试数据集:约1天的数据量
test_datasets_1 = datasets_1.filter(datasets_1.timestamp>(1494691186-24*60*60))
# 所有的特征的特征向量已经汇总到在features字段中
train_datasets_1.show(5)
test_datasets_1.show(5)

显示结果:

+----------+---+-------------+------+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
| timestamp|clk|    pid_value| price|cms_segid|cms_group_id|final_gender_code|age_level|shopping_level|occupation|pl_onehot_value|nucl_onehot_value|            features|
+----------+---+-------------+------+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
|1494261938|  0|(2,[1],[1.0])| 108.0|        0|          11|                1|        5|             3|         0|  (4,[0],[1.0])|    (5,[1],[1.0])|(18,[1,2,4,5,6,7,...|
|1494261938|  0|(2,[1],[1.0])|1880.0|        0|          11|                1|        5|             3|         0|  (4,[0],[1.0])|    (5,[1],[1.0])|(18,[1,2,4,5,6,7,...|
|1494553913|  0|(2,[1],[1.0])|2360.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494553913|  0|(2,[1],[1.0])|2200.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494436784|  0|(2,[1],[1.0])|5649.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
+----------+---+-------------+------+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
only showing top 5 rows

+----------+---+-------------+-----+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
| timestamp|clk|    pid_value|price|cms_segid|cms_group_id|final_gender_code|age_level|shopping_level|occupation|pl_onehot_value|nucl_onehot_value|            features|
+----------+---+-------------+-----+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
|1494677292|  0|(2,[1],[1.0])|176.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494677292|  0|(2,[1],[1.0])|698.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494677292|  0|(2,[1],[1.0])|697.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494684007|  0|(2,[1],[1.0])|247.0|       18|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[4],[1.0])|(18,[1,2,3,4,5,6,...|
|1494684007|  0|(2,[1],[1.0])|109.0|       18|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[4],[1.0])|(18,[1,2,3,4,5,6,...|
+----------+---+-------------+-----+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
only showing top 5 rows

from pyspark.ml.classification import LogisticRegression
lr = LogisticRegression()
# 设置目标字段、特征值字段并训练
model = lr.setLabelCol("clk").setFeaturesCol("features").fit(train_datasets_1)
# 对模型进行存储
model.save("hdfs://localhost:9000/models/CTRModel_Normal.obj")
# 载入训练好的模型
from pyspark.ml.classification import LogisticRegressionModel
model = LogisticRegressionModel.load("hdfs://localhost:9000/models/CTRModel_Normal.obj")
# 根据测试数据进行预测
result_1 = model.transform(test_datasets_1)
# 按probability升序排列数据,probability表示预测结果的概率
# 如果预测值是0,其概率是0.9248,那么反之可推出1的可能性就是1-0.9248=0.0752,即点击概率约为7.52%
# 因为前面提到广告的点击率一般都比较低,所以预测值通常都是0,因此通常需要反减得出点击的概率
result_1.select("clk", "price", "probability", "prediction").sort("probability").show(100)

显示结果:

+---+-----------+--------------------+----------+
|clk|      price|         probability|prediction|
+---+-----------+--------------------+----------+
|  0|      1.0E8|[0.86822033939259...|       0.0|
|  0|      1.0E8|[0.88410457194969...|       0.0|
|  0|      1.0E8|[0.89175497837562...|       0.0|
|  1|5.5555556E7|[0.92481456486873...|       0.0|
|  0|      1.5E7|[0.93741450446939...|       0.0|
|  0|      1.5E7|[0.93757135079959...|       0.0|
|  0|      1.5E7|[0.93834723093801...|       0.0|
|  0|     1099.0|[0.93972095713786...|       0.0|
|  0|      338.0|[0.93972134993018...|       0.0|
|  0|      311.0|[0.93972136386626...|       0.0|
|  0|      300.0|[0.93972136954393...|       0.0|
|  0|      278.0|[0.93972138089925...|       0.0|
|  0|      188.0|[0.93972142735283...|       0.0|
|  0|      176.0|[0.93972143354663...|       0.0|
|  0|      168.0|[0.93972143767584...|       0.0|
|  0|      158.0|[0.93972144283734...|       0.0|
|  1|      138.0|[0.93972145316035...|       0.0|
|  0|      125.0|[0.93972145987031...|       0.0|
|  0|      119.0|[0.93972146296721...|       0.0|
|  0|       78.0|[0.93972148412937...|       0.0|
|  0|      59.98|[0.93972149343040...|       0.0|
|  0|       58.0|[0.93972149445238...|       0.0|
|  0|       56.0|[0.93972149548468...|       0.0|
|  0|       38.0|[0.93972150477538...|       0.0|
|  1|       35.0|[0.93972150632383...|       0.0|
|  0|       33.0|[0.93972150735613...|       0.0|
|  0|       30.0|[0.93972150890458...|       0.0|
|  0|       27.6|[0.93972151014334...|       0.0|
|  0|       18.0|[0.93972151509838...|       0.0|
|  0|       30.0|[0.93980311191464...|       0.0|
|  0|       28.0|[0.93980311294563...|       0.0|
|  0|       25.0|[0.93980311449212...|       0.0|
|  0|      688.0|[0.93999362023323...|       0.0|
|  0|      339.0|[0.93999379960808...|       0.0|
|  0|      335.0|[0.93999380166395...|       0.0|
|  0|      220.0|[0.93999386077017...|       0.0|
|  0|      176.0|[0.93999388338470...|       0.0|
|  0|      158.0|[0.93999389263610...|       0.0|
|  0|      158.0|[0.93999389263610...|       0.0|
|  1|      149.0|[0.93999389726180...|       0.0|
|  0|      122.5|[0.93999391088191...|       0.0|
|  0|       99.0|[0.93999392296012...|       0.0|
|  0|       88.0|[0.93999392861375...|       0.0|
|  0|       79.0|[0.93999393323945...|       0.0|
|  0|       75.0|[0.93999393529532...|       0.0|
|  0|       68.0|[0.93999393889308...|       0.0|
|  0|       68.0|[0.93999393889308...|       0.0|
|  0|       59.9|[0.93999394305620...|       0.0|
|  0|      44.98|[0.93999395072458...|       0.0|
|  0|       35.5|[0.93999395559698...|       0.0|
|  0|       33.0|[0.93999395688189...|       0.0|
|  0|       32.8|[0.93999395698469...|       0.0|
|  0|       30.0|[0.93999395842379...|       0.0|
|  0|       28.0|[0.93999395945172...|       0.0|
|  0|       19.9|[0.93999396361485...|       0.0|
|  0|       19.8|[0.93999396366625...|       0.0|
|  0|       19.8|[0.93999396366625...|       0.0|
|  0|       12.0|[0.93999396767518...|       0.0|
|  0|        6.7|[0.93999397039920...|       0.0|
|  0|      568.0|[0.94000369247841...|       0.0|
|  0|      398.0|[0.94000377983931...|       0.0|
|  0|      158.0|[0.94000390317214...|       0.0|
|  0|     5718.0|[0.94001886593718...|       0.0|
|  0|     5718.0|[0.94001886593718...|       0.0|
|  1|     5608.0|[0.94001892245145...|       0.0|
|  0|     4120.0|[0.94001968693052...|       0.0|
|  0|     1027.5|[0.94002127571285...|       0.0|
|  0|     1027.5|[0.94002127571285...|       0.0|
|  0|      989.0|[0.94002129549211...|       0.0|
|  0|      672.0|[0.94002145834965...|       0.0|
|  0|      660.0|[0.94002146451460...|       0.0|
|  0|      598.0|[0.94002149636681...|       0.0|
|  0|      598.0|[0.94002149636681...|       0.0|
|  0|      563.0|[0.94002151434789...|       0.0|
|  0|      509.0|[0.94002154209012...|       0.0|
|  0|      509.0|[0.94002154209012...|       0.0|
|  0|      500.0|[0.94002154671382...|       0.0|
|  0|      498.0|[0.94002154774131...|       0.0|
|  0|      440.0|[0.94002157753851...|       0.0|
|  0|      430.0|[0.94002158267595...|       0.0|
|  0|      388.0|[0.94002160425322...|       0.0|
|  0|      369.0|[0.94002161401436...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      366.0|[0.94002161555560...|       0.0|
|  0|      366.0|[0.94002161555560...|       0.0|
|  0|      348.0|[0.94002162480299...|       0.0|
|  0|      299.0|[0.94002164997645...|       0.0|
|  0|      299.0|[0.94002164997645...|       0.0|
|  0|      299.0|[0.94002164997645...|       0.0|
|  0|      298.0|[0.94002165049020...|       0.0|
|  0|      297.0|[0.94002165100394...|       0.0|
|  0|      278.0|[0.94002166076508...|       0.0|
|  1|      275.0|[0.94002166230631...|       0.0|
|  0|      275.0|[0.94002166230631...|       0.0|
|  0|      273.0|[0.94002166333380...|       0.0|
|  0|      258.0|[0.94002167103995...|       0.0|
|  0|      256.0|[0.94002167206744...|       0.0|
+---+-----------+--------------------+----------+
only showing top 100 rows
  • 查看样本中点击的被实际点击的条目的预测情况
result_1.filter(result_1.clk==1).select("clk", "price", "probability", "prediction").sort("probability").show(100)

显示结果:

+---+-----------+--------------------+----------+
|clk|      price|         probability|prediction|
+---+-----------+--------------------+----------+
|  1|5.5555556E7|[0.92481456486873...|       0.0|
|  1|      138.0|[0.93972145316035...|       0.0|
|  1|       35.0|[0.93972150632383...|       0.0|
|  1|      149.0|[0.93999389726180...|       0.0|
|  1|     5608.0|[0.94001892245145...|       0.0|
|  1|      275.0|[0.94002166230631...|       0.0|
|  1|       35.0|[0.94002178560473...|       0.0|
|  1|       49.0|[0.94004219516957...|       0.0|
|  1|      915.0|[0.94021082858784...|       0.0|
|  1|      598.0|[0.94021099096349...|       0.0|
|  1|      568.0|[0.94021100633025...|       0.0|
|  1|      398.0|[0.94021109340848...|       0.0|
|  1|      368.0|[0.94021110877521...|       0.0|
|  1|      299.0|[0.94021114411869...|       0.0|
|  1|      278.0|[0.94021115487539...|       0.0|
|  1|      259.0|[0.94021116460765...|       0.0|
|  1|      258.0|[0.94021116511987...|       0.0|
|  1|      258.0|[0.94021116511987...|       0.0|
|  1|      258.0|[0.94021116511987...|       0.0|
|  1|      195.0|[0.94021119738998...|       0.0|
|  1|      188.0|[0.94021120097554...|       0.0|
|  1|      178.0|[0.94021120609778...|       0.0|
|  1|      159.0|[0.94021121583003...|       0.0|
|  1|      149.0|[0.94021122095226...|       0.0|
|  1|      138.0|[0.94021122658672...|       0.0|
|  1|       58.0|[0.94021126756458...|       0.0|
|  1|       49.0|[0.94021127217459...|       0.0|
|  1|       35.0|[0.94021127934572...|       0.0|
|  1|       25.0|[0.94021128446795...|       0.0|
|  1|     2890.0|[0.94028789742257...|       0.0|
|  1|      220.0|[0.94028926340218...|       0.0|
|  1|      188.0|[0.94031410659516...|       0.0|
|  1|       68.0|[0.94031416796289...|       0.0|
|  1|       58.0|[0.94031417307687...|       0.0|
|  1|      198.0|[0.94035413548387...|       0.0|
|  1|      208.0|[0.94039204931181...|       0.0|
|  1|     8888.0|[0.94045237642030...|       0.0|
|  1|      519.0|[0.94045664687995...|       0.0|
|  1|      478.0|[0.94045666780037...|       0.0|
|  1|      349.0|[0.94045673362308...|       0.0|
|  1|      348.0|[0.94045673413334...|       0.0|
|  1|      316.0|[0.94045675046144...|       0.0|
|  1|      298.0|[0.94045675964600...|       0.0|
|  1|      298.0|[0.94045675964600...|       0.0|
|  1|      199.0|[0.94045681016104...|       0.0|
|  1|      199.0|[0.94045681016104...|       0.0|
|  1|      198.0|[0.94045681067129...|       0.0|
|  1|      187.1|[0.94045681623305...|       0.0|
|  1|      176.0|[0.94045682189685...|       0.0|
|  1|      168.0|[0.94045682597887...|       0.0|
|  1|      160.0|[0.94045683006090...|       0.0|
|  1|      158.0|[0.94045683108140...|       0.0|
|  1|      158.0|[0.94045683108140...|       0.0|
|  1|      135.0|[0.94045684281721...|       0.0|
|  1|      129.0|[0.94045684587872...|       0.0|
|  1|      127.0|[0.94045684689923...|       0.0|
|  1|      125.0|[0.94045684791973...|       0.0|
|  1|      124.0|[0.94045684842999...|       0.0|
|  1|      118.0|[0.94045685149150...|       0.0|
|  1|      109.0|[0.94045685608377...|       0.0|
|  1|      108.0|[0.94045685659402...|       0.0|
|  1|       99.0|[0.94045686118630...|       0.0|
|  1|       98.0|[0.94045686169655...|       0.0|
|  1|       79.8|[0.94045687098314...|       0.0|
|  1|       79.0|[0.94045687139134...|       0.0|
|  1|       77.0|[0.94045687241185...|       0.0|
|  1|       72.5|[0.94045687470798...|       0.0|
|  1|       69.0|[0.94045687649386...|       0.0|
|  1|       68.0|[0.94045687700412...|       0.0|
|  1|       60.0|[0.94045688108613...|       0.0|
|  1|      43.98|[0.94045688926037...|       0.0|
|  1|       40.0|[0.94045689129118...|       0.0|
|  1|       39.9|[0.94045689134220...|       0.0|
|  1|       39.6|[0.94045689149528...|       0.0|
|  1|       32.0|[0.94045689537319...|       0.0|
|  1|       31.0|[0.94045689588345...|       0.0|
|  1|      25.98|[0.94045689844491...|       0.0|
|  1|       23.0|[0.94045689996546...|       0.0|
|  1|       19.0|[0.94045690200647...|       0.0|
|  1|       16.9|[0.94045690307800...|       0.0|
|  1|       10.0|[0.94045690659874...|       0.0|
|  1|        3.5|[0.94045690991538...|       0.0|
|  1|        3.5|[0.94045690991538...|       0.0|
|  1|        0.4|[0.94045691149716...|       0.0|
|  1|     3960.0|[0.94055740378069...|       0.0|
|  1|     3088.0|[0.94055784801535...|       0.0|
|  1|     1689.0|[0.94055856072019...|       0.0|
|  1|      998.0|[0.94055891273943...|       0.0|
|  1|      888.0|[0.94055896877705...|       0.0|
|  1|      788.0|[0.94055901972029...|       0.0|
|  1|      737.0|[0.94055904570133...|       0.0|
|  1|      629.0|[0.94055910071996...|       0.0|
|  1|      599.0|[0.94055911600291...|       0.0|
|  1|      599.0|[0.94055911600291...|       0.0|
|  1|      599.0|[0.94055911600291...|       0.0|
|  1|      499.0|[0.94055916694603...|       0.0|
|  1|      468.0|[0.94055918273839...|       0.0|
|  1|      459.0|[0.94055918732327...|       0.0|
|  1|      399.0|[0.94055921788912...|       0.0|
|  1|      399.0|[0.94055921788912...|       0.0|
+---+-----------+--------------------+----------+
only showing top 100 rows

  • 训练CTRModel_AllOneHot

    • “pid_value”, 类别型特征,已被转换为多维特征==> 2维
    • “price”, 统计型特征 ===> 1维
    • “cms_segid”, 类别型特征,约97个分类 ===> 1维
    • “cms_group_id”, 类别型特征,约13个分类 ==> 1维
    • “final_gender_code”, 类别型特征,2个分类 ==> 1维
    • “age_level”, 类别型特征,7个分类 ==> 1维
    • “shopping_level”, 类别型特征,3个分类 ==> 1维
    • “occupation”, 类别型特征,2个分类 ==> 1维
    • “pl_onehot_value”, 类别型特征,已被转换为多维特征 ==> 4维
    • “nucl_onehot_value” 类别型特征,已被转换为多维特征 ==> 5维

    类别性特征都可以考虑进行热独编码,将单一变量变为多变量,相当于增加了相关特征的数量

    • “cms_segid”, 类别型特征,约97个分类 ===> 97维 舍弃
    • “cms_group_id”, 类别型特征,约13个分类 ==> 13维
    • “final_gender_code”, 类别型特征,2个分类 ==> 2维
    • “age_level”, 类别型特征,7个分类 ==>7维
    • “shopping_level”, 类别型特征,3个分类 ==> 3维
    • “occupation”, 类别型特征,2个分类 ==> 2维

    但由于cms_segid分类过多,这里考虑舍弃,避免数据过于稀疏

datasets_1.first()

显示结果:

datasets_1.first()
datasets_1.first()
Row(timestamp=1494261938, clk=0, pid_value=SparseVector(2, {1: 1.0}), price=1880.0, cms_segid=0, cms_group_id=11, final_gender_code=1, age_level=5, shopping_level=3, occupation=0, pl_onehot_value=SparseVector(4, {0: 1.0}), nucl_onehot_value=SparseVector(5, {1: 1.0}), features=SparseVector(18, {1: 1.0, 2: 1880.0, 4: 11.0, 5: 1.0, 6: 5.0, 7: 3.0, 9: 1.0, 14: 1.0}))
# 先将下列五列数据转为字符串类型,以便于进行热独编码
# - "cms_group_id",   类别型特征,约13个分类 ==> 13
# - "final_gender_code", 类别型特征,2个分类 ==> 2
# - "age_level",    类别型特征,7个分类 ==>7
# - "shopping_level",    类别型特征,3个分类 ==> 3
# - "occupation",    类别型特征,2个分类 ==> 2

datasets_2 = datasets.withColumn("cms_group_id", datasets.cms_group_id.cast(StringType()))\
    .withColumn("final_gender_code", datasets.final_gender_code.cast(StringType()))\
    .withColumn("age_level", datasets.age_level.cast(StringType()))\
    .withColumn("shopping_level", datasets.shopping_level.cast(StringType()))\
    .withColumn("occupation", datasets.occupation.cast(StringType()))
useful_cols_2 = [
    # 时间值,划分训练集和测试集
    "timestamp",
    # label目标值
    "clk",  
    # 特征值
    "price",
    "cms_group_id",
    "final_gender_code",
    "age_level",
    "shopping_level",
    "occupation",
    "pid_value", 
    "pl_onehot_value",
    "nucl_onehot_value"
]
# 筛选指定字段数据
datasets_2 = datasets_2.select(*useful_cols_2)
# 由于前面使用的是outer方式合并的数据,产生了部分空值数据,这里必须先剔除掉
datasets_2 = datasets_2.dropna()


from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline
# 热编码处理函数封装
def oneHotEncoder(col1, col2, col3, data):
    stringindexer = StringIndexer(inputCol=col1, outputCol=col2)
    encoder = OneHotEncoder(dropLast=False, inputCol=col2, outputCol=col3)
    pipeline = Pipeline(stages=[stringindexer, encoder])
    pipeline_fit = pipeline.fit(data)
    return pipeline_fit.transform(data)

# 对这五个字段进行热独编码
#     "cms_group_id",
#     "final_gender_code",
#     "age_level",
#     "shopping_level",
#     "occupation",
datasets_2 = oneHotEncoder("cms_group_id", "cms_group_id_feature", "cms_group_id_value", datasets_2)
datasets_2 = oneHotEncoder("final_gender_code", "final_gender_code_feature", "final_gender_code_value", datasets_2)
datasets_2 = oneHotEncoder("age_level", "age_level_feature", "age_level_value", datasets_2)
datasets_2 = oneHotEncoder("shopping_level", "shopping_level_feature", "shopping_level_value", datasets_2)
datasets_2 = oneHotEncoder("occupation", "occupation_feature", "occupation_value", datasets_2)
  • "cms_group_id"特征对应关系:
+------------+-------------------------+
|cms_group_id|min(cms_group_id_feature)|
+------------+-------------------------+
|           7|                      9.0|
|          11|                      6.0|
|           3|                      0.0|
|           8|                      8.0|
|           0|                     12.0|
|           5|                      3.0|
|           6|                     10.0|
|           9|                      5.0|
|           1|                      7.0|
|          10|                      4.0|
|           4|                      1.0|
|          12|                     11.0|
|           2|                      2.0|
+------------+-------------------------+

  • "final_gender_code"特征对应关系:
+-----------------+------------------------------+
|final_gender_code|min(final_gender_code_feature)|
+-----------------+------------------------------+
|                1|                           1.0|
|                2|                           0.0|
+-----------------+------------------------------+

  • "age_level"特征对应关系:
+---------+----------------------+
|age_level|min(age_level_feature)|
+---------+----------------------+
|        3|                   0.0|
|        0|                   6.0|
|        5|                   2.0|
|        6|                   5.0|
|        1|                   4.0|
|        4|                   1.0|
|        2|                   3.0|
+---------+----------------------+

  • "shopping_level"特征对应关系:
|shopping_level|min(shopping_level_feature)|
+--------------+---------------------------+
|             3|                        0.0|
|             1|                        2.0|
|             2|                        1.0|
+--------------+---------------------------+

  • "occupation"特征对应关系:
+----------+-----------------------+
|occupation|min(occupation_feature)|
+----------+-----------------------+
|         0|                    0.0|
|         1|                    1.0|
+----------+-----------------------+

datasets_2.groupBy("cms_group_id").min("cms_group_id_feature").show()
datasets_2.groupBy("final_gender_code").min("final_gender_code_feature").show()
datasets_2.groupBy("age_level").min("age_level_feature").show()
datasets_2.groupBy("shopping_level").min("shopping_level_feature").show()
datasets_2.groupBy("occupation").min("occupation_feature").show()

显示结果:

+------------+-------------------------+
|cms_group_id|min(cms_group_id_feature)|
+------------+-------------------------+
|           7|                      9.0|
|          11|                      6.0|
|           3|                      0.0|
|           8|                      8.0|
|           0|                     12.0|
|           5|                      3.0|
|           6|                     10.0|
|           9|                      5.0|
|           1|                      7.0|
|          10|                      4.0|
|           4|                      1.0|
|          12|                     11.0|
|           2|                      2.0|
+------------+-------------------------+

+-----------------+------------------------------+
|final_gender_code|min(final_gender_code_feature)|
+-----------------+------------------------------+
|                1|                           1.0|
|                2|                           0.0|
+-----------------+------------------------------+

+---------+----------------------+
|age_level|min(age_level_feature)|
+---------+----------------------+
|        3|                   0.0|
|        0|                   6.0|
|        5|                   2.0|
|        6|                   5.0|
|        1|                   4.0|
|        4|                   1.0|
|        2|                   3.0|
+---------+----------------------+

+--------------+---------------------------+
|shopping_level|min(shopping_level_feature)|
+--------------+---------------------------+
|             3|                        0.0|
|             1|                        2.0|
|             2|                        1.0|
+--------------+---------------------------+

+----------+-----------------------+
|occupation|min(occupation_feature)|
+----------+-----------------------+
|         0|                    0.0|
|         1|                    1.0|
+----------+-----------------------+

# 由于热独编码后,特征字段不再是之前的字段,重新定义特征值字段
feature_cols = [
    # 特征值
    "price",
    "cms_group_id_value",
    "final_gender_code_value",
    "age_level_value",
    "shopping_level_value",
    "occupation_value",
    "pid_value",
    "pl_onehot_value",
    "nucl_onehot_value"
]
# 根据特征字段计算出特征向量,并划分出训练数据集和测试数据集
from pyspark.ml.feature import VectorAssembler
datasets_2 = VectorAssembler().setInputCols(feature_cols).setOutputCol("features").transform(datasets_2)
train_datasets_2 = datasets_2.filter(datasets_2.timestamp<=(1494691186-24*60*60))
test_datasets_2 = datasets_2.filter(datasets_2.timestamp>(1494691186-24*60*60))
train_datasets_2.printSchema()
train_datasets_2.first()

显示结果:

root
 |-- timestamp: long (nullable = true)
 |-- clk: integer (nullable = true)
 |-- price: float (nullable = true)
 |-- cms_group_id: string (nullable = true)
 |-- final_gender_code: string (nullable = true)
 |-- age_level: string (nullable = true)
 |-- shopping_level: string (nullable = true)
 |-- occupation: string (nullable = true)
 |-- pid_value: vector (nullable = true)
 |-- pl_onehot_value: vector (nullable = true)
 |-- nucl_onehot_value: vector (nullable = true)
 |-- cms_group_id_feature: double (nullable = false)
 |-- cms_group_id_value: vector (nullable = true)
 |-- final_gender_code_feature: double (nullable = false)
 |-- final_gender_code_value: vector (nullable = true)
 |-- age_level_feature: double (nullable = false)
 |-- age_level_value: vector (nullable = true)
 |-- shopping_level_feature: double (nullable = false)
 |-- shopping_level_value: vector (nullable = true)
 |-- occupation_feature: double (nullable = false)
 |-- occupation_value: vector (nullable = true)
 |-- features: vector (nullable = true)

Row(timestamp=1494261938, clk=0, price=108.0, cms_group_id='11', final_gender_code='1', age_level='5', shopping_level='3', occupation='0', pid_value=SparseVector(2, {1: 1.0}), pl_onehot_value=SparseVector(4, {0: 1.0}), nucl_onehot_value=SparseVector(5, {1: 1.0}), cms_group_id_feature=6.0, cms_group_id_value=SparseVector(13, {6: 1.0}), final_gender_code_feature=1.0, final_gender_code_value=SparseVector(2, {1: 1.0}), age_level_feature=2.0, age_level_value=SparseVector(7, {2: 1.0}), shopping_level_feature=0.0, shopping_level_value=SparseVector(3, {0: 1.0}), occupation_feature=0.0, occupation_value=SparseVector(2, {0: 1.0}), features=SparseVector(39, {0: 108.0, 7: 1.0, 15: 1.0, 18: 1.0, 23: 1.0, 26: 1.0, 29: 1.0, 30: 1.0, 35: 1.0}))

  • 创建逻辑回归训练器,并训练模型
from pyspark.ml.classification import LogisticRegression
lr2 = LogisticRegression()
#设置目标值对应的列   setFeaturesCol 设置特征值对应的列名
model2 = lr2.setLabelCol("clk").setFeaturesCol("features").fit(train_datasets_2)
# 存储模型
model2.save("hdfs://localhost:9000/models/CTRModel_AllOneHot.obj")
from pyspark.ml.classification import LogisticRegressionModel
# 载入训练好的模型
model2 = LogisticRegressionModel.load("hdfs://localhost:9000/models/CTRModel_AllOneHot.obj")
result_2 = model2.transform(test_datasets_2)
# 按probability升序排列数据,probability表示预测结果的概率
result_2.select("clk", "price", "probability", "prediction").sort("probability").show(100)

# 对比前面的result_1的预测结果,能发现这里的预测率稍微准确了一点,这里top20里出现了3个点击的,但前面的只出现了1个
# 因此可见对特征的细化处理,已经帮助我们提高模型的效果的

显示结果:

+---+-----------+--------------------+----------+
|clk|      price|         probability|prediction|
+---+-----------+--------------------+----------+
|  0|      1.0E8|[0.85524418892857...|       0.0|
|  0|      1.0E8|[0.88353143762124...|       0.0|
|  0|      1.0E8|[0.89169808985616...|       0.0|
|  1|5.5555556E7|[0.92511743960350...|       0.0|
|  0|     179.01|[0.93239951738307...|       0.0|
|  1|      159.0|[0.93239952905659...|       0.0|
|  0|      118.0|[0.93239955297535...|       0.0|
|  0|      688.0|[0.93451506165953...|       0.0|
|  0|      339.0|[0.93451525933626...|       0.0|
|  0|      335.0|[0.93451526160190...|       0.0|
|  0|      220.0|[0.93451532673881...|       0.0|
|  0|      176.0|[0.93451535166074...|       0.0|
|  0|      158.0|[0.93451536185607...|       0.0|
|  0|      158.0|[0.93451536185607...|       0.0|
|  1|      149.0|[0.93451536695374...|       0.0|
|  0|      122.5|[0.93451538196353...|       0.0|
|  0|       99.0|[0.93451539527410...|       0.0|
|  0|       88.0|[0.93451540150458...|       0.0|
|  0|       79.0|[0.93451540660224...|       0.0|
|  0|       75.0|[0.93451540886787...|       0.0|
|  0|       68.0|[0.93451541283272...|       0.0|
|  0|       68.0|[0.93451541283272...|       0.0|
|  0|       59.9|[0.93451541742061...|       0.0|
|  0|      44.98|[0.93451542587140...|       0.0|
|  0|       35.5|[0.93451543124094...|       0.0|
|  0|       33.0|[0.93451543265696...|       0.0|
|  0|       32.8|[0.93451543277024...|       0.0|
|  0|       30.0|[0.93451543435618...|       0.0|
|  0|       28.0|[0.93451543548899...|       0.0|
|  0|       19.9|[0.93451544007688...|       0.0|
|  0|       19.8|[0.93451544013353...|       0.0|
|  0|       19.8|[0.93451544013353...|       0.0|
|  0|       12.0|[0.93451544455150...|       0.0|
|  0|        6.7|[0.93451544755345...|       0.0|
|  0|      568.0|[0.93458159339238...|       0.0|
|  0|      398.0|[0.93458168959099...|       0.0|
|  0|      158.0|[0.93458182540058...|       0.0|
|  0|      245.0|[0.93471518526899...|       0.0|
|  0|       99.0|[0.93471526772971...|       0.0|
|  0|       88.0|[0.93471527394249...|       0.0|
|  0|     1288.0|[0.93474589600376...|       0.0|
|  0|      688.0|[0.93474623473450...|       0.0|
|  0|      656.0|[0.93474625280009...|       0.0|
|  0|      568.0|[0.93474630248045...|       0.0|
|  0|      498.0|[0.93474634199889...|       0.0|
|  0|      399.0|[0.93474639788922...|       0.0|
|  0|      396.0|[0.93474639958287...|       0.0|
|  0|      298.0|[0.93474645490860...|       0.0|
|  0|      293.0|[0.93474645773134...|       0.0|
|  0|      209.0|[0.93474650515337...|       0.0|
|  0|      198.0|[0.93474651136339...|       0.0|
|  0|      198.0|[0.93474651136339...|       0.0|
|  0|      169.0|[0.93474652773527...|       0.0|
|  0|      168.0|[0.93474652829982...|       0.0|
|  0|      159.0|[0.93474653338074...|       0.0|
|  0|      155.0|[0.93474653563893...|       0.0|
|  0|      139.0|[0.93474654467169...|       0.0|
|  0|      138.0|[0.93474654523624...|       0.0|
|  0|      119.0|[0.93474655596264...|       0.0|
|  0|       99.0|[0.93474656725358...|       0.0|
|  0|       99.0|[0.93474656725358...|       0.0|
|  0|       88.0|[0.93474657346360...|       0.0|
|  0|       88.0|[0.93474657346360...|       0.0|
|  0|       79.0|[0.93474657854453...|       0.0|
|  0|       59.0|[0.93474658983547...|       0.0|
|  0|       59.0|[0.93474658983547...|       0.0|
|  0|       59.0|[0.93474658983547...|       0.0|
|  0|       58.0|[0.93474659040002...|       0.0|
|  0|       57.0|[0.93474659096456...|       0.0|
|  0|       49.8|[0.93474659502930...|       0.0|
|  0|      39.98|[0.93474660057315...|       0.0|
|  0|       36.8|[0.93474660236841...|       0.0|
|  0|       34.0|[0.93474660394914...|       0.0|
|  0|     6520.0|[0.93480919087761...|       0.0|
|  0|     3699.0|[0.93481078202537...|       0.0|
|  0|     1980.0|[0.93481175158689...|       0.0|
|  0|      660.0|[0.93481249609274...|       0.0|
|  0|      660.0|[0.93481249609274...|       0.0|
|  0|      398.0|[0.93481264386492...|       0.0|
|  0|      369.0|[0.93481266022137...|       0.0|
|  0|      299.0|[0.93481269970243...|       0.0|
|  0|      295.0|[0.93481270195849...|       0.0|
|  0|      278.0|[0.93481271154674...|       0.0|
|  0|      270.0|[0.93481271605886...|       0.0|
|  0|      228.0|[0.93481273974748...|       0.0|
|  0|      228.0|[0.93481273974748...|       0.0|
|  0|    11368.0|[0.93494253131370...|       0.0|
|  0|     9999.0|[0.93494330201510...|       0.0|
|  0|     1099.0|[0.93494360670448...|       0.0|
|  1|     8888.0|[0.93494392746484...|       0.0|
|  0|      338.0|[0.93494403511659...|       0.0|
|  0|      311.0|[0.93494405031645...|       0.0|
|  0|      300.0|[0.93494405650898...|       0.0|
|  0|      278.0|[0.93494406889404...|       0.0|
|  0|      188.0|[0.93494411956019...|       0.0|
|  0|      176.0|[0.93494412631568...|       0.0|
|  0|      168.0|[0.93494413081933...|       0.0|
|  0|      158.0|[0.93494413644890...|       0.0|
|  1|      138.0|[0.93494414770804...|       0.0|
|  0|      125.0|[0.93494415502647...|       0.0|
+---+-----------+--------------------+----------+
only showing top 100 rows

result_2.filter(result_2.clk==1).select("clk", "price", "probability", "prediction").sort("probability").show(100)
# 从该结果也可以看出,result_2的点击率预测率普遍要比result_1高出一点点

显示结果:

+---+-----------+--------------------+----------+
|clk|      price|         probability|prediction|
+---+-----------+--------------------+----------+
|  1|5.5555556E7|[0.92511743960350...|       0.0|
|  1|      159.0|[0.93239952905659...|       0.0|
|  1|      149.0|[0.93451536695374...|       0.0|
|  1|     8888.0|[0.93494392746484...|       0.0|
|  1|      138.0|[0.93494414770804...|       0.0|
|  1|       35.0|[0.93494420569256...|       0.0|
|  1|      519.0|[0.93494863870621...|       0.0|
|  1|      478.0|[0.93494866178596...|       0.0|
|  1|      349.0|[0.93494873440265...|       0.0|
|  1|      348.0|[0.93494873496557...|       0.0|
|  1|      316.0|[0.93494875297901...|       0.0|
|  1|      298.0|[0.93494876311156...|       0.0|
|  1|      298.0|[0.93494876311156...|       0.0|
|  1|      199.0|[0.93494881884058...|       0.0|
|  1|      199.0|[0.93494881884058...|       0.0|
|  1|      198.0|[0.93494881940350...|       0.0|
|  1|      187.1|[0.93494882553931...|       0.0|
|  1|      176.0|[0.93494883178772...|       0.0|
|  1|      168.0|[0.93494883629107...|       0.0|
|  1|      160.0|[0.93494884079442...|       0.0|
|  1|      158.0|[0.93494884192026...|       0.0|
|  1|      158.0|[0.93494884192026...|       0.0|
|  1|      135.0|[0.93494885486740...|       0.0|
|  1|      129.0|[0.93494885824491...|       0.0|
|  1|      127.0|[0.93494885937075...|       0.0|
|  1|      125.0|[0.93494886049659...|       0.0|
|  1|      124.0|[0.93494886105951...|       0.0|
|  1|      118.0|[0.93494886443702...|       0.0|
|  1|      109.0|[0.93494886950329...|       0.0|
|  1|      108.0|[0.93494887006621...|       0.0|
|  1|       99.0|[0.93494887513247...|       0.0|
|  1|       98.0|[0.93494887569539...|       0.0|
|  1|       79.8|[0.93494888594051...|       0.0|
|  1|       79.0|[0.93494888639085...|       0.0|
|  1|       77.0|[0.93494888751668...|       0.0|
|  1|       72.5|[0.93494889004982...|       0.0|
|  1|       69.0|[0.93494889202003...|       0.0|
|  1|       68.0|[0.93494889258295...|       0.0|
|  1|       60.0|[0.93494889708630...|       0.0|
|  1|      43.98|[0.93494890610426...|       0.0|
|  1|       40.0|[0.93494890834467...|       0.0|
|  1|       39.9|[0.93494890840096...|       0.0|
|  1|       39.6|[0.93494890856984...|       0.0|
|  1|       32.0|[0.93494891284802...|       0.0|
|  1|       31.0|[0.93494891341094...|       0.0|
|  1|      25.98|[0.93494891623679...|       0.0|
|  1|       23.0|[0.93494891791428...|       0.0|
|  1|       19.0|[0.93494892016596...|       0.0|
|  1|       16.9|[0.93494892134809...|       0.0|
|  1|       10.0|[0.93494892523222...|       0.0|
|  1|        3.5|[0.93494892889119...|       0.0|
|  1|        3.5|[0.93494892889119...|       0.0|
|  1|        0.4|[0.93494893063624...|       0.0|
|  1|     1288.0|[0.93501426059874...|       0.0|
|  1|      980.0|[0.93501443381533...|       0.0|
|  1|      788.0|[0.93501454179429...|       0.0|
|  1|      698.0|[0.93501459240937...|       0.0|
|  1|      695.0|[0.93501459409654...|       0.0|
|  1|      688.0|[0.93501459803326...|       0.0|
|  1|      599.0|[0.93501464808591...|       0.0|
|  1|      588.0|[0.93501465427219...|       0.0|
|  1|      516.0|[0.93501469476419...|       0.0|
|  1|      495.0|[0.93501470657436...|       0.0|
|  1|      398.0|[0.93501476112603...|       0.0|
|  1|      368.0|[0.93501477799768...|       0.0|
|  1|      339.0|[0.93501479430693...|       0.0|
|  1|      335.0|[0.93501479655648...|       0.0|
|  1|      324.0|[0.93501480274275...|       0.0|
|  1|      316.0|[0.93501480724185...|       0.0|
|  1|      299.0|[0.93501481680244...|       0.0|
|  1|      295.0|[0.93501481905199...|       0.0|
|  1|      279.0|[0.93501482805020...|       0.0|
|  1|      268.0|[0.93501483423646...|       0.0|
|  1|      259.0|[0.93501483929795...|       0.0|
|  1|      259.0|[0.93501483929795...|       0.0|
|  1|      249.0|[0.93501484492182...|       0.0|
|  1|      238.0|[0.93501485110809...|       0.0|
|  1|      199.0|[0.93501487304119...|       0.0|
|  1|      198.0|[0.93501487360358...|       0.0|
|  1|      179.0|[0.93501488428894...|       0.0|
|  1|      175.0|[0.93501488653849...|       0.0|
|  1|      129.0|[0.93501491240829...|       0.0|
|  1|      128.0|[0.93501491297068...|       0.0|
|  1|      118.0|[0.93501491859455...|       0.0|
|  1|      109.0|[0.93501492365603...|       0.0|
|  1|       98.0|[0.93501492984229...|       0.0|
|  1|       89.0|[0.93501493490377...|       0.0|
|  1|       79.0|[0.93501494052764...|       0.0|
|  1|       75.0|[0.93501494277718...|       0.0|
|  1|       69.8|[0.93501494570159...|       0.0|
|  1|       30.0|[0.93501496808458...|       0.0|
|  1|       15.0|[0.93501497652038...|       0.0|
|  1|      368.0|[0.93665387743951...|       0.0|
|  1|      198.0|[0.93665397079735...|       0.0|
|  1|      178.0|[0.93665398178062...|       0.0|
|  1|      158.0|[0.93665399276388...|       0.0|
|  1|      158.0|[0.93665399276388...|       0.0|
|  1|      149.0|[0.93665399770635...|       0.0|
|  1|       68.0|[0.93665404218855...|       0.0|
|  1|       36.0|[0.93665405976176...|       0.0|
+---+-----------+--------------------+----------+
only showing top 100 rows

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

lucky-zhao

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值