spark.sql(query).show() renders the result of a SQL query as a table:

query = """
SELECT
    ROW_NUMBER() OVER (ORDER BY time) AS row,
    train_id,
    station,
    time,
    LEAD(time, 1) OVER (ORDER BY time) AS time_next
FROM schedule
"""
spark.sql(query).show()
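For comparison, the same query can be sketched in dot notation, assuming the schedule view was registered from a DataFrame df (e.g. via df.createOrReplaceTempView('schedule')); row_number and lead come from pyspark.sql.functions:

from pyspark.sql import Window
from pyspark.sql.functions import row_number, lead

# One ordered window, reused by both expressions
w = Window.orderBy('time')
df.withColumn('row', row_number().over(w)) \
  .withColumn('time_next', lead('time', 1).over(w)) \
  .show()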

Aggregation

Getting a single aggregate value per group; the two commands below produce identical output:

# These two commands give identical results
spark.sql('SELECT train_id, MIN(time) AS start FROM schedule GROUP BY train_id').show()
df.groupBy('train_id').agg({'time':'min'}).withColumnRenamed('min(time)', 'start').show()

+--------+-----+
|train_id|start|
+--------+-----+
|     217|6:06a|
|     324|7:59a|
+--------+-----+
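A slightly tidier dot-notation variant is to alias the aggregate directly instead of renaming it afterwards; a sketch:

from pyspark.sql.functions import min

# alias() names the column at aggregation time, so no withColumnRenamed is needed
df.groupBy('train_id').agg(min('time').alias('start')).show()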

Getting multiple aggregate values:

spark.sql('SELECT train_id, MIN(time), MAX(time) FROM schedule GROUP BY train_id').show()
result = df.groupBy('train_id').agg({'time':'min', 'time':'max'})
result.show()

The SQL in the first line returns both the min and the max as intended. The dot-notation call in the second line returns only the max, because a Python dict cannot hold two entries with the same key 'time': the later entry silently overwrites the earlier one.
+--------+---------+---------+
|train_id|min(time)|max(time)|
+--------+---------+---------+
|     217|    6:06a|    6:59a|
|     324|    7:59a|    9:05a|
+--------+---------+---------+

+--------+---------+
|train_id|max(time)|
+--------+---------+
|     217|    6:59a|
|     324|    9:05a|
+--------+---------+
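This follows from plain Python dict semantics rather than anything Spark-specific; a duplicate key keeps only its last value:

exprs = {'time': 'min', 'time': 'max'}
print(exprs)  # {'time': 'max'} -- the second 'time' entry overwrote the first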

For multiple aggregates in dot notation, the correct form is the following, though it is rather verbose:

from pyspark.sql.functions import min, max, col
expr = [min(col("time")).alias('start'), max(col("time")).alias('end')]
dot_df = df.groupBy("train_id").agg(*expr)
dot_df.show()

+--------+-----+-----+
|train_id|start|  end|
+--------+-----+-----+
|     217|6:06a|6:59a|
|     324|7:59a|9:05a|
+--------+-----+-----+
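The list-plus-unpacking form above is equivalent to passing the expressions to agg directly; a sketch:

from pyspark.sql.functions import min, max

# agg() accepts any number of column expressions as positional arguments
df.groupBy('train_id').agg(min('time').alias('start'),
                           max('time').alias('end')).show()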

Implementing window functions in Spark with dot notation

The SQL approach:

df = spark.sql("""
SELECT *, 
LEAD(time,1) OVER(PARTITION BY train_id ORDER BY time) AS time_next 
FROM schedule
""")

The dot-notation approach:

from pyspark.sql import Window
from pyspark.sql.functions import lead
dot_df = df.withColumn('time_next', lead('time', 1)
        .over(Window.partitionBy('train_id')
        .orderBy('time')))

Here df.withColumn adds a column named 'time_next' to the original df and returns a new DataFrame.
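withColumn is not limited to window expressions; it adds (or replaces) one column and returns a new DataFrame, leaving the original untouched. A minimal sketch with a hypothetical station_upper column:

from pyspark.sql.functions import upper

# df itself is unchanged; df2 is a new DataFrame with the extra column
df2 = df.withColumn('station_upper', upper('station'))
df2.show()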

Computing time differences with LEAD and UNIX_TIMESTAMP

query = """
SELECT *,
    (UNIX_TIMESTAMP(LEAD(time, 1) OVER (PARTITION BY train_id ORDER BY time), 'H:m')
     - UNIX_TIMESTAMP(time, 'H:m')) / 60 AS diff_min
FROM schedule
"""
sql_df = spark.sql(query)
sql_df.show()

UNIX_TIMESTAMP returns a value in seconds, which is why the difference is divided by 60 to get minutes.
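The same computation can be sketched in dot notation with unix_timestamp from pyspark.sql.functions:

from pyspark.sql import Window
from pyspark.sql.functions import lead, unix_timestamp

w = Window.partitionBy('train_id').orderBy('time')
# unix_timestamp yields seconds, so divide by 60 for minutes
dot_df = df.withColumn(
    'diff_min',
    (unix_timestamp(lead('time', 1).over(w), 'H:m')
     - unix_timestamp('time', 'H:m')) / 60)
dot_df.show()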
Note: LEAD and LAG require the rows inside the window to be ordered, so the OVER clause must contain an ORDER BY; otherwise Spark raises an error.
OVER() defines the set of rows that LAG() and LEAD() operate on. Inside it you can use PARTITION BY (for grouping) and ORDER BY (for sorting): PARTITION BY a ORDER BY b groups the rows by column a and sorts them by column b within each group.
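As a quick illustration of PARTITION BY a ORDER BY b, the following sketch computes both the previous and the next departure within each train_id:

from pyspark.sql import Window
from pyspark.sql.functions import lag, lead

# Group rows by train_id, sort by time within each group
w = Window.partitionBy('train_id').orderBy('time')
df.withColumn('time_prev', lag('time', 1).over(w)) \
  .withColumn('time_next', lead('time', 1).over(w)) \
  .show()  # time_prev is NULL on the first row of each partition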