tenrec mmoe实操

项目地址:https://github.com/yuangh-x/2022-NIPS-Tenrec

该项目包含多种任务数据集,以及详细清晰的代码,十分适合用于推荐项目练习使用,但是在运行该项目过程中,遇见了一些麻烦,现记录下来。

1.首先按照requirements.txt安装项目所需包。

2.下载数据集,项目目录下建立data文件夹,将相应数据放到该目录下,建立checkpoint目录,用来保存模型。

2.运转我的命令而不是项目中指示的命令

python main.py --task_name=mtl --seed=100 --model_name=mmoe --dataset_path='data/ctr_data_1M.csv' --train_batch_size=4096 --val_batch_size=4096 --test_batch_size=4096 --epochs=20 --lr=0.0001 --embedding_size=32 --mtl_task_num=2 --device='cpu'

这个命令的缺点就是项目使用CPU训练而不是GPU了,博主也暂时没有跑GPU版本,GPU可能会有不少问题。

注意:

运行中可以在代码中添加一些中间日志,比如在main.py相应位置添加print代码

    elif args.task_name == 'mtl':
        print("============data generate begin=============")
        train_dataloader, val_dataloader, test_dataloader, user_feature_dict, item_feature_dict = get_data(args)
        print("============data generate end===============")
        if args.mtl_task_num == 2:
            num_task = 2
        else:
            num_task = 1
        if args.model_name == 'esmm':
            model = ESMM(user_feature_dict, item_feature_dict, emb_dim=args.embedding_size, num_task=num_task)
        else:
            model = MMOE(user_feature_dict, item_feature_dict, emb_dim=args.embedding_size, device=args.device, num_task=num_task)
        mtlTrain(model, train_dataloader, val_dataloader, test_dataloader, args, train=False)

一个epoch需要时间较多,可以在trainer.py程序里添加一些中间日志


            print('train batch num:', len(train_loader))
            for idx, (x, y1, y2) in enumerate(train_loader):
                x, y1, y2 = x.to(device), y1.to(device), y2.to(device)
                predict = model(x)
                y_train_click_true += list(y1.squeeze().cpu().numpy())
                y_train_like_true += list(y2.squeeze().cpu().numpy())
                y_train_click_predict += list(predict[0].squeeze().cpu().detach().numpy())
                y_train_like_predict += list(predict[1].squeeze().cpu().detach().numpy())
                loss_1 = loss_function(predict[0], y1.unsqueeze(1).float())
                loss_2 = loss_function(predict[1], y2.unsqueeze(1).float())
                loss = loss_1 + loss_2
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += float(loss)
                count += 1
                if idx % 100 == 0:
                    print(f'trained {idx} batches')

再检查到程序运行无问题后,可以将中间日志删掉,以防影响查看训练指标。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值