12.6 实现推荐模型
项目中 "src\models" 目录下的源代码用于实现各种推荐模型和算法,它们共同构成了推荐系统的各个核心组件。
12.6.1 实现商品推荐和排序
编写文件predict_model.py实现推荐系统的类 Recommendation,它提供了一系列方法来生成推荐候选项并进行排序。这个类的目的是根据用户的历史行为和商品信息,生成个性化的商品推荐列表。它使用了预先计算好的数据文件和训练好的模型来提高推荐的效果。文件predict_model.py的具体实现代码如下所示。
class Recommendation:
    """Generate and rank personalised article recommendations for one customer.

    Pipeline: candidate generation (popular items of last year, items bought
    together, last week's best sellers) -> relevance scoring with
    ``resnet_based_prevence`` -> ranking with ``ranking_model``.
    Every pre-computed artifact and model is loaded lazily on first use and
    cached on the instance.
    """

    def __init__(self):
        # Cached list of the most popular items of the previous year.
        self.popular_items_last_year_list = []
        # Relevance-scoring model, created on first use.
        self.resnet_obj = None
        # Ranking-model wrapper, created on first use.
        self.ranking_obj = None
        # Pre-computed "popular items last year" mapping; its first ntop keys
        # are used as the popular items.
        self.pairs_items_obj = None
        # Pre-computed mapping: article -> articles bought together with it.
        self.items_paried_together_obj = None
        # Transaction history (uses columns t_dat, customer_id, article_id).
        self.transaction_data_obj = None
        # Latest transaction date; set when the transaction data is loaded.
        self.last_tran_date_record = None

    def popular_items_last_year(self, ntop=15):
        """Return the popular items of the previous year.

        Args:
            ntop: number of popular items to return.

        Returns:
            list: the first ``ntop`` keys of the pre-computed mapping read from
            the ``candidate-popular-items-last-year`` output configured in YAML.
        """
        if self.pairs_items_obj is None:
            popular_items_path = os.path.join(
                hlpread.read_yaml_key(CONFIGURATION_PATH, 'model', 'output_folder'),
                hlpread.read_yaml_key(CONFIGURATION_PATH, 'candidate-popular-items-last-year',
                                      'popular-items-last-year-folder'),
                hlpread.read_yaml_key(CONFIGURATION_PATH, 'candidate-popular-items-last-year',
                                      'popular-items-last-year-output'),
            )
            self.pairs_items_obj = hlpread.read_object(popular_items_path)
        # Re-slice only when a different size than the cached one is requested.
        if len(self.popular_items_last_year_list) != ntop:
            self.popular_items_last_year_list = list(self.pairs_items_obj.keys())[:ntop]
        return self.popular_items_last_year_list

    def _items_bought_with(self, article_ids, ntop):
        """Flatten the first ``ntop`` co-purchased items of every id in ``article_ids``.

        Ids absent from the pre-computed mapping are skipped instead of
        raising ``KeyError``; an empty ``article_ids`` yields ``[]``.
        """
        paired = self.items_paried_together_obj
        return [item
                for article_id in article_ids
                if article_id in paired
                for item in paired[article_id][:ntop]]

    def items_paried_together(self, customer_id, popular_items_last_year_list, ntop=20):
        """Build the de-duplicated candidate list for one customer.

        Candidates are the union of:
          1. ``popular_items_last_year_list`` plus the items bought together
             with each of them;
          2. the items bought together with each article the customer bought
             in the 4 weeks up to the latest transaction date;
          3. the ``ntop`` best sellers of the last week that are not already in
             the paired-item lists of sources 1 and 2.

        Args:
            customer_id: customer id as a hex string; its last 16 hex digits are
                converted to the integer id used in the transaction data.
            popular_items_last_year_list: popular items of the previous year.
            ntop: number of paired items per article and of best sellers kept.

        Returns:
            list: unique candidate article ids (order not guaranteed).
        """
        customer_id = int(customer_id[-16:], 16)

        # Candidate 1: items bought together with last year's popular items.
        if self.items_paried_together_obj is None:
            items_paired_together_path = os.path.join(
                hlpread.read_yaml_key(CONFIGURATION_PATH, 'model', 'output_folder'),
                hlpread.read_yaml_key(CONFIGURATION_PATH, 'candidate-item-purchase-together',
                                      'item-purchase-together-folder'),
                hlpread.read_yaml_key(CONFIGURATION_PATH, 'candidate-item-purchase-together',
                                      'item-purchase-together-output'),
            )
            self.items_paried_together_obj = hlpread.read_object(items_paired_together_path)
        items_paried_together = self._items_bought_with(popular_items_last_year_list, ntop)

        # Candidate 2: items bought together with what this customer bought
        # during the last 4 weeks (empty when there are no such purchases).
        if self.transaction_data_obj is None:
            transaction_data_path = os.path.join(
                hlpread.read_yaml_key(CONFIGURATION_PATH, 'data_source', 'data_folders'),
                hlpread.read_yaml_key(CONFIGURATION_PATH, 'data_source', 'processed_data_folder'),
                hlpread.read_yaml_key(CONFIGURATION_PATH, 'data_source', 'train_data'),
            )
            self.transaction_data_obj = hlpread.read_from_parquet(transaction_data_path)
            # The latest transaction date stands in for "today"; replace it
            # with the current date in a live environment.
            self.last_tran_date_record = self.transaction_data_obj.t_dat.max()
        transactions = self.transaction_data_obj
        recent_purchases = transactions[
            (transactions.t_dat >= (self.last_tran_date_record - timedelta(weeks=4)))
            & (transactions.customer_id == customer_id)
        ].article_id.unique()
        item_paried_together_for_items_user_purchase_last_4week = self._items_bought_with(
            recent_purchases, ntop)

        # Candidate 3: last week's best sellers.
        sales_last_week = transactions[
            transactions.t_dat >= (self.last_tran_date_record - timedelta(weeks=1))
        ].article_id.value_counts()
        # Name both columns explicitly: value_counts().reset_index() yields
        # ('index', 'article_id') on pandas 1.x but ('article_id', 'count') on
        # pandas 2.x, so renaming by the old names breaks on 2.x.
        vc = sales_last_week.rename_axis('article_id').reset_index(name='cnt')
        already_paired = set(items_paried_together) | set(
            item_paried_together_for_items_user_purchase_last_4week)
        vc = vc[~vc.article_id.isin(already_paired)]
        vc = vc.sort_values(by='cnt', ascending=False)
        top_items_sold_last_1week = list(vc[:ntop].article_id)

        # Final candidate list without duplicates.
        candidates = (list(popular_items_last_year_list) + items_paried_together
                      + item_paried_together_for_items_user_purchase_last_4week
                      + top_items_sold_last_1week)
        return list(set(candidates))

    def candinate_generation(self, customer_id):
        """Return the candidate article ids for ``customer_id``.

        Last year's popular items are fetched first and then expanded and
        merged with the other candidate sources by ``items_paried_together``.
        """
        popular_items = self.popular_items_last_year()
        return self.items_paried_together(customer_id, popular_items)

    def find_relevent_items_of_candinate(self, customer_id, candinate_items):
        """Score the candidate items for ``customer_id``.

        Args:
            customer_id: customer id.
            candinate_items: candidate article ids.

        Returns:
            DataFrame with customer_id, article_id and y_hat (relevance score)
            as produced by ``resnet_based_prevence.predict``.
        """
        if self.resnet_obj is None:
            self.resnet_obj = resnet_based_prevence(False)
        # NOTE(review): the meaning of the trailing -1 argument is defined by
        # resnet_based_prevence.predict — confirm there.
        return self.resnet_obj.predict(customer_id, candinate_items, -1)

    def rank_relevent_items(self, recommended_items):
        """Rank the scored items with the LightGBM ranking model.

        Args:
            recommended_items: scored candidate items.

        Returns:
            The ranked items returned by ``ranking_model.predict``.
        """
        if self.ranking_obj is None:
            ranking_folder = os.path.join(
                read_yaml_key(CONFIGURATION_PATH, 'model', 'output_folder'),
                read_yaml_key(CONFIGURATION_PATH, 'lightgbm-param', 'ranking-model-output-folder'),
            )
            saved_model = os.path.join(
                ranking_folder,
                read_yaml_key(CONFIGURATION_PATH, 'lightgbm-param', 'saved_model'))
            saved_pipeline = os.path.join(
                ranking_folder,
                read_yaml_key(CONFIGURATION_PATH, 'lightgbm-param', 'saved_engg_pipeline'))
            self.ranking_obj = ranking_model(saved_model, saved_pipeline, CONFIGURATION_PATH)
        return self.ranking_obj.predict(recommended_items)

    def predict(self, customer_id, ntop=15):
        """Recommend articles for one customer.

        Args:
            customer_id: customer id (hex string).
            ntop: number of top-ranked items to return (default 15).

        Returns:
            The first ``ntop`` rows of the ranked recommendations.
        """
        # Converting customer_id to the hashed user id is done inside the
        # feature pipeline, so it is not repeated here.
        log.write_log(f'Get candinate for customer: {customer_id}...', log.logging.DEBUG)
        candinate_lists = self.candinate_generation(customer_id)

        log.write_log(f'Find relevent items from shortlisted candinate for customer: {customer_id}...', log.logging.DEBUG)
        relevent_items = self.find_relevent_items_of_candinate(customer_id, candinate_lists)

        log.write_log(f'Rank the relevent items from shortlisted candinate for customer: {customer_id}...', log.logging.DEBUG)
        recommended_items = self.rank_relevent_items(relevent_items)

        # The ranking model already returns article_ids, so no inverse
        # item_id -> article_id mapping is needed.
        log.write_log(f'Return Top {ntop} relevent items to recommend for the customer: {customer_id}...', log.logging.DEBUG)
        return recommended_items[:ntop]
