Deep Interest Network代码讲解

本文档详细介绍了如何实现Deep Interest Network。提供了代码仓库链接,包含1_convert_pd.py、2_remap_id.py、build_dataset.py和train.py等步骤。数据集来源于Stanford SNAP,包括reviews_Electronics和meta_Electronics。在理解代码前,需要将数据解压并放入raw_data目录。首先运行数据转换脚本,接着进行ID映射,之后构建数据集,最后启动训练过程。
摘要由CSDN通过智能技术生成

代码链接:https://github.com/zhougr1993/DeepInterestNetwork
论文数据下载:http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Electronics_5.json.gz
http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Electronics.json.gz

在代码DeepInterestNetwork文件夹下,新建一个raw_data,将下载好的数据解压,然后放到raw_data中即可

首先运行1_convert_pd.py

import pickle
import pandas as pd
# json格式转化为pandas的dataframe格式,并保存为pickle二进制文件格式。解释一下为什么要保存pickle文件格式,因为pickle文件是二进制形式,读取速度快。
'''
(1)将reviews_Electronics_5.json转换成dataframe,列分别为reviewID ,asin, reviewerName等,
(2)将meta_Electronics.json转成dataframe,并且只保留在reviewes文件中出现过的商品,去重。
(3)转换完的文件保存成pkl格式。
'''
def to_df(file_path):
  with open(file_path, 'r') as fin:
    df = {
   }
    i = 0
    for line in fin:
      df[i] = eval(line)
      i += 1
    df = pd.DataFrame.from_dict(df, orient='index')
    print(df)
    return df

reviews_df = to_df('../raw_data/reviews_Electronics_5.json')
with open('../raw_data/reviews.pkl', 'wb') as f:
  #pickle.dump(obj, file, [,protocol])将对象obj保存到文件file中去。protocol为序列化使用的协议版本
  pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL)

meta_df = to_df('../raw_data/meta_Electronics.json')
meta_df = meta_df[meta_df['asin'].isin(reviews_df['asin'].unique())]
#reviews_df['asin'].unique() 以数组形式(numpy.ndarray)返回列的所有唯一值(特征的所有唯一值)
meta_df = meta_df.reset_index(drop=True)#重置索引
# print(meta_df)
with open('../raw_data/meta.pkl', 'wb') as f:
  pickle.dump(meta_df, f, pickle.HIGHEST_PROTOCOL)

然后,2_remap_id.py文件

import random
import pickle
import numpy as np
#将asin,categories,reviewerID三个字段进行位置编码。位置编码主要通过build_map。
#特别解读一下build_map函数的作用,就是讲id排序,并转换成对应的位置索引
'''
(1)将reviews_df只保留reviewerID, asin, unixReviewTime三列;
(2)将meta_df保留asin, categories列,并且类别列只保留三级类目;(至此,用到的数据只设计5列,(reviewerID, asin, unixReviewTime),(asin, categories));

'''
random.seed(1234)

with open('../raw_data/reviews.pkl', 'rb') as f:
  reviews_df = pickle.load(f)
  reviews_df = reviews_df[['reviewerID', 'asin', 'unixReviewTime']]
with open('../raw_data/meta.pkl', 'rb') as f:
  meta_df = pickle.load(f)
  meta_df = meta_df[['asin', 'categories']]
  # 返回categories的最后一个类
  meta_df['categories'] = meta_df['categories'].map(lambda x: x[-1][-1])

def build_map(df, col_name):
  key = sorted(df[col_name].unique().tolist())
  # zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表(02131321,0)
  m = dict(zip(key, range(len(key))))
  df[col_name] = df[col_name].map(lambda x: m[x])
  return m, key
#(3)用asin,categories,reviewerID分别生产三个map(asin_map, cate_map, revi_map),key为对应的原始信息,
# value为按key排序后的index(从0开始顺序排序),然后将原数据的对应列原始数据转换成key对应的index;
asin_map, asin_key = build_map(meta_df, 'asin')
cate_map, cate_key = build_map(meta_df, 'categories')
revi_map, revi_key = build_map(reviews_df, 'reviewerID')
# print("asin_map:","\n",asin_map)
# print("cate_map:","\n",cate_map)
# print("revi_map:","\n",revi_map)
user_count, item_count, cate_count, example_count =\
    len(revi_map), len(asin_map), len(cate_map), reviews_df.shape[0]
