四 LR实现CTR预估

4.1 Spark逻辑回归(LR)训练点击率预测模型

  • 本小节主要根据广告点击样本数据集(raw_sample)、广告基本特征数据集(ad_feature)、用户基本信息数据集(user_profile)构建出了一个完整的样本数据集,并按日期划分为了训练集(前七天)和测试集(最后一天),利用逻辑回归进行训练。

    训练模型时,通过对类别特征数据进行处理,一定程度达到提高了模型的效果

'''从HDFS中加载样本数据信息'''
_raw_sample_df1 = spark.read.csv("hdfs://localhost:9000/data/raw_sample.csv", header=True)
# _raw_sample_df1.show() # 展示数据,默认前20条
# 更改表结构,转换为对应的数据类型
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType, LongType, StringType
  
# 更改df表结构:更改列类型和列名称
_raw_sample_df2 = _raw_sample_df1.\
    withColumn("user", _raw_sample_df1.user.cast(IntegerType())).withColumnRenamed("user", "userId").\
    withColumn("time_stamp", _raw_sample_df1.time_stamp.cast(LongType())).withColumnRenamed("time_stamp", "timestamp").\
    withColumn("adgroup_id", _raw_sample_df1.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("pid", _raw_sample_df1.pid.cast(StringType())).\
    withColumn("nonclk", _raw_sample_df1.nonclk.cast(IntegerType())).\
    withColumn("clk", _raw_sample_df1.clk.cast(IntegerType()))
_raw_sample_df2.printSchema()
_raw_sample_df2.show()

# 样本数据pid特征处理
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

stringindexer = StringIndexer(inputCol='pid', outputCol='pid_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='pid_feature', outputCol='pid_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(_raw_sample_df2)
raw_sample_df = pipeline_fit.transform(_raw_sample_df2)
raw_sample_df.show()

'''pid和特征的对应关系 430548_1007:0 430549_1007:1 '''

显示结果:

root
 |-- userId: integer (nullable = true)
 |-- timestamp: long (nullable = true)
 |-- adgroupId: integer (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: integer (nullable = true)
 |-- clk: integer (nullable = true)

+------+----------+---------+-----------+------+---+
|userId| timestamp|adgroupId|        pid|nonclk|clk|
+------+----------+---------+-----------+------+---+
|581738|1494137644|        1|430548_1007|     1|  0|
|449818|1494638778|        3|430548_1007|     1|  0|
|914836|1494650879|        4|430548_1007|     1|  0|
|914836|1494651029|        5|430548_1007|     1|  0|
|399907|1494302958|        8|430548_1007|     1|  0|
|628137|1494524935|        9|430548_1007|     1|  0|
|298139|1494462593|        9|430539_1007|     1|  0|
|775475|1494561036|        9|430548_1007|     1|  0|
|555266|1494307136|       11|430539_1007|     1|  0|
|117840|1494036743|       11|430548_1007|     1|  0|
|739815|1494115387|       11|430539_1007|     1|  0|
|623911|1494625301|       11|430548_1007|     1|  0|
|623911|1494451608|       11|430548_1007|     1|  0|
|421590|1494034144|       11|430548_1007|     1|  0|
|976358|1494156949|       13|430548_1007|     1|  0|
|286630|1494218579|       13|430539_1007|     1|  0|
|286630|1494289247|       13|430539_1007|     1|  0|
|771431|1494153867|       13|430548_1007|     1|  0|
|707120|1494220810|       13|430548_1007|     1|  0|
|530454|1494293746|       13|430548_1007|     1|  0|
+------+----------+---------+-----------+------+---+
only showing top 20 rows

+------+----------+---------+-----------+------+---+-----------+-------------+
|userId| timestamp|adgroupId|        pid|nonclk|clk|pid_feature|    pid_value|
+------+----------+---------+-----------+------+---+-----------+-------------+
|581738|1494137644|        1|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|449818|1494638778|        3|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494650879|        4|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494651029|        5|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|399907|1494302958|        8|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|628137|1494524935|        9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|298139|1494462593|        9|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|775475|1494561036|        9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|555266|1494307136|       11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|117840|1494036743|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|739815|1494115387|       11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|623911|1494625301|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|623911|1494451608|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|421590|1494034144|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|976358|1494156949|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|286630|1494218579|       13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|286630|1494289247|       13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|771431|1494153867|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|707120|1494220810|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|530454|1494293746|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
+------+----------+---------+-----------+------+---+-----------+-------------+
only showing top 20 rows

'pid和特征的对应关系\n430548_1007:0\n430549_1007:1\n'
  • 从HDFS中加载广告基本信息数据
_ad_feature_df = spark.read.csv("hdfs://localhost:9000/datasets/ad_feature.csv", header=True)

# 更改表结构,转换为对应的数据类型
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType

# 替换掉NULL字符串
_ad_feature_df = _ad_feature_df.replace("NULL", "-1")
 
# 更改df表结构:更改列类型和列名称
ad_feature_df = _ad_feature_df.\
    withColumn("adgroup_id", _ad_feature_df.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("cate_id", _ad_feature_df.cate_id.cast(IntegerType())).withColumnRenamed("cate_id", "cateId").\
    withColumn("campaign_id", _ad_feature_df.campaign_id.cast(IntegerType())).withColumnRenamed("campaign_id", "campaignId").\
    withColumn("customer", _ad_feature_df.customer.cast(IntegerType())).withColumnRenamed("customer", "customerId").\
    withColumn("brand", _ad_feature_df.brand.cast(IntegerType())).withColumnRenamed("brand", "brandId").\
    withColumn("price", _ad_feature_df.price.cast(FloatType()))
ad_feature_df.printSchema()
ad_feature_df.show()

显示结果:

root
 |-- adgroupId: integer (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- campaignId: integer (nullable = true)
 |-- customerId: integer (nullable = true)
 |-- brandId: integer (nullable = true)
 |-- price: float (nullable = true)

+---------+------+----------+----------+-------+-----+
|adgroupId|cateId|campaignId|customerId|brandId|price|
+---------+------+----------+----------+-------+-----+
|    63133|  6406|     83237|         1|  95471|170.0|
|   313401|  6406|     83237|         1|  87331|199.0|
|   248909|   392|     83237|         1|  32233| 38.0|
|   208458|   392|     83237|         1| 174374|139.0|
|   110847|  7211|    135256|         2| 145952|32.99|
|   607788|  6261|    387991|         6| 207800|199.0|
|   375706|  4520|    387991|         6|     -1| 99.0|
|    11115|  7213|    139747|         9| 186847| 33.0|
|    24484|  7207|    139744|         9| 186847| 19.0|
|    28589|  5953|    395195|        13|     -1|428.0|
|    23236|  5953|    395195|        13|     -1|368.0|
|   300556|  5953|    395195|        13|     -1|639.0|
|    92560|  5953|    395195|        13|     -1|368.0|
|   590965|  4284|     28145|        14| 454237|249.0|
|   529913|  4284|     70206|        14|     -1|249.0|
|   546930|  4284|     28145|        14|     -1|249.0|
|   639794|  6261|     70206|        14|  37004| 89.9|
|   335413|  4284|     28145|        14|     -1|249.0|
|   794890|  4284|     70206|        14| 454237|249.0|
|   684020|  6261|     70206|        14|  37004| 99.0|
+---------+------+----------+----------+-------+-----+
only showing top 20 rows
  • 从HDFS加载用户基本信息数据
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, FloatType

# 构建表结构schema对象
schema = StructType([
    StructField("userId", IntegerType()),
    StructField("cms_segid", IntegerType()),
    StructField("cms_group_id", IntegerType()),
    StructField("final_gender_code", IntegerType()),
    StructField("age_level", IntegerType()),
    StructField("pvalue_level", IntegerType()),
    StructField("shopping_level", IntegerType()),
    StructField("occupation", IntegerType()),
    StructField("new_user_class_level", IntegerType())
])
# 利用schema从hdfs加载
_user_profile_df1 = spark.read.csv("hdfs://localhost:9000/datasets/user_profile.csv", header=True, schema=schema)
# user_profile_df.printSchema()
# user_profile_df.show()

'''对缺失数据进行特征热编码'''
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

# 使用热编码转换pvalue_level的一维数据为多维,增加n-1个虚拟变量,n为pvalue_level的取值范围

# 需要先将缺失值全部替换为数值,便于处理,否则会抛出异常
from pyspark.sql.types import StringType
_user_profile_df2 = _user_profile_df1.na.fill(-1)
# _user_profile_df2.show()

# 热编码时,必须先将待处理字段转为字符串类型才可处理
_user_profile_df3 = _user_profile_df2.withColumn("pvalue_level", _user_profile_df2.pvalue_level.cast(StringType()))\
    .withColumn("new_user_class_level", _user_profile_df2.new_user_class_level.cast(StringType()))
# _user_profile_df3.printSchema()

# 对pvalue_level进行热编码,求值
# 运行过程是先将pvalue_level转换为一列新的特征数据,然后对该特征数据求出的热编码值,存在了新的一列数据中,类型为一个稀疏矩阵
stringindexer = StringIndexer(inputCol='pvalue_level', outputCol='pl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='pl_onehot_feature', outputCol='pl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(_user_profile_df3)
_user_profile_df4 = pipeline_fit.transform(_user_profile_df3)
# pl_onehot_value列的值为稀疏矩阵,存储热编码的结果
# _user_profile_df4.printSchema()
# _user_profile_df4.show()

# 使用热编码转换new_user_class_level的一维数据为多维
stringindexer = StringIndexer(inputCol='new_user_class_level', outputCol='nucl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='nucl_onehot_feature', outputCol='nucl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(_user_profile_df4)
user_profile_df = pipeline_fit.transform(_user_profile_df4)
user_profile_df.show()

显示结果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|pl_onehot_feature|pl_onehot_value|nucl_onehot_feature|nucl_onehot_value|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|              0.0|  (4,[0],[1.0])|                2.0|    (5,[2],[1.0])|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|              2.0|  (4,[2],[1.0])|                1.0|    (5,[1],[1.0])|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|              0.0|  (4,[0],[1.0])|                1.0|    (5,[1],[1.0])|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|              1.0|  (4,[1],[1.0])|                4.0|    (5,[4],[1.0])|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|              2.0|  (4,[2],[1.0])|                2.0|    (5,[2],[1.0])|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
only showing top 20 rows

  • 热编码中:"pvalue_level"特征对应关系:
+------------+----------------------+
|pvalue_level|pl_onehot_feature     |
+------------+----------------------+
|          -1|                   0.0|
|           3|                   3.0|
|           1|                   2.0|
|           2|                   1.0|
+------------+----------------------+
  • “new_user_class_level”的特征对应关系
+--------------------+------------------------+
|new_user_class_level|nucl_onehot_feature     |
+--------------------+------------------------+
|                  -1|                     0.0|
|                   3|                     2.0|
|                   1|                     4.0|
|                   4|                     3.0|
|                   2|                     1.0|
+--------------------+------------------------+
user_profile_df.groupBy("pvalue_level").min("pl_onehot_feature").show()
user_profile_df.groupBy("new_user_class_level").min("nucl_onehot_feature").show()

显示结果:

+------------+----------------------+
|pvalue_level|min(pl_onehot_feature)|
+------------+----------------------+
|          -1|                   0.0|
|           3|                   3.0|
|           1|                   2.0|
|           2|                   1.0|
+------------+----------------------+

+--------------------+------------------------+
|new_user_class_level|min(nucl_onehot_feature)|
+--------------------+------------------------+
|                  -1|                     0.0|
|                   3|                     2.0|
|                   1|                     4.0|
|                   4|                     3.0|
|                   2|                     1.0|
+--------------------+------------------------+

# raw_sample_df和ad_feature_df合并条件
condition = [raw_sample_df.adgroupId==ad_feature_df.adgroupId]
_ = raw_sample_df.join(ad_feature_df, condition, 'outer')

# _和user_profile_df合并条件
condition2 = [_.userId==user_profile_df.userId]
datasets = _.join(user_profile_df, condition2, "outer")
# 查看datasets的结构
datasets.printSchema()
# 查看datasets条目数
print(datasets.count())

显示结果:

root
 |-- userId: integer (nullable = true)
 |-- timestamp: long (nullable = true)
 |-- adgroupId: integer (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: integer (nullable = true)
 |-- clk: integer (nullable = true)
 |-- pid_feature: double (nullable = true)
 |-- pid_value: vector (nullable = true)
 |-- adgroupId: integer (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- campaignId: integer (nullable = true)
 |-- customerId: integer (nullable = true)
 |-- brandId: integer (nullable = true)
 |-- price: float (nullable = true)
 |-- userId: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: string (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level: string (nullable = true)
 |-- pl_onehot_feature: double (nullable = true)
 |-- pl_onehot_value: vector (nullable = true)
 |-- nucl_onehot_feature: double (nullable = true)
 |-- nucl_onehot_value: vector (nullable = true)

26557961
  • 训练CTRModel_Normal:直接将对应的特征的特征值组合成对应的特征向量进行训练
# 剔除冗余、不需要的字段
useful_cols = [
    # 
    # 时间字段,划分训练集和测试集
    "timestamp",
    # label目标值字段
    "clk",  
    # 特征值字段
    "pid_value",       # 资源位的特征向量
    "price",    # 广告价格
    "cms_segid",    # 用户微群ID
    "cms_group_id",    # 用户组ID
    "final_gender_code",    # 用户性别特征,[1,2]
    "age_level",    # 年龄等级,1-
    "shopping_level",
    "occupation",
    "pl_onehot_value",
    "nucl_onehot_value"
]
# 筛选指定字段数据,构建新的数据集
datasets_1 = datasets.select(*useful_cols)
# 由于前面使用的是outer方式合并的数据,产生了部分空值数据,这里必须先剔除掉
datasets_1 = datasets_1.dropna()
print("剔除空值数据后,还剩:", datasets_1.count())

显示结果:

剔除空值数据后,还剩: 25029435

  • 根据特征字段计算出特征向量,并划分出训练数据集和测试数据集
from pyspark.ml.feature import VectorAssembler
# 根据特征字段计算特征向量
datasets_1 = VectorAssembler().setInputCols(useful_cols[2:]).setOutputCol("features").transform(datasets_1)
# 训练数据集: 约7天的数据
train_datasets_1 = datasets_1.filter(datasets_1.timestamp<=(1494691186-24*60*60))
# 测试数据集:约1天的数据量
test_datasets_1 = datasets_1.filter(datasets_1.timestamp>(1494691186-24*60*60))
# 所有的特征的特征向量已经汇总到在features字段中
train_datasets_1.show(5)
test_datasets_1.show(5)

显示结果:

+----------+---+-------------+------+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
| timestamp|clk|    pid_value| price|cms_segid|cms_group_id|final_gender_code|age_level|shopping_level|occupation|pl_onehot_value|nucl_onehot_value|            features|
+----------+---+-------------+------+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
|1494261938|  0|(2,[1],[1.0])| 108.0|        0|          11|                1|        5|             3|         0|  (4,[0],[1.0])|    (5,[1],[1.0])|(18,[1,2,4,5,6,7,...|
|1494261938|  0|(2,[1],[1.0])|1880.0|        0|          11|                1|        5|             3|         0|  (4,[0],[1.0])|    (5,[1],[1.0])|(18,[1,2,4,5,6,7,...|
|1494553913|  0|(2,[1],[1.0])|2360.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494553913|  0|(2,[1],[1.0])|2200.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494436784|  0|(2,[1],[1.0])|5649.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
+----------+---+-------------+------+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
only showing top 5 rows

+----------+---+-------------+-----+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
| timestamp|clk|    pid_value|price|cms_segid|cms_group_id|final_gender_code|age_level|shopping_level|occupation|pl_onehot_value|nucl_onehot_value|            features|
+----------+---+-------------+-----+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
|1494677292|  0|(2,[1],[1.0])|176.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494677292|  0|(2,[1],[1.0])|698.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494677292|  0|(2,[1],[1.0])|697.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494684007|  0|(2,[1],[1.0])|247.0|       18|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[4],[1.0])|(18,[1,2,3,4,5,6,...|
|1494684007|  0|(2,[1],[1.0])|109.0|       18|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[4],[1.0])|(18,[1,2,3,4,5,6,...|
+----------+---+-------------+-----+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
only showing top 5 rows

from pyspark.ml.classification import LogisticRegression
lr = LogisticRegression()
# 设置目标字段、特征值字段并训练
model = lr.setLabelCol("clk").setFeaturesCol("features").fit(train_datasets_1)
# 对模型进行存储
model.save("hdfs://localhost:9000/models/CTRModel_Normal.obj")
# 载入训练好的模型
from pyspark.ml.classification import LogisticRegressionModel
model = LogisticRegressionModel.load("hdfs://localhost:9000/models/CTRModel_Normal.obj")
# 根据测试数据进行预测
result_1 = model.transform(test_datasets_1)
# 按probability升序排列数据,probability表示预测结果的概率
# 如果预测值是0,其概率是0.9248,那么反之可推出1的可能性就是1-0.9248=0.0752,即点击概率约为7.52%
# 因为前面提到广告的点击率一般都比较低,所以预测值通常都是0,因此通常需要反减得出点击的概率
result_1.select("clk", "price", "probability", "prediction").sort("probability").show(100)

显示结果:

+---+-----------+--------------------+----------+
|clk|      price|         probability|prediction|
+---+-----------+--------------------+----------+
|  0|      1.0E8|[0.86822033939259...|       0.0|
|  0|      1.0E8|[0.88410457194969...|       0.0|
|  0|      1.0E8|[0.89175497837562...|       0.0|
|  1|5.5555556E7|[0.92481456486873...|       0.0|
|  0|      1.5E7|[0.93741450446939...|       0.0|
|  0|      1.5E7|[0.93757135079959...|       0.0|
|  0|      1.5E7|[0.93834723093801...|       0.0|
|  0|     1099.0|[0.93972095713786...|       0.0|
|  0|      338.0|[0.93972134993018...|       0.0|
|  0|      311.0|[0.93972136386626...|       0.0|
|  0|      300.0|[0.93972136954393...|       0.0|
|  0|      278.0|[0.93972138089925...|       0.0|
|  0|      188.0|[0.93972142735283...|       0.0|
|  0|      176.0|[0.93972143354663...|       0.0|
|  0|      168.0|[0.93972143767584...|       0.0|
|  0|      158.0|[0.93972144283734...|       0.0|
|  1|      138.0|[0.93972145316035...|       0.0|
|  0|      125.0|[0.93972145987031...|       0.0|
|  0|      119.0|[0.93972146296721...|       0.0|
|  0|       78.0|[0.93972148412937...|       0.0|
|  0|      59.98|[0.93972149343040...|       0.0|
|  0|       58.0|[0.93972149445238...|       0.0|
|  0|       56.0|[0.93972149548468...|       0.0|
|  0|       38.0|[0.93972150477538...|       0.0|
|  1|       35.0|[0.93972150632383...|       0.0|
|  0|       33.0|[0.93972150735613...|       0.0|
|  0|       30.0|[0.93972150890458...|       0.0|
|  0|       27.6|[0.93972151014334...|       0.0|
|  0|       18.0|[0.93972151509838...|       0.0|
|  0|       30.0|[0.93980311191464...|       0.0|
|  0|       28.0|[0.93980311294563...|       0.0|
|  0|       25.0|[0.93980311449212...|       0.0|
|  0|      688.0|[0.93999362023323...|       0.0|
|  0|      339.0|[0.93999379960808...|       0.0|
|  0|      335.0|[0.93999380166395...|       0.0|
|  0|      220.0|[0.93999386077017...|       0.0|
|  0|      176.0|[0.93999388338470...|       0.0|
|  0|      158.0|[0.93999389263610...|       0.0|
|  0|      158.0|[0.93999389263610...|       0.0|
|  1|      149.0|[0.93999389726180...|       0.0|
|  0|      122.5|[0.93999391088191...|       0.0|
|  0|       99.0|[0.93999392296012...|       0.0|
|  0|       88.0|[0.93999392861375...|       0.0|
|  0|       79.0|[0.93999393323945...|       0.0|
|  0|       75.0|[0.93999393529532...|       0.0|
|  0|       68.0|[0.93999393889308...|       0.0|
|  0|       68.0|[0.93999393889308...|       0.0|
|  0|       59.9|[0.93999394305620...|       0.0|
|  0|      44.98|[0.93999395072458...|       0.0|
|  0|       35.5|[0.93999395559698...|       0.0|
|  0|       33.0|[0.93999395688189...|       0.0|
|  0|       32.8|[0.93999395698469...|       0.0|
|  0|       30.0|[0.93999395842379...|       0.0|
|  0|       28.0|[0.93999395945172...|       0.0|
|  0|       19.9|[0.93999396361485...|       0.0|
|  0|       19.8|[0.93999396366625...|       0.0|
|  0|       19.8|[0.93999396366625...|       0.0|
|  0|       12.0|[0.93999396767518...|       0.0|
|  0|        6.7|[0.93999397039920...|       0.0|
|  0|      568.0|[0.94000369247841...|       0.0|
|  0|      398.0|[0.94000377983931...|       0.0|
|  0|      158.0|[0.94000390317214...|       0.0|
|  0|     5718.0|[0.94001886593718...|       0.0|
|  0|     5718.0|[0.94001886593718...|       0.0|
|  1|     5608.0|[0.94001892245145...|       0.0|
|  0|     4120.0|[0.94001968693052...|       0.0|
|  0|     1027.5|[0.94002127571285...|       0.0|
|  0|     1027.5|[0.94002127571285...|       0.0|
|  0|      989.0|[0.94002129549211...|       0.0|
|  0|      672.0|[0.94002145834965...|       0.0|
|  0|      660.0|[0.94002146451460...|       0.0|
|  0|      598.0|[0.94002149636681...|       0.0|
|  0|      598.0|[0.94002149636681...|       0.0|
|  0|      563.0|[0.94002151434789...|       0.0|
|  0|      509.0|[0.94002154209012...|       0.0|
|  0|      509.0|[0.94002154209012...|       0.0|
|  0|      500.0|[0.94002154671382...|       0.0|
|  0|      498.0|[0.94002154774131...|       0.0|
|  0|      440.0|[0.94002157753851...|       0.0|
|  0|      430.0|[0.94002158267595...|       0.0|
|  0|      388.0|[0.94002160425322...|       0.0|
|  0|      369.0|[0.94002161401436...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      366.0|[0.94002161555560...|       0.0|
|  0|      366.0|[0.94002161555560...|       0.0|
|  0|      348.0|[0.94002162480299...|       0.0|
|  0|      299.0|[0.94002164997645...|       0.0|
|  0|      299.0|[0.94002164997645...|       0.0|
|  0|      299.0|[0.94002164997645...|       0.0|
|  0|      298.0|[0.94002165049020...|       0.0|
|  0|      297.0|[0.94002165100394...|       0.0|
|  0|      278.0|[0.94002166076508...|       0.0|
|  1|      275.0|[0.94002166230631...|       0.0|
|  0|      275.0|[0.94002166230631...|       0.0|
|  0|      273.0|[0.94002166333380...|       0.0|
|  0|      258.0|[0.94002167103995...|       0.0|
|  0|      256.0|[0.94002167206744...|       0.0|
+---+-----------+--------------------+----------+
only showing top 100 rows

  • 查看样本中点击的被实际点击的条目的预测情况
result_1.filter(result_1.clk==1).select("clk", "price", "probability", "prediction").sort("probability").show(100)

显示结果:

+---+-----------+--------------------+----------+
|clk|      price|         probability|prediction|
+---+-----------+--------------------+----------+
|  1|5.5555556E7|[0.92481456486873...|       0.0|
|  1|      138.0|[0.93972145316035...|       0.0|
|  1|       35.0|[0.93972150632383...|       0.0|
|  1|      149.0|[0.93999389726180...|       0.0|
|  1|     5608.0|[0.94001892245145...|       0.0|
|  1|      275.0|[0.94002166230631...|       0.0|
|  1|       35.0|[0.94002178560473...|       0.0|
|  1|       49.0|[0.94004219516957...|       0.0|
|  1|      915.0|[0.94021082858784...|       0.0|
|  1|      598.0|[0.94021099096349...|       0.0|
|  1|      568.0|[0.94021100633025...|       0.0|
|  1|      398.0|[0.94021109340848...|       0.0|
|  1|      368.0|[0.94021110877521...|       0.0|
|  1|      299.0|[0.94021114411869...|       0.0|
|  1|      278.0|[0.94021115487539...|       0.0|
|  1|      259.0|[0.94021116460765...|       0.0|
|  1|      258.0|[0.94021116511987...|       0.0|
|  1|      258.0|[0.94021116511987...|       0.0|
|  1|      258.0|[0.94021116511987...|       0.0|
|  1|      195.0|[0.94021119738998...|       0.0|
|  1|      188.0|[0.94021120097554...|       0.0|
|  1|      178.0|[0.94021120609778...|       0.0|
|  1|      159.0|[0.94021121583003...|       0.0|
|  1|      149.0|[0.94021122095226...|       0.0|
|  1|      138.0|[0.94021122658672...|       0.0|
|  1|       58.0|[0.94021126756458...|       0.0|
|  1|       49.0|[0.94021127217459...|       0.0|
|  1|       35.0|[0.94021127934572...|       0.0|
|  1|       25.0|[0.94021128446795...|       0.0|
|  1|     2890.0|[0.94028789742257...|       0.0|
|  1|      220.0|[0.94028926340218...|       0.0|
|  1|      188.0|[0.94031410659516...|       0.0|
|  1|       68.0|[0.94031416796289...|       0.0|
|  1|       58.0|[0.94031417307687...|       0.0|
|  1|      198.0|[0.94035413548387...|       0.0|
|  1|      208.0|[0.94039204931181...|       0.0|
|  1|     8888.0|[0.94045237642030...|       0.0|
|  1|      519.0|[0.94045664687995...|       0.0|
|  1|      478.0|[0.94045666780037...|       0.0|
|  1|      349.0|[0.94045673362308...|       0.0|
|  1|      348.0|[0.94045673413334...|       0.0|
|  1|      316.0|[0.94045675046144...|       0.0|
|  1|      298.0|[0.94045675964600...|       0.0|
|  1|      298.0|[0.94045675964600...|       0.0|
|  1|      199.0|[0.94045681016104...|       0.0|
|  1|      199.0|[0.94045681016104...|       0.0|
|  1|      198.0|[0.94045681067129...|       0.0|
|  1|      187.1|[0.94045681623305...|       0.0|
|  1|      176.0|[0.94045682189685...|       0.0|
|  1|      168.0|[0.94045682597887...|       0.0|
|  1|      160.0|[0.94045683006090...|       0.0|
|  1|      158.0|[0.94045683108140...|       0.0|
|  1|      158.0|[0.94045683108140...|       0.0|
|  1|      135.0|[0.94045684281721...|       0.0|
|  1|      129.0|[0.94045684587872...|       0.0|
|  1|      127.0|[0.94045684689923...|       0.0|
|  1|      125.0|[0.94045684791973...|       0.0|
|  1|      124.0|[0.94045684842999...|       0.0|
|  1|      118.0|[0.94045685149150...|       0.0|
|  1|      109.0|[0.94045685608377...|       0.0|
|  1|      108.0|[0.94045685659402...|       0.0|
|  1|       99.0|[0.94045686118630...|       0.0|
|  1|       98.0|[0.94045686169655...|       0.0|
|  1|       79.8|[0.94045687098314...|       0.0|
|  1|       79.0|[0.94045687139134...|       0.0|
|  1|       77.0|[0.94045687241185...|       0.0|
|  1|       72.5|[0.94045687470798...|       0.0|
|  1|       69.0|[0.94045687649386...|       0.0|
|  1|       68.0|[0.94045687700412...|       0.0|
|  1|       60.0|[0.94045688108613...|       0.0|
|  1|      43.98|[0.94045688926037...|       0.0|
|  1|       40.0|[0.94045689129118...|       0.0|
|  1|       39.9|[0.94045689134220...|       0.0|
|  1|       39.6|[0.94045689149528...|       0.0|
|  1|       32.0|[0.94045689537319...|       0.0|
|  1|       31.0|[0.94045689588345...|       0.0|
|  1|      25.98|[0.94045689844491...|       0.0|
|  1|       23.0|[0.94045689996546...|       0.0|
|  1|       19.0|[0.94045690200647...|       0.0|
|  1|       16.9|[0.94045690307800...|       0.0|
|  1|       10.0|[0.94045690659874...|       0.0|
|  1|        3.5|[0.94045690991538...|       0.0|
|  1|        3.5|[0.94045690991538...|       0.0|
|  1|        0.4|[0.94045691149716...|       0.0|
|  1|     3960.0|[0.94055740378069...|       0.0|
|  1|     3088.0|[0.94055784801535...|       0.0|
|  1|     1689.0|[0.94055856072019...|       0.0|
|  1|      998.0|[0.94055891273943...|       0.0|
|  1|      888.0|[0.94055896877705...|       0.0|
|  1|      788.0|[0.94055901972029...|       0.0|
|  1|      737.0|[0.94055904570133...|       0.0|
|  1|      629.0|[0.94055910071996...|       0.0|
|  1|      599.0|[0.94055911600291...|       0.0|
|  1|      599.0|[0.94055911600291...|       0.0|
|  1|      599.0|[0.94055911600291...|       0.0|
|  1|      499.0|[0.94055916694603...|       0.0|
|  1|      468.0|[0.94055918273839...|       0.0|
|  1|      459.0|[0.94055918732327...|       0.0|
|  1|      399.0|[0.94055921788912...|       0.0|
|  1|      399.0|[0.94055921788912...|       0.0|
+---+-----------+--------------------+----------+
only showing top 100 rows