- __init__(self): 初始化函数,初始化一些属性。
- popular_items_last_year(self, ntop=15): 获取去年热门商品列表,此函数会读取预先计算好的数据文件,并返回指定数量的热门商品列表。
- items_paried_together(self, customer_id, popular_items_last_year_list, ntop=20): 为指定用户生成候选商品列表。此函数会读取预先计算好的"共同购买"数据和交易数据,合并以下三类候选项并去重后返回:上一年度热门商品及与其一起购买的商品、与该用户最近四周所购商品一起购买的商品,以及上一周最畅销的商品。
- candinate_generation(self, customer_id): 生成推荐候选项列表。此函数先调用 popular_items_last_year 获取上一年度热门商品列表,再将其传给 items_paried_together,由后者合并各类候选项并返回最终的候选项列表。
- find_relevent_items_of_candinate(self, customer_id, candinate_items): 根据候选项列表找到相关的商品。此函数会使用一个名为 resnet_based_prevence 的模型进行预测,并返回包含相关商品的数据帧。
- rank_relevent_items(self, recommended_items): 对相关商品进行排序。此函数会使用一个名为 ranking_model 的模型对相关商品进行排序,并返回排序后的商品列表。
- predict(self, customer_id): 进行推荐的主要方法。此函数会调用上述方法来生成候选项列表、获取相关商品和对商品进行排序,并返回前15个排名最高的商品。
12.6.2 排序模型
在上面的商品推荐和排序文件predict_model.py中,调用了文件ranking_model.py中的功能模块。编写文件ranking_model.py实现排序模型,该模型用于根据商品特征和用户特征对推荐候选项进行排序。在此文件中创建了一个用于排名模型的特征处理和数据获取的类ranking_model,封装了特征处理流程和数据获取操作,方便在实际应用中使用。文件ranking_model.py的具体实现流程如下:
(1)编写初始化函数__init__(self, saved_model, saved_pipeline, config_path = CONFIGURATION_PATH),用于创建ranking_model类的实例。初始化时传入已保存的模型文件路径和特征工程流水线文件路径;如果对应文件已经存在,则分别加载已训练的LightGBM模型和已拟合的特征工程流水线。对应的实现代码如下所示:
def __init__(self, saved_model, saved_pipeline, config_path = CONFIGURATION_PATH):
    """Create the ranking-model wrapper and load any previously saved artifacts.

    Args:
        saved_model: path of the LightGBM model file.
        saved_pipeline: path of the saved feature-engineering pipeline.
        config_path: path of the YAML configuration file.
    """
    self.config_path = config_path
    self.saved_model_filepath = saved_model
    self.saved_pipeline_filepath = saved_pipeline
    # Feature-store handle; created later by init_repo_path().
    self.fs = None

    # Load the trained booster only when the model file already exists.
    self.model_trained = os.path.exists(self.saved_model_filepath)
    if self.model_trained:
        self.ranker_bst = lgb.Booster(model_file = self.saved_model_filepath)

    # Likewise reuse a previously fitted feature pipeline (loaded via read_object).
    self.fited_pipeline = os.path.exists(self.saved_pipeline_filepath)
    if self.fited_pipeline:
        self.feature_engg = read_object(self.saved_pipeline_filepath)
