Xgboost模型评估

    # 4.1 模型评估

    """ 检查 XGBoost 模型是否正确,如不正确则抛出异常 """
    logger.info(' --- 检查 XGBoost 模型是否正确')

    if not xgboost_model:
        logger.error('XGBoost模型不存在,请检查!')
        raise ValueError("XGBoost模型不存在,请检查!")
    if not isinstance(xgboost_model, xgb.sklearn.XGBModel):
        logger.error('XGBoost模型不正确,请重新训练!')
        raise TypeError("XGBoost模型不正确,请重新训练!")
    predictions = xgboost_model.predict(x_test)
    mae = mean_absolute_error(y_test, predictions)
    rmse = np.sqrt(mean_squared_error(y_test, predictions))
    r_square_score = r2_score(y_test, predictions)
    evaluation_criteria = {
            'mae': mae,
            'rmse': rmse,
            'r2': r_square_score
        }
    logger.info('模型评估完成,请查看成员变量 evaluation_criteria')

    # 4.2 获取模型评估结果
    # evaluation_criteria = xgb_reg_helper.evaluation_criteria
    logger.info('模型评估效果: {}', evaluation_criteria)
    # 4.3 获取特征重要性
    #feature_importance = xgb_reg_helper.get_feature_importance()
    """ 返回特征重要性 """
    logger.info(' --- 返回特征重要性')
    #self._check_whether_xgb_model_correct()
    xgboost_model.feature_importances_
    logger.info('模型特征重要性: {}', xgboost_model.feature_importances_)
    # 4.4 绘制特征重要性图像
    # xgb_reg_helper.plot_feature_importance()

    xgb.plot_importance(xgboost_model).set_yticklabels(features_column_names)

    plt.show()


    # 5. 模型预测,预测各个输入回归值
    # (可提供 ndarray 形式入参;如不提供,则预测测试集)
    # predict_result = xgb_reg_helper.predict()
    # logger.info('预测测试集: {}', predict_result)

    """
            执行预测任务。对于分类器,预测结果为类别;对于回归器,预测结果为回归值
                可提供一维 np.ndarray 形式待预测数据,如不提供,则默认使用测试集数据进行预测

            Args:
                x_predict: 一维 np.ndarray 形式待预测数据
            """
    logger.info(' --- 执行预测任务。对于分类器,预测结果为类别;对于回归器,预测结果为回归值')

    # if x_predict:
    #     if not isinstance(x_predict, np.ndarray) or x_predict.ndim != 2:
    #         logger.error('请提供二维 np.ndarray 作为输出')
    #         raise TypeError("请提供二维 np.ndarray 作为输出")
    #     return self.xgboost_model.predict(x_predict)
    predict_result = xgboost_model.predict(x_test)
    logger.info('预测测试集: {}', predict_result)

  • 4
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值