4 Implementing CTR Prediction with LR
4.1 Training a CTR Prediction Model with Spark Logistic Regression (LR)
This subsection builds a complete sample dataset from the ad click log (raw_sample), the ad feature dataset (ad_feature), and the user profile dataset (user_profile), splits it by date into a training set (the first seven days) and a test set (the last day), and trains a logistic regression model on it.
During training, the categorical features are encoded, which improves the model's performance to some extent.
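For reference, logistic regression models the click probability as a sigmoid applied to a weighted sum of the features. A minimal plain-Python sketch of what the trained model computes for a single sample (the weights and features here are made up for illustration, not taken from the model trained below):
import math

def predict_ctr(weights, bias, features):
    # p(clk=1 | x) = sigmoid(w·x + b)
    z = sum(w * x for w, x in zip(weights, features)) + bias
    return 1.0 / (1.0 + math.exp(-z))

# e.g. predict_ctr([0.3, -1.2], 0.05, [1.0, 0.0]) ≈ 0.587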
'''Load the raw sample data (raw_sample) from HDFS'''
_raw_sample_df1 = spark.read.csv("hdfs://localhost:9000/data/raw_sample.csv", header=True)
# _raw_sample_df1.show()  # preview the data; show() prints the first 20 rows by default
# Adjust the schema: cast columns to the appropriate data types
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType, LongType, StringType
# Cast column types and rename columns
_raw_sample_df2 = _raw_sample_df1.\
withColumn("user", _raw_sample_df1.user.cast(IntegerType())).withColumnRenamed("user", "userId").\
withColumn("time_stamp", _raw_sample_df1.time_stamp.cast(LongType())).withColumnRenamed("time_stamp", "timestamp").\
withColumn("adgroup_id", _raw_sample_df1.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
withColumn("pid", _raw_sample_df1.pid.cast(StringType())).\
withColumn("nonclk", _raw_sample_df1.nonclk.cast(IntegerType())).\
withColumn("clk", _raw_sample_df1.clk.cast(IntegerType()))
_raw_sample_df2.printSchema()
_raw_sample_df2.show()
# Encode the pid feature of the sample data
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline
stringindexer = StringIndexer(inputCol='pid', outputCol='pid_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='pid_feature', outputCol='pid_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(_raw_sample_df2)
raw_sample_df = pipeline_fit.transform(_raw_sample_df2)
raw_sample_df.show()
'''Mapping between pid and its index: 430548_1007 -> 0, 430539_1007 -> 1'''
Output:
root
|-- userId: integer (nullable = true)
|-- timestamp: long (nullable = true)
|-- adgroupId: integer (nullable = true)
|-- pid: string (nullable = true)
|-- nonclk: integer (nullable = true)
|-- clk: integer (nullable = true)
+------+----------+---------+-----------+------+---+
|userId| timestamp|adgroupId| pid|nonclk|clk|
+------+----------+---------+-----------+------+---+
|581738|1494137644| 1|430548_1007| 1| 0|
|449818|1494638778| 3|430548_1007| 1| 0|
|914836|1494650879| 4|430548_1007| 1| 0|
|914836|1494651029| 5|430548_1007| 1| 0|
|399907|1494302958| 8|430548_1007| 1| 0|
|628137|1494524935| 9|430548_1007| 1| 0|
|298139|1494462593| 9|430539_1007| 1| 0|
|775475|1494561036| 9|430548_1007| 1| 0|
|555266|1494307136| 11|430539_1007| 1| 0|
|117840|1494036743| 11|430548_1007| 1| 0|
|739815|1494115387| 11|430539_1007| 1| 0|
|623911|1494625301| 11|430548_1007| 1| 0|
|623911|1494451608| 11|430548_1007| 1| 0|
|421590|1494034144| 11|430548_1007| 1| 0|
|976358|1494156949| 13|430548_1007| 1| 0|
|286630|1494218579| 13|430539_1007| 1| 0|
|286630|1494289247| 13|430539_1007| 1| 0|
|771431|1494153867| 13|430548_1007| 1| 0|
|707120|1494220810| 13|430548_1007| 1| 0|
|530454|1494293746| 13|430548_1007| 1| 0|
+------+----------+---------+-----------+------+---+
only showing top 20 rows
+------+----------+---------+-----------+------+---+-----------+-------------+
|userId| timestamp|adgroupId| pid|nonclk|clk|pid_feature| pid_value|
+------+----------+---------+-----------+------+---+-----------+-------------+
|581738|1494137644| 1|430548_1007| 1| 0| 0.0|(2,[0],[1.0])|
|449818|1494638778| 3|430548_1007| 1| 0| 0.0|(2,[0],[1.0])|
|914836|1494650879| 4|430548_1007| 1| 0| 0.0|(2,[0],[1.0])|
|914836|1494651029| 5|430548_1007| 1| 0| 0.0|(2,[0],[1.0])|
|399907|1494302958| 8|430548_1007| 1| 0| 0.0|(2,[0],[1.0])|
|628137|1494524935| 9|430548_1007| 1| 0| 0.0|(2,[0],[1.0])|
|298139|1494462593| 9|430539_1007| 1| 0| 1.0|(2,[1],[1.0])|
|775475|1494561036| 9|430548_1007| 1| 0| 0.0|(2,[0],[1.0])|
|555266|1494307136| 11|430539_1007| 1| 0| 1.0|(2,[1],[1.0])|
|117840|1494036743| 11|430548_1007| 1| 0| 0.0|(2,[0],[1.0])|
|739815|1494115387| 11|430539_1007| 1| 0| 1.0|(2,[1],[1.0])|
|623911|1494625301| 11|430548_1007| 1| 0| 0.0|(2,[0],[1.0])|
|623911|1494451608| 11|430548_1007| 1| 0| 0.0|(2,[0],[1.0])|
|421590|1494034144| 11|430548_1007| 1| 0| 0.0|(2,[0],[1.0])|
|976358|1494156949| 13|430548_1007| 1| 0| 0.0|(2,[0],[1.0])|
|286630|1494218579| 13|430539_1007| 1| 0| 1.0|(2,[1],[1.0])|
|286630|1494289247| 13|430539_1007| 1| 0| 1.0|(2,[1],[1.0])|
|771431|1494153867| 13|430548_1007| 1| 0| 0.0|(2,[0],[1.0])|
|707120|1494220810| 13|430548_1007| 1| 0| 0.0|(2,[0],[1.0])|
|530454|1494293746| 13|430548_1007| 1| 0| 0.0|(2,[0],[1.0])|
+------+----------+---------+-----------+------+---+-----------+-------------+
only showing top 20 rows
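The pid-to-index mapping can also be read directly off the fitted pipeline: the first stage is the fitted StringIndexerModel, whose labels are ordered by their encoded index. A minimal sketch, assuming pipeline_fit from above:
# Print each pid label next to its StringIndexer index
pid_indexer_model = pipeline_fit.stages[0]
for index, label in enumerate(pid_indexer_model.labels):
    print(label, "->", index)   # expected: 430548_1007 -> 0, 430539_1007 -> 1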
- Load the ad feature data (ad_feature) from HDFS
_ad_feature_df = spark.read.csv("hdfs://localhost:9000/datasets/ad_feature.csv", header=True)
# Adjust the schema: cast columns to the appropriate data types
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType
# Replace the literal string "NULL" with "-1"
_ad_feature_df = _ad_feature_df.replace("NULL", "-1")
# Cast column types and rename columns
ad_feature_df = _ad_feature_df.\
withColumn("adgroup_id", _ad_feature_df.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
withColumn("cate_id", _ad_feature_df.cate_id.cast(IntegerType())).withColumnRenamed("cate_id", "cateId").\
withColumn("campaign_id", _ad_feature_df.campaign_id.cast(IntegerType())).withColumnRenamed("campaign_id", "campaignId").\
withColumn("customer", _ad_feature_df.customer.cast(IntegerType())).withColumnRenamed("customer", "customerId").\
withColumn("brand", _ad_feature_df.brand.cast(IntegerType())).withColumnRenamed("brand", "brandId").\
withColumn("price", _ad_feature_df.price.cast(FloatType()))
ad_feature_df.printSchema()
ad_feature_df.show()
Output:
root
|-- adgroupId: integer (nullable = true)
|-- cateId: integer (nullable = true)
|-- campaignId: integer (nullable = true)
|-- customerId: integer (nullable = true)
|-- brandId: integer (nullable = true)
|-- price: float (nullable = true)
+---------+------+----------+----------+-------+-----+
|adgroupId|cateId|campaignId|customerId|brandId|price|
+---------+------+----------+----------+-------+-----+
| 63133| 6406| 83237| 1| 95471|170.0|
| 313401| 6406| 83237| 1| 87331|199.0|
| 248909| 392| 83237| 1| 32233| 38.0|
| 208458| 392| 83237| 1| 174374|139.0|
| 110847| 7211| 135256| 2| 145952|32.99|
| 607788| 6261| 387991| 6| 207800|199.0|
| 375706| 4520| 387991| 6| -1| 99.0|
| 11115| 7213| 139747| 9| 186847| 33.0|
| 24484| 7207| 139744| 9| 186847| 19.0|
| 28589| 5953| 395195| 13| -1|428.0|
| 23236| 5953| 395195| 13| -1|368.0|
| 300556| 5953| 395195| 13| -1|639.0|
| 92560| 5953| 395195| 13| -1|368.0|
| 590965| 4284| 28145| 14| 454237|249.0|
| 529913| 4284| 70206| 14| -1|249.0|
| 546930| 4284| 28145| 14| -1|249.0|
| 639794| 6261| 70206| 14| 37004| 89.9|
| 335413| 4284| 28145| 14| -1|249.0|
| 794890| 4284| 70206| 14| 454237|249.0|
| 684020| 6261| 70206| 14| 37004| 99.0|
+---------+------+----------+----------+-------+-----+
only showing top 20 rows
- Load the user profile data (user_profile) from HDFS
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, FloatType
# Build the schema object for the table
schema = StructType([
StructField("userId", IntegerType()),
StructField("cms_segid", IntegerType()),
StructField("cms_group_id", IntegerType()),
StructField("final_gender_code", IntegerType()),
StructField("age_level", IntegerType()),
StructField("pvalue_level", IntegerType()),
StructField("shopping_level", IntegerType()),
StructField("occupation", IntegerType()),
StructField("new_user_class_level", IntegerType())
])
# Load from HDFS using the schema
_user_profile_df1 = spark.read.csv("hdfs://localhost:9000/datasets/user_profile.csv", header=True, schema=schema)
# _user_profile_df1.printSchema()
# _user_profile_df1.show()
'''One-hot encode the features that contain missing values'''
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline
# One-hot encode pvalue_level, expanding the single column into n dummy variables,
# where n is the number of distinct pvalue_level values (dropLast=False keeps all n)
# Missing values must first be replaced with a numeric placeholder, otherwise an exception is thrown
from pyspark.sql.types import StringType
_user_profile_df2 = _user_profile_df1.na.fill(-1)
# _user_profile_df2.show()
# Before one-hot encoding, the column must first be cast to string type
_user_profile_df3 = _user_profile_df2.withColumn("pvalue_level", _user_profile_df2.pvalue_level.cast(StringType()))\
.withColumn("new_user_class_level", _user_profile_df2.new_user_class_level.cast(StringType()))
# _user_profile_df3.printSchema()
# One-hot encode pvalue_level
# StringIndexer first maps pvalue_level to a new index column; OneHotEncoder then turns
# that index into a one-hot value stored in another new column as a sparse vector
stringindexer = StringIndexer(inputCol='pvalue_level', outputCol='pl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='pl_onehot_feature', outputCol='pl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(_user_profile_df3)
_user_profile_df4 = pipeline_fit.transform(_user_profile_df3)
# The pl_onehot_value column holds the one-hot result as a sparse vector
# _user_profile_df4.printSchema()
# _user_profile_df4.show()
# Likewise, one-hot encode new_user_class_level
stringindexer = StringIndexer(inputCol='new_user_class_level', outputCol='nucl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='nucl_onehot_feature', outputCol='nucl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(_user_profile_df4)
user_profile_df = pipeline_fit.transform(_user_profile_df4)
user_profile_df.show()
Output:
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|pl_onehot_feature|pl_onehot_value|nucl_onehot_feature|nucl_onehot_value|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
| 234| 0| 5| 2| 5| -1| 3| 0| 3| 0.0| (4,[0],[1.0])| 2.0| (5,[2],[1.0])|
| 523| 5| 2| 2| 2| 1| 3| 1| 2| 2.0| (4,[2],[1.0])| 1.0| (5,[1],[1.0])|
| 612| 0| 8| 1| 2| 2| 3| 0| -1| 1.0| (4,[1],[1.0])| 0.0| (5,[0],[1.0])|
| 1670| 0| 4| 2| 4| -1| 1| 0| -1| 0.0| (4,[0],[1.0])| 0.0| (5,[0],[1.0])|
| 2545| 0| 10| 1| 4| -1| 3| 0| -1| 0.0| (4,[0],[1.0])| 0.0| (5,[0],[1.0])|
| 3644| 49| 6| 2| 6| 2| 3| 0| 2| 1.0| (4,[1],[1.0])| 1.0| (5,[1],[1.0])|
| 5777| 44| 5| 2| 5| 2| 3| 0| 2| 1.0| (4,[1],[1.0])| 1.0| (5,[1],[1.0])|
| 6211| 0| 9| 1| 3| -1| 3| 0| 2| 0.0| (4,[0],[1.0])| 1.0| (5,[1],[1.0])|
| 6355| 2| 1| 2| 1| 1| 3| 0| 4| 2.0| (4,[2],[1.0])| 3.0| (5,[3],[1.0])|
| 6823| 43| 5| 2| 5| 2| 3| 0| 1| 1.0| (4,[1],[1.0])| 4.0| (5,[4],[1.0])|
| 6972| 5| 2| 2| 2| 2| 3| 1| 2| 1.0| (4,[1],[1.0])| 1.0| (5,[1],[1.0])|
| 9293| 0| 5| 2| 5| -1| 3| 0| 4| 0.0| (4,[0],[1.0])| 3.0| (5,[3],[1.0])|
| 9510| 55| 8| 1| 2| 2| 2| 0| 2| 1.0| (4,[1],[1.0])| 1.0| (5,[1],[1.0])|
| 10122| 33| 4| 2| 4| 2| 3| 0| 2| 1.0| (4,[1],[1.0])| 1.0| (5,[1],[1.0])|
| 10549| 0| 4| 2| 4| 2| 3| 0| -1| 1.0| (4,[1],[1.0])| 0.0| (5,[0],[1.0])|
| 10812| 0| 4| 2| 4| -1| 2| 0| -1| 0.0| (4,[0],[1.0])| 0.0| (5,[0],[1.0])|
| 10912| 0| 4| 2| 4| 2| 3| 0| -1| 1.0| (4,[1],[1.0])| 0.0| (5,[0],[1.0])|
| 10996| 0| 5| 2| 5| -1| 3| 0| 4| 0.0| (4,[0],[1.0])| 3.0| (5,[3],[1.0])|
| 11256| 8| 2| 2| 2| 1| 3| 0| 3| 2.0| (4,[2],[1.0])| 2.0| (5,[2],[1.0])|
| 11310| 31| 4| 2| 4| 1| 3| 0| 4| 2.0| (4,[2],[1.0])| 3.0| (5,[3],[1.0])|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
only showing top 20 rows
- One-hot encoding: index mapping for the "pvalue_level" feature (StringIndexer assigns indices by descending frequency, so the fill value -1, being the most common, gets index 0):
+------------+----------------------+
|pvalue_level|pl_onehot_feature |
+------------+----------------------+
| -1| 0.0|
| 3| 3.0|
| 1| 2.0|
| 2| 1.0|
+------------+----------------------+
- Index mapping for the "new_user_class_level" feature:
+--------------------+------------------------+
|new_user_class_level|nucl_onehot_feature |
+--------------------+------------------------+
| -1| 0.0|
| 3| 2.0|
| 1| 4.0|
| 4| 3.0|
| 2| 1.0|
+--------------------+------------------------+
user_profile_df.groupBy("pvalue_level").min("pl_onehot_feature").show()
user_profile_df.groupBy("new_user_class_level").min("nucl_onehot_feature").show()
Output:
+------------+----------------------+
|pvalue_level|min(pl_onehot_feature)|
+------------+----------------------+
| -1| 0.0|
| 3| 3.0|
| 1| 2.0|
| 2| 1.0|
+------------+----------------------+
+--------------------+------------------------+
|new_user_class_level|min(nucl_onehot_feature)|
+--------------------+------------------------+
| -1| 0.0|
| 3| 2.0|
| 1| 4.0|
| 4| 3.0|
| 2| 1.0|
+--------------------+------------------------+
- Joining the DataFrames: pyspark.sql.DataFrame.join
# Join condition between raw_sample_df and ad_feature_df
condition = [raw_sample_df.adgroupId==ad_feature_df.adgroupId]
_ = raw_sample_df.join(ad_feature_df, condition, 'outer')
# Join condition between the intermediate result and user_profile_df
condition2 = [_.userId==user_profile_df.userId]
datasets = _.join(user_profile_df, condition2, "outer")
# Inspect the schema of datasets
datasets.printSchema()
# Count the rows in datasets
print(datasets.count())
Output:
root
|-- userId: integer (nullable = true)
|-- timestamp: long (nullable = true)
|-- adgroupId: integer (nullable = true)
|-- pid: string (nullable = true)
|-- nonclk: integer (nullable = true)
|-- clk: integer (nullable = true)
|-- pid_feature: double (nullable = true)
|-- pid_value: vector (nullable = true)
|-- adgroupId: integer (nullable = true)
|-- cateId: integer (nullable = true)
|-- campaignId: integer (nullable = true)
|-- customerId: integer (nullable = true)
|-- brandId: integer (nullable = true)
|-- price: float (nullable = true)
|-- userId: integer (nullable = true)
|-- cms_segid: integer (nullable = true)
|-- cms_group_id: integer (nullable = true)
|-- final_gender_code: integer (nullable = true)
|-- age_level: integer (nullable = true)
|-- pvalue_level: string (nullable = true)
|-- shopping_level: integer (nullable = true)
|-- occupation: integer (nullable = true)
|-- new_user_class_level: string (nullable = true)
|-- pl_onehot_feature: double (nullable = true)
|-- pl_onehot_value: vector (nullable = true)
|-- nucl_onehot_feature: double (nullable = true)
|-- nucl_onehot_value: vector (nullable = true)
26557961
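The schema above lists adgroupId and userId twice, because the join conditions compare two same-named columns from different DataFrames. If the duplicated key columns become inconvenient, a sketch of an alternative join keyed on the column name, which keeps each key only once:
# Alternative outer joins keyed on the shared column names; each join key then
# appears only once in the merged schema
_ = raw_sample_df.join(ad_feature_df, on="adgroupId", how="outer")
datasets = _.join(user_profile_df, on="userId", how="outer")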
- Training CTRModel_Normal: assemble the raw feature values directly into a feature vector and train on it
# Drop redundant and unneeded columns
useful_cols = [
    # time column, used to split the training and test sets
    "timestamp",
    # label column
    "clk",
    # feature columns
    "pid_value",           # one-hot vector for the ad slot (pid)
    "price",               # ad price
    "cms_segid",           # user micro-segment ID
    "cms_group_id",        # user group ID
    "final_gender_code",   # user gender, values [1, 2]
    "age_level",           # age level
    "shopping_level",      # shopping depth level
    "occupation",          # occupation flag
    "pl_onehot_value",     # one-hot vector for pvalue_level
    "nucl_onehot_value"    # one-hot vector for new_user_class_level
]
# Select only these columns to build a new dataset
datasets_1 = datasets.select(*useful_cols)
# The outer joins above produced rows with nulls, which must be dropped first
datasets_1 = datasets_1.dropna()
print("剔除空值数据后,还剩:", datasets_1.count())
Output:
Rows remaining after dropping nulls: 25029435
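In other words, the outer joins left 26,557,961 - 25,029,435 = 1,528,526 rows with nulls in the selected columns, and those rows are discarded here.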
- Assemble the feature columns into a feature vector and split into training and test sets
from pyspark.ml.feature import VectorAssembler
# Assemble the feature columns into a single feature vector
datasets_1 = VectorAssembler().setInputCols(useful_cols[2:]).setOutputCol("features").transform(datasets_1)
# Training set: roughly the first 7 days of data
train_datasets_1 = datasets_1.filter(datasets_1.timestamp<=(1494691186-24*60*60))
# Test set: roughly the last 1 day of data
test_datasets_1 = datasets_1.filter(datasets_1.timestamp>(1494691186-24*60*60))
# All individual features are now assembled in the features column
train_datasets_1.show(5)
test_datasets_1.show(5)
Output:
+----------+---+-------------+------+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
| timestamp|clk| pid_value| price|cms_segid|cms_group_id|final_gender_code|age_level|shopping_level|occupation|pl_onehot_value|nucl_onehot_value| features|
+----------+---+-------------+------+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
|1494261938| 0|(2,[1],[1.0])| 108.0| 0| 11| 1| 5| 3| 0| (4,[0],[1.0])| (5,[1],[1.0])|(18,[1,2,4,5,6,7,...|
|1494261938| 0|(2,[1],[1.0])|1880.0| 0| 11| 1| 5| 3| 0| (4,[0],[1.0])| (5,[1],[1.0])|(18,[1,2,4,5,6,7,...|
|1494553913| 0|(2,[1],[1.0])|2360.0| 19| 3| 2| 3| 3| 0| (4,[1],[1.0])| (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494553913| 0|(2,[1],[1.0])|2200.0| 19| 3| 2| 3| 3| 0| (4,[1],[1.0])| (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494436784| 0|(2,[1],[1.0])|5649.0| 19| 3| 2| 3| 3| 0| (4,[1],[1.0])| (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
+----------+---+-------------+------+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
only showing top 5 rows
+----------+---+-------------+-----+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
| timestamp|clk| pid_value|price|cms_segid|cms_group_id|final_gender_code|age_level|shopping_level|occupation|pl_onehot_value|nucl_onehot_value| features|
+----------+---+-------------+-----+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
|1494677292| 0|(2,[1],[1.0])|176.0| 19| 3| 2| 3| 3| 0| (4,[1],[1.0])| (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494677292| 0|(2,[1],[1.0])|698.0| 19| 3| 2| 3| 3| 0| (4,[1],[1.0])| (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494677292| 0|(2,[1],[1.0])|697.0| 19| 3| 2| 3| 3| 0| (4,[1],[1.0])| (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494684007| 0|(2,[1],[1.0])|247.0| 18| 3| 2| 3| 3| 0| (4,[1],[1.0])| (5,[4],[1.0])|(18,[1,2,3,4,5,6,...|
|1494684007| 0|(2,[1],[1.0])|109.0| 18| 3| 2| 3| 3| 0| (4,[1],[1.0])| (5,[4],[1.0])|(18,[1,2,3,4,5,6,...|
+----------+---+-------------+-----+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
only showing top 5 rows
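The assembled features vector has 2 (pid_value) + 1 (price) + 1 (cms_segid) + 1 (cms_group_id) + 1 (final_gender_code) + 1 (age_level) + 1 (shopping_level) + 1 (occupation) + 4 (pl_onehot_value) + 5 (nucl_onehot_value) = 18 dimensions, which matches the (18, ...) sparse vectors shown in the features column above.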
- Create the logistic regression estimator and train the model: LogisticRegression, LogisticRegressionModel
from pyspark.ml.classification import LogisticRegression
lr = LogisticRegression()
# Set the label column and the features column, then fit
model = lr.setLabelCol("clk").setFeaturesCol("features").fit(train_datasets_1)
# Persist the trained model
model.save("hdfs://localhost:9000/models/CTRModel_Normal.obj")
# Load the trained model back
from pyspark.ml.classification import LogisticRegressionModel
model = LogisticRegressionModel.load("hdfs://localhost:9000/models/CTRModel_Normal.obj")
# Predict on the test set
result_1 = model.transform(test_datasets_1)
# Sort by probability in ascending order; probability holds the predicted class probabilities
# If the predicted class is 0 with probability 0.9248, then the probability of class 1 is
# 1 - 0.9248 = 0.0752, i.e. a click probability of about 7.52%
# Since ad CTR is generally low, the predicted class is almost always 0, so the click
# probability is usually obtained by this subtraction
result_1.select("clk", "price", "probability", "prediction").sort("probability").show(100)
Output:
+---+-----------+--------------------+----------+
|clk| price| probability|prediction|
+---+-----------+--------------------+----------+
| 0| 1.0E8|[0.86822033939259...| 0.0|
| 0| 1.0E8|[0.88410457194969...| 0.0|
| 0| 1.0E8|[0.89175497837562...| 0.0|
| 1|5.5555556E7|[0.92481456486873...| 0.0|
| 0| 1.5E7|[0.93741450446939...| 0.0|
| 0| 1.5E7|[0.93757135079959...| 0.0|
| 0| 1.5E7|[0.93834723093801...| 0.0|
| 0| 1099.0|[0.93972095713786...| 0.0|
| 0| 338.0|[0.93972134993018...| 0.0|
| 0| 311.0|[0.93972136386626...| 0.0|
| 0| 300.0|[0.93972136954393...| 0.0|
| 0| 278.0|[0.93972138089925...| 0.0|
| 0| 188.0|[0.93972142735283...| 0.0|
| 0| 176.0|[0.93972143354663...| 0.0|
| 0| 168.0|[0.93972143767584...| 0.0|
| 0| 158.0|[0.93972144283734...| 0.0|
| 1| 138.0|[0.93972145316035...| 0.0|
| 0| 125.0|[0.93972145987031...| 0.0|
| 0| 119.0|[0.93972146296721...| 0.0|
| 0| 78.0|[0.93972148412937...| 0.0|
| 0| 59.98|[0.93972149343040...| 0.0|
| 0| 58.0|[0.93972149445238...| 0.0|
| 0| 56.0|[0.93972149548468...| 0.0|
| 0| 38.0|[0.93972150477538...| 0.0|
| 1| 35.0|[0.93972150632383...| 0.0|
| 0| 33.0|[0.93972150735613...| 0.0|
| 0| 30.0|[0.93972150890458...| 0.0|
| 0| 27.6|[0.93972151014334...| 0.0|
| 0| 18.0|[0.93972151509838...| 0.0|
| 0| 30.0|[0.93980311191464...| 0.0|
| 0| 28.0|[0.93980311294563...| 0.0|
| 0| 25.0|[0.93980311449212...| 0.0|
| 0| 688.0|[0.93999362023323...| 0.0|
| 0| 339.0|[0.93999379960808...| 0.0|
| 0| 335.0|[0.93999380166395...| 0.0|
| 0| 220.0|[0.93999386077017...| 0.0|
| 0| 176.0|[0.93999388338470...| 0.0|
| 0| 158.0|[0.93999389263610...| 0.0|
| 0| 158.0|[0.93999389263610...| 0.0|
| 1| 149.0|[0.93999389726180...| 0.0|
| 0| 122.5|[0.93999391088191...| 0.0|
| 0| 99.0|[0.93999392296012...| 0.0|
| 0| 88.0|[0.93999392861375...| 0.0|
| 0| 79.0|[0.93999393323945...| 0.0|
| 0| 75.0|[0.93999393529532...| 0.0|
| 0| 68.0|[0.93999393889308...| 0.0|
| 0| 68.0|[0.93999393889308...| 0.0|
| 0| 59.9|[0.93999394305620...| 0.0|
| 0| 44.98|[0.93999395072458...| 0.0|
| 0| 35.5|[0.93999395559698...| 0.0|
| 0| 33.0|[0.93999395688189...| 0.0|
| 0| 32.8|[0.93999395698469...| 0.0|
| 0| 30.0|[0.93999395842379...| 0.0|
| 0| 28.0|[0.93999395945172...| 0.0|
| 0| 19.9|[0.93999396361485...| 0.0|
| 0| 19.8|[0.93999396366625...| 0.0|
| 0| 19.8|[0.93999396366625...| 0.0|
| 0| 12.0|[0.93999396767518...| 0.0|
| 0| 6.7|[0.93999397039920...| 0.0|
| 0| 568.0|[0.94000369247841...| 0.0|
| 0| 398.0|[0.94000377983931...| 0.0|
| 0| 158.0|[0.94000390317214...| 0.0|
| 0| 5718.0|[0.94001886593718...| 0.0|
| 0| 5718.0|[0.94001886593718...| 0.0|
| 1| 5608.0|[0.94001892245145...| 0.0|
| 0| 4120.0|[0.94001968693052...| 0.0|
| 0| 1027.5|[0.94002127571285...| 0.0|
| 0| 1027.5|[0.94002127571285...| 0.0|
| 0| 989.0|[0.94002129549211...| 0.0|
| 0| 672.0|[0.94002145834965...| 0.0|
| 0| 660.0|[0.94002146451460...| 0.0|
| 0| 598.0|[0.94002149636681...| 0.0|
| 0| 598.0|[0.94002149636681...| 0.0|
| 0| 563.0|[0.94002151434789...| 0.0|
| 0| 509.0|[0.94002154209012...| 0.0|
| 0| 509.0|[0.94002154209012...| 0.0|
| 0| 500.0|[0.94002154671382...| 0.0|
| 0| 498.0|[0.94002154774131...| 0.0|
| 0| 440.0|[0.94002157753851...| 0.0|
| 0| 430.0|[0.94002158267595...| 0.0|
| 0| 388.0|[0.94002160425322...| 0.0|
| 0| 369.0|[0.94002161401436...| 0.0|
| 0| 368.0|[0.94002161452811...| 0.0|
| 0| 368.0|[0.94002161452811...| 0.0|
| 0| 368.0|[0.94002161452811...| 0.0|
| 0| 368.0|[0.94002161452811...| 0.0|
| 0| 366.0|[0.94002161555560...| 0.0|
| 0| 366.0|[0.94002161555560...| 0.0|
| 0| 348.0|[0.94002162480299...| 0.0|
| 0| 299.0|[0.94002164997645...| 0.0|
| 0| 299.0|[0.94002164997645...| 0.0|
| 0| 299.0|[0.94002164997645...| 0.0|
| 0| 298.0|[0.94002165049020...| 0.0|
| 0| 297.0|[0.94002165100394...| 0.0|
| 0| 278.0|[0.94002166076508...| 0.0|
| 1| 275.0|[0.94002166230631...| 0.0|
| 0| 275.0|[0.94002166230631...| 0.0|
| 0| 273.0|[0.94002166333380...| 0.0|
| 0| 258.0|[0.94002167103995...| 0.0|
| 0| 256.0|[0.94002167206744...| 0.0|
+---+-----------+--------------------+----------+
only showing top 100 rows
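Since the predicted class is almost always 0, the click probability of interest is P(clk=1) = 1 - P(clk=0), i.e. the second element of the probability vector. A minimal sketch for extracting it as a plain column and ranking by it, assuming result_1 from above and Spark 3.0+ (for pyspark.ml.functions.vector_to_array):
from pyspark.ml.functions import vector_to_array
from pyspark.sql import functions as F

# Pull P(clk=1) = probability[1] out of the probability vector and rank descending
result_with_ctr = result_1.withColumn("p_click", vector_to_array("probability")[1])
result_with_ctr.select("clk", "price", "p_click").sort(F.desc("p_click")).show(10)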
- Inspect the predictions for the samples that were actually clicked
result_1.filter(result_1.clk==1).select("clk", "price", "probability", "prediction").sort("probability").show(100)
Output:
+---+-----------+--------------------+----------+
|clk| price| probability|prediction|
+---+-----------+--------------------+----------+
| 1|5.5555556E7|[0.92481456486873...| 0.0|
| 1| 138.0|[0.93972145316035...| 0.0|
| 1| 35.0|[0.93972150632383...| 0.0|
| 1| 149.0|[0.93999389726180...| 0.0|
| 1| 5608.0|[0.94001892245145...| 0.0|
| 1| 275.0|[0.94002166230631...| 0.0|
| 1| 35.0|[0.94002178560473...| 0.0|
| 1| 49.0|[0.94004219516957...| 0.0|
| 1| 915.0|[0.94021082858784...| 0.0|
| 1| 598.0|[0.94021099096349...| 0.0|
| 1| 568.0|[0.94021100633025...| 0.0|
| 1| 398.0|[0.94021109340848...| 0.0|
| 1| 368.0|[0.94021110877521...| 0.0|
| 1| 299.0|[0.94021114411869...| 0.0|
| 1| 278.0|[0.94021115487539...| 0.0|
| 1| 259.0|[0.94021116460765...| 0.0|
| 1| 258.0|[0.94021116511987...| 0.0|
| 1| 258.0|[0.94021116511987...| 0.0|
| 1| 258.0|[0.94021116511987...| 0.0|
| 1| 195.0|[0.94021119738998...| 0.0|
| 1| 188.0|[0.94021120097554...| 0.0|
| 1| 178.0|[0.94021120609778...| 0.0|
| 1| 159.0|[0.94021121583003...| 0.0|
| 1| 149.0|[0.94021122095226...| 0.0|
| 1| 138.0|[0.94021122658672...| 0.0|
| 1| 58.0|[0.94021126756458...| 0.0|
| 1| 49.0|[0.94021127217459...| 0.0|
| 1| 35.0|[0.94021127934572...| 0.0|
| 1| 25.0|[0.94021128446795...| 0.0|
| 1| 2890.0|[0.94028789742257...| 0.0|
| 1| 220.0|[0.94028926340218...| 0.0|
| 1| 188.0|[0.94031410659516...| 0.0|
| 1| 68.0|[0.94031416796289...| 0.0|
| 1| 58.0|[0.94031417307687...| 0.0|
| 1| 198.0|[0.94035413548387...| 0.0|
| 1| 208.0|[0.94039204931181...| 0.0|
| 1| 8888.0|[0.94045237642030...| 0.0|
| 1| 519.0|[0.94045664687995...| 0.0|
| 1| 478.0|[0.94045666780037...| 0.0|
| 1| 349.0|[0.94045673362308...| 0.0|
| 1| 348.0|[0.94045673413334...| 0.0|
| 1| 316.0|[0.94045675046144...| 0.0|
| 1| 298.0|[0.94045675964600...| 0.0|
| 1| 298.0|[0.94045675964600...| 0.0|
| 1| 199.0|[0.94045681016104...| 0.0|
| 1| 199.0|[0.94045681016104...| 0.0|
| 1| 198.0|[0.94045681067129...| 0.0|
| 1| 187.1|[0.94045681623305...| 0.0|
| 1| 176.0|[0.94045682189685...| 0.0|
| 1| 168.0|[0.94045682597887...| 0.0|
| 1| 160.0|[0.94045683006090...| 0.0|
| 1| 158.0|[0.94045683108140...| 0.0|
| 1| 158.0|[0.94045683108140...| 0.0|
| 1| 135.0|[0.94045684281721...| 0.0|
| 1| 129.0|[0.94045684587872...| 0.0|
| 1| 127.0|[0.94045684689923...| 0.0|
| 1| 125.0|[0.94045684791973...| 0.0|
| 1| 124.0|[0.94045684842999...| 0.0|
| 1| 118.0|[0.94045685149150...| 0.0|
| 1| 109.0|[0.94045685608377...| 0.0|
| 1| 108.0|[0.94045685659402...| 0.0|
| 1| 99.0|[0.94045686118630...| 0.0|
| 1| 98.0|[0.94045686169655...| 0.0|
| 1| 79.8|[0.94045687098314...| 0.0|
| 1| 79.0|[0.94045687139134...| 0.0|
| 1| 77.0|[0.94045687241185...| 0.0|
| 1| 72.5|[0.94045687470798...| 0.0|
| 1| 69.0|[0.94045687649386...| 0.0|
| 1| 68.0|[0.94045687700412...| 0.0|
| 1| 60.0|[0.94045688108613...| 0.0|
| 1| 43.98|[0.94045688926037...| 0.0|
| 1| 40.0|[0.94045689129118...| 0.0|
| 1| 39.9|[0.94045689134220...| 0.0|
| 1| 39.6|[0.94045689149528...| 0.0|
| 1| 32.0|[0.94045689537319...| 0.0|
| 1| 31.0|[0.94045689588345...| 0.0|
| 1| 25.98|[0.94045689844491...| 0.0|
| 1| 23.0|[0.94045689996546...| 0.0|
| 1| 19.0|[0.94045690200647...| 0.0|
| 1| 16.9|[0.94045690307800...| 0.0|
| 1| 10.0|[0.94045690659874...| 0.0|
| 1| 3.5|[0.94045690991538...| 0.0|
| 1| 3.5|[0.94045690991538...| 0.0|
| 1| 0.4|[0.94045691149716...| 0.0|
| 1| 3960.0|[0.94055740378069...| 0.0|
| 1| 3088.0|[0.94055784801535...| 0.0|
| 1| 1689.0|[0.94055856072019...| 0.0|
| 1| 998.0|[0.94055891273943...| 0.0|
| 1| 888.0|[0.94055896877705...| 0.0|
| 1| 788.0|[0.94055901972029...| 0.0|
| 1| 737.0|[0.94055904570133...| 0.0|
| 1| 629.0|[0.94055910071996...| 0.0|
| 1| 599.0|[0.94055911600291...| 0.0|
| 1| 599.0|[0.94055911600291...| 0.0|
| 1| 599.0|[0.94055911600291...| 0.0|
| 1| 499.0|[0.94055916694603...| 0.0|
| 1| 468.0|[0.94055918273839...| 0.0|
| 1| 459.0|[0.94055918732327...| 0.0|
| 1| 399.0|[0.94055921788912...| 0.0|
| 1| 399.0|[0.94055921788912...| 0.0|
+---+-----------+--------------------+----------+
only showing top 100 rows
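To put a number on how well this baseline model ranks clicks against non-clicks, one option is the area under the ROC curve on the test predictions. A minimal sketch using BinaryClassificationEvaluator, assuming result_1 from above:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# AUC on the test-set predictions; uses the rawPrediction column that
# LogisticRegressionModel.transform adds alongside probability and prediction
evaluator = BinaryClassificationEvaluator(
    rawPredictionCol="rawPrediction", labelCol="clk", metricName="areaUnderROC")
print("Test AUC:", evaluator.evaluate(result_1))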