##部分截图(草稿,后续会删除)
args的各个参数值
# %% 3. Setup model
def get_task(envname, yfeat):
def setup_world_model
二、代码阅读
2.1.4设置世界模型(第三步)
task, task_logit_dim, is_ranking = get_task(args.env, args.yfeat)
获得了
ensemble_models = setup_world_model(args, x_columns, y_columns, ab_columns, task, task_logit_dim, is_ranking, MODEL_SAVE_PATH)
获得的 ensemble_models 是一堆路径和模型
对 user_model 的解释
以user_model的第一行为例,
她所代表的含义是
初始是通过传递参数以创建一个EnsembleModel
对象获得的
接着获取损失函数
这里选择的是负点对损失函数,其构成如下:
解释如下
y_weighted, loss_ab = process_logit(y_deepfm_pos, score, alpha_u=alpha_u, beta_i=beta_i, args=args)
的处理过程如下
此处对应了论文中的公式(10)
损失函数的输入:
y: 对应函数参数y,表示真实的目标值。
y_deepfm_pos: 对应函数参数y_deepfm_pos,表示DeepFM模型预测的正样本预测值。
y_deepfm_neg: 对应函数参数y_deepfm_neg,表示DeepFM模型预测的负样本预测值。
score: 对应函数参数score,表示用于计算损失函数的得分或权重。
alpha_u: 对应函数参数alpha_u,表示用于计算损失函数的用户相关参数。
beta_i: 对应函数参数beta_i,表示用于计算损失函数的物品相关参数。
args: 对应函数参数args,表示传递给损失函数的配置参数。
损失函数的输出:
总损失loss。
接着对 ensemble_models 进行编译
具体来说 compile() 方法会对每个模型(共5个)进行编译,编译的步骤如下:
以第一个模型的编译结果为例
ensemble_models的user_model的第一个模型中,loss_func的值为
functools.partial(<function loss_pointwise_negative at 0x0000019FE49D1E10>, args=Namespace(env='KuaiEnv-v0', resume=False, optimizer='adam', seed=2022, bpr_weight=0.5, neg_K=3, n_models=5, is_softmax=False, num_trajectory=200, force_length=10, top_rate=0.8, deterministic=True, draw_bar=False, is_all_item_ranking=False, loss='pointneg', rankingK=(20, 10, 5), max_turn=30, l2_reg_dnn=0.1, lambda_ab=10, epsilon=0, is_ucb=False, dnn_activation='relu', feature_dim=8, entity_dim=8, user_model_name='DeepFM', dnn=(128, 128), dnn_var=(), batch_size=4096, epoch=5, cuda=0, tau=0, is_ab=False, message='pointneg', all_item_ranking=False, leave_threshold=0, num_leave_compute=10, is_userinfo=False, is_binarize=False, need_transform=True, entropy_window=[1, 2], yfeat='watch_ratio_normed', use_userEmbedding=False))
也就是说,第一个模型的loss_func
是一个经过部分应用的损失函数,其中的参数通过args
对象进行绑定。ensemble_models的每一个模型均有一个loss_func
其中,loss_func的keywords中的‘args’具体为
表示的含义为
全部编译好后返回 ensemble_models
继续获取真实环境env,环境任务类env_task_class(未返回)和关键字参数字典kwargs_um
env:
kwargs_um
接着根据环境名称获取训练中的物品主导信息
从上至下,所示类别占据主导地位的强度逐渐下降,也就是说,类别28的主导性最强。也可以理解为,用户观看类别28的视频对最有利于获得更好的结果(比如对应的观看时长更长等)。
最后,再次编译 ensemble_models (利用.compile_RL_test()方法)为其设置评估函数
ensemble_models.compile_RL_test( functools.partial(test_static_model_in_RL_env, env=env, dataset_val=dataset_val, is_softmax=args.is_softmax, epsilon=args.epsilon, is_ucb=args.is_ucb, need_transform=args.need_transform, num_trajectory=args.num_trajectory, item_feat_domination=item_feat_domination, force_length=args.force_length, top_rate=args.top_rate))
循环编译5个模型
对于第一个模型
test_static_model_in_RL_env所在位置如下:
处理好后,第三步完成
user_models的第一行如下:
2.1.5学习并评估模型(第五步)
history_list = ensemble_models.fit_data(dataset_train, dataset_val, batch_size=args.batch_size, epochs=args.epoch, shuffle=True, callbacks=[LoggerCallback_Update(logger_path)])
对.fit_data方法的描述
此处因突然开始训练并报错:BufferError: memoryview has 1 exported buffer,工作暂时中断。