(2)编写初始化特征仓库的方法init_repo_path(self, is_training),该方法根据配置文件中的路径创建特征仓库(FeatureStore)对象;在非训练(在线预测)模式下,还会将指定的特征视图增量物化到前一天,方便后续获取训练数据和在线数据。对应的实现代码如下所示:
def init_repo_path(self, is_training):
    """Create the feature-store handle on first use.

    The repository folder is read from the ``feature_store`` section of the
    YAML config. When ``is_training`` is False, the listed feature views are
    additionally materialized incrementally up to one day before now
    (presumably so that get_online_* can read them — confirm).

    Args:
        is_training: False for online serving, True for training.
    """
    if self.fs == None:
        repo_path = os.path.join(
            read_yaml_key(self.config_path, 'feature_store', 'feature_store_folder')
        )
        self.fs = FeatureStore(repo_path = repo_path)
        # NOTE(review): the original indentation was lost; this branch is
        # assumed to sit inside the "fs not yet created" guard, so it runs
        # once per instance — confirm against the source project.
        if is_training == False:
            # Feature views to materialize for serving.
            fv = ['user_avg_median_purchase_price_fv',
                  'user_avg_median_purchase_price_last_8week_fv',
                  #'customer_elapsed_day_fv',
                  'item_previous_days_sales_count_fv',
                  'item_avg_sales_price_fv'
                  ]
            # Materialize up to yesterday (naive UTC timestamp).
            self.fs.materialize_incremental(end_date = datetime.utcnow() - timedelta(days = 1),
                                            feature_views = fv
                                            )
