场景描述:一句话是一个文本序列,通常可以直接使用word2vec编码;类似的,用户浏览商城时,依先后顺序点击的物品id,也构成物品序列,可以对各个用户的物品序列用word2vec训练。模型训练完,对一个新来的物品id,可以用模型预测,预测的向量就是对物品的embedding向量。
%matplotlib inline
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
plt.rc('font', family='SimHei', size=13)
import os,gc,re,warnings,sys
warnings.filterwarnings("ignore")
path = './data/'
##### train
train_user_df = pd.read_csv(path+'underexpose_train/underexpose_user_feat.csv', names=['user_id','user_age_level','user_gender','user_city_level'])
train_item_df = pd.read_csv(path+'underexpose_train/underexpose_item_feat.csv')
train_click_0_df = pd.read_csv(path+'underexpose_train/underexpose_train_click-0.csv',names=['user_id','item_id','time'])
print(train_user_df.shape, train_item_df.shape, train_click_0_df.shape)
# (6789, 4) (108915, 257) (241784, 3)
接下来看看商品嵌入表示。这里使用word2vec进行构造,当然还可以尝试图嵌入等方式来提取嵌入表示。
tmp = train_click_0_df.sort_values('time') #按时间由小到大排序
# 提取用户点击序列(用户点击的物品id列表),并构成文本
doc = tmp.groupby(['user_id'])['item_id'].agg({list}).reset_index()['list'].values.tolist()
#[[],[],[]]每个子列表 对应一个用户的 所有点击物品列表
# 导入 Word2Vec
from gensim.models import Word2Vec
# 转为字符串型才能进行训练
for i in range(len(doc)):
doc[i] = [str(x) for x in doc[i]]
model = Word2Vec(doc, size=128, window=5, min_count=3, sg=0, hs=1, seed=2020)
#模型训练样本是同一个人点击的物品清单,就是极有可能的一起被查看的物品清单
# 训练结果提取
values = set(tmp['item_id'].values)
w2v=[]
for v in values:
try:
a = [int(v)]
a.extend(model[str(v)])
w2v.append(a) #w2v存每个物品的[item_id,模型预测的相近id清单]
except:
pass
out_df = pd.DataFrame(w2v) #预测每个item_id的相近item_id
out_df.columns = ['item_id'] + ['item_vec'+str(i) for i in range(128)]
out_df.head()
# 用户合并item id向量
tmpone = train_click_0_df[train_click_0_df['user_id']==5701]
tmpone = tmpone.merge(out_df, on='item_id', how='left')
nonull_tmp = tmpone[~tmpone['item_vec0'].isnull()]
# 可视化展示前后商品向量相似性
sim_list = []
for i in range(0, nonull_tmp.shape[0]-1):
emb1 = nonull_tmp.values[i][-128:]
emb2 = nonull_tmp.values[i+1][-128:]
sim_list.append(np.dot(emb1,emb2)/(np.linalg.norm(emb1)*(np.linalg.norm(emb2))))
sim_list.append(0)
plt.figure()
plt.figure(figsize=(10, 6))
fig = sns.lineplot(x=[i for i in range(len(sim_list))], y=sim_list)
for item in fig.get_xticklabels():
item.set_rotation(90)
plt.tight_layout()
plt.title('前后商品向量的相似性')
plt.show()