from pyspark.sql.functions import collect_list
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
from pyspark.sql.window import Window
import pyspark.sql.functions as F
# Rank rows within each uid partition by descending "times", then keep only
# each user's 200 most frequent rows (row_number gives a strict 1..n ranking,
# so ties are broken arbitrarily but at most 200 rows survive per uid).
window = Window.partitionBy("uid").orderBy(F.col("times").desc())
df = (
    df.withColumn("topn", F.row_number().over(window))
      .where(F.col("topn") <= 200)
)
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession, Window
from pyspark.sql.types import IntegerType, StringType
from pyspark.sql.functions import udf, col
import pyspark.sql.functions as f
# Rank each user's queries within every (u_id, cate) group by descending
# category-query frequency; 'index' == 1 is the group's most frequent query.
rank_window = Window.partitionBy('u_id', 'cate') \
    .orderBy(f.col('cate_query_freq').desc())
new_data = new_data.select(
    'u_id', 'query', 'cate_query_freq', 'cate', 'prefer_probility', 'older', 'i_counts',
    f.row_number().over(rank_window).alias('index'),
)
# Bug fix: DataFrame.show() prints the frame and returns None, so the original
# print(new_data.show()) emitted a spurious "None" line — call show() directly.
new_data.show()
# Keep at most TOP_K / 10 * older of the top-ranked queries per group.
# NOTE(review): TOP_K / 10 is true (float) division in Python 3 — confirm that
# a fractional threshold is intended rather than TOP_K // 10.
new_data = new_data.where(new_data['index'] <= TOP_K / 10 * new_data['older'])