首先读入的文件有movie和book和news三种,肯定会选择一种进行解析!
下面我们以movies作为样例分析:
本论文中使用的数据集是movie-1m
-
ratings.dat:
分别是用户::电影::评分::电影编号
-
item_index2entity_id_rehashed.txt文件:
内容:
import argparse
import numpy as np
RATING_FILE_NAME = dict({'movie': 'ratings.dat', 'book': 'BX-Book-Ratings.csv', 'news': 'ratings.txt'}) #定义字典 字典保存中保存的都是原始文件
SEP = dict({'movie': '::', 'book': ';', 'news': '\t'}) #定义的分隔符!
THRESHOLD = dict({'movie': 4, 'book': 0, 'news': 0}) #定义电影喜好的阈值吧
def read_item_index_to_entity_id_file(): #看名字:读item的索引转化为实体的id
file = '../data/' + DATASET + '/item_index2entity_id_rehashed.txt'
#../data/movie/item_index2entity_id_rehashed.txt
print('reading item index to entity id file: ' + file + ' ...')
i = 0
for line in open(file, encoding='utf-8').readlines():
item_index = line.strip().split('\t')[0]
satori_id = line.strip().split('\t')[1]
#返回字符列表,并获取第一个和第二个元素, 第一个元素是原item的索引,第二个元素是satori中实体的索引。 satori是微软的大型知识图谱。 具体看后面解析
item_index_old2new[item_index] = i # item 的旧的index转换为新的index
entity_id2index[satori_id] = i # 实体id转换为index
i += 1
def convert_rating():
file = '../data/' + DATASET + '/' + RATING_FILE_NAME[DATASET]
# '../data/movie/ratings.dat
print('reading rating file ...')
item_set = set(item_index_old2new.values()) # 将item新的index转化为集合
user_pos_ratings = dict() # 用户正样本的评分
user_neg_ratings = dict() # 用户负样本的评分
for line in open(file, encoding='utf-8').readlines()[1:]:
array = line.strip().split(SEP[DATASET]) #看上面,我们经过分割后得到四个元素
# remove prefix and suffix quotation marks for BX dataset
if DATASET == 'book':
array = list(map(lambda x: x[1:-1], array))
item_index_old = array[1] # 取的是第二个元素,item的旧index
if item_index_old not in item_index_old2new: # the item is not in the final item set
# 比较的是keys,不是values,item_index_old也是字符,查看评价的items是不是在我们记录的item_index中,如果不在直接终止
continue
item_index = item_index_old2new[item_index_old] #如果在,那么我们就赋值新的item_index
user_index_old = int(array[0]) # 获得user旧的id的index
rating = float(array[2]) #获得用户的电影评分
if rating >= THRESHOLD[DATASET]: #我们选取列表中所有大于阈值的评分
if user_index_old not in user_pos_ratings: #注意这里比较的是keys值
user_pos_ratings[user_index_old] = set()
# 积极评分的设置为set集合
user_pos_ratings[user_index_old].add(item_index) #list列表中添加用户旧的index
#并且添加了item新的index
else:
if user_index_old not in user_neg_ratings: #同样的道理,这里存储列表中小于阈值的评分
user_neg_ratings[user_index_old] = set()
user_neg_ratings[user_index_old].add(item_index)
print('converting rating file ...') #将用户的index转为新的
writer = open('../data/' + DATASET + '/ratings_final.txt', 'w', encoding='utf-8')
user_cnt = 0
user_index_old2new = dict()
for user_index_old, pos_item_set in user_pos_ratings.items():
if user_index_old not in user_index_old2new:
user_index_old2new[user_index_old] = user_cnt #记录user的总数
user_cnt += 1
user_index = user_index_old2new[user_index_old] #
for item in pos_item_set:
writer.write('%d\t%d\t1\n' % (user_index, item))
unwatched_set = item_set - pos_item_set
if user_index_old in user_neg_ratings:
unwatched_set -= user_neg_ratings[user_index_old]
for item in np.random.choice(list(unwatched_set), size=len(pos_item_set), replace=False):
writer.write('%d\t%d\t0\n' % (user_index, item))
writer.close()
print('number of users: %d' % user_cnt)
print('number of items: %d' % len(item_set))
def convert_kg(): #基本都是转变id的事
print('converting kg file ...')
entity_cnt = len(entity_id2index)
relation_cnt = 0
writer = open('../data/' + DATASET + '/kg_final.txt', 'w', encoding='utf-8')
files = []
if DATASET == 'movie':
files.append(open('../data/' + DATASET + '/kg_part1_rehashed.txt', encoding='utf-8'))
files.append(open('../data/' + DATASET + '/kg_part2_rehashed.txt', encoding='utf-8'))
else:
files.append(open('../data/' + DATASET + '/kg_rehashed.txt', encoding='utf-8'))
for file in files:
for line in file:
array = line.strip().split('\t')
head_old = array[0]
relation_old = array[1]
tail_old = array[2]
if head_old not in entity_id2index:
entity_id2index[head_old] = entity_cnt
entity_cnt += 1
head = entity_id2index[head_old]
if tail_old not in entity_id2index:
entity_id2index[tail_old] = entity_cnt
entity_cnt += 1
tail = entity_id2index[tail_old]
if relation_old not in relation_id2index:
relation_id2index[relation_old] = relation_cnt
relation_cnt += 1
relation = relation_id2index[relation_old]
writer.write('%d\t%d\t%d\n' % (head, relation, tail))
writer.close()
print('number of entities (containing items): %d' % entity_cnt)
print('number of relations: %d' % relation_cnt)
if __name__ == '__main__':
np.random.seed(555)
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--dataset', type=str, default='movie', help='which dataset to preprocess')
args = parser.parse_args()
DATASET = args.dataset
entity_id2index = dict()
relation_id2index = dict()
item_index_old2new = dict()
read_item_index_to_entity_id_file()
convert_rating()
convert_kg()
print('done')
补充:
1. line.strip.split(’\t’)
- 描述
Python strip() 方法用于移除字符串头尾指定的字符(默认为空格)或字符序列。
注意:该方法只能删除开头或是结尾的字符,不能删除中间部分的字符。 - 语法
strip()方法语法:
str.strip([chars]);
-
参数
chars – 移除字符串头尾指定的字符序列。 -
返回值
返回移除字符串头尾指定的字符序列生成的新字符串
2. split(’\t’)
已经在上个代码分析中讨论过了,这里只是简单说一下,它会返回字符列表!
- 源代码分析:
def read_item_index_to_entity_id_file():
file = '../data/movie/item_index2entity_id_rehashed.txt'
print('reading item index to entity id file: ' + file + ' ...')
i = 0
for line in open(file, encoding='utf-8').readlines():
i = i + 1
if i < 10:
print(line)
print(len(line))
print(line.strip())
print(len(line.strip()))
print(line.strip().split('\t'))
print(line.strip().split('\t')[0])
read_item_index_to_entity_id_file()
可以看出其一,如果只是输出一行的数据,长度为4,该字符串是"1 \t 0 空格" 多个空格为一个!
所以我们在获取一行数据的时候,要特别注意这些空格符(在首尾)、分隔符(在中间)!
最后split返回的是字符列表!
3. Set()集合
集合是为了啥? 关系运算啊! 并交差集
定义:
set() 函数创建一个无序不重复元素集,可进行关系测试,删除重复数据,还可以计算交集、差集、并集等。
注意是没有顺序,而是是不重复的集合!
返回值:
返回新的集合对象
实例:
>>>x = set('runoob')
>>> y = set('google')
>>> x, y
(set(['b', 'r', 'u', 'o', 'n']), set(['e', 'o', 'g', 'l'])) # 重复的被删除
>>> x & y # 交集
set(['o'])
>>> x | y # 并集
set(['b', 'e', 'g', 'l', 'o', 'n', 'r', 'u'])
>>> x - y # 差集
set(['r', 'b', 'u', 'n'])
>>>
4. “XXX” not in dict
比较的是keys,不是values; 如果字典中没有,那么就返回False,否则返回True。
配合的操作就是如果没有,那么就添加该key值!
源码举例:
if item_index_old not in item_index_old2new: # the item is not in the final item set
# 比较的是keys,不是values,item_index_old也是字符,查看评价的items是不是在我们记录的item_index中,如果不在直接终止
continue
item_index = item_index_old2new[item_index_old] #如果在,那么我们就赋值新的item_index