MeLU源码解读

github地址:https://github.com/hoyeoplee/MeLU

综述

  1. 以用户为单位封装support_set & query_set方便训练
  2. 引入MAML的思想,将参数更新分为分任务更新和元更新。代码实现上应用OrderedDict()辅助分任务更新
  3. 提出一种筛选新设备的商品推荐策略,item_score = popularity_value * discriminative_value

data_generation.py

def item_converting ()
	将item信息转成index格式:可看成one-hot
def user_converting ()
	将user信息转成index格式:可看成one-hot
def generate()
	分四种场景
	将每种场景下的每个 user 生成 support_set 和 query_set
	每行格式 : movie_i - user_fixed

dataset.py

class movielens_1m(object):
	def load(self):
		profile_data = pd.read_csv("users.dat")
		item_data = pd.read_csv("ratings.dat")
		score_data = pd.read_csv("movies_extrainfos.dat")
		

embeddings.py

class item(torch.nn.Module):
	将item各个属性(rate、genre、director、actors)emb成32维
	对genre、director、actors还要avg_pooling
	
class user(torch.nn.Module):
	将user各个属性(gender、age、occupation、area)emb成32维
	不需要avg_pooling,每种属性只有一个值

MeLU.py

class user_preference_estimator:
	构建有两个hidden_layer的MLP,输出二分类概率值
	
class MeLU(torch.nn.Module):
	self.model = user_preference_estimator(config)
	self.local_update_target_weight_name = ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'linear_out.weight', 'linear_out.bias'] #分任务更新Loop:不更新emb层(torch.nn.Embedding)
	def store_parameters(self):
        self.keep_weight = deepcopy(self.model.state_dict()) # 存储分任务更新前的参数
        self.weight_name = list(self.keep_weight.keys()) #         分任务更新时排除emb参数
        self.weight_len = len(self.keep_weight) # for 逐参数写SGD
        self.fast_weights = OrderedDict() # 顺序存储多个SGD
    def forward(self, support_set_x, support_set_y, query_set_x, num_local_update):
    	grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True) # 关键create_graph=True
    	# 巧用fast_weights完成单个分任务上num_local_update轮更新
    def global_update(self, support_set_xs, support_set_ys, query_set_xs, query_set_ys, num_local_update):
    	batch_sz = len(support_set_xs) # 有多少个分任务
    	
    	for i in range(batch_sz):
            query_set_y_pred = self.forward(support_set_xs[i], support_set_ys[i], query_set_xs[i], num_local_update) # 每个分任务完成分任务更新
            loss_q = F.mse_loss(query_set_y_pred, query_set_ys[i].view(-1, 1))
            losses_q.append(loss_q)
        losses_q = torch.stack(losses_q).mean(0)
        self.meta_optim.zero_grad()
        losses_q.backward() 
        self.meta_optim.step() # 元更新:具体是梯度求和的意思
        self.store_parameters() # 对参数重新初始化
    # 完成evidence_candidate
    # 统计分任务(每个user)averaging Forbenius norm
    def get_weight_avg_norm(self, support_set_x, support_set_y, num_local_update)

model_training.py

def training()
	# 采样一个batch用户,执行 melu.global_update
	# 在执行num_epoch轮后,torch.save

main.py

# 数据封装与解析,调用 generate(master_path),pickle.load()
# 调用model_training.py中函数training(),完成训练
# 调用函数selection(),返回用户的候选商品

evidence_candidate.py

# 附加内容
def selection()
# 从不同用户的support_xs中计算每个商品的区分度和流行度,商品得分 = 区分度 * 流行度
# 相当于是不对用户个性化做的,只是从全量商品中统计哪些商品最好,可用于完全没有行为用户的冷启动。
  • 4
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值