全局变量None

代码
# Module-level state shared between shuffle() and the pool worker
# _get_train_batch(); shuffle() assigns most of these before it creates the
# pool, and the worker reads them instead of receiving them as arguments.
_user_input = None       # user column of the (user, item) training pairs
_item_input_pos = None   # item column, aligned index-for-index with _user_input
_batch_size = 512        # positive pairs per batch; overwritten by shuffle()
_index = None            # shuffled permutation of positions into _user_input
_model = None            # object with an integer `dns` attribute (negatives drawn
                         # per positive pair); read by _get_train_batch and must be
                         # assigned before batches are built

# NOTE(review): _sess, _K, _feed_dict and _output are not referenced by any
# code visible here; presumably used elsewhere in the file - verify before removing.
_sess = None
_dataset = None          # dataset object read by _get_train_batch (num_items, trainList)
_K = None
_feed_dict = None
_output = None
#---------- data preparation -------
# data sampling and shuffling

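# input: dataset (Mat, List, Rating, Negatives); sampling() reads only dataset.trainMatrix,
#        and the batch choice (batch_size) is consumed later by shuffle()
# output of sampling(): (user_input_list, item_input_pos_list), one entry per positive pair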
def sampling(dataset):
    """Collect every positive (user, item) training pair from `dataset`.

    Args:
        dataset: object whose `trainMatrix.keys()` yields (user, item) pairs.

    Returns:
        A tuple of two parallel Python lists (users, items), in the order the
        trainMatrix keys are iterated.
    """
    # Snapshot the key view once, then split it into its two columns.
    pairs = list(dataset.trainMatrix.keys())
    users = [user for user, _ in pairs]
    items = [item for _, item in pairs]
    return users, items

def shuffle(samples, batch_size, dataset, model=None):
    """Shuffle the training pairs and build mini-batches in a process pool.

    The inputs are stored in module-level globals that the worker function
    `_get_train_batch` reads. The pool is created only after they are set.

    NOTE(review): this relies on workers inheriting those globals, which holds
    for the 'fork' start method; under 'spawn' (Windows, macOS default) the
    workers re-import the module and see the initial values - confirm the
    target platform.

    Args:
        samples: pair (user_list, item_list), e.g. as returned by sampling().
        batch_size: positive pairs per batch. Only full batches are built;
            a trailing partial batch is dropped (integer division below).
        dataset: object read by the workers (num_items, trainList).
        model: optional object whose integer `dns` attribute sets how many
            negatives are drawn per positive pair. When None, whatever `_model`
            was previously assigned is left in place.

    Returns:
        (user_list, item_pos_list, user_dns_list, item_dns_list): four lists
        holding one numpy array per batch, as produced by _get_train_batch.
    """
    global _user_input, _item_input_pos, _batch_size, _index, _model, _dataset
    _user_input, _item_input_pos = samples
    _batch_size = batch_size
    _dataset = dataset
    if model is not None:
        _model = model
    _index = np.arange(len(_user_input))
    np.random.shuffle(_index)
    num_batch = len(_user_input) // _batch_size
    # The context manager guarantees the pool is torn down even if map()
    # raises; the original close()/join() pair was skipped on error.
    with Pool(cpu_count()) as pool:
        res = pool.map(_get_train_batch, range(num_batch))
    user_list = [r[0] for r in res]
    item_pos_list = [r[1] for r in res]
    user_dns_list = [r[2] for r in res]
    item_dns_list = [r[3] for r in res]
    return user_list, item_pos_list, user_dns_list, item_dns_list

def _get_train_batch(i):
    """Build the i-th mini-batch from the shuffled training pairs.

    Pool worker for shuffle(). It reads the module-level _user_input,
    _item_input_pos, _index, _batch_size, _model and _dataset instead of
    taking them as arguments.

    For every positive pair in the batch it also draws `_model.dns` negative
    items for the same user. Each item is drawn uniformly from
    [0, _dataset.num_items) and redrawn while it appears in that user's
    _dataset.trainList entry.

    Args:
        i: batch number; covers shuffled positions
            [i * _batch_size, (i + 1) * _batch_size).

    Returns:
        (users, items, neg_users, neg_items), each an (n, 1) numpy array.
        users/items have _batch_size rows; neg_users/neg_items have
        _batch_size * _model.dns rows.
    """
    user_batch, item_batch = [], []
    user_neg_batch, item_neg_batch = [], []
    begin = i * _batch_size
    for idx in range(begin, begin + _batch_size):
        pos = _index[idx]
        user = _user_input[pos]
        user_batch.append(user)
        item_batch.append(_item_input_pos[pos])
        # The unused `_dataset.testRatings[user][1]` lookup was removed: its
        # value was discarded but it could raise for users with no test rating.
        for _ in range(_model.dns):
            user_neg_batch.append(user)
            positives = _dataset.trainList[user]
            # Rejection sampling: redraw until the item is not a training positive.
            j = np.random.randint(_dataset.num_items)
            while j in positives:
                j = np.random.randint(_dataset.num_items)
            item_neg_batch.append(j)
    return np.array(user_batch)[:, None], np.array(item_batch)[:, None], \
           np.array(user_neg_batch)[:, None], np.array(item_neg_batch)[:, None]
def train(data=None, batch_size=None):
    """Sample positive pairs from a dataset and split them into shuffled batches.

    Args:
        data: dataset passed to sampling() and shuffle(). Defaults to the
            module-level `dataset` created in the __main__ block.
        batch_size: positive pairs per batch. Defaults to `args.batch_size`
            from the module-level `args` parsed in the __main__ block.

    Returns:
        The 4-tuple of per-batch lists returned by shuffle().
    """
    # Falling back to the script globals keeps the original zero-argument call
    # working, while letting importers (where those globals do not exist)
    # pass their own inputs instead of hitting a NameError.
    if data is None:
        data = dataset
    if batch_size is None:
        batch_size = args.batch_size
    samples = sampling(data)
    return shuffle(samples, batch_size, data)
    
if __name__ == '__main__':
    # Script entry: parse the CLI options, load the dataset, then build batches.
    # `args` and `dataset` must remain module-level names because train()
    # reads them.
    args = parse_args()
    dataset_path = args.path + args.dataset
    dataset = Dataset(dataset_path)
    train()
    # Spot-check: print the user and the item of the training pair stored at
    # position 90. shuffle() keeps the pairs in sampling order and applies
    # the permutation via _index only.
    for column in (_user_input, _item_input_pos):
        print(column[90])

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值