(3)编写方法get_training_user_features(self, X)获取训练用户的特征,该方法以X为实体数据,从特征仓库中获取用户的历史特征(最近8周和全部历史的购买价格中位数),返回合并了用户特征的DataFrame;如果特征仓库尚未初始化,则原样返回X。对应的实现代码如下所示:
def get_training_user_features(self, X):
    """Attach historical (point-in-time) user features to the entity frame ``X``.

    Two lookups are made. The result of ``self.user_features_list`` is stored
    as ``user_last8week_median_purchase_price``; the result of
    ``self.user_features_list2`` is stored as ``user_overall_median_purchase_price``.

    Args:
        X: entity DataFrame passed to the feature store.

    Returns:
        The enriched DataFrame, or ``X`` itself when the feature store has not
        been initialised.
    """
    if self.fs is None:
        # Feature store not created yet (see init_repo_path): nothing to add.
        return X
    lookups = (
        (self.user_features_list, 'user_last8week_median_purchase_price'),
        (self.user_features_list2, 'user_overall_median_purchase_price'),
    )
    for features, column_name in lookups:
        X = self.fs.get_historical_features(
            entity_df = X,
            features = features
        ).to_df()
        # Each lookup yields 'user_prev_median_purchase_price'; rename it before
        # the next lookup so the two values do not collide.
        X = X.rename(columns = {'user_prev_median_purchase_price': column_name})
    return X
