今天分享一段面向大规模用户场景的协同过滤模型训练与更新代码,其中仍有不少值得优化的地方。
'''=================================================
@Function -> 用TensorFlow2实现协同过滤矩阵的分解
@Author :郭艳丹
@Date :2023-01-11
=================================================='''
import numpy as np
import tensorflow as tf
from keras.callbacks import ModelCheckpoint
def get_all_user_ids():
    """Return the full list of known user ids.

    Placeholder data — swap in the real user table when available (TODO).
    """
    return [f"user{i}" for i in range(1, 7)]
def get_all_item_ids():
    """Return the full list of known item ids.

    Placeholder data — swap in the real item catalogue when available (TODO).
    """
    suffixes = ("1", "2", "3", "4", "11", "22", "33", "44", "111", "222", "333", "444")
    return ["item" + s for s in suffixes]
def add_negtive_samples(item_ids: set, user_click_item_ids: set, total_len=20):
    """Sample negative items for one user.

    Draws random items the user has NOT clicked, so that the positives plus
    the returned negatives add up to ``total_len`` training examples.

    Args:
        item_ids: full catalogue of item ids.
        user_click_item_ids: items this user clicked (the positives).
        total_len: desired total number of samples (positives + negatives).

    Returns:
        np.ndarray of sampled negative item ids (may contain repeats,
        since sampling is with replacement).
    """
    # TODO: replace uniform sampling with popularity-weighted sampling.
    candidate_set = list(item_ids - user_click_item_ids)
    # Clamp at zero so a user with more clicks than total_len does not
    # trigger a negative `size` error in np.random.choice.
    n_neg = max(total_len - len(user_click_item_ids), 0)
    if n_neg == 0 or not candidate_set:
        return np.array([], dtype=object)
    # Bug fix: the original computed neg_list but then executed a bare
    # `return`, so the caller always received None.
    return np.random.choice(candidate_set, size=n_neg, replace=True)
import pandas as pd
def get_pretrained_data():
    """Build a toy training set of (user, item) pairs with binary targets.

    Positives come from the hard-coded click log below; one negative is
    drawn per positive row from the items that same user did NOT click,
    so a negative can never contradict a positive example.

    Returns:
        Tuple of (item_ids, user_ids, feature_dict, target) where
        feature_dict is ``{"user": Series, "item": Series}`` suitable for
        ``model.fit`` and target is a Series of 0/1 labels.
    """
    # TODO: migrate to tf.data for real workloads.
    user_click_matrix = [("user1","item1"),
                         ("user2","item3"),
                         ("user1","item3"),
                         ("user2","item2"),
                         ("user4","item3"),
                         ("user4","item33"),
                         ("user3","item3"),
                         ("user5","item3"),
                         ("user5","item11"),
                         ("user1","item3"),
                         ("user1","item2"),
                         ]
    original_click_data = pd.DataFrame(user_click_matrix, columns=["user", "item"])
    original_click_data["target"] = 1
    item_ids = get_all_item_ids()
    user_ids = get_all_user_ids()
    # Negative-sample construction. Bug fix: the original sampled uniformly
    # from ALL items, so a "negative" could be an item the user actually
    # clicked, producing the same (user, item) pair labelled both 0 and 1.
    # Here each negative is drawn from that user's unclicked items.
    clicked_by_user = original_click_data.groupby("user")["item"].agg(set).to_dict()
    all_items = set(item_ids)
    negtive_samples = [
        np.random.choice(sorted(all_items - clicked_by_user[u]))
        for u in original_click_data["user"]
    ]
    negtive_click_datas = pd.DataFrame({"user": original_click_data["user"], "item": negtive_samples})
    negtive_click_datas["target"] = 0
    result = pd.concat([original_click_data, negtive_click_datas], ignore_index=True, sort=False)
    return item_ids, user_ids, {"user": result["user"], "item": result["item"]}, result["target"]
def als_by_batch_train():
    """Train a two-tower matrix-factorisation model on the toy click data.

    Builds user/item embedding towers scored by cosine similarity, fits
    them against the 0/1 click targets with MSE, and returns the two
    trained sub-models.

    Returns:
        Tuple of (user_embedding_model, item_embedding_model), Keras models
        mapping raw id strings to their learned embedding vectors.
    """
    item_ids, user_ids, train_data, target = get_pretrained_data()
    # shape=(None,) means each feature arrives as a 1-D vector of strings.
    user_input = tf.keras.layers.Input(shape=(None,), name="user", dtype=tf.string)
    item_input = tf.keras.layers.Input(shape=(None,), name="item", dtype=tf.string)
    # StringLookup maps raw ids to integer indices (index 0 reserved for OOV).
    user_string_lookup = tf.keras.layers.StringLookup(vocabulary=user_ids)(user_input)
    item_string_lookup = tf.keras.layers.StringLookup(vocabulary=item_ids)(item_input)
    # +5 leaves headroom in the embedding table beyond the fixed vocabulary.
    user_embedding = tf.keras.layers.Embedding(len(user_ids) + 5, 64)(user_string_lookup)
    item_embedding = tf.keras.layers.Embedding(len(item_ids) + 5, 64)(item_string_lookup)
    # Dot over the last axis with normalize=True L2-normalises each row
    # first, i.e. this computes cosine similarity.
    cons_sim_result = tf.keras.layers.Dot(-1, normalize=True)([user_embedding, item_embedding])
    model = tf.keras.Model(inputs=[user_input, item_input], outputs=cons_sim_result)
    model.compile(optimizer="adam", loss=tf.keras.losses.MSE)
    callbacks = [
        # Bug fix: use tf.keras's own ModelCheckpoint — mixing the
        # standalone `keras` package's callbacks with a tf.keras model can
        # fail at runtime in TF2. Also dropped the stray space in the epoch
        # format spec so files are named "01.h5" rather than " 1.h5".
        tf.keras.callbacks.ModelCheckpoint(filepath='base_path/als_models/' + '{epoch:02d}.h5')
    ]
    model.fit(train_data, target, batch_size=4, epochs=2, verbose=1,
              validation_split=0.01, callbacks=callbacks)
    # Sub-models share the trained embedding layers with the scoring model.
    return (tf.keras.Model(inputs=[user_input], outputs=user_embedding),
            tf.keras.Model(inputs=[item_input], outputs=item_embedding))
def get_embedding(model, inputs):
    """Look up embeddings for raw id strings via an embedding sub-model.

    Args:
        model: a Keras model mapping string ids to embedding vectors.
        inputs: list of id strings to embed.

    Returns:
        Tensor of embeddings, one row per input id.
    """
    id_tensor = tf.constant(inputs)
    return model(id_tensor)
if __name__ == '__main__':
    # Train once, then probe both towers. The id lists deliberately include
    # unseen ids ("user100", "itemk") to exercise the OOV lookup path.
    user_tower, item_tower = als_by_batch_train()
    user_vectors = get_embedding(user_tower, ["user1", "user2", "user100"])
    item_vectors = get_embedding(item_tower, ["item2", "item1", "itemk"])
    print(user_vectors)
    print(item_vectors)
其中,tf.keras.layers.Dot 会对两个输入矩阵的对应行逐元素相乘并求和;通过设置 normalize=True,可以在相乘之前先对每一行做 L2 归一化,从而等价于计算余弦相似度。相关代码已开源到:https://gitee.com/sparkle_code_guy/rec_sys