eval.py
# coding:utf-8
import lightgbm as lgb
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from woe.eval import eval_segment_metrics
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_auc_score
import sys
# Cut-off threshold, read from the first command-line argument
# (e.g. `python eval.py 0.69`). NOTE(review): presumably applied to the
# predicted probabilities in the __main__ section — confirm.
# It is parsed at import time: running/importing without an argument raises
# IndexError, and a non-numeric argument raises ValueError.
# Previously hard-coded values, kept for reference: 0.49, 0.693920.
THRESHOLD: float = float(sys.argv[1])
# Number of score segments that assign_grp_idx splits the ranked samples into.
segment_cnt: int = 50
def assign_grp_idx(df, proba=None, n_segments=None):
    """Assign each row of *df* a 1-based score-segment index in column 'grp'.

    Rows are ranked by predicted probability in descending order and cut into
    consecutive chunks of ``int(len(df) / n_segments)`` rows (at least one row
    per chunk); the chunk with the highest scores gets group 1. The final chunk
    may be smaller, and when ``len(df)`` is not a multiple of ``n_segments``
    more than ``n_segments`` groups are produced (unchanged from the original
    behaviour).

    Args:
        df: DataFrame modified in place; a 'grp' column is written.
        proba: Predicted probabilities aligned positionally with the rows of
            ``df``. Defaults to the module-level ``y_pred_prob`` for backward
            compatibility (NOTE(review): that global is presumably set in the
            ``__main__`` section — confirm).
        n_segments: Target number of segments. Defaults to the module-level
            ``segment_cnt``.

    Raises:
        ValueError: if ``proba`` and ``df`` have different lengths.
    """
    if proba is None:
        proba = y_pred_prob
    if n_segments is None:
        n_segments = segment_cnt
    proba = np.asarray(proba)
    total_sample_cnt = len(df)
    if len(proba) != total_sample_cnt:
        raise ValueError('proba has %d entries but df has %d rows'
                         % (len(proba), total_sample_cnt))
    # Fewer rows than segments would give a chunk size of 0, which made the
    # original while-loop never advance; force at least one row per chunk.
    segment_sample_cnt = max(1, int(total_sample_cnt / n_segments))
    # argsort yields positions; map them to index labels so that .loc updates
    # the intended rows even when df lacks a default RangeIndex.
    descend_labels = df.index[np.argsort(proba)[::-1]]
    for grp_idx, start_idx in enumerate(
            range(0, total_sample_cnt, segment_sample_cnt), start=1):
        chunk = descend_labels[start_idx:start_idx + segment_sample_cnt]
        df.loc[chunk, 'grp'] = grp_idx
def _safe_ratio(num, den):
    """Return num / den, or NaN when the denominator is zero/empty."""
    return num / den if den else float('nan')


def model_stat(train, test, df_opt, dump_report=False,
               report_path='../../reports/model_stat.xlsx'):
    """Summarise sample and positive-sample counts for the evaluation report.

    Builds a 6-row table with two columns: total sample count and positive
    sample count (sum of 'd7_open'). Rows are: train, test, the test rows
    kept by the threshold (``df_opt``), the ratio ``df_opt / test``, and the
    sums of 'new_user' over ``test`` and ``df_opt`` (count column only).

    Args:
        train, test, df_opt: DataFrames with 'd7_open' and (test/df_opt)
            'new_user' columns.
        dump_report: When True, also write the table to ``report_path``.
        report_path: Excel destination; defaults to the previously
            hard-coded path.

    Returns:
        The summary DataFrame. Ratios are NaN when their denominator is 0
        (e.g. an empty test set) instead of raising or yielding inf.
    """
    test_pos = test['d7_open'].sum()
    opt_pos = df_opt['d7_open'].sum()
    res = pd.DataFrame(
        {
            '样本量': [len(train), len(test), len(df_opt),
                    _safe_ratio(len(df_opt), len(test)),
                    test['new_user'].sum(), df_opt['new_user'].sum()],
            '正样本量': [train['d7_open'].sum(), test_pos, opt_pos,
                     _safe_ratio(opt_pos, test_pos), None, None],
        },
        index=['训练数据', '测试数据', '测试门槛召回数据', '占比', '原始打开用户', '优化打开用户'],
    )
    if dump_report:
        res.to_excel(report_path, index=True)
    return res
def plot_pr_curve(y_test, y_score, out_path='../../reports/figures/PR-curve.png',
                  ylim_max=0.0175):
    """Plot a step-style precision-recall curve and save it as an image.

    Args:
        y_test: True binary labels, passed to
            ``sklearn.metrics.precision_recall_curve``.
        y_score: Scores for the positive class, aligned with ``y_test``.
        out_path: Destination file for the figure; defaults to the path
            previously hard-coded here.
        ylim_max: Upper bound of the precision axis. The default 0.0175
            keeps the original zoom; pass ``None`` to let matplotlib
            autoscale the y-axis.

    Returns:
        The matplotlib Figure, so the caller can show or close it.
    """
    precision, recall, _ = precision_recall_curve(y_test, y_score)
    # Use a dedicated figure: drawing on pyplot's implicit current figure
    # overlays the curve on whatever was plotted earlier in the process.
    fig, ax = plt.subplots()
    ax.step(recall, precision, color='b', alpha=0.2, where='post')
    ax.fill_between(recall, precision, step='post', alpha=0.2, color='b')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    if ylim_max is not None:
        ax.set_ylim([0.0, ylim_max])
    ax.set_xlim([0.0, 1.0])
    fig.savefig(out_path)
    return fig
if __name__ == '__main__':
print('THRESHOLD = %.4f' % THRESHOLD)
train = pd.read_csv('../../data/processed/lx_new_users_algo2_nf_train.tsv', sep='\t')
test = pd.read_csv('../../data/processed/lx_new_users_algo2_nf_test.tsv', sep='\t')
feature_names = ['up_cnt', 'if_333', 'open_1d_fuids_cnt', 'interact_hb_14d_fuids_rate',
'interact_hb_30d_fuids_ra