(4)编写方法get_training_item_features(self, X)获取训练商品的特征,该方法以X为实体数据,从特征仓库中获取商品的历史特征,返回合并了商品特征的DataFrame。对应的实现代码如下所示:
def get_training_item_features(self, X):
    """Return ``X`` enriched with the historical item features in ``self.item_features_list``.

    Args:
        X: entity DataFrame passed to the feature store.

    Returns:
        The DataFrame produced by the feature store's historical lookup.
    """
    historical = self.fs.get_historical_features(
        entity_df = X,
        features = self.item_features_list,
    )
    return historical.to_df()
(5)编写方法get_online_user_features(self, X)获取在线用户特征,该方法根据X中(去重后的)customer_id从特征仓库的在线存储中获取用户特征,并将其合并到X中返回。对应的实现代码如下所示:
def get_online_user_features(self, X):
    """Attach online user purchase-price features to every row of ``X``.

    Args:
        X: DataFrame with a ``customer_id`` column; it may contain several
            rows per customer (e.g. one per candidate article).

    Returns:
        ``X`` inner-joined on ``customer_id`` with the columns
        ``user_overall_median_purchase_price`` and
        ``user_last8week_median_purchase_price``.
    """
    # Query each customer once. Duplicate entity rows would make both
    # customer_id merges below many-to-many and multiply the row count.
    entity_rows = X[['customer_id']].drop_duplicates().to_dict(orient = 'records')

    last8week = self.fs.get_online_features(
        entity_rows = entity_rows,
        features = self.user_features_list
    ).to_df()
    last8week = last8week.rename(
        columns = {'user_median_purchase_price': 'user_last8week_median_purchase_price'})

    overall = self.fs.get_online_features(
        entity_rows = entity_rows,
        features = self.user_features_list2
    ).to_df()
    overall = overall.rename(
        columns = {'user_median_purchase_price': 'user_overall_median_purchase_price'})

    user_features = overall.merge(last8week, on = ['customer_id'], how = 'inner')
    return X.merge(user_features, on = ['customer_id'], how = 'inner')
(6)编写方法get_online_item_features(self, X)获取在线商品特征,该方法根据X中去重后的article_id从特征仓库的在线存储中获取商品特征,将其合并到X中并生成模型所需的特征列后返回。对应的实现代码如下所示:
def get_online_item_features(self, X):
    """Attach online item features to every row of ``X``.

    Args:
        X: DataFrame with an ``article_id`` column.

    Returns:
        ``X`` inner-joined on ``article_id`` with the online item features,
        plus ``prev_day_sales_cnt`` (copy of ``sale_count``) and
        ``item_prev_median_sales_price`` (copy of ``item_median_sales_price``).
    """
    # Query each article once.
    item_rows = X[['article_id']].drop_duplicates().reset_index(drop = True)
    item_features = self.fs.get_online_features(
        entity_rows = item_rows.to_dict(orient = 'records'),
        features = self.item_features_list,
    ).to_df()
    # Merge first: sale_count / item_median_sales_price come from the online
    # item features, so they are not available on X before this join.
    X = X.merge(item_features, on = ['article_id'], how = 'inner')
    # Expose them under the names used by the feature pipeline's quartile binning.
    X['prev_day_sales_cnt'] = X['sale_count']
    X['item_prev_median_sales_price'] = X['item_median_sales_price']
    return X
