【代码解析(6)】Secure Federated Matrix Factorization

load_data.py

import os
import csv
import numpy as np


def load_csv(fileName, fileWithHeader=True):
    """Read a CSV file and return its header row and its data rows.

    Args:
        fileName: Path of the CSV file to read.
        fileWithHeader: When True, the first row is consumed and returned
            as the header. When False, the header is an empty list and
            every row is treated as data.

    Returns:
        A ``(header, data)`` tuple. ``header`` is a list of strings
        (empty when there is no header row or the file is empty) and
        ``data`` is a list of rows, each row being a list of strings.

    Raises:
        OSError: If the file cannot be opened.
    """
    # newline='' hands line-ending handling to the csv module, so line
    # breaks embedded inside quoted fields are returned unchanged.
    with open(fileName, 'r', newline='') as f:
        reader = csv.reader(f)
        # The default of next() keeps an empty file from leaking
        # StopIteration to the caller.
        header = next(reader, []) if fileWithHeader else []
        data = list(reader)
    return header, data


# ---- Experiment configuration ----
# num_items: how many of the most frequently rated items are kept.
# num_users: how many user ids (smallest numeric ids first) are kept.
num_items, num_users = 40, 200

# Presumably the number of ratings held out per user for testing -- confirm.
predict_step = 3
# Users with fewer than this many retained ratings are dropped.
least_rating_num = 5

# ---- Dataset location: the data directory sits next to this script ----
current_path = os.path.split(os.path.abspath(__file__))[0]
data_path = os.path.join(current_path, 'ml-latest-small')

# (2022/5/16) The script used to load 'ratings_data.csv' from the same
# directory; it now reads 'new_Steam.csv' instead.
# With the earlier file the author recorded:
#   headers -> ['userId', 'movieId', 'rating', 'timestamp']
#   ratings -> one list of strings per rating record,
#              e.g. ['474', '457', '5.0', '974667331']
# NOTE(review): the column layout of 'new_Steam.csv' is not visible here;
# the code below only relies on columns 0 (user), 1 (item) and 2 (rating).
headers, ratings = load_csv(os.path.join(data_path, 'new_Steam.csv'))

# Count how many rating records each item (column 1) received.
item_frequent_dict = {}
for _row in ratings:
    _item_key = _row[1]
    if _item_key in item_frequent_dict:
        item_frequent_dict[_item_key] += 1
    else:
        item_frequent_dict[_item_key] = 1

# From here on the name holds a list of (item_id, count) pairs ordered from
# the most-rated item to the least-rated one. The sort is stable, so items
# with equal counts keep their first-seen order.
item_frequent_dict = list(item_frequent_dict.items())
item_frequent_dict.sort(key=lambda pair: pair[1], reverse=True)

# Keep the ids of the `num_items` most-rated items, as integers.
item_id_list = [int(item_key) for item_key, _ in item_frequent_dict[:num_items]]

# Example value of item_id_list recorded by the original author
# (appears to come from an earlier run on the movie ratings file -- confirm):
# [356, 318, 296, 593, 2571, 260, 480, 110, 589, 527, 2959,
#  1, 1196, 50, 2858, 47, 780, 150, 1198, 4993, 1210, 858,
#  457, 592, 2028, 5952, 7153, 588, 608, 2762, 380, 32, 364,
#  1270, 377, 3578, 4306, 1580, 590, 648]
# Collect the distinct user ids (column 0), order them numerically and keep
# the first `num_users` of them. The ids stay strings.
user_id_list = sorted({e[0] for e in ratings}, key=int)[:num_users]

# One empty rating list per selected user: {'1': [], '2': [], ...}
ratings_dict = {e: [] for e in user_id_list}

# Map every selected item id to its position in item_id_list, so each record
# needs a dict lookup instead of a linear `in` test plus `list.index` scan.
# setdefault keeps the first position, matching list.index semantics.
_item_position = {}
for _position, _item_id in enumerate(item_id_list):
    _item_position.setdefault(_item_id, _position)

counter = 0  # number of records that survive both filters
for record in ratings:
    # Skip records of users that were not selected. The keys of ratings_dict
    # are exactly the entries of user_id_list.
    if record[0] not in ratings_dict:
        continue
    # The item id is only converted for selected users, as before.
    _position = _item_position.get(int(record[1]))
    if _position is None:
        # The item is not among the `num_items` most-rated items.
        continue
    counter += 1
    # (2022/5/16) The timestamp column is no longer stored; each entry is
    # [index of the item in item_id_list, rating].
    ratings_dict[record[0]].append([_position, float(record[2])])

# Resulting layout of ratings_dict:
#   {'user id': [[item index in item_id_list, rating], ...], ...}
# The author noted 137 matching records in an earlier, smaller run.
train_data = {}
test_data = {}

for user_id in ratings_dict:
    # Drop users that have fewer than `least_rating_num` retained ratings.
    if len(ratings_dict[user_id]) < least_rating_num:
        continue

    # NOTE(review): each entry is [item_index, rating], so x[-1] is the
    # rating value, not a timestamp. The records are therefore ordered by
    # ascending rating, and the held-out tail contains the user's highest
    # ratings. Confirm this is intended now that timestamps are not stored.
    sorted_rate = sorted(ratings_dict[user_id], key=lambda x: x[-1], reverse=False)

    # Hold out the last `predict_step` records for testing and train on the
    # rest. An explicit split index (rather than [:-predict_step]) keeps the
    # slices correct when predict_step is 0 or exceeds the record count.
    _split_at = max(len(sorted_rate) - predict_step, 0)
    train_data[user_id] = sorted_rate[:_split_at]
    test_data[user_id] = sorted_rate[_split_at:]

# Author's note from an earlier 10-user run: 7 users remained, each
# contributing 3 test records (21 in total); the remaining user ids were
# ['1', '4', '5', '6', '7', '8', '10'].

# From here on user_id_list only contains users that passed the filter,
# ordered by numeric id.
user_id_list = sorted(train_data, key=int)
# ---- Summary of the prepared dataset ----
# item_id_list holds the retained item ids; user_id_list holds the ids of
# the users that ended up with training data.
print('Number of items', len(item_id_list))
print('Number of users', len(user_id_list))


# Total number of rating records across all users in each split.
print('Number of training ratings', np.sum(list(map(len, train_data.values()))))
print('Number of testing ratings', np.sum(list(map(len, test_data.values()))))

# print(ratings_dict)
# print('.....')
# print(train_data)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值