# load_data.py
import csv
import os
from collections import Counter

import numpy as np
def load_csv(fileName, fileWithHeader=True, *, encoding=None):
    """Read a CSV file into a header row and a list of data rows.

    Args:
        fileName: Path of the CSV file to read.
        fileWithHeader: If True, the first row is returned separately as the
            header; otherwise the header is an empty list and every row is
            returned as data.
        encoding: Text encoding passed to ``open``. ``None`` (the default)
            keeps the platform's default encoding.

    Returns:
        A ``(header, data)`` tuple. ``header`` is a list of strings (empty
        when ``fileWithHeader`` is False or the file is empty) and ``data``
        is a list of rows, each a list of strings.
    """
    # The csv module requires newline='' so that quoted fields containing
    # line breaks are parsed correctly.
    with open(fileName, 'r', newline='', encoding=encoding) as f:
        reader = csv.reader(f)
        # A default of [] avoids StopIteration when the file is empty.
        header = next(reader, []) if fileWithHeader else []
        data = list(reader)
    return header, data
# ---- Experiment configuration ----------------------------------------------
# Keep only the num_items most frequently rated items and the first
# num_users user ids (in numeric order).
num_items, num_users = 40, 200
# Presumably the number of ratings held out per user for testing (the
# train/test split below holds out 3) -- confirm.
predict_step = 3
# Users with fewer kept ratings than this are left out of train/test.
least_rating_num = 5

# ---- Load raw ratings ------------------------------------------------------
# Each row is a list of strings: [user id, item id, rating, ...].
current_path = os.path.dirname(os.path.abspath(__file__))
data_path = os.path.join(current_path, 'ml-latest-small')
# NOTE(review): the Steam CSV is read from the MovieLens folder name -- confirm path.
_steam_csv_path = os.path.join(data_path, 'new_Steam.csv')
headers, ratings = load_csv(_steam_csv_path)
# (2022/5/16)
# Select the num_items most frequently rated items, most popular first.
# Counter.most_common() sorts by count descending and keeps first-seen order
# for ties, exactly like the stable reverse sort it replaces. Despite its
# name, item_frequent_dict is a list of (item id string, count) pairs.
item_frequent_dict = Counter(e[1] for e in ratings).most_common()
item_id_list = [int(item_id) for item_id, _ in item_frequent_dict[:num_items]]
# Sample value of item_id_list from an earlier run:
# [356, 318, 296, 593, 2571, 260, 480, 110, 589, 527, 2959,
#  1, 1196, 50, 2858, 47, 780, 150, 1198, 4993, 1210, 858,
#  457, 592, 2028, 5952, 7153, 588, 608, 2762, 380, 32, 364,
#  1270, 377, 3578, 4306, 1580, 590, 648]
# Keep the first num_users user ids in numeric order.
user_id_list = sorted({e[0] for e in ratings}, key=int)[:num_users]

# user id (str) -> [[position of the item in item_id_list, rating], ...],
# appended in CSV row order. The timestamp column is deliberately not kept
# (dropped 2022/5/16). Illustrative shape:
#   {'1': [[11, 4.0], [15, 5.0], ...], '2': [[1, 3.0], ...], ...}
ratings_dict = {e: [] for e in user_id_list}

# O(1) lookups instead of scanning user_id_list / item_id_list for every
# rating row (list `in` and list.index() are linear scans).
_kept_users = set(user_id_list)
_item_position = {}
for _pos, _item_id in enumerate(item_id_list):
    # First occurrence wins, matching list.index().
    _item_position.setdefault(_item_id, _pos)

counter = 0  # number of rating rows kept after filtering
for record in ratings:
    # record[0] is the user id, record[1] the item (movie) id, record[2] the rating.
    if record[0] not in _kept_users:
        continue
    item_pos = _item_position.get(int(record[1]))
    if item_pos is None:
        continue
    counter += 1
    ratings_dict[record[0]].append([item_pos, float(record[2])])
# Per-user train/test split. Users with fewer than least_rating_num kept
# ratings are dropped. For every other user, the last predict_step ratings
# form the test set and the rest form the training set.
train_data = {}
test_data = {}
for user_id, user_ratings in ratings_dict.items():
    if len(user_ratings) < least_rating_num:
        continue
    # Each record is [item position, rating] in CSV row order. The records
    # used to be sorted by x[-1], which was the timestamp before it was
    # dropped. With only two fields left, that key is the *rating*, so the
    # test set always held each user's highest-rated items. Keep row order
    # instead.
    # NOTE(review): assumes rows are chronological per user -- confirm for new_Steam.csv.
    split_at = max(len(user_ratings) - predict_step, 0)
    train_data[user_id] = user_ratings[:split_at]
    test_data[user_id] = user_ratings[split_at:]

# Users that survived the least_rating_num filter, in numeric order,
# e.g. ['1', '4', '5', '6', '7', '8', '10'].
user_id_list = sorted(train_data, key=int)
# Summary of the prepared dataset.
print('Number of items', len(item_id_list))
print('Number of users', len(user_id_list))
print('Number of training ratings', np.sum(list(map(len, train_data.values()))))
print('Number of testing ratings', np.sum(list(map(len, test_data.values()))))