(7)编写方法get_user_previous_purchase_details(self, X)获取用户之前的购买详情,该方法读取预先计算好的用户特征文件(最近8周和全部历史的购买价格中位数),按customer_id与X进行内连接,并返回合并后的数据。对应的实现代码如下所示:
def get_user_previous_purchase_details(self, X):
    """Join the pre-computed user purchase-price features onto ``X``.

    Two parquet files (located via the ``customer_features`` config section)
    are inner-joined on ``customer_id``. The ``user_median_purchase_price``
    column of the first is renamed to ``user_last8week_median_purchase_price``
    and that of the second to ``user_overall_median_purchase_price``.

    Args:
        X: DataFrame with a ``customer_id`` column.

    Returns:
        The joined DataFrame (customers missing from a file are dropped).
    """
    feature_files = (
        ('user_avg_median_purchase_price_last_8week', 'user_last8week_median_purchase_price'),
        ('user_avg_median_purchase_price', 'user_overall_median_purchase_price'),
    )
    for file_key, column_name in feature_files:
        feature_path = os.path.join(
            read_yaml_key(self.config_path, 'data_source', 'data_folders'),
            read_yaml_key(self.config_path, 'data_source', 'feature_folder'),
            read_yaml_key(self.config_path, 'customer_features', 'customer_folder'),
            read_yaml_key(self.config_path, 'customer_features', file_key),
        )
        purchase_prices = read_from_parquet(feature_path)
        X = X.merge(purchase_prices[['customer_id', 'user_median_purchase_price']],
                    on = ['customer_id'],
                    how = 'inner')
        X.rename(columns = {'user_median_purchase_price': column_name}, inplace = True)
        # Release the feature table before reading the next one.
        del purchase_prices
        gc.collect()
    return X
(8)编写方法get_items_previous_sales_details(self, X)获取商品之前的销售详情,该方法读取预先计算好的商品销量特征文件和销售价格中位数特征文件,按article_id与X进行内连接,并返回合并后的数据。对应的实现代码如下所示:
def get_items_previous_sales_details(self, X):
    """Join the pre-computed item sales features onto ``X``.

    Reads the ``item_prev_days_sales_count`` and ``item_avg_median_sales_price``
    parquet files (``article_feature`` config section) and inner-joins them on
    ``article_id``. ``sale_count`` is renamed to ``prev_day_sales_cnt`` and
    ``item_median_sales_price`` to ``item_prev_median_sales_price``.

    Args:
        X: DataFrame with an ``article_id`` column.

    Returns:
        The joined and renamed DataFrame.
    """
    def article_feature_path(file_key):
        # <data_folders>/<feature_folder>/<article_folder>/<configured file>
        return os.path.join(
            read_yaml_key(self.config_path, 'data_source', 'data_folders'),
            read_yaml_key(self.config_path, 'data_source', 'feature_folder'),
            read_yaml_key(self.config_path, 'article_feature', 'article_folder'),
            read_yaml_key(self.config_path, 'article_feature', file_key),
        )

    sales_counts = read_from_parquet(article_feature_path('item_prev_days_sales_count'))
    X = X.merge(sales_counts[['article_id', 'prev_1w_sales_cnt', 'prev_year_sales_cnt', 'sale_count']],
                on = ['article_id'],
                how = 'inner')
    del sales_counts
    gc.collect()

    median_prices = read_from_parquet(article_feature_path('item_avg_median_sales_price'))
    X = X.merge(median_prices[['article_id', 'item_median_sales_price']],
                on = ['article_id'],
                how = 'inner')
    del median_prices
    gc.collect()

    # Column names expected by the feature pipeline's quartile binning.
    return X.rename(columns = {"sale_count": "prev_day_sales_cnt",
                               "item_median_sales_price": "item_prev_median_sales_price"})
