python版本, lightgbm使用示例

1、安装lightgbm包,

pip install lightgbm -i https://pypi.tuna.tsinghua.edu.cn/simple --default-timeout=100

2、lightgbm原理:

https://www.cnblogs.com/jiangxinyang/p/9337094.html

 

3、lightgbm使用示例:

def train(x_train, y_train, q_train, model_save_path):
    '''
    模型的训练和保存
    :param x_train:
    :param y_train:
    :param q_train:
    :param model_save_path:
    :return:
    '''

    train_data = lgb.Dataset(x_train, label=y_train, group=q_train)
    params = {
        'task': 'train',  # 执行的任务类型
        'boosting_type': 'gbrt',  # 基学习器
        'objective': 'lambdarank',  # 排序任务(目标函数)
        'metric': 'ndcg',  # 度量的指标(评估函数)
        'max_position': 10,  # @NDCG 位置优化
        'metric_freq': 1,  # 每隔多少次输出一次度量结果
        'train_metric': True,  # 训练时就输出度量结果
        'ndcg_at': [10],
        'max_bin': 255,  # 一个整数,表示最大的桶的数量。默认值为 255。lightgbm 会根据它来自动压缩内存。如max_bin=255 时,则lightgbm 将使用uint8 来表示特征的每一个值。
        'num_iterations': 200,  # 迭代次数,即生成的树的棵数
        'learning_rate': 0.01,  # 学习率
        'num_leaves': 31,  # 叶子数
        # 'max_depth':6,
        'tree_learner': 'serial',  # 用于并行学习,‘serial’: 单台机器的tree learner
        'min_data_in_leaf': 30,  # 一个叶子节点上包含的最少样本数量
        'verbose': 2  # 显示训练时的信息
    }
    gbm = lgb.train(params, train_data, valid_sets=[train_data])
    gbm.save_model(model_save_path)


def predict(x_test, comments, model_input_path):
    '''
     预测得分并排序
    :param x_test:
    :param comments:
    :param model_input_path:
    :return:
    '''

    gbm = lgb.Booster(model_file=model_input_path)  # 加载model

    ypred = gbm.predict(x_test)

    predicted_sorted_indexes = np.argsort(ypred)[::-1]  # 返回从大到小的索引

    t_results = comments[predicted_sorted_indexes]  # 返回对应的comments,从大到小的排序

    return t_results

def test_data_ndcg(model_path, test_path):
    '''
    评估测试数据的ndcg
    :param model_path:
    :param test_path:
    :return:
    '''

    with open(test_path, 'r', encoding='utf-8') as testfile:
        test_X, test_y, test_qids, comments = read_dataset(testfile)

    gbm = lgb.Booster(model_file=model_path)
    test_predict = gbm.predict(test_X)

    average_ndcg, _ = validate(test_qids, test_y, test_predict, 60)
    # 所有qid的平均ndcg
    print("all qid average ndcg: ", average_ndcg)
    print("job done!")

 

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

samoyan

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值