# Code
# Module-level state written by ``shuffle`` and read by the
# ``_get_train_batch`` pool workers (which receive only a batch number).
_user_input = None       # first element of each trainMatrix key (presumably user ids)
_item_input_pos = None   # second element of each key, aligned with _user_input
_batch_size = 512        # default; overwritten by shuffle()
_index = None            # shuffled permutation of positions into _user_input
_sess = None             # not referenced in the visible code (TODO confirm use elsewhere)
_dataset = None          # dataset object set by shuffle(); workers read num_items/trainList
_model = None            # object exposing ``dns`` (negatives drawn per positive pair);
                         # read by _get_train_batch, so it must be set before shuffle() runs
_K = None                # not referenced in the visible code (TODO confirm use elsewhere)
_feed_dict = None        # not referenced in the visible code (TODO confirm use elsewhere)
_output = None           # not referenced in the visible code (TODO confirm use elsewhere)
def sampling(dataset):
    """Split the training interactions into two parallel lists.

    Each key of ``dataset.trainMatrix`` is unpacked as a ``(u, i)`` pair.
    The function returns ``(users, items)``, where ``users[k]`` and
    ``items[k]`` come from the k-th key in iteration order.
    """
    pairs = list(dataset.trainMatrix.keys())
    users = [u for u, _ in pairs]
    items = [i for _, i in pairs]
    return users, items
def shuffle(samples, batch_size, dataset, model=None):
    """Shuffle the training pairs and build one epoch of batches in parallel.

    Stores the samples, the batch size, a shuffled index permutation and the
    dataset in module-level globals. It then maps ``_get_train_batch`` over
    every full batch using a process pool, and the workers read those globals.

    Args:
        samples: ``(user_list, item_list)`` as returned by ``sampling``.
        batch_size: number of positive pairs per batch. Trailing samples that
            do not fill a complete batch are dropped.
        dataset: object read by ``_get_train_batch`` (``num_items``,
            ``trainList``).
        model: optional object exposing ``dns``. When given, it is stored in
            the module-level ``_model`` that ``_get_train_batch`` reads. When
            omitted, ``_model`` must already have been set elsewhere.

    Returns:
        ``(user_list, item_pos_list, user_dns_list, item_dns_list)``: four
        lists holding one entry per batch, each being the corresponding
        array returned by ``_get_train_batch``.
    """
    global _user_input
    global _item_input_pos
    global _batch_size
    global _index
    global _model
    global _dataset
    _user_input, _item_input_pos = samples
    _batch_size = batch_size
    _index = np.arange(len(_user_input))
    _dataset = dataset
    if model is not None:
        # The ``global _model`` declaration shows this function is meant to
        # publish the model to the workers; accept it as an optional argument.
        _model = model
    np.random.shuffle(_index)
    # Integer division: a trailing partial batch is dropped.
    num_batch = len(_user_input) // _batch_size
    # NOTE(review): workers see the globals above only because they are forked
    # from this process. Under the "spawn"/"forkserver" start methods they would
    # re-import the module and see the initial values instead.
    # Forked workers also inherit the parent's NumPy RNG state, so without
    # reseeding every worker would draw the same "random" negative items.
    # np.random.seed() with no argument reseeds from fresh OS entropy.
    with Pool(cpu_count(), initializer=np.random.seed) as pool:
        res = pool.map(_get_train_batch, range(num_batch))
    user_list = [r[0] for r in res]
    item_pos_list = [r[1] for r in res]
    user_dns_list = [r[2] for r in res]
    item_dns_list = [r[3] for r in res]
    return user_list, item_pos_list, user_dns_list, item_dns_list
def _get_train_batch(i):
    """Build the i-th training batch, including sampled negative items.

    Reads the module-level state populated by ``shuffle``. The batch covers
    ``_batch_size`` positions of the shuffled ``_index``, starting at
    ``i * _batch_size``. For each position it emits:

    * the positive (user, item) pair;
    * ``_model.dns`` (user, negative item) pairs. Each negative is drawn
      uniformly from ``range(_dataset.num_items)`` and redrawn while it
      appears in ``_dataset.trainList[user]``.

    Args:
        i: batch number.

    Returns:
        Four numpy arrays of shape ``(n, 1)``: positive users, positive
        items, negative-sample users and negative items.

    Raises:
        ValueError: if a user's training items cover every item id, in which
            case no negative can be sampled (the original looped forever).
    """
    user_batch, item_batch = [], []
    user_neg_batch, item_neg_batch = [], []
    num_neg = _model.dns
    num_items = _dataset.num_items
    begin = i * _batch_size
    for idx in range(begin, begin + _batch_size):
        pos = _index[idx]
        user = _user_input[pos]
        user_batch.append(user)
        item_batch.append(_item_input_pos[pos])
        if num_neg <= 0:
            continue
        # Build the user's positive set once per sample, so each rejection
        # draw is an O(1) membership test instead of a scan of trainList[user].
        positives = set(_dataset.trainList[user])
        if len(positives) >= num_items and positives.issuperset(range(num_items)):
            raise ValueError(
                f"user {user} has interacted with all {num_items} items; "
                "no negative item can be sampled"
            )
        for _ in range(num_neg):
            j = np.random.randint(num_items)
            while j in positives:
                j = np.random.randint(num_items)
            user_neg_batch.append(user)
            item_neg_batch.append(j)
    return (
        np.array(user_batch)[:, None],
        np.array(item_batch)[:, None],
        np.array(user_neg_batch)[:, None],
        np.array(item_neg_batch)[:, None],
    )
def train():
    """Sample the training pairs and shuffle them into batches for one epoch.

    Relies on the module-level names ``dataset`` and ``args``, which are
    assigned in the ``__main__`` block. Calling this without those names
    defined raises NameError.
    """
    pairs = sampling(dataset)
    epoch_batches = shuffle(pairs, args.batch_size, dataset)
    return epoch_batches
if __name__ == '__main__':
    # Script entry: parse CLI options, load the dataset and build one epoch
    # of batches. ``args`` and ``dataset`` stay module-level because train()
    # reads them as globals.
    args = parse_args()
    dataset = Dataset(args.path + args.dataset)
    train()
    # Spot-check the pair at position 90. shuffle() permutes _index, not
    # these lists, so this is position 90 in sampling order. At least 91
    # samples are required.
    for column in (_user_input, _item_input_pos):
        print(column[90])