(9)编写方法feature_transform(self, X, is_training)实现特征转换,该方法接收一个特征数据的DataFrame作为输入,使用特征处理流水线对其进行转换。如果流水线尚未拟合,则先构建流水线,在X上拟合并保存到文件。转换包括商品与用户ID映射、颜色转换、类别特征合并、序数编码、留一法(Leave-One-Out)目标编码以及特征分箱等操作,转换后的特征数据将作为模型的输入。对应的实现代码如下所示:
def feature_transform(self, X, is_training):
    """Transform ``X`` with the feature-engineering pipeline, fitting it first if needed.

    If no fitted pipeline is available, a ``Pipeline`` is built, fitted on
    ``X`` (its encoders use the ``label`` column), saved to
    ``self.saved_pipeline_filepath`` and marked as fitted. Otherwise the
    existing pipeline only transforms ``X``.

    Args:
        X: feature DataFrame.
        is_training: not used by this method; kept for interface compatibility.

    Returns:
        The transformed features.
    """
    if self.fited_pipeline:
        # Pipeline loaded in __init__ or fitted by an earlier call.
        return self.feature_engg.transform(X)

    ordinal_encoder_columns = ['product_group_name']
    ordinal_encoder_label_columns = ['label']
    target_encode_columns = ['product_desc']
    target_encode_label = ['label']
    seed = 1001
    q_list = [0.1, 0.25, 0.5, 0.75, 0.9]
    # Features binned by quantiles computed per product_desc group.
    quartile_features = {
        'user_overall_median_purchase_price': {'groupby_col': 'product_desc', 'quartile_list': q_list},
        'user_last8week_median_purchase_price': {'groupby_col': 'product_desc', 'quartile_list': q_list},
        'item_prev_median_sales_price': {'groupby_col': 'product_desc', 'quartile_list': q_list},
        'prev_day_sales_cnt': {'groupby_col': 'product_desc', 'quartile_list': q_list},
        'prev_1w_sales_cnt': {'groupby_col': 'product_desc', 'quartile_list': q_list},
    }
    bin_features = {
        'color': [0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95],
    }
    self.feature_engg = Pipeline(steps = [
        ('transform_article_mapping', transform_article_mapping(config_path = self.config_path)),
        ('transform_customer_mapping', transform_customer_mapping(hash_conversion = False, config_path = self.config_path)),
        ('transform_color_rgb', transform_color_rgb()),
        ('merge_catagorical_feature', merge_catagorical_feature()),
        ('catagory_ordinal_encode', catagory_ordinal_encoder(ordinal_encoder_label_columns,
                                                             ordinal_encoder_columns)),
        ('catagory_leave_one_encoder', catagory_leave_one_out_encoder(target_encode_label,
                                                                      target_encode_columns,
                                                                      seed)),
        ('bin_feature_based_on_other_features', bin_feature_based_on_feature(quartile_features)),
        ('bin_feature', bin_feature(bin_features))
    ], verbose = True)
    X = self.feature_engg.fit_transform(X)
    save_object(self.saved_pipeline_filepath, self.feature_engg)
    # Mark the pipeline as fitted. Without this, a later call on the same
    # instance (e.g. predict() after train_model()) would re-fit on new,
    # possibly unlabeled data and overwrite the saved pipeline.
    self.fited_pipeline = True
    return X
(10)编写方法get_features(self, X, is_training)获取特征,该方法接收特征数据X和一个布尔值is_training(用于指示是否为训练模式)。它依次调用get_items_features方法和get_user_features方法获取商品特征和用户特征(is_training会传递给这两个方法),再通过feature_transform方法对特征数据进行转换,最后返回转换后的特征数据。对应的实现代码如下所示:
def get_features(self, X, is_training):
    """Build model-ready features for ``X``.

    Applies, in order, item-feature lookup, user-feature lookup and the
    feature-engineering pipeline, passing ``is_training`` to each step.

    Returns:
        The output of ``feature_transform``.
    """
    for step in (self.get_items_features, self.get_user_features, self.feature_transform):
        X = step(X, is_training = is_training)
    return X
