数据读取
训练和验证集的划分
划分训练和验证集的原因是为了在线下验证模型参数的好坏,为了完全模拟测试集,我们这里就在训练集中抽取部分用户的所有信息来作为验证集。提前做训练验证集划分的好处就是可以分解制作排序特征时的压力,一次性做整个数据集的排序特征可能时间会比较长。
1
# all_click_df指的是训练集
2
# sample_user_nums 采样作为验证集的用户数量
3
def trn_val_split(all_click_df, sample_user_nums):
4
all_click = all_click_df
5
all_user_ids = all_click.user_id.unique()
6
7
# replace=True表示可以重复抽样,反之不可以
8
sample_user_ids = np.random.choice(all_user_ids, size=sample_user_nums, replace=False)
9
10
click_val = all_click[all_click['user_id'].isin(sample_user_ids)]
11
click_trn = all_click[~all_click['user_id'].isin(sample_user_ids)]
12
13
# 将验证集中的最后一次点击给抽取出来作为答案
14
click_val = click_val.sort_values(['user_id', 'click_timestamp'])
15
val_ans = click_val.groupby('user_id').tail(1)
16
17
click_val = click_val.groupby('user_id').apply(lambda x: x[:-1]).reset_index(drop=True)
18
19
# 去除val_ans中某些用户只有一个点击数据的情况,如果该用户只有一个点击数据,又被分到ans中,
20
# 那么训练集中就没有这个用户的点击数据,出现用户冷启动问题,给自己模型验证带来麻烦
21
val_ans = val_ans[val_ans.user_id.isin(click_val.user_id.unique())] # 保证答案中出现的用户再验证集中还有
22
click_val = click_val[click_val.user_id.isin(val_ans.user_id.unique())]
23
24
return click_trn, click_val, val_ans
获取历史点击和最后一次点击
1
# 获取当前数据的历史点击和最后一次点击
2
def get_hist_and_last_click(all_click):
3
all_click = all_click.sort_values(by=['user_id', 'click_timestamp'])
4
click_last_df = all_click.groupby('user_id').tail(1)
5
6
# 如果用户只有一个点击,hist为空了,会导致训练的时候这个用户不可见,此时默认泄露一下
7
def hist_func(user_df):
8
if len(user_df) == 1:
9
return user_df
10
else:
11
return user_df[:-1]
12
13
click_hist_df = all_click.groupby('user_id').apply(hist_func).reset_index(drop=True)
14
15
return click_hist_df, click_last_df
读取训练、验证及测试集¶
1
def get_trn_val_tst_data(data_path, offline=True):
2
if offline:
3
click_trn_data = pd.read_csv(data_path+'train_click_log.csv') # 训练集用户点击日志
4
click_trn_data = reduce_mem(click_trn_data)
5
click_trn, click_val, val_ans = trn_val_split(click_trn_data, sample_user_nums)
6
else:
7
click_trn = pd.read_csv(data_path+'train_click_log.csv')
8
click_trn = reduce_mem(click_trn)
9
click_val = None
10
val_ans = None
11
12
click_tst = pd.read_csv(data_path+'testA_click_log.csv')
return click_trn, click_val, click_tst, val_ans
读取召回列表
1
# 返回多路召回列表或者单路召回
2
def get_recall_list(save_path, single_recall_model=None, multi_recall=False):
3
if multi_recall:
4
return pickle.load(open(save_path + 'final_recall_items_dict.pkl', 'rb'))
5
6
if single_recall_model == 'i2i_itemcf':
7
return pickle.load(open(save_path + 'itemcf_recall_dict.pkl', 'rb'))
8
elif single_recall_model == 'i2i_emb_itemcf':
9
return pickle.load(open(save_path + 'itemcf_emb_dict.pkl', 'rb'))
10
elif single_recall_model == 'user_cf':
11
return pickle.load(open(save_path + 'youtubednn_usercf_dict.pkl', 'rb'))
12
elif single_recall_model == 'youtubednn':
13
return pickle.load(open(save_path + 'youtube_u2i_dict.pkl', rb')