print('user_count: %d\titem_count: %d\tcate_count: %d\texample_count: %d' %
      (user_count, item_count, cate_count, example_count))
#(4)将meta_df按asin对应的index进行排序
meta_df = meta_df.sort_values('asin')
meta_df = meta_df.reset_index(drop=True)
# print(meta_df)
#(5)将reiviews_df中的asin转换成asin_map中asin对应的value值
reviews_df['asin'] = reviews_df['asin'].map(lambda x: asin_map[x])
#并且按照reviewerID和时间排序。
reviews_df = reviews_df.sort_values(['reviewerID', 'unixReviewTime'])
reviews_df = reviews_df.reset_index(drop=True)
reviews_df = reviews_df[['reviewerID', 'asin', 'unixReviewTime']]
# print(reviews_df)
#(6)生成cate_list, 就是把meta_df的'categories'列取出来。
cate_list = [meta_df['categories'][i] for i in range(len(asin_map))]
cate_list = np.array(cate_list, dtype=np.int32)
# print(cate_list)

with open('../raw_data/remap.pkl', 'wb') as f:
  #pickle.dump将对象obj保存到文件file中去
  pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL) # uid, iid
  pickle.dump(cate_list, f, pickle.HIGHEST_PROTOCOL) # cid of iid line
  pickle.dump((user_count, item_count, cate_count, example_count),
              f, pickle.HIGHEST_PROTOCOL)
  pickle.dump((asin_key, cate_key, revi_key), f, pickle.HIGHEST_PROTOCOL)

第三,运行build_dataset.py

import random
import pickle

random.seed(1234)

with open('../raw_data/remap.pkl', 'rb') as f:
  reviews_df = pickle.load(f)
  cate_list = pickle.load(f)
  user_count, item_count, cate_count, example_count = pickle.load(f)
# pos_list(每个点击者点击的商品 ID 组成的 list)例如: [8] [9,6,4,5] [3] [8]
train_set = []
test_set = []
#(1)将reviews_df按reviewerID进行聚合
for reviewerID, hist in reviews_df.groupby('reviewerID'):
  # print(reviewerID)
  # print(hist)
  # (2)将hist的asin列作为每个reviewerID(也就是用户)的正样本列表(pos_list),注意这里的asin存的已经不是原始的item_id了,
  # 而是通过asin_map转换过来的index。负样本列表(neg_list)为在item_count范围内产生不在pos_list中的随机数列表。
  pos_list = hist['asin'].tolist()
  def gen_neg():
    # 取每个用户点击列表的第一个商品
    neg = pos_list[0]
    while neg in pos_list:
      # 随机初始化,即给点击者随机初始化一个商品,item_count-1 为商品数
      neg = random.randint(0, item_count-1)
    return neg
  neg_list = [gen_neg() for i in range(len(pos_list))]

  # 如果用户点击的商品数大于 1,则循环
  for i in range(1, len(pos_list)):
    hist = pos_list[:i]
    '''
        下面的 if 语句控制正负样本的个数和格式),例如某用户点击过 abcd 四个商品,
        则最终生成的样本为:(其中 X 为随机初始化的某商品 ID) 
        ((user_id,a,(b,1)) (user_id,a,(X,0)) (user_id,(a,b),(c,1)) 
        user_id,(a,b),(X,0)) (user_id,(a,b,c),(d,1)) (user_id,(a,b,c),(X,0))
        '''
    if i != len(pos_list) - 1:
      train_set.append((reviewerID, hist, pos_list[i], 1))
      train_set.append((reviewerID, hist, neg_list[i], 0))
      # 验证集格式(user_id,a,(b,X))
    else:
      label = (pos_list[i], neg_list[i])
      test_set.append((reviewerID, hist, label))
#最终的数据集里点击商品数小于 1 的数据删除掉了
random.shuffle(train_set)
random.shuffle(test_set)

assert len(test_set) == user_count
# assert(len(test_set) + len(train_set) // 2 == reviews_df.shape[0])

with open('dataset.pkl', 'wb') as f:
  pickle.dump(train_set, f, pickle.HIGHEST_PROTOCOL)
  pickle.dump(test_set, f, pickle.HIGHEST_PROTOCOL)
  pickle.dump(cate_list, f, pickle.HIGHEST_PROTOCOL)
  pickle.dump((user_count, item_count, cate_count), f, pickle.HIGHEST_PROTOCOL)

第四,运行din文件夹下的train.py

import os
import time
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值