(11)编写方法train_model(self, X)实现模型训练,该方法接收特征数据 X,调用 get_features 方法获取特征数据,并进行模型训练。首先,从特征数据中选择训练所需的特征列,并按照用户ID和标签进行排序。然后,计算每个用户的样本数,并将其作为 group 参数传递给 LightGBM 模型。接下来,将标签列转换为整数类型,并从特征数据中删除用户ID、商品ID和标签列,得到训练数据集 ddf_x_train。最后,使用 LightGBM 模型训练训练数据集,并将训练好的模型保存。对应的实现代码如下所示:
def train_model(self, X):
    """Train the LightGBM ranker on ``X`` and save it to ``self.saved_model_filepath``.

    Args:
        X: raw training rows; features are built with ``get_features(X, True)``
            and must yield ``label``, ``item_id``, ``user_id`` and the model
            feature columns selected below.

    Raises:
        RecommendationException: wraps any error raised while building
            features or training the model.
    """
    try:
        X = self.get_features(X, True)
        x_train = X[['label',
                     'item_id', 'user_id',
                     'product_group_name_oce',
                     'product_desc_tce',
                     'user_overall_median_purchase_price_bin',
                     'user_last8week_median_purchase_price_bin',
                     'item_prev_median_sales_price_bin',
                     'prev_day_sales_cnt_bin',
                     'prev_1w_sales_cnt_bin',
                     'color_bin',
                     'graphical_appearance_no',
                     ]]
        # Rows of one query (user) must be contiguous for LightGBM ranking;
        # within a user, higher labels come first.
        x_train = x_train.sort_values(by = ['user_id', 'label'], ascending = [True, False], na_position = 'first')
        # Group sizes in ascending user_id order (same order as the sorted rows).
        # size() counts every row, so the sizes sum to the row count as
        # LightGBM requires. count() would skip rows whose item_id is NaN.
        qids_train = x_train.groupby('user_id', sort = True).size().to_numpy()
        # Relevance label.
        y_train = x_train['label'].astype(int)
        # Keep only the model features.
        features_train = x_train.drop(['user_id', 'item_id', 'label'], axis = 1)
        train_set = lgb.Dataset(data = features_train, label = y_train, group = qids_train, free_raw_data = False)
        param = read_yaml_key(self.config_path, 'lightgbm-param', 'param')
        self.ranker_bst = lgb.train(params = param,
                                    num_boost_round = param['n_estimators'],
                                    train_set = train_set,
                                    keep_training_booster = True
                                    )
        self.ranker_bst.save_model(self.saved_model_filepath)
        # Keep the flag consistent with __init__, which sets it when a model is loaded.
        self.model_trained = True
    except Exception as e:
        raise RecommendationException(e, sys) from e
(12)编写方法predict(self, X)实现模型预测,该方法接收特征数据 X,调用 get_features 方法获取特征数据,并使用训练好的 LightGBM 模型进行预测。首先,从特征数据中选择预测所需的特征列。然后,通过 self.ranker_bst.predict 方法使用训练好的模型对特征数据进行预测,得到每个商品的得分。接着,将得分作为 rank 列与 article_id 列组合,并按得分降序对数据进行排序。最后,返回包含 article_id 和 rank 两列的数据。对应的实现代码如下所示:
def predict(self, X):
    """Score and rank candidate rows with the trained LightGBM ranker.

    Args:
        X: raw candidate rows; features are built with ``get_features(X, False)``.

    Returns:
        DataFrame with columns ``article_id`` and ``rank`` (model score),
        sorted by ``rank`` in descending order.
    """
    X = self.get_features(X, False)
    col = ['product_group_name_oce',
           'product_desc_tce',
           'user_overall_median_purchase_price_bin',
           'user_last8week_median_purchase_price_bin',
           'item_prev_median_sales_price_bin',
           'prev_day_sales_cnt_bin',
           'prev_1w_sales_cnt_bin',
           'color_bin',
           'graphical_appearance_no',
           ]
    scores = self.ranker_bst.predict(data = X[col])
    # Build a new frame rather than sorting a column slice of X in place,
    # which pandas may flag with SettingWithCopyWarning.
    ranked = X[['article_id']].assign(rank = scores)
    return ranked.sort_values('rank', ascending = False)
上述方法的主要功能是获取和处理排名模型所需的特征数据,为模型训练和预测提供支持。另外,文件ranking_model.py(在feature_transform方法中)还用到了如下所示的自定义流水线组件:
- src.models.pipeline.catagory_leave_one_out_encoder:用于类别 Leave-One-Out 编码的自定义管道。
- src.models.pipeline.bin_feature_based_on_feature:基于其他特征分组计算分位数、并据此对特征进行分箱的自定义管道。
- src.models.pipeline.transform_customer_mapping:用于转换客户映射的自定义管道。
- src.models.pipeline.merge_catagorical_feature:用于合并类别特征的自定义管道。
- src.models.pipeline.transform_article_mapping:用于转换商品(article)映射的自定义管道。
- src.models.pipeline.catagory_ordinal_encoder:用于类别序数编码的自定义管道。
- src.models.pipeline.transform_color_rgb:用于转换颜色 RGB 的自定义管道。