TowardsDataScience 2023 博客中文翻译(二百五十八)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

使用卷积网络预测结核分枝杆菌的药物耐药性 — 论文评审

原文:towardsdatascience.com/predicting-drug-resistance-in-mycobacterium-tuberculosis-using-a-convolutional-network-paper-b5905e3e3977?source=collection_archive---------20-----------------------#2023-03-20

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Uri Almog

·

关注 发表在 Towards Data Science ·6 分钟阅读·2023 年 3 月 20 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:CDCUnsplash

神经网络可以提高对病原体药物耐药性的预测能力

在这篇文章中,我将回顾一篇关于医学研究与建模及机器学习之间接口的最新论文。论文Green, A.G., Yoon, C.H., Chen, M.L. 等. 卷积神经网络突显与抗微生物耐药性相关的结核分枝杆菌突变. Nat Commun 13, 3817 (2022). https://doi.org/10.1038/s41467-022-31236-0描述了两种培训神经网络模型以预测给定 M. tuberculosis (MTB)菌株对 13 种抗生素的耐药性的 Approaches。这种建模技术的优点是能够生成一个显著性图,突显出对预测影响最大的特征,从而解决了一些关于模型可解释性的担忧。

问题陈述

结核病(TB)是导致全球感染性病原体死亡的主要原因。其病原体 M. tuberculosis(或 MTB)正逐渐对抗生素产生耐药性——这一过程对公共卫生构成威胁。虽然对每个患者的 MTB 分离株进行一系列抗生素的耐药性实证测试可能是最准确的方法,但可能需要几周才能完成,并且无法及时治疗。分离株的分子诊断只需数小时或数天,但仅关注基因组序列中的特定位点。因此,学习表型(药物耐药性)与病原体基因型(被诊断位点的结构)之间依赖关系的机器学习模型可能提供所需的解决方案。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由Julia Koblitz拍摄,发布于Unsplash

单药和多药模型

作者描述了两种建模方法:第一种名为 SD-CNN(单药 CNN),训练 13 个不同的 CNN,每个 CNN 预测对不同药物的耐药性。第二种名为 MD-CNN(多药 CNN),同时预测对 13 种药物的耐药性。这种建模技术背后的见解是关于多任务学习的开创性研究(Caruana, R. 多任务学习. Mach. Learn. 28, 41–75 (1997)),该研究表明,与直觉相反,训练 CNN 同时执行不同任务,确实可以提高其在每个单独任务上的表现,前提是这些任务是相关的。这个结果的解释是,一个任务生成的特征对其他任务的表现是有利的(例如,通过路标检测的辅助任务训练自动驾驶汽车转向模型)。多任务学习在遗传学研究中的优势由Dobrescu, A., Giuffrida, M. V. & Tsaftaris, S. A. 以更少做更多:一种植物表型的多任务深度学习方法. Front. Plant Sci. 11*, 141 (2020)*证明。

模型输入

用于训练的数据是 10,201 个 M. tuberculosis 病原体分离株,这些分离株在 13 种抗生素上进行了耐药性测试。MD-CNN 的输入是一个 5x18x10,291 的数组,其中 5 是 4 种核苷酸的独热编码(腺嘌呤、胸腺嘧啶、鸟嘌呤、胞嘧啶和一个缺口字符),18 是 locus 索引(作者使用 18 个与药物耐药性相关的 loci),10,291 是最长 locus 的长度。locus(复数—loci)是染色体上的一个特定固定位置,其中存在特定基因或基因序列。locus 由其起始索引和结束索引定义,从约定的起点计算核苷酸。不同的 loci 具有不同的长度。

每个 13 个 SD-CNN 模型的输入包含 18 个 loci 中的一个子集,这些 loci 对该药物的抗性有已知影响。

模型输出

MD-CNN 模型的输出是一个 13 元素的向量(按抗结核药物索引),每个元素包含该菌株对该药物的抗性信心的 sigmoid 结果。SD-CNN 模型返回一个对应于该药物抗性信心的单一 sigmoid 值。

模型架构

该模型是一个 CNN,由 2 个 1-D 卷积和最大池化块组成,之后是 3 个全连接层。描述见图 1。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1 — MD-CNN 架构。Conv 1a 和 1b 的卷积核尺寸为 5x12 和 1x12。Conv 2a 和 2b 的卷积核尺寸为 1x3。最大池化层的形状为 1x3。所有步幅均为 1x1。所有激活函数均为 ReLU,输出层除外,输出层使用 sigmoid。每层的输出维度在图形表示下方给出。SD-CNN 模型与此图的不同之处在于它们的 locus 维度不是 18,输出维度是 1。图像由作者提供。

结果

SD-CNN 和 MD-CNN 模型在相互之间及与两个先前模型:Reg+L2 和 SOTA 模型 WDNN(Chen, M. L. 等. 超越多药耐药性:利用稀有变异通过机器和统计学习模型进行结核分枝杆菌耐药性预测。EBioMedicine 43, 356–369 (2019)*)的基准测试。基准测试使用了对训练集的 5 折交叉验证。

测试显示,MD-CNN 的表现与 WDNN(当前的 SOTA 模型,使用布尔编码已知突变的基因组序列作为输入。它被设计为多层感知器的组合,即不使用卷积)相当。MD-CNN 的平均 AUC 在一线药物上为 0.948(WDNN 为 0.960),在二线药物上为 0.912(WDNN 为 0.924)。SD-CNN 的准确率略低,两组药物的 AUC 均为 0.888。MD-CNN 和 SD-CNN 展示了对新数据的泛化能力,在一个单独收集的 12,848 个样本的测试集上取得了大致相同的 AUC)。— 有关模型的图形比较,请参见 原始论文

作者指出,MD-CNN 模型的敏感性高于 SD-CNN 模型(即药物耐药性的漏检率较低),而 SD-CNN 模型的特异性更高(即错误将样本分类为耐药药物的比例较低)。换句话说——MD-CNN 不那么保守,倾向于将更多案例分类为‘耐药’。

分析 SD-CNN 的性能时,作者检查了假阴性案例。在检查数据时,他们观察到具有相同模型输入的样本在某些情况下对同一种药物具有耐药性,而在其他情况下则对该药物敏感(即它们的实际分类不同)。这使得作者假设,SD-CNN 模型未包含的位点中的突变可能是耐药性的原因。

可解释性和显著性映射

作者使用DeepLIFTAvanti Shrikumar, Peyton Greenside, and Anshul Kundaje. 2017. 通过传播激活差异来学习重要特征。发表于第 34 届国际机器学习大会 — 第 70 卷(ICML’17)。JMLR.org, 3145–3153。),这是一种计算输入特征对输出贡献的方法,来解释模型的预测。通过在计算机中变更基因型输入(模拟输入)并将结果与‘参考结果’进行比较,作者发现了以前未知的变异对 MTB 药物耐药性的影响。

对模型架构的几点思考

作为一名机器学习工程师和研究员,大部分关注点在计算机视觉领域,我从阅读这篇论文和相关背景材料中学到了很多。显然,神经网络在医学和生物学领域的建模技术方面具有巨大潜力。在将此模型中使用的技术与我自己的经验进行比较时,我想到了一些如果我在研究的第二阶段工作时会感兴趣尝试的东西:

  1. Gap encoding — 四种核苷酸以 1-hot 编码表示,并额外增加一个表示间隙的元素。我很好奇如果将间隙表示改为[0, 0, 0, 0],结果是否会有所改善。

  2. 特征深度——此处呈现的架构在整个模型中使用单一特征。我对计算机视觉的直觉让我对特征多样化的可能性感到好奇。正如在计算机视觉中,训练过程可能会收敛到图像中的单一位置具有多种特征,如‘圆度’,‘金属性’,‘平滑度’,我猜在基因组序列中也可能是这样。

  3. 填充类型——作者在其卷积层中使用‘valid’填充,而不是计算机视觉中常用的‘same’填充。这会随着序列在层之间传递逐渐缩短序列。‘Same’填充保持序列的空间大小,允许序列边缘附近的结构即使在模型的后期阶段仍然保持一些效果。它还允许诸如将来自不同阶段的层的输出连接起来等操作。

  4. 注意机制——(Vaswani 等人,Attention Is All You Need, 2017, NIPS)——注意块在发现序列中远程标记之间的微妙关系(例如,NLP 中的句子不同部分)时非常有用,当一个标记的值可能对另一个标记的值的解释产生重大影响时,它们尤其相关。看看添加注意块是否能改善结果,如果能——则使用它来追溯基因组中的区域间隐藏关系会很有趣。

预测高急诊室使用率

原文:towardsdatascience.com/predicting-high-emergency-room-visit-rates-5fff6a8950f4

使用 Python 分析健康的社会决定因素(SDOH)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Meagan Burkhart

·发布于数据科学前沿 ·10 分钟阅读·2023 年 5 月 23 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由国家癌症研究所拍摄,来自 Unsplash

本项目的目标是利用来自AHRQ的按县划分的社会决定因素(SDoH)数据,查找特定变量与县急诊就诊率之间的关系。最终,我希望开发一个与高急诊率相关的顶级特征的预测模型。我决定查看 2019 年和 2020 年的数据(2018 年数据不可用)。该数据集经过AHRQ的明确许可使用。

这个逐步教程介绍了我加载、清理、分析和建模数据的过程。

加载数据

第一步是加载两个数据文件并检查其形状。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者的代码

由于两个数据框的列数不同,我将导入数据字典并提取相同的列。

我根据列名(内连接)合并了数据字典,以获取最终的公共列列表。一旦获得列,我选择了每个数据框中这些列的子集,并将它们与轴=0 进行连接,以垂直添加它们。我的 df_final 包含了 2019 年和 2020 年数据中的公共列。

dictionary2019=pd.read_csv('Data/datadictionary2019.csv', encoding= "ISO-8859–1")
dictionary2020=pd.read_csv('Data/datadictionary.csv', encoding= "ISO-8859–1")
commoncolumns=dictionary2020.merge(dictionary2019, how='inner', left_on='name', right_on='name')['name'].values.tolist()

dfa=df2019[commoncolumns]
dfb=df2020[commoncolumns]
df_final=pd.concat([dfa, dfb], axis=0)
df_final

关于数据

该数据有 674 列,因此我们需要尝试缩小需要查看的特征。让我们从我感兴趣的变量——急诊访问率开始。

数据集包括每 1,000 名男性医疗保险(双重和非双重)受益人的急诊科访问次数。然而,数据提供了单独的男性和女性比率,因此我将创建一个加权平均的整体急诊率。

为此,我将比率分别乘以男性的百分比和女性的百分比,然后将这些值相加。

男性 ED 比率:597.1129158207091

女性 ED 比率:639.9023742580443

import numpy as np
#create an average overall ED rate by weighting the male and female rates by their percentage of the population and adding
df_final['Malerate']=(df_final['ACS_PCT_MALE']*df_final['MMD_ED_VISITS_M_RATE'])/100
df_final['Femalerate']=(df_final['ACS_PCT_FEMALE']*df_final['MMD_ED_VISITS_F_RATE'])/100
df_final['EDrate']=df_final['Malerate']+df_final['Femalerate']
#print the mean ED rate to use as our baseline for Good and Bad outcomes

数据清理

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

查看数据框后,我们可以看到一些县的数据缺失。为了解决这个问题,我删除了某些列,并用均值插补了其他列。

对于这个项目,我们将删除 ED 比率计算中缺失值的行,因为这意味着他们没有相关的急诊访问数据。

然后,我们将找到数据集的 80 百分位数,作为‘高 ED 比率’的截断点。这将成为我未来预测的结果变量。整体 ED 比率将用于相关性和探索性数据分析。

#create cutoff value for high EDrate
cutoff=np.percentile(df_final['EDrate'], 80)
#if ED rate is greater than the 50th percentile then flag as high or else 0 for low
df_final['HighED']=np.where(df_final['EDrate']>cutoff, 1, 0)

为了处理一些缺失数据,我开始时删除了任何缺失值超过 10%的列。然后,我得到了剩余列中有缺失值的列表,如下所示。

# drop columns with >10% missing
df_final.dropna(thresh=0.90*len(df_final),axis=1, inplace=True)

#list columns remaining with missing values
df_final.isnull().sum().to_frame(name='counts').query('counts > 0').sort_values(by='counts', ascending=False)

我们先用均值对所有浮点型列进行简单插补。首先,我们需要将训练集和测试集分开:

from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
columns=df_final.loc[:, df_final.dtypes == float].columns.values
X_train, X_test, y_train, y_test = train_test_split( df_final.index, df_final['HighED'], stratify=df_final['HighED'], test_size=0.25,  random_state=42)
df_train=df_final.iloc[X_train]
df_test=df_final.iloc[X_test]
imp = SimpleImputer( strategy='mean')
df_impute_train=df_train[columns].copy(deep=True)
df_impute_test=df_test[columns].copy(deep=True)
df_impute_train[:]=imp.fit_transform(df_train[columns])
df_impute_test[:]=imp.fit_transform(df_test[columns])

df_impute_train

接下来,我们将选择所有浮点型列,以便我们可以运行每个特征与目标特征之间的相关性分析。根据相关性,我们将设定阈值,并保留那些与 ED 比率显著正相关或负相关的列。

#print positive correlations over .2 (small-moderate effect size)
def positivecorrelations(threshold=.2):
    columns=df_impute_train.columns.values
    positivecolumns=[]
    for col in columns:        
        if df_impute_train['EDrate'].corr(df_impute_train[col])>threshold:
            positivecolumns.append(col)
    return positivecolumns

poscols=(positivecorrelations())
#print negative correlations less than -.2 (small-moderate effect size)
def negativecorrelations(threshold=-.2):
    columns=df_impute_train.columns.values
    negativecolumns=[]
    for col in columns:        
        if df_impute_train['EDrate'].corr(df_impute_train[col])<threshold:
            negativecolumns.append(col)
    return negativecolumns

negcols=(negativecorrelations())

我制作了一个最终的列列表,如下所示:

#make a final list of significant columns
sigcols=poscols+negcols
print(sigcols)
len(sigcols)

我们最终得到了 140 列。在打印列的列表后,我意识到我仍然需要进一步清理 -

我们要确保不包括任何包含 ED 的变量,因此我们会通过以下代码过滤掉所有这些列,以及我们计算出的女性比率和男性比率。

stringVal = "ED"
sigcols.remove('Femalerate')
sigcols.remove('Malerate')
finalcols=[x for x in sigcols if stringVal not in x]

len(finalcols)

这使我们剩下了 140 列。我们将数据集缩小到 112 列。但现在我查看列的列表时,看到我们还应排除任何包含 _IP(住院)、_PA(急性后期)和 _EM(E&M)的列。我们也不关心每月的最低和最高温度,因此我会删除这些列。

stringVal = "_IP"

finalcols=[x for x in finalcols if stringVal not in x]
stringVal2="TEMP_"

finalcols=[x for x in finalcols if stringVal2 not in x]
stringVal3="_PA"
stringVal4="_EM"

finalcols=[x for x in finalcols if stringVal3 not in x]
finalcols=[x for x in finalcols if stringVal4 not in x]
len(finalcols)
#result is 77

基于对输出的另一轮仔细检查,我发现有些特征测量的是非常相似的事物(即整体估计百分比与年龄 X-Y)。如果整体值存在,就删除那些指定年龄范围的列。此外,PQI 在所有不同的人群子集中的重复,因此我们将采用加权平均来找出整体比率,就像我们之前用 ED 比率一样。

finalcols=[x for x in finalcols if x not in ('ACS_PCT_PRIVATE_SELF_BELOW64', 'ACS_PCT_PRIVATE_SELF','SAIPE_PCT_POV_0_17', 'ACS_PCT_PRIVATE_ANY_BELOW64', 'SAIPE_PCT_POV_5_17', 'NEPHTN_HEATIND_90', 'NEPHTN_HEATIND_95', 'NEPHTN_HEATIND_100')]

df_impute_train.loc[:, 'MalePQI']=(df_impute_train['ACS_PCT_MALE']*df_impute_train['MMD_OVERALL_PQI_M_RATE'])/100
df_impute_train.loc[:, 'FemalePQI']=(df_impute_train['ACS_PCT_FEMALE']*df_impute_train['MMD_OVERALL_PQI_F_RATE'])/100
df_impute_train.loc[:, 'PQI']=df_impute_train['MalePQI']+df_impute_train['FemalePQI']
df_impute_train['PQI'].describe()

df_impute_test.loc[:, 'MalePQI']=(df_impute_test['ACS_PCT_MALE']*df_impute_test['MMD_OVERALL_PQI_M_RATE'])/100
df_impute_test.loc[:, 'FemalePQI']=(df_impute_test['ACS_PCT_FEMALE']*df_impute_test['MMD_OVERALL_PQI_F_RATE'])/100
df_impute_test.loc[:, 'PQI']=df_impute_test['MalePQI']+df_impute_test['FemalePQI']
df_impute_test['PQI'].describe()

rate="_RATE"
finalcols=[x for x in finalcols if rate not in x]
race="ACS_PCT_BLACK_"
finalcols=[x for x in finalcols if race not in x]

dictionary2020[['name', 'label']][dictionary2020['name'].isin(finalcols)]

我还从数据字典中提取了所有列的完整标签,并审查了它们,以确保所有列在我对医疗分析的背景知识下都是实用的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

最终结果得到 78 个可以使用的列。接下来,我想快速可视化这些变量与 EDrate 之间的关系,于是写了一个小循环来创建散点图:

import matplotlib.pyplot as plt
def create_scatterplots(var='EDrate'):
    for col in finalcols:
        plt.scatter(df_impute_train[col], df_impute_train[var])
        plt.title(("{} vs {}".format(col, var)))
        plt.show()
create_scatterplots()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

本分析的下一部分涉及开发几个预测模型。在这个项目中,我使用了逻辑回归、支持向量机、随机森林分类器和 XGBoost 分类器。我们首先测试的是逻辑回归和 SVM,因此需要对数据进行缩放。

我选择了 MinMaxScaler 来尝试减少异常值的影响。

from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()

X_train=df_impute_train[finalcols].copy(deep=True)
X_test=df_impute_test[finalcols].copy(deep=True)
y_train=df_train['HighED']
y_test=df_test['HighED']
scaler.fit(X_train)
X_train_scaled=scaler.transform(X_train)
X_test_scaled=scaler.transform(X_test)

逻辑回归

如果你对分类问题不熟悉,可以查看逻辑回归介绍这篇文章,由Ayush Pant撰写。对于这次逻辑回归,我决定设置 class_weight=’balanced’,因为这是一个不平衡的分类问题:

from sklearn.linear_model import LogisticRegression
model = LogisticRegression(solver='liblinear', random_state=0, class_weight='balanced').fit(X_train_scaled, y_train)

model.score(X_test_scaled, y_test)

下方的混淆矩阵显示了 TP、TN、FP 和 FN。我们还打印了分类报告。

from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt

cm = confusion_matrix(y_test, model.predict(X_test_scaled))

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(cm, cmap='summer', alpha=0.3)
ax.grid(False)
ax.xaxis.set(ticks=(0, 1), ticklabels=('Predicted LowED', 'Predicted HighED'))
ax.yaxis.set(ticks=(0, 1), ticklabels=('Actual LowED', 'Actual HighED'))
ax.set_ylim(1.5, -0.5)
for i in range(2):
    for j in range(2):
        ax.text(j, i, cm[i, j], ha='center', va='center', color='black')
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

支持向量机

我尝试的下一个模型是支持向量机。如果这是一个你不熟悉的话题,我推荐查看Ajay Yadav的文章支持向量机(SVM)

在我的项目中,我对类别权重进行了调整,以观察哪些设置能在模型结果方面取得平衡。

from sklearn.svm import SVC
clf=SVC(class_weight='balanced')
clf.fit(X_train_scaled, y_train)
clf.score(X_test_scaled, y_test)
cm = confusion_matrix(y_test, clf.predict(X_test_scaled))

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(cm, cmap='summer', alpha=0.3)
ax.grid(False)
ax.xaxis.set(ticks=(0, 1), ticklabels=('Predicted LowED', 'Predicted HighED'))
ax.yaxis.set(ticks=(0, 1), ticklabels=('Actual LowED', 'Actual HighED'))
ax.set_ylim(1.5, -0.5)
for i in range(2):
    for j in range(2):
        ax.text(j, i, cm[i, j], ha='center', va='center', color='black')
plt.show()
print(classification_report(y_test, clf.predict(X_test_scaled)))

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

接下来要测试的模型是随机森林分类器和 XGBoost。由于基于树的模型不需要缩放,我将使用原始的 X 数据来进行这两项测试。

随机森林分类器

在你开始运行随机森林分类器之前,可能需要阅读一篇背景文章,例如Tony Yiu的文章理解随机森林。一旦你对概念有了良好的理解,可以尝试运行下面的简单示例代码:

from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
clf = RandomForestClassifier( max_features=None, random_state=0, class_weight='balanced')
clf.fit(X_train, y_train)
clf.score(X_test, y_test)
cm = confusion_matrix(y_test, clf.predict(X_test))

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(cm, cmap='summer', alpha=0.3)
ax.grid(False)
ax.xaxis.set(ticks=(0, 1), ticklabels=('Predicted LowED', 'Predicted HighED'))
ax.yaxis.set(ticks=(0, 1), ticklabels=('Actual LowED', 'Actual HighED'))
ax.set_ylim(1.5, -0.5)
for i in range(2):
    for j in range(2):
        ax.text(j, i, cm[i, j], ha='center', va='center', color='black')
plt.show()
print(classification_report(y_test, clf.predict(X_test)))

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

到目前为止,我们的随机森林分类报告如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

凭借 94%的准确率,这个模型表现得相当不错,但我对假阴性率感到担忧。让我们看看 XGBoost 是否能做得更好。

XGBoost

对 XGBoost 不熟悉?查看George Seif的文章——XGBoost 初学者指南,在运行下面的代码之前熟悉一下提升树。

 import xgboost as xgb

# Init classifier
xgb_cl = xgb.XGBClassifier(random_state=0)

# Fit
xgb_cl.fit(X_train, y_train)

# Predict
preds = xgb_cl.predict(X_test)
cm = confusion_matrix(y_test, xgb_cl.predict(X_test))

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(cm)
ax.grid(False)
ax.xaxis.set(ticks=(0, 1), ticklabels=('Predicted LowED', 'Predicted HighED'))
ax.yaxis.set(ticks=(0, 1), ticklabels=('Actual LowED', 'Actual HighED'))
ax.set_ylim(1.5, -0.5)
for i in range(2):
    for j in range(2):
        ax.text(j, i, cm[i, j], ha='center', va='center', color='red')
plt.show()
print(classification_report(y_test, xgb_cl.predict(X_test)))

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

看起来随机森林分类器是明显的赢家,具有更高的准确性、精确度、召回率和 f1。

让我们来看看一些特征重要性值:

 features = df_impute_test.columns.values
importances = clf.feature_importances_
indices = np.argsort(importances)
plt.figure(figsize=(8, 20))
plt.title('Feature Importances')
plt.barh(range(len(indices)), importances[indices], color='b', align='center')
plt.yticks(range(len(indices)), [features[i] for i in indices])
plt.xlabel('Relative Importance')
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

 import shap
explainer = shap.TreeExplainer(clf)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values, X_test, plot_type="bar")

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

主要发现

  1. 最重要的特征是黑人女性的比例

— 这表明少数群体在急性医疗方面面临更大的困难,任何歧视或健康不平等问题应立即解决。

  1. 残疾退伍军人的比例也是一个重要特征

— 这表明残疾和退伍军人身份导致了急诊科的高使用率,需要解决潜在的风险因素

  1. SHAP 值显示 PQI 和过去 12 个月领取食物券/SNAP 的家庭比例有显著影响

— 这间接地将社会经济地位与健康不良结果联系起来,表明如果医疗提供者能够在这一领域提供帮助,患者可能不会那么频繁地去急诊科。

推荐和未来方向

基于这些发现,医疗保健领域的数据专业人员可能考虑开发预测模型,以隔离与健康不良结果相关的社会决定因素特征。

希望目标医疗程序可以解决诸如由于种族、残疾或社会经济地位而产生的歧视等社会决定健康需求。通过为高风险个体或高风险区域提供更密集的护理和支持,医疗保健分析师可以致力于在非急性(非急诊)环境中更好地治疗这些患者。

这可以通过社区健康设施、家庭护理机构以及与健康保险公司合作的其他伙伴来实现。

结论

在这篇文章中,我介绍了一个与社会决定健康相关的公共数据集。通过分析 2019 年和 2020 年的数据,我得出了几个预测模型。该模型的目标是根据社会决定健康因素预测一个县是否会有高急诊科使用率。我的分析中最好的模型是随机森林分类器。我们回顾了驱动模型的核心特征,并解释了对医疗保健分析师的影响。

— — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — —

查看 我的 GitHub 上的整个笔记本

如果你是 Medium 的新用户,喜欢这样的故事,可以在这里注册。

互联网引用: 社会决定健康数据库。内容最后审查于 2022 年 11 月。医疗保健研究与质量局,罗克维尔,MD。www.ahrq.gov/sdoh/data-analytics/sdoh-data.html

使用 GPT-3 预测人道主义数据集的元数据

原文:towardsdatascience.com/predicting-metadata-for-humanitarian-datasets-using-gpt-3-b104be17716d?source=collection_archive---------4-----------------------#2023-01-18

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Matthew Harris

·

关注 发布于 Towards Data Science ·19 分钟阅读·2023 年 1 月 18 日

快速响应人道主义灾难,更好的是,能够预测这些灾难,可以挽救生命[1]。数据是关键,不仅仅是拥有大量数据,而是清洁的数据并且理解良好[2],以便创建对实际情况的清晰视图。在许多情况下,这些关键数据被存储在数百个小型电子表格中,因此在进行数据整合时可能非常耗时,并且在发生人道主义事件时,随着新数据的不断涌入,维护这些数据也很困难。自动化数据发现过程可能会加快响应速度并改善受影响人员的结果。

简化发现的一个方法是确保表格数据有描述每列的元数据。这可以帮助将数据集链接在一起,例如知道一个地雷位置表中的列指定了经度和纬度,这与另一个表中定位野战医院的列类似。列名并不总能明显显示它们可能包含的数据,这些数据可能以多种语言呈现,并遵循不同的标准。在理想的情况下,这种元数据是与数据一起提供的,但正如我们下面将看到的那样,这通常不是情况。手动处理这项工作可能非常庞大。

在这篇文章中,我将探讨我们如何通过使用OpenAI 的 GPT-3 大型语言模型来预测人道主义数据集的元数据属性,从而帮助自动化这个过程,并改进以往工作的表现。

人道主义数据交换(HDX)

人道主义数据交换(HDX)是一个极好的平台,旨在通过以标准化的方式将人道主义数据集合在一起,解决这些问题。截至我写这篇文章时,全球共有 20,403 个数据集,涵盖了广泛的领域和文件类型。这些数据集中的 CSV 和 Excel 文件产生了大约 148,000 个不同的表格,数据量非常庞大!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

人道主义数据交换(HDX)平台上的文件类型。有关数据如何汇总的信息,请参见这个笔记本

人道主义交换语言(HXL)

HDX 平台的一个优点是它鼓励数据拥有者使用人道主义交换语言(HXL)格式来标记他们的数据。这些元数据使得将数据结合起来并以有意义的方式使用变得更容易,从而在时间紧迫时加快处理速度。

HXL 标签有两种形式,一种是设置在数据集级别的标签,另一种是应用于表格数据中列的字段级别标签。后者看起来像这样:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

第二行有 HXL 标签的表格示例[#HXL Standards examples]

注意列标题下方的第二行,那些是 HXL 标签。它们由前缀为‘#’的标签(例如‘#adm1’)和某些情况下的属性(例如‘+name’)组成。

挑战在于这些字段级标签并不总是设置在 HDX 数据集上,这使得使用那里的数据变得更加困难。查看肯尼亚的 CSV 和 Excel 数据,大多数表格似乎缺少列 HXL 标签。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

分析人道主义数据交换 (HDX)平台上肯尼亚的数据文件,查看哪些文件有 HXL 列标签。有关如何整理数据的细节,请参见这个笔记本

如果我们能填补那些空白并为尚未拥有标签的列填充 HXL 标签,那不是很好吗?

微软已经在使用fastText 嵌入预测 HXL 标签方面做了一些非常出色的工作,请参见这个笔记本和相应的论文 [3]。作者在预测标签方面达到了 95% 的准确率,预测属性方面达到了 92%,表现非常出色。

不过,我想知道我们是否可以使用另一种技术,现在有了一些新技术……

GPT-3

正如我在上一篇文章中提到的,去年对生成性 AI 似乎真的非常关注。这个故事的明星之一是 Open AI 的GPT-3 大型语言模型(LLM),它有一些非常惊人的能力。重要的是,它可以经过微调来学习语言的特殊应用模式,如计算机代码

所以我想到 HXL 标签只是另一种语言‘特殊情况’,可能可以通过一些 HXL 标签示例对 GPT-3 进行微调,然后查看它是否能对新数据进行预测。

从 HDX 获取一些训练数据

首先,值得澄清一下 HDX 数据集、资源和表格的层次结构。‘数据集’可以包含一组‘资源’,这些资源是文件。数据集有自己的页面,比如这个,提供了很多关于历史、上传者和数据集级别标签的有用信息。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

HDX 平台上的一个 HDX 数据集示例

上面的示例有两个 CSV 文件资源,如果选择更多 > 在 HDX 上预览,可以显示 HXL 标签。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一个 HDX 平台上的示例资源

这是一个超级酷的平台!

我们将下载像上面这样的资源进行分析。HDX 提供了一个 用于与其 API 交互的 Python 库,可以通过……进行安装。

pip install hdx-python-api

然后你需要设置连接。由于我们仅下载开放数据集,因此不需要设置 API 密钥……

from hdx.utilities.easy_logging import setup_logging
from hdx.api.configuration import Configuration
from hdx.data.dataset import Dataset

setup_logging()
Configuration.create(hdx_site="prod", user_agent="my_agent_name", hdx_read_only=True)

经过一些实验,我写了一个小的包装器来下载每个数据集的资源(文件)。它支持 CSV、TSV、XLS 和 XLSX 文件类型,这些类型应该包括足够的表格用于我们的模型微调。它还保存数据集和资源的 HDX JSON 元数据以及每个文件。

def is_supported_filetype(format):
    """
    Checks if the file format is currently supported for extracting meta data.

    Parameters
    ----------
    format : str
        The file format to check.

    Returns
    -------
    bool
        True if the file format is supported, False otherwise.
    """
    matches = ["CSV", "XLSX", "XLS", "TSV"]
    if any(x in format for x in matches):
        return True
    else:
        return False

def download_data(datasets, output_folder):
    """
    Downloads data from HDX. Will save dataset and resource meta data for each file

    Parameters
    ----------
    datasets : pandas.DataFrame
        A dataframe containing the datasets to download.
    output_folder : str
        The folder to download the data to.
    """
    if not os.path.exists(output_folder):
        os.mkdir(output_folder)

    for index, row in datasets.iterrows():
        dataset = Dataset.read_from_hdx(row["id"])
        resources = dataset.get_resources()
        for resource in resources:
            dir = f"./{output_folder}/{row['name']}_{row['id']}"
            print(
                f"Downloading {row['name']} - {resource['name']} - {resource['format']}"
            )
            resource["dataset_name"] = row["name"]
            if not os.path.exists(dir):
                dump_hdx_meta_file(dataset, dir, "dataset.json")
            try:
                dir = f'{dir}/{get_safe_name(resource["name"])}_{get_safe_name(resource["id"])}'
                if not os.path.exists(dir):
                    dump_hdx_meta_file(resource, dir, "resource.json")
                    if is_supported_filetype(resource["format"]):
                        url, path = resource.download(dir)
                    else:
                        print(
                            f"*** Skipping file as it is not a supported filetype *** {resource['name']}"
                        )
                else:
                    print(f"Skipping {dir} as it already exists")
            except Exception as e:
                traceback.print_exc()
                sys.exit()

    print("Done")

上述内容有点啰嗦,因为我想能够重新启动下载,并让过程继续从中断的地方开始。此外,API 似乎偶尔会出现错误,可能是由于我的互联网连接,所以里面有一些 Try/Except。通常我不喜欢 Try/Except,但目标是创建一个训练数据集,所以我不介意缺少一些资源,只要我有一个代表性的样本来训练 GPT-3。

使用搜索 HDX API,我们搜索‘HXL’以寻找可能具有 HXL 标签的数据集,然后下载这些文件……

datasets_hxl = pd.DataFrame(Dataset.search_in_hdx("HXL"))
download_data(datasets_hxl, output_folder)

这可能需要一段时间(几个小时),所以不妨喝杯好茶!

根据我能发现的,列 HXL 标签在 HDX 资源元数据中未列出,因此要提取这些标签,我们必须分析下载的文件。经过一些实验,我写了一些辅助函数……

def check_hdx_header(first_row):
    """
    This function checks if the first row of a csv file likely an HDX header.
    """
    matches = ["#meta", "#country", "#data", "#loc", "#geo"]
    if any(x in first_row for x in matches):
        return True
    else:
        return False

def set_meta_data_fields(data, file, dataset, resource, sheet, type):
    """
    This function create a data frame with meta data about the data, as well as a snippet of its
    first nrows.

    Parameters:
        data: a dataframe
        file: the name of the data file
        dataset: the dataset JSON object from HDX
        resource: the resource JSON object from HDX
        sheet: the sheet name if the data was a tab in a sheet
        type: the type of file, CSV, XLSX, etc.
    Returns:
        dict: a dictionary with metadata about the dataframe
    """

    nrows = 10

    # Data preview to only include values
    data = data.dropna(axis=1, how="all")

    cols = str(list(data.columns))
    if data.shape[0] > 0:
        first_row = str(list(data.iloc[0]))
        has_hxl_header = check_hdx_header(first_row)
        num_rows = int(data.shape[0])
        num_cols = int(data.shape[1])
        first_nrows = data.head(nrows)
    else:
        first_row = "No data"
        has_hxl_header = "No data"
        num_rows = 0
        num_cols = 0
        first_nrows = None

    dict = {}

    dict["resource_id"] = resource["id"]
    dict["resource_name"] = resource["name"]
    dict["resource_format"] = resource["format"]
    dict["dataset_id"] = dataset["id"]
    dict["dataset_name"] = dataset["name"]
    dict["dataset_org_title"] = dataset["organization"]["title"]
    dict["dataset_last_modified"] = dataset["last_modified"]
    dict["dataset_tags"] = dataset["tags"]
    dict["dataset_groups"] = dataset["groups"]
    dict["dataset_total_res_downloads"] = dataset["total_res_downloads"]
    dict["dataset_pageviews_last_14_days"] = dataset["pageviews_last_14_days"]
    dict["file"] = file
    dict["type"] = type
    dict["dataset"] = dataset
    dict["sheet"] = sheet
    dict["resource"] = resource
    dict["num_rows"] = num_rows
    dict["num_cols"] = num_cols
    dict["columns"] = cols
    dict["first_row"] = first_row
    dict["has_hxl_header"] = has_hxl_header
    dict["first_nrows"] = first_nrows
    return dict

def extract_data_details(f, dataset, resource, nrows, data_details):
    """

    Reads saved CVS and XLSX HDX files and extracts headers, HDX tags and sample data.
    For XLSX files, it extracts data from all sheets.

    Parameters
    ----------
    f : str
        The file name
    dataset : str
        The dataset name
    resource : str
        The resource name
    nrows : int
        The number of rows to read
    data_details : list
        The list of data details

    Returns
    -------
    data_details : list
        The list of data details

    """
    if f.endswith(".xlsx") or f.endswith(".xls"):
        print(f"Loading xslx file {f} ...")
        try:
            sheet_to_df_map = pd.read_excel(f, sheet_name=None)
        except Exception:
            print("An exception occurred trying to read the file {f}")
            return data_details
        for sheet in sheet_to_df_map:
            data = sheet_to_df_map[sheet]
            data_details.append(
                set_meta_data_fields(data, f, dataset, resource, sheet, "xlsx")
            )
    elif f.endswith(".csv"):
        print(f"Loading csv file {f}")
        # Detect encoding
        with open(f, "rb") as rawdata:
            r = chardet.detect(rawdata.read(100000))
        try:
            data = pd.read_csv(f, encoding=r["encoding"], encoding_errors="ignore")
        except Exception:
            print("An exception occurred trying to read the file {f}")
            return data_details
        data_details.append(set_meta_data_fields(data, f, dataset, resource, "", "csv"))
    else:
        type = f.split(".")[-1]
        print(f"Type {type} for {f}")
        data = pd.DataFrame()
        data_details.append(set_meta_data_fields(data, f, dataset, resource, "", type))

    return data_details

# Loop through downloaded folders
def extract_all_data_details(startpath, data_details):
    """
    Extracts all data details for downloaded HDX files in a given directory.

    Parameters
    ----------
    startpath : str
        The path to the directory containing all datasets.
    data_details : list
        Results

    Returns
    -------
    data_details : pandas.DataFrame
        Results, to which new meta data was appended. 
        See function set_meta_data_fields for columns
    """
    for d in os.listdir(startpath):
          d = f"{startpath}/{d}"
          with open(f"{d}/dataset.json") as f:
              dataset = json.load(f)
          for r in os.listdir(d):
              if "dataset.json" not in r:
                  with open(f"{d}/{r}/resource.json") as f:
                      resource = json.load(f)
                  for f in os.listdir(f"{d}/{r}"):
                      file = str(f"{d}/{r}/{f}")
                      if ".json" not in file:
                          data_details = extract_data_details(
                              file, dataset, resource, 5, data_details
                          )
    data_details = pd.DataFrame(data_details)
    return data_details

现在我们可以在之前下载的数据文件上运行……

hxl_resources_data_details = extract_all_data_details(f"./data/hxl_datasets/", [])
print(hxl_resources_data_details.shape)

(25695, 22)

这个数据框包含 25,695 行,用于在 HDX 上搜索‘HXL’时扫描 CSV 和 Excel 文件找到的每个表格数据集,包含数据预览、列名称和在某些情况下的 HXL 标签。

训练/测试拆分

通常,我会简单地使用 Scikit learn 的 train_test_split 来处理要用于模型的数据。然而,在这样做时,我注意到同一个数据集中的重复资源(文件)可能会出现在训练集和测试集中。例如,一个组织可能会提供多个机场的文件,每个文件的格式和 HXL 标签完全相同。如果我们生成一个提示数据框然后分割,这些机场将同时出现在训练集和测试集中,这无法很好地反映我们的需求,即预测全新的数据集的 HXL 标签。

为了解决这个问题,我采取了以下措施:

  1. 将 HDX ‘数据集’拆分为训练集/测试集(请记住,一个数据集可能包含多个资源文件)

  2. 使用每个,我创建了资源的数据框,每行一个数据文件。

  3. 然后,使用这些训练/测试资源数据框,我创建了训练/测试数据框,每列一个。这些是 GPT-3 微调所需的提示。

创建 GPT-3 微调提示

为了微调 GPT-3,我们需要提供一个 JSONL 格式的提示和响应训练文件。我决定使用(i)列名;(ii)该列的一个数据样本。补全将是 HXL 标签和属性。

这里是一个例子……

{"prompt": " 'scheduled_service' | \"['1', '1', '0', '0', '0', '0', '0', '0']\"", "completion": " #status+scheduled"}

GPT-3 的格式非常特殊,花了一段时间才调整好!像在补全开始处添加空格这样的设置是 OpenAI 推荐的。

 def get_hxl_tags_list(resources):
    """
    Build a list of the HXL tags found in a dataframe of HDX resources.

    Parameters
    ----------
    resources : pandas dataframe
        A dataframe of HDX resources

    Returns
    -------
    hxl_tags : list
        A list of HXL tags.
    """
    hxl_tags = []
    for row, d in resources.iterrows():
        if d["has_hxl_header"] == True:
            fr = d["first_row"].replace(" ", "")
            for c in fr.split(","):
                fr = re.sub("\[|\]|\"|\'","", c)
                hdxs = fr.split("+")
                for h in hdxs:
                    if h not in hxl_tags and len(h) > 0:
                        hxl_tags.append(h.lower())
    hxl_tags = list(set(hxl_tags))
    hxl_tags.remove('nan')
    return hxl_tags

def get_prompt(col_name, data):
    """
    Builds the prompt for GPT-3 for predicting HXL tags and attributes

    Parameters
    ----------
    col_name : str
        Column name
    data : list
        A list of sample data for the column

    Returns
    -------
    prompt : string
        A prompt for GPT-3.
    """
    ld = len(data) - 1
    col_data = json.dumps(str(list(data.iloc[1:ld])))
    prompt = f" {col_name} | {col_data}".lower()
    return prompt

def create_training_set(resources):
    """
    Builds a jsonl training data file for GPT-3 where each row is a prompt for a column HXL tag.

    It will only output prompts where the sample data for the column didn't contain nans.

    Parameters
    ----------
    resources : pandas dataframe
        A dataframe of HDX resources

    Returns
    -------
    train_data : list
        A list of prompts and completions for the HXL tag autocomplete feature.
    """
    train_data = []
    for row, d in resources.iterrows():
        if d["has_hxl_header"] == True:
            cols = d["columns"][1:-1].split(",")
            hdxs = d["first_row"][1:-1].split(",")
            data = d["first_nrows"]
            has_hxl_header = d["has_hxl_header"]
            if len(cols) == len(hdxs) and len(cols) > 1:
                ld = len(data) - 1
                for i in range(0, len(cols)):
                    if i < len(hdxs):
                        hdx = re.sub("'|\"", "", hdxs[i])
                    # Only include is has HXL tags and good sample data in column
                    if has_hxl_header == True and hdx != np.nan:
                        prompt = get_prompt(cols[i], data.iloc[:,i])
                        if 'nan' not in hdx and 'nan, nan' not in prompt:
                            p = {
                                "prompt": prompt,
                                "completion": f" {hdx}",
                            }
                            train_data.append(p)
    return train_data

你会注意到我在上面排除了数据中存在 NaNs 的提示。我认为我们应该从良好的数据样本开始,但这是需要在未来重新审视的内容。

我们现在可以生成一个训练数据集,并将其保存为 GPT-3 的文件……

# Create training set
X_train = create_training_set(X_train_resources)
print(f"Training records: {len(X_train)}")

train_file = "fine_tune_openai_train.jsonl"

with open(train_file, "w") as f:
    for p in X_train:
        json.dump(p, f)
        f.write("\n")

print("Done")

这就是训练数据的样子……

{"prompt": "  'Country ISO3' | \"['COD', 'COD', 'COD', 'COD', 'COD', 'COD', 'COD', 'COD']\"", "completion": "  #country+code"}
{"prompt": "  'Year' | \"['2010', '2005', '2000', '1995', '1990', '1985', '1980', '1975']\"", "completion": "  #date+year"}
{"prompt": "  'Indicator Name' | \"['Barro-Lee: Percentage of female population age 15-19 with no education', 'Barro-Lee: Percentage of female population age 15-19 with no education', 'Barro-Lee: Percentage of female population age 15-19 with no education', 'Barro-Lee: Percentage of female population age 15-19 with no education', 'Barro-Lee: Percentage of female population age 15-19 with no education', 'Barro-Lee: Percentage of female population age 15-19 with no education', 'Barro-Lee: Percentage of female population age 15-19 with no education', 'Barro-Lee: Percentage of female population age 15-19 with no education']\"", "completion": "  #indicator+name"}
{"prompt": "  'Indicator Code' | \"['BAR.NOED.1519.FE.ZS', 'BAR.NOED.1519.FE.ZS', 'BAR.NOED.1519.FE.ZS', 'BAR.NOED.1519.FE.ZS', 'BAR.NOED.1519.FE.ZS', 'BAR.NOED.1519.FE.ZS', 'BAR.NOED.1519.FE.ZS', 'BAR.NOED.1519.FE.ZS']\"", "completion": "  #indicator+code"}
{"prompt": "  'Value' | \"['48.1', '51.79', '52.1', '43.62', '35.44', '38.02', '43.47', '49.08']\"", "completion": "  #indicator+value+num"}
{"prompt": " 'Country ISO3' | \"['COD', 'COD', 'COD', 'COD', 'COD', 'COD', 'COD', 'COD']\"", "completion": " #country+code"}
{"prompt": "  'Year' | \"['2015', '2014', '2013', '2012', '2011', '2010', '2009', '2008']\"", "completion": "  #date+year"}

这个训练数据集中有 139,503 行,每列一行,来自我们从 HDX 下载的表格数据,专门用于那些列中有 HXL 标签的情况。

生成 OpenAI API 密钥

在我们能做任何事情之前,你需要先注册一个 OpenAI 账户。完成后,你应该有$18 的免费积分。如果使用少量数据,这应该足够,但在这次分析和几次模型训练中,我累计了 $50 的账单,因此你可能需要将信用卡绑定到你的账户上。

一旦你有了账户,你可以生成 API 密钥。我选择将其保存到本地文件并在代码中引用,但OpenAI Python 库也支持使用环境变量。

微调 GPT-3

好了,现在是令人兴奋的部分!有了我们精美的训练数据,我们可以按如下方式微调 GPT-3……

import openai
from openai import cli

# Open AI API key should be put into this file
openai.api_key_path = "./api_key.txt"

print("Uploading training file ...")
training_id = cli.FineTune._get_or_upload(train_file, True)
# validation_id = cli.FineTune._get_or_upload(validation_file_name, True)

print("Fine-tuning model ...")
create_args = {
    "training_file": training_id,
    # "validation_file": test_file,
    "model": "ada",
}
# https://beta.openai.com/docs/api-reference/fine-tunes/create
resp = openai.FineTune.create(**create_args)
job_id = resp["id"]
status = resp["status"]

print(f"Fine-tunning model with jobID: {job_id}.")

在上面,我们将微调模型提交给 OpenAI,然后可以查看状态……

result = openai.FineTune.retrieve(id=job_id)
print(result['status'])

我选择保持简单,但你也可以将其提交给 OpenAI,并通过此处显示的流来监控状态。

一旦状态显示为‘成功’,你现在可以获得一个模型 ID 用于预测(补全)……

result = openai.FineTune.retrieve(id=job_id)
model = result["fine_tuned_model"]

用我们微调过的 GPT-3 模型预测 HXL 标签

我们现在有了一个模型,来看看它能做什么!

要调用 GPT-3,你可以使用Open AI Python 库的 ‘create’ 方法。查看文档了解你可以调整的参数是值得的。

def create_prediction_dataset_from_resources(resources):
    """
    Generate a list of model column-level prompts from a list of resources (tables).

    It will only output prompts where the sample data for the column didn't contain nans.

    Parameters
    ----------
    resources : list
        A list of dictionaries containing the resource name, columns, first_row, and first_nrows.

    Returns
    -------
    prediction_data : list
        A list of dictionaries containing GPT-3 prompts (one per column in resource table)
    """

    prediction_data = []
    for index, d in resources.iterrows():
        cols = d["columns"][1:-1].split(",")
        hdxs = d["first_row"][1:-1].split(",")
        data = d["first_nrows"]
        has_hxl_header = d["has_hxl_header"]
        if len(cols) == len(hdxs) and len(cols) > 1:
            ld = len(data) - 1
            # Loop through columns 
            for i in range(0, len(cols)):
                if i < len(hdxs) and i < data.shape[1]:
                    prompt = get_prompt(cols[i], data.iloc[:,i])
                    # Skip any prompts with at least two nan values in sample data
                    if 'nan, nan' not in prompt:
                        r = {
                                "prompt": prompt
                            }
                        # If we were called with HXL tags (ie for test set), populate 'expected'
                        if has_hxl_header == True:
                            hdx = re.sub("'|\"| ", "", hdxs[i])
                            # Row has HXL tags, but this particular column doesn't have tags
                            if hdx == 'nan':
                                continue
                            else:
                                r["expected"]= hdx
                        prediction_data.append(r)
    return prediction_data

def make_gpt3_prediction(prompt, model, temperature=0.99, max_tokens=13):
    """
    Wrapper to call GPT-3 to make a prediction (completion) on a single prompt.

    Parameters
    ----------
    prompt : str
        Prompt to use for prediction
    model : str
        GPT-3 model to use
    temperature : float
        Temperature to use for sampling
    max_tokens : int
        Maximum number of tokens to use for sampling

    Returns
    -------
    result : dict
        Dictionary with prompt, predicted, and 
        log probabilities of each completed token
    """
    result = {}
    result["prompt"] = prompt
    model_result = openai.Completion.create(
        engine=model,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        stop=["\n"],
        logprobs=1
    )
    result["predicted"] = model_result["choices"][0]["text"].replace(" ","")
    result["logprobs"]  = model_result['choices'][0]['logprobs']['top_logprobs']
    return result

def make_gpt3_predictions(
    sample_size, prediction_data, model, temperature=0.99, max_tokens=13, logprob_cutoff=-0.01
):

    """
    Wrapper to call GPT-3 to make predictions on test file for sample_size samples.

    Parameters
    ----------
    sample_size : int
        Number of predictions to make from test file
    prediction_data : list
        List of dictionaries with prompts
    model : str
        GPT-3 model to use
    postprocess : bool
        Whether to postprocess the predictions
    temperature : float
        Temperature to use for sampling
    max_tokens : int
        Maximum number of tokens to use for sampling
    prob_cutoff : float
        Logprob cutoff for filtering out low probability tokens

    Returns
    -------
    results : list
        List of dictionaries with prompt, predicted, predicted_post_processed
    """
    results = []
    prediction_data = sample(prediction_data, sample_size)
    for i in range(0, sample_size):
        prompt = prediction_data[i]["prompt"]
        res = make_gpt3_prediction(
            prompt, model, temperature, max_tokens
        )

        # Filter out low logprob predictions
        pred = ""
        seen_tokens = []
        for w in res["logprobs"]:
            token = list(w.keys())[0]
            prob = w[token]
            if prob > logprob_cutoff and token not in seen_tokens:
                pred += token
                if '+' not in token:
                    seen_tokens.append(token)
            else:
                break
        pred = re.sub(r" |\+$|\+v_$", "", pred)

        r = {
                "prompt": prompt,
                "predicted": res["predicted"],
                "predicted_log_prob_cutoff": pred,
                #"logprobs": res["logprobs"]
            }
        # For test sets we have expected values, add back for performance reporting
        if "expected" in prediction_data[i]:
            r['expected'] = prediction_data[i]['expected'].replace(' ', '')
        results.append(r)
    return results

我们使用以下方式调用,限制为 500 个提示……

# Generate the prompts we want GPT-3 to complete
print("Building model input ...")
prediction_data = create_prediction_dataset_from_resources(X_test_resources)

# How many predictions to try from the test set
sample_size = 500

# Make the predictions
print("Making GPT-3 predictions (completions) ...")
results = make_gpt3_predictions(
   sample_size, prediction_data, model, temperature=0.99, max_tokens=20, logprob_cutoff=-0.001
) 

这产生了以下结果……

def output_prediction_metrics(results, prediction_field="predicted_post_processed"):
    """
    Prints out model performance report if provided results in the format:

    [
        {
            'prompt': ' \'ISO3\' | "[\'RWA\', \'RWA\', \'RWA\', \'RWA\', \'RWA\', \'RWA\', \'RWA\', \'RWA\']"', 
            'predicted': ' #country+code+iso3+v_iso3+', 
            'predicted_post_processed': '#country+code', 
            'expected': '#country+code'
        }, 
        ... etc ...
    ]

    Parameters
    ----------
    results : list
        See above for format
    prediction_field : str
        Field name of element with prediction. Handy for comparing raw and post-processed predictions.
    """
    y_test = []
    y_pred = []
    y_justtag_test = []
    y_justtag_pred = []
    for r in results:
        if "expected" not in r:
            print("Provided results do not contain expected values.")
            sys.exit()
        y_pred.append(r[prediction_field])
        y_test.append(r["expected"])
        expected_tag = r["expected"].split("+")[0]
        predicted_tag = r[prediction_field].split("+")[0]
        y_justtag_test.append(expected_tag)
        y_justtag_pred.append(predicted_tag)

    print(f"GPT-3 results for {prediction_field}, {len(results)} predictions ...")
    print("\nJust HXL tags ...\n")
    print(f"Accuracy: {round(accuracy_score(y_justtag_test, y_justtag_pred),2)}")
    print(
        f"Precision: {round(precision_score(y_justtag_test, y_justtag_pred, average='weighted', zero_division=0),2)}"
    )
    print(
        f"Recall: {round(recall_score(y_justtag_test, y_justtag_pred, average='weighted', zero_division=0),2)}"
    )
    print(
        f"F1: {round(f1_score(y_justtag_test, y_justtag_pred, average='weighted', zero_division=0),2)}"
    )

    print(f"\nTags and attributes with {prediction_field} ...\n")
    print(f"Accuracy: {round(accuracy_score(y_test, y_pred),2)}")
    print(
        f"Precision: {round(precision_score(y_test, y_pred, average='weighted', zero_division=0),2)}"
    )
    print(
        f"Recall: {round(recall_score(y_test, y_pred, average='weighted', zero_division=0),2)}"
    )
    print(
        f"F1: {round(f1_score(y_test, y_pred, average='weighted', zero_division=0),2)}"
    )

    return 
output_prediction_metrics(results, prediction_field="predicted")

GPT-3 results for predicted, 500 predictions ...

Just HXL tags ...

Accuracy: 0.99
Precision: 0.99
Recall: 0.99
F1: 0.99

Tags and attributes with predicted ...

Accuracy: 0.0
Precision: 0.0
Recall: 0.0
F1: 0.0

嗯!?那,嗯……糟糕。仅预测 HXL 标签效果非常好,但预测标签和属性的效果就差很多。

让我们看看一些失败的预测……

 {
    "prompt": " 'gho (code)' | \"['mort_100', 'mort_100', 'mort_100', 'mort_100', 'mort_100', 'mort_100', 'mort_100', 'mort_100']\"",
    "predicted": "#indicator+code+v_hor_funder_",
    "expected": "#indicator+code"
}
{
    "prompt": "  'region (code)' | \"['afr', 'afr', 'afr', 'afr', 'afr', 'afr', 'afr', 'afr']\"",
    "predicted": "#region+code+v_reliefweb+f",
    "expected": "#region+code"
}
{
    "prompt": "  'dataid' | \"['310633', '310634', '310635', '310636', '310629', '310631', '310630', '511344']\"",
    "predicted": "#meta+id+fts_internal_view_all",
    "expected": "#meta+id"
}
{
    "prompt": "  'gho (url)' | \"['https://www.who.int/data/gho/indicator-metadata-registry/imr-details/5580', 'https://www.who.int/data/gho/indicator-metadata-registry/imr-details/5580']\"",
    "predicted": "#indicator+url+name+has_more_",
    "expected": "#indicator+url"
}
{
    "prompt": "  'year (display)' | \"['2014', '2014', '2014', '2014', '2014', '2014', '2014', '2014']\"",
    "predicted": "#date+year+name+tariff+for+",
    "expected": "#date+year"
}
{
    "prompt": "  'byvariablelabel' | \"[nan]\"",
    "predicted": "#indicator+label+code+placeholder+Hubble",
    "expected": "#indicator+label"
}
{
    "prompt": " 'gho (code)' | \"['ntd_bejelstatus', 'ntd_pintastatus', 'ntd_yawsend', 'ntd_leishcend', 'ntd_leishvend', 'ntd_leishcnum_im', 'ntd_leishcnum_im', 'ntd_leishcnum_im']\"",
    "predicted": "#indicator+code+v_ind+olk_ind",
    "expected": "#indicator+code"
}
{
    "prompt": "  'enddate' | \"['2002-12-31', '2003-12-31', '2004-12-31', '2005-12-31', '2006-12-31', '2007-12-31', '2008-12-31', '2009-12-31']\"",
    "predicted": "#date+enddate+enddate+usd+",
    "expected": "#date+end"
}
{
    "prompt": "  'endyear' | \"['2013', '2013', '2013', '2013', '2013', '2013', '2013', '2013']\"",
    "predicted": "#date+year+endyear+end_of_",
    "expected": "#date+year+end"
}
{
    "prompt": "  'country (code)' | \"['dnk', 'dnk', 'dnk', 'dnk', 'dnk', 'dnk', 'dnk', 'dnk']\"",
    "predicted": "#country+code+v_iso2+v_",
    "expected": "#country+code"
}

有趣的是。似乎模型几乎完美地完成并捕捉了正确的标签和属性,然后在末尾添加了一些额外的属性。例如……

"predicted": "#country+code+v_iso2+v_",
"expected": "#country+code"

让我们看看期望的标签和属性在预测的前半部分出现的频率……

passes = 0
fails = 0
for r in results:
    if r["predicted"].startswith(r["expected"]):
        passes += 1
    else:
        fails += 1
        #print(json.dumps(r, indent=4, sort_keys=False))

print(f" Out of {passes + fails} predictions, the expected tags and attributes where in the predicted tags and attributes {round(100*passes/(passes+fails),1)}% of the time.")

Out of 500 predictions, the expected tags and attributes where in the predicted tags and attributes 99.0% of the time.

在 500 次预测中,期望的标签和属性**99%**的时间都出现在预测的标签和属性中。换句话说,期望的值通常是大多数预测的首部分。

所以 GPT-3 在预测标签和属性方面具有很高的准确性,但在末尾添加了额外的属性。

那么,如何排除那些额外的标记呢?

嗯,结果证明 GPT-3 返回了每个标记的对数概率。如上所述,我们还计算了一个预测,假设我们在对数概率高于某个截止值时停止完成标记……

# Filter out low logprob predictions
pred = ""
seen_tokens = []
for w in res["logprobs"]:
    token = list(w.keys())[0]
    prob = w[token]
    if prob > logprob_cutoff and token not in seen_tokens:
        pred += token
        if '+' not in token:
            seen_tokens.append(token)
    else:
        break
pred = re.sub(r" |\+$|\+v_$", "", pred)

让我们看看在假设截止值为 -0.001 的情况下表现如何……

output_prediction_metrics(results, prediction_field="predicted_log_prob_cutoff")

Just HXL tags ...

Accuracy: 0.99
Precision: 1.0
Recall: 0.99
F1: 0.99

Tags and attributes with predicted_log_prob_cutoff ...

Accuracy: 0.94
Precision: 0.99
Recall: 0.94
F1: 0.95

这还不错,标签和属性的准确率为 0.94。既然我们知道正确的标签和属性在预测中出现的概率为 99%,我们应该通过调整对数概率截止值和进行一些后处理来做得更好。

结论与未来工作

以上是对 GPT-3 在预测元数据,特别是人道主义数据集上的 HXL 标签方面应用的快速分析。它在这一任务上的表现非常好,并且在类似的元数据预测任务中有很大的潜力。

当然,还需要更多的工作来完善方法,例如:

  1. 尝试其他模型(我上面使用的是 ‘ada’)以查看是否能改善性能(尽管这会增加成本)。

  2. 模型超参数调整。对数概率截止值可能非常重要。

  3. 可能需要更多的提示工程,比如在表格中包含列列表,以提供更好的上下文,以及在两行标题表格上的覆盖列。

  4. 更多的预处理。这篇文章的处理不多,盲目地使用从 CSV 文件中提取的表格,因此数据可能有些混乱。

也就是说,我觉得使用 GPT-3 来预测数据集上的元数据有很大的潜力。

敬请关注更多更新!

参考文献

[1] Mark Lowcock, 人道主义事务副秘书长和紧急救援协调员,Anticipation saves lives: How data and innovative financing can help improve the world’s response to humanitarian crises (2019)

[2] Sarah Telford, Opinion: Humanitarian world is full of data myths. Here are the most popular (2018)

[3] Vinitra Swamy 等人,人道主义数据的机器学习:使用 HXL 标准进行标签预测(2019 年)

用于此分析的笔记本可以在这里找到。

预测 NBA 薪资的机器学习方法

原文:towardsdatascience.com/predicting-nba-salaries-with-machine-learning-ed68b6f75566?source=collection_archive---------7-----------------------#2023-08-24

使用 Python 构建机器学习模型,以预测 NBA 薪资并分析最有影响力的变量

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Gabriel Pastorello

·

关注 发表在 Towards Data Science · 9 分钟阅读 · 2023 年 8 月 24 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(照片由 Emanuel Ekström 提供,来源于 Unsplash

NBA 作为体育界最盈利竞争激烈的联赛之一脱颖而出。在过去几年里,NBA 球员的薪资呈上升趋势,但每一个令人惊叹的扣篮和三分球背后都隐藏着决定这些薪资的复杂因素。

从球员表现和球队成功到市场需求和代言交易,许多变量都会影响结果。谁没想过为什么他们的球队会在表现不佳的球员身上花这么多钱,或对成功的交易感到惊讶?

在本文中,我们利用 Python 的机器学习能力来预测 NBA 薪资,并揭示对球员收入影响最大的关键因素

所有使用的代码和数据都可以在GitHub上找到。

理解问题

在深入探讨问题之前,了解联赛薪资系统的基本原理至关重要。当一名球员在市场上待签合同时,他被称为自由球员(FA),这是本项目中将频繁出现的术语。

NBA 在一套复杂的规则和规定下运营,旨在保持球队之间的竞争平衡。这个系统的核心有两个关键概念:薪资上限奢侈税

薪资上限作为支出限制,限制了球队在一个赛季中可以花费的球员薪资总额。上限由联赛收入决定,每年更新一次,以确保球队在合理的财务框架内运营。它还旨在防止大市场球队显著超支,促进球队之间的公平竞争。

薪资上限在球员之间的分配可以有所不同,顶级球员有最高薪资,而新秀和老将则有最低薪资。

然而,超越薪资上限并不罕见,尤其是对那些希望组建争冠阵容的球队。当一支球队超过薪资上限时,它进入了奢侈税的范畴。奢侈税对超支的球队施加处罚,抑制球队过度支出,同时为联赛提供额外收入。

还有许多其他规则作为例外,例如中层例外(MLE)和交易例外,允许球队进行战略性的阵容调整,但对于这个项目来说,了解薪资上限和奢侈税就足够了。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

NBA 薪资上限从 1984 年到 2023 年的演变(图片来源:作者)

由于薪资上限的持续增长,选择的方式将是使用薪资上限的百分比作为目标,而不是薪资金额本身。这个决定旨在融入薪资上限的演变特性,确保结果不会受到时间变化的影响,并在评估历史赛季时仍然适用。然而,需要注意的是,这并不完美,仅仅是一个近似值。

数据

对于这个项目,目标是使用仅来自上一个赛季的数据来预测球员签订新合同的薪资

使用的个人统计数据包括:

  • 每场比赛平均统计数据

  • 统计数据

  • 高级统计数据

  • 个体变量:年龄、位置

  • 与薪资相关的变量:上一赛季的薪资、上一赛季和当前赛季的最大薪资上限以及该薪资的上限百分比。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

2022–23 赛季的薪资分布(图片由作者提供)

由于我们不知道球员将签约的球队,只包括了个体特征。

总体而言,这项研究为每名球员提供了78 个特征加上目标。

大部分数据通过BRScraper获得,这是我最近创建的Python 包,可以轻松抓取和访问来自Basketball Reference的篮球数据,包括 NBA、G 联盟和其他国际联赛。遵守了所有关于不对网站造成伤害或阻碍其性能的指南。

数据处理

一个值得考虑的有趣方面是选择球员来训练模型。最初,我选择了所有可用的球员,但大多数球员可能已经有合同,在这种情况下,薪资值不会发生剧烈变化。

例如,假设一名球员签订了为期 4 年的 2000 万美元合同。他每年大约获得 500 万美元(虽然很少每年的薪资完全相同,通常薪资在 500 万美元左右会有一定的递增)。然而,当自由球员签订新合同时,薪资值可能会发生更大幅度的变化。

这意味着用所有可用球员训练模型可能会整体上表现更好(毕竟,大多数球员的薪资接近最后一个!),但在评估仅自由球员时,表现会显著变差

由于目标是预测签订新合同的球员的薪资,因此数据中仅应包含这种类型的球员,以便模型能更好地理解这些球员之间的模式。

关注的赛季是即将到来的 2023–24 赛季,但会使用2020–21 赛季及之后的数据来增加样本量,这得益于目标的选择。由于缺乏自由球员的数据,未使用较早的赛季。

这使得在选择的三个赛季中有426 名球员,其中 84 名是 2023–24 赛季的自由球员。

建模

训练-测试拆分的设计是为了确保 2023–24 赛季的所有自由球员仅包含在测试集中,保持了大约 70/30 的拆分比例。

最初使用了几种回归模型:

  • 支持向量机(SVM)

  • 弹性网

  • 随机森林

  • AdaBoost

  • 梯度提升

  • 轻量级梯度提升机(LGBM)

通过均方根误差(RMSE)和决定系数()评估了每种方法的表现。

你可以在我之前的文章中找到每个指标的公式和解释,使用机器学习预测 NBA MVP

## 使用机器学习预测 NBA MVP

构建一个机器学习模型来预测 NBA MVP 并分析最有影响力的变量。

towardsdatascience.com

结果

查看整个数据集中的所有赛季,获得了以下结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

各种模型中获得的 RMSE 和 R²值(图像由作者提供)

模型整体表现良好,其中随机森林梯度提升取得了最低的 RMSE 和最高的 R²,而AdaBoost在使用的模型中表现最差。

变量分析

一种有效的可视化模型预测关键变量的方法是通过SHAP 值,这是一种提供合理解释每个特征如何影响模型预测的技术。

关于 SHAP 及其图表解读的更深入解释可以在使用机器学习预测 NBA MVP中找到。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

与随机森林模型相关的 SHAP 图表(图像由作者提供)

我们可以从这个图表中得出一些重要结论:

  • 每场比赛的分钟数 (MP) 和 得分 (PTS) 每场比赛总数 是三个最有影响力的特征。

  • 上赛季薪资 (Salary S-1) 和 该薪资的薪资占比 (% Cap S-1) 也非常有影响力,分别排在第 4 位和第 5 位。

  • 先进统计数据 在最重要的特征中并不占主导地位,只有两个出现了在列表中,WS (胜利贡献值) 和 VORP (替代球员价值)。

这是一项非常令人惊讶的结果,因为与MVP 项目不同,在该项目中,先进的统计数据主导了 SHAP 的最终结果,球员薪资似乎与常见统计数据如分钟、得分和首发场次有更大的关系。

这令人惊讶,因为大多数先进统计数据的设计初衷正是为了更好地评估球员的表现。PER (球员效率评级) 在前 20 名中缺席(排名第 43)尤为引人注目。

这提出了一个可能性,即在薪资谈判过程中,总经理可能遵循了一种相对简单的方法,可能忽视了更广泛的表现评估指标。

也许问题并没有那么复杂! 简化来看,打得时间最长、得分最多的球员赚得更多!

附加结果

聚焦于今年的自由球员,并将他们的预测与实际薪水进行比较:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

2023–24 赛季随机森林模型的主要结果(单位:百万)(图片由作者提供)

在顶部,我们有五名被低估的球员(收入低于他们应得的),中间是五名估值正确的球员,底部是五名被高估的球员(收入高于他们应得的)。值得注意的是,这些评估仅基于模型的输出。

从顶部开始,前 MVP拉塞尔·威斯布鲁克是模型中最被低估的球员,我认为这是事实,因为他与快船签下了约每年 400 万美元的合同。埃里克·戈登梅森·普拉姆利马利克·比斯利也处于类似的情况,他们在获得良好表现的同时收入非常少。达安吉洛·拉塞尔也出现在这五人榜单中,尽管他的年薪为 1700 万美元,这表明他应该赚得更多。

值得注意的是,这些球员都签约了竞争性球队(快船、太阳、雄鹿和湖人)。这是一种已知行为,球员们选择减少薪水以有机会为能够赢得冠军的球队效力。

在中间,塔雷安·普林斯奥兰多·罗宾逊凯文·诺克斯德里克·罗斯的薪水都较低,看似足够。卡里斯·勒弗特年薪 1500 万美元,但也确实值这个价钱。

在底部,弗雷德·范弗利特被评为最被高估的球员。火箭队作为一支重建中的球队,在他的三年合同上投入了 1.285 亿美元,这是一项引人注目的举动。他们还签下了迪龙·布鲁克斯,合同金额高于预期。

克里斯·米德尔顿在这个夏天签下了大合同。尽管雄鹿队是一个竞争者队伍,但他们属于非主要市场,无法承受失去其中一名最佳球员的风险。德雷蒙德·格林卡梅隆·约翰逊在各自的球队中也有类似的情况。

结论

预测体育结果始终充满挑战。从目标选择到球员筛选,这个项目证明比预期要复杂。然而,结果证明其实相当简单,取得的结果非常令人满意!

当然,还有多种改进的方法,其中之一是使用特征选择或降维技术来减少特征空间,从而减少方差。

此外,访问之前赛季的自由球员数据也会使样本数量增加。然而,目前似乎没有公开获取这些数据的途径。

许多其他外部变量也会影响此问题。例如,毫无疑问,如果某种方式能够知道球队,像去年种子季后赛结果已经使用的薪资百分比这样的变量可能会非常有用。然而,保持镜像实际自由球员场景的方式,即球队未知,可能会得出更贴近球员**“真实价值”**的结果,不管签约球队的背景如何。

本项目的主要前提之一是仅使用上一赛季的数据来预测下一赛季的薪水。加入旧赛季的统计数据确实可能会提高结果,因为球员的历史表现可以提供有价值的见解。然而,这些数据的广泛性质将需要精心的特征选择来管理其复杂性和高维度。

再次说明,所有使用的代码和数据都可以在GitHub上找到。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(照片由Marius Christensen拍摄,来源于Unsplash

我始终可以通过我的渠道联系到(LinkedInGitHub)。

感谢您的关注!👏

加布里埃尔·斯佩兰扎·帕斯托雷洛

预测星巴克奖励计划的成功

原文:towardsdatascience.com/predicting-success-of-a-reward-program-at-starbucks-b32b77dcf9b8

初学者友好 — 从头到尾逐步解释完整项目

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Erdem Isbilen

·发布在 Towards Data Science ·9 分钟阅读·2023 年 6 月 20 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Robert LinderUnsplash 拍摄

项目概述

该项目专注于识别能够有效吸引星巴克现有客户并吸引新客户的奖励计划优惠。

星巴克是一家数据驱动的公司,通过利用包含客户信息、特别优惠和交易数据的数据集,致力于全面了解其客户。

为了开发一个能够确定奖励计划成功的模型,我将项目分为三个阶段

  1. 检查和清理 Udacity 提供的数据。

  2. 创建一个结合所有相关信息的数据集。

  3. 构建和评估三个分类模型的性能以预测特定人员的奖励计划的成功或失败

问题陈述

对市场营销活动进行重大投资是一个复杂的决策,需要获得各种利益相关者的批准、财务资源和时间。因此,拥有一个能够对特定目标群体发起特定优惠是否值得的预测模型,对于任何公司来说都是一个战略资产。

为了创建这个模型,我们将使用监督学习技术进行二分类。

模型的结果将指示该优惠是否预计会有效。

数据集探索与整理

Udacity 提供了三个 JSON 格式的数据集:portfolioprofiletranscript。每个数据集都有不同的用途,并为我们的分析提供了宝贵的信息。

Portfolio 数据集

该数据集提供了有关星巴克当前有效优惠的信息。

  • id(字符串)——优惠 id

  • offer_type(字符串)——优惠类型,即 BOGO、折扣、信息性

  • difficulty(整数)——完成优惠所需的最低花费

  • reward(整数)——完成优惠后给予的奖励

  • duration(整数)——优惠开放的时间,以天为单位

  • channels(字符串列表)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

组合数据集(原始)——图像作者:Erdem Isbilen

portfolio数据集中有十行和六列。这是一个简单的数据集,没有缺失、空值或重复值。

**‘channels’,‘id’,‘offer_type’列是分类变量,而‘difficulty’,‘duration’,‘reward’**是整数。

请参见我对数据集所做的修改:

  • 对**‘channels’‘offer_type’**进行独热编码

  • 将***‘id’更改为‘offer_id’***

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

组合数据集(经过数据清理)——图像作者:Erdem Isbilen

个人资料数据集

个人资料数据集包含有关星巴克客户的人口统计信息。

  • age(整数)——客户的年龄

  • became_member_on(整数)——客户创建应用账户的日期

  • gender(字符串)——客户的性别(注意有些条目包含‘O’代表其他,而不是 M 或 F)

  • id(字符串)——客户 id

  • income(浮点数)——客户的收入

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

个人资料数据集(原始)——图像作者:Erdem Isbilen

这个数据集中有 17000 行(数据集中唯一的人员数量)和 5 列,有 2175 个空值(出现在genderincome列中)。由于这些行的年龄值也为 118,我从数据集中删除了所有 2175 行。

请参见我对数据集所做的修改:

  • 删除 2175 行缺失值(也包括年龄值为 118 的行)

  • 将***‘id’更改为‘customer_id’***

  • ***‘become_member_on’***字符串转日期

  • 创建***‘year_joined’‘membership_days’***列

  • 对***‘gender’***进行独热编码

  • 创建***‘age_group’以将客户分类为青少年、年轻成人、成年人、老年人*

  • 创建***‘income_range’以将客户分类为普通、高于平均水平、高收入*

  • 创建***‘member_type’以将客户分类为新会员、常规会员、忠实会员*

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

个人资料数据集(经过数据清理)——图像作者:Erdem Isbilen

可以看到,加入该计划的人数在 2013 年至 2017 年之间呈上升趋势,2017 年是最佳年份。50%的会员年龄在 42 至 66 岁之间。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

年龄、收入、加入年份列的直方图——图像作者:Erdem Isbilen

如下所示,在低收入和中等收入区男性人口超过女性人口,而在高收入区女性人口超过男性人口。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

不同性别群体收入的直方图图表— 图片来源(作者)Erdem Isbilen

在考虑性别时,数据集存在一些偏差,因为男性人口数量超过女性人口,而且其他类别的人数较少。具体来说,数据集中有 8484 名男性、6129 名女性和仅 212 名其他类别的人。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

性别的直方图图表— 图片来源(作者)Erdem Isbilen

记录集数据

记录集数据捕捉了客户与优惠活动的互动。

  • event (str) — 记录描述(即交易、优惠接收、优惠查看等)

  • person (str) — 客户 ID

  • time (int) — 从测试开始的时间,以小时为单位。数据从时间 t=0 开始

  • value — (字符串字典) — 根据记录是优惠 ID 还是交易金额

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

记录集数据(原始)— 图片来源(作者)Erdem Isbilen

如果记录集中的事件对应于三种可能的优惠状态之一(查看、接收或完成),则值列包含优惠的id。除了优惠 ID 之外,如果事件状态为***‘offer completed’,还会有reward*** 值。

但是,如果事件是交易,则值列将仅显示交易金额

请查看我对数据集所做的修改:

  • 将***‘value’展开为‘offer_id’‘amount’‘rewards’*** 新列。

  • 通过将时间(小时)转换为天数来创建***‘time_in_days’***。

  • 将***‘person’更改为‘customer_id’***

  • 记录集数据拆分为两个子数据集:offer_tr(优惠数据)和transaction_tr(交易数据)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

offer_tr 数据集(数据处理后)— 图片来源(作者)Erdem Isbilen

构建模型

一旦数据集被清理并完成必要的修改,我们需要将它们合并为一个单一的数据集。然后我们将创建一个名为offer_successful的新列。该列将指示某个优惠是否对特定客户成功。这将使我们能够建立可以预测特定类型客户对某个优惠是否成功的模型。

客户必须查看完成优惠活动在允许的时间范围内,才能将该优惠视为成功

在开发模型时遇到的困难是数据集中没有特定的列指明某个客户的报价是否成功。因此,我开发了一个辅助函数,通过考虑已完成和查看的报价数据以及这些事件之间的时间范围来计算目标值。

成功和失败的报价数量分别为 35136 和 31365。这意味着我们在考虑目标时有一个平衡的数据集。因此,我们在选择模型时没有任何限制(这些限制通常与不平衡的数据集有关)。

为了创建一个可以预测报价是否成功的模型,我们需要在最终数据集上训练一个模型。

由于这是一个二分类问题,我们将使用三种不同的监督学习算法。

  • 逻辑回归

  • 随机森林

  • 梯度提升

我将使用 sklearn 的默认设置开始构建模型,以了解不同模型的准确性水平。

然后,我将使用RandomizedSearchCV进行 12 次迭代,以优化模型的超参数,因为它比GridSearchCV计算开销更小。RandomizedSearchCV通过从指定分布中随机抽样超参数值来工作。

这些模型中的每一个都在训练数据集上进行训练,并在测试数据集上进行评估,以避免过拟合,并查看模型在未见数据上的表现。

指标和结果

我将使用混淆矩阵和下面的指标来评估模型的性能。

我将特别关注精确度,因为研究的主要目的是尽可能准确地定义正类。

  • 准确率: 准确率是评估分类模型准确性最常用的指标。它通过将正确预测的数量除以预测总数来计算。

  • 精确度: 精确度是衡量模型在预测正类时准确性的指标。它通过将真正例的数量除以真正例的数量加上假正例的数量来计算。

  • 召回率: 召回率是衡量模型在预测正类时完整性的指标。它通过将真正例的数量除以真正例的数量加上假负例的数量来计算。

  • F1 分数: F1 分数是精确度和召回率的加权平均值。它通过将 2 * (精确度 * 召回率)除以(精确度 + 召回率)来计算。

超参数微调前的初始准确率指标:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

表格形式的模型表现 — 图片作者(Erdem Isbilen)

超参数微调后的结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

表格形式的模型表现 - 图片作者(Erdem Isbilen)

尽管所有模型提供了类似的性能指标,但随机森林准确性略微更高。我们发现超参数的微调对模型性能的影响很小。这可能是因为我们为调优选择的自定义参数比默认参数更差。因此,我们可能需要考虑选择不同的参数集来微调模型。

论证

由于随机森林梯度提升可以更好地处理异常值和高维数据,因此它们在我们的案例中表现优于逻辑回归。

结论

在这个项目中,我分析了星巴克提供的数据集,并开发了一个可以预测顾客是否会完成优惠的模型。

我对每个数据集进行了探索性分析,并可视化数据以获得全面的理解。这包括检查数据集的各个方面。随后,我进行了预处理部分,这是最耗时和具有挑战性的任务。数据集复杂,包含的数据需要应用数据整理、工程和预处理技术,以获得最终的清洁版本。

最后一步是开发几个二元分类模型以进行预测。模型构建完成后,使用测试数据评估模型的性能。结论是这三种模型提供了类似的性能结果。

改进

尽管这些模型提供了一个良好的起点,但精度为 66%的结果仍有改进的空间。

对于这个项目的一个有趣改进是创建多个监督学习模型,并将它们组合成一个自定义集成模型。通过组合多个监督学习模型来创建集成模型可以带来错误补偿的优势。通过利用不同模型的优势并弥补它们的弱点,集成模型实现了更好的泛化性能,从而提高了准确性和鲁棒性。

为了提高我们奖励计划预测的准确性,我们可以考虑将不同类型的奖励分开,并为每个计划开发单独的模型。通过这种方法,我们可以根据每个奖励计划的具体特征和目标量身定制我们的建模技术,从而实现更精确的预测和更好的结果。这种方法还可以帮助我们识别每个奖励计划中的任何独特趋势或模式,从而使计划设计和实施更加有效。

另一种优化奖励计划效果的方法是识别并排除那些无论奖励计划如何都购买的个人。通过这样做,我们可以将资源集中在那些更有可能积极影响计划结果的人身上,从而最大化我们从计划中获得的利益。

反思

我发现这个顶点项目是一个非常愉快的经历,让我能够提升数据预处理和建模的技能。数据预处理步骤是最耗时且令人畏惧的任务。然而,数据集本身很有启发性,这激励我继续努力工作。

使用 XGBoost 预测水泵的功能性

原文:towardsdatascience.com/predicting-the-functionality-of-water-pumps-with-xgboost-8768b07ac7bb

一个从“数据挖掘水表”竞赛中获得灵感的端到端机器学习项目

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Aashish Nair

·发表于 Towards Data Science ·10 分钟阅读·2023 年 6 月 1 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Kelly 拍摄:www.pexels.com/photo/close-up-of-a-child-s-hands-catching-water-from-the-spout-of-a-water-pump-3030281/

目录

∘ 介绍

∘ 目标

∘ 工具/框架

∘ 探索性数据分析

∘ 特征工程

∘ 创建训练和测试数据集

∘ 确定评估指标

∘ 创建基线模型

∘ 数据建模方法

∘ 超参数调优方法

∘ XGBoost 模型

∘ CatBoost 模型

∘ LightGBM 模型

∘ 选择最佳模型

∘ 模型解释

∘ 模型部署

∘ 限制

∘ 结论

∘ 参考文献

介绍

注:此项目灵感来自 DrivenData 主办的Pump it Up: 数据挖掘水表竞赛

坦桑尼亚目前面临严重的水危机,28%的居民缺乏安全用水。一种可行的解决办法是确保全国安装的水泵保持功能正常。

利用 Taarifa 提供的数据,这些数据来自坦桑尼亚水务部,有机会利用机器学习来检测不再功能或需要维修的水泵。

目标

本项目的目标是训练和部署一个机器学习模型,以预测水泵是否功能正常、无法运作或正常但需要维修。

工具/框架

本项目需要使用各种工具和框架。

便于数据分析和建模的脚本都使用 Python 编写。

数据预处理和特征工程使用 Pandas 和 Scikit Learn 模块完成。数据建模则结合了 Scikit Learn 和其他机器学习库进行。

最终模型被集成在一个使用 Streamlit 库构建的 Web 应用程序中。该应用程序随后通过 Heroku 进行部署。

为了更全面地了解项目的依赖关系,请访问GitHub 仓库

探索性数据分析

进行探索性数据分析(EDA)将揭示数据集的组成、数据应经历的过程以及应该考虑的机器学习算法。

提供的数据包括 59400 个数据点和 41 个特征,其中包括目标标签。

41 个特征如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

代码输出(作者创建)

注意:有关这些特征的详细信息,请访问竞赛问题描述

status_group特征将作为项目的目标标签。它揭示了水泵是否正常运作、无法运作,或正常但需要维修。

如代码输出所示,数据主要由类别特征组成。

此外,许多特征报告了类似的信息。例如,latitudelongituderegionregion_code特征都显示了水泵的位置。包含所有这些特征是多余的,甚至可能会影响模型的性能。

此外,数据集中有几个特征存在缺失值。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

缺失值(作者创建)

最终,目标标签中的值分布表明数据是不平衡的,其中“功能需要维修”类别的样本较少。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

目标标签(作者创建)

特征工程

从 EDA 的结果可以看出,许多特征在建模之前需要移除或修改。

  1. 移除不相关的列

idrecorded_bywpt_name特征已被移除,因为它们对目标标签没有影响。

2. 移除冗余列

包含冗余信息的特征也应当被移除。这些特征包括:subvillagelatitudelongituderegion_codedistrict_codelgawardscheme_nameextraction_typeextraction_type_grouppaymentwater_qualityquantitysourcesource_typewaterpoint_typemanagement

3. 创建“年龄”特征

construction_yeardate_recorded 特征与水泵的状态无关。然而,通过使用这两个特征,我们可以推导出水泵的“年龄”(即从建设开始的年数),以了解它们有多旧。

2. 移除弱预测变量

最后,应移除与目标标签关系不够强的预测变量。

数值特征与目标标签之间的关系通过 ANOVA 进行评估。以下代码片段创建了一个图表,显示了每个特征的 p 值。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

代码输出(作者创建)

类别特征与目标标签之间的关系通过卡方独立性检验进行评估。以下代码片段创建了一个图表,显示了每个特征的 p 值。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

代码输出(作者创建)

在测试的数值特征和类别特征中,只有 num_private 特征因其高 p 值而被移除。

在特征选择过程后,数据集从 41 个特征缩减为 18 个特征。

创建训练集和测试集

原始数据集被分割为训练集和测试集,并采用分层抽样,以确保目标标签中的各组在每个划分中有相同的代表性。

确定评估指标

数据已为建模做好准备,但首先需要确定最适合该项目的评估指标。

为此,我们需要考虑最终用户的优先级。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者创建

机器学习解决方案应通过检测非功能性或需要修理的水泵来提高获取清洁水的可及性。该解决方案还应通过正确识别不需要修理或更换的水泵来限制资金和资源的浪费。

值得注意的是,错误预测是非常不受欢迎的。

未能正确识别需要修理或更换的水泵(即假阴性)将减少清洁水的获取。依赖这些水泵的居民将无法用于农业和卫生目的,生活水平将下降。此外,建造水泵的政府和/或组织将失去声誉。

另一方面,未能正确识别功能正常的水泵(即假阳性)也是一种不理想的结果。这将导致将有限的资金和资源浪费在不需要修理或更换的水泵上。

鉴于假阳性和假阴性的巨大成本,机器学习模型应考虑精确度和召回率指标。然而,由于假阴性似乎更具后果,因此应更加重视提高召回率。

因此,用于项目的评估指标是 f2 分数指标,该指标考虑了精确度和召回率,但对召回率给予更大权重。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

F2-score 公式

创建基线模型

基线将有助于为机器学习模型的结果提供背景。本项目将利用两个基线模型作为参考:一个哑分类器和一个逻辑回归。

哑分类器将始终对水泵的功能进行随机预测。

在对类别特征和缺失数据进行编码和填补后,将训练一个具有默认超参数的逻辑回归模型。

逻辑回归的作用是展示一个简单模型在现有数据上的表现。如果逻辑回归的表现不如哑分类器,则会表明数据存在问题。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

代码输出(作者创建)

如输出所示,逻辑回归的 f2 分数明显高于哑分类器,这表明数据有足够的信号。

数据建模方法

构建模型的过程可以在以下流程图中呈现:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据建模流程图(作者创建)

将用于项目的三种模型包括Catboost 分类器LGBM 分类器XGBoost 分类器。所有这些分类器都包含集成学习,非常适合处理不平衡的数据。此外,它们支持类别特征和/或缺失数据。

对于这些模型中的每一个,都确定了最佳超参数集。然后使用这些超参数训练模型,并用测试集进行评估。

一旦每个模型都经过测试,最终将选择最佳模型(即 f-2 分数最高的模型)。该模型将用于网络应用程序。

超参数调整方法

超参数调整方法本身包含许多关键技术,因此值得通过另一个流程图进行详细说明。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

超参数调优流程图(作者创建)

超参数调优将使用 Optuna 库执行。

该过程包括创建一个 Optuna 研究。在每个研究中,分类器使用 100 个超参数集进行训练和评估。每个超参数集都通过分层交叉验证进行评估,该方法将训练数据分成多个折叠,每个折叠用来训练一个模型。

每个超参数集将通过训练模型的平均 f-2 分数来衡量。产生最高 f-2 分数的超参数集将被认为是最佳超参数集。

XGBoost 模型

为了展示对分类器进行的数据建模和超参数调优,以下代码片段展示了 XGBoost 的训练和评估过程。

首先,运行 Optuna 研究以找到 XGBoost 分类器的最佳超参数。

该研究将确定最佳超参数组合。这些超参数随后用于训练 XGBoost 分类器,然后在测试集上进行评估。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

代码输出(作者创建)

CatBoost 模型

使用 XGBoost 分类器的程序,对 CatBoost 分类器进行训练并在测试集上进行评估(有关整个代码库,请访问 GitHub 仓库)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

代码输出(作者创建)

LightGBM 模型

使用 XGBoost 分类器的程序,对 LightGBM 分类器进行训练并在测试集上进行评估(有关整个代码库,请访问 GitHub 仓库)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

代码输出(作者创建)

选择最佳模型

所有模型的性能记录在下表中。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

每个模型的性能(作者创建)

由于 XGBoost 分类器产生了最高的 f-2 分数(≈0.80),因此被认为是最佳模型。

模型解释

XGBoost 模型的性能可以通过分类报告和混淆矩阵进行背景说明,这些报告将预测值与实际值进行比较。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

分类报告(作者创建)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

混淆矩阵(作者创建)

如分类报告和混淆矩阵所示,功能性和非功能性水泵的精度和召回率相对较高。然而,该模型在功能正常但需要修理的水泵上的表现不佳。

模型部署

现在建模过程已完成,该模型应部署到一个可供最终用户访问的 Web 应用程序中。

该网页应用程序是使用 Streamlit 库构建的,文件名为 app.py。该文件的底层代码如下所示:

当使用 streamlit run app.py 命令运行时,应用程序应如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Streamlit 应用程序(由作者创建)

该应用程序包含一个侧边栏,用户可以在其中输入感兴趣的水泵参数。在点击“预测水泵状态”后,XGBoost 模型将预测具有选定特征的水泵是否功能正常、功能失效或功能正常但需要维修。结果会在应用程序中输出。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

进行预测(由作者创建)

该网络应用程序也已通过 Heroku 托管,因此你可以通过点击下面的链接访问它:

[## Streamlit

预测水泵功能的 Streamlit 应用程序

water-pump-functionality-app.herokuapp.com](https://water-pump-functionality-app.herokuapp.com/?source=post_page-----8768b07ac7bb--------------------------------)

限制

尽管该项目已经产生了一个功能性网络应用程序,但它仍然存在值得注意的某些限制。

1. 没有现成的解决方案作为参考

尽管所提出的解决方案确实使用户能够确定坦桑尼亚水泵的功能,但由于没有现成的解决方案可以作为参考,因此很难向客户推介。因此,很难确定此模型能节省多少资金以及它能在多大程度上改善水的可及性。

2. 约束条件知识有限

项目在假设下进行,即假阴性(即将非功能水泵识别为功能正常)比假阳性(即将功能正常的水泵识别为非功能)更不可取。然而,只有在修理和更换水泵的资金和资源没有重大限制的情况下,这一假设才成立。

不幸的是,没有对这些约束条件的清晰理解,就无法确定最适合的机器学习模型评估指标。

3. 缺乏领域知识

数据集中在分类特征中存在许多独特的值。然而,DrivenData 没有提供这些值代表的含义。因此,该项目缺乏基于证据的分类特征处理策略。

结论

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Alexas_Fotos 提供,来源于 Unsplash

总体而言,该项目旨在利用 Taarifa 收集的数据训练一个预测水泵功能的机器学习模型,并将其纳入具有商业价值的应用程序中。

要访问整个代码库,请访问 GitHub 仓库:

[## GitHub - anair123/Detecting-Faulty-Water-Pumps-With-Machine-Learning

目前无法执行该操作。您在另一个标签或窗口中已登录。您在另一个标签或…

github.com](https://github.com/anair123/Detecting-Faulty-Water-Pumps-With-Machine-Learning/tree/main?source=post_page-----8768b07ac7bb--------------------------------)

感谢您的阅读!

参考文献

Bull, P., Slavitt, I., & Lipstein, G. (2016 年 6 月 24 日). 利用群众的力量来提高社会部门的数据科学能力。arXiv.org. arxiv.org/abs/1606.07781

什么是泊松分布

原文:towardsdatascience.com/predicting-the-unpredictable-an-introduction-to-the-poisson-distribution-5afd4d70b1d7

对最著名概率分布之一的概述

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Egor Howell

·发表于 Towards Data Science ·阅读时间 4 分钟·2023 年 6 月 6 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由Anne Nygård拍摄,来源于Unsplash

背景

泊松分布 是一种普遍的离散概率分布。它由西门·丹尼斯·泊松于 19 世纪初首次发布,并且已经在许多行业中找到应用,包括保险、流行病学和电子商务。因此,数据科学家需要了解这一重要概念。在这篇文章中,我们将深入探讨这一分布的复杂性,并提供实际世界的例子。

补充视频。

直观

泊松分布的核心概念是量化一个事件在给定时间间隔内发生特定次数的概率。

作为一个例子,假设我们有一个零售店,每小时平均接待 20 位顾客。利用泊松分布,我们可以计算该店在一个小时内接待特定数量顾客的概率,例如 10、15 或 30 位顾客。

理论

泊松分布的概率质量函数 (PMF) 为:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

公式由作者以 LaTeX 编写。

其中:

  • e: 欧拉数(约 2.73)

  • k: 发生次数(≥ 0)

  • X: 离散随机变量(≥ 0)

  • λ: 预期发生次数(≥ 0)

泊松分布的参数是λ,它表示发生次数的均值,E(X) = λ,以及分布的方差,VAR(X) = λ。有关均值和方差的推导,请参见这里

值得注意的是,泊松分布实际上是从二项分布派生出来的。虽然我们在本文中不会详细讨论其推导过程,但有兴趣的读者可以在这里找到相关信息。

泊松分布的条件:

  • 事件的数量, k*,是独立发生的 (泊松过程)*

  • 事件在时间间隔内随机发生

  • 事件的预期数量是固定的

  • 在时间间隔的任何点获得事件的概率是相等的

示例与图表

回到我们之前的商店示例,其中每小时的平均顾客数为 20。商店在一个小时内接待 10 名顾客的概率是多少?

所以,我们得到的是:

  • λ = 20

  • k = 10

将这些值代入 PMF 公式:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Equation by author in LaTeX.

如我们所见,这个概率非常低。为了更好地理解顾客访问的分布情况,我们可以绘制整个 PMF:

GitHub Gist by author.

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Plot generated by author in Python.

正如观察到的,顾客访问的分布几乎呈钟形曲线,最可能的顾客数量为 20。这是因为 20 是预期的数量。为了进一步了解,我们可以探索一些均值为 10 或 30 的情景,并绘制相应的分布:

GitHub Gist by author.

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Plot generated by author in Python.

因此,当均值变小的时候,分布中的大部分概率质量会向左移动。这种移动是预期的,因为均值率代表了顾客访问的预期速率。因此,顾客人数更可能接近均值。

总结与进一步思考

泊松分布是数据科学和统计学中广泛使用且著名的概率分布。它基于给定的均值率来建模事件以特定速率发生的概率。泊松分布在遗传学、保险、欺诈检测等多个行业中都有应用。

如果你想查看本文中使用的完整代码,它可以在我的 GitHub 上找到:

[## Medium-Articles/poisson.py at main · egorhowell/Medium-Articles

我在我的中等博客/文章中使用的代码。通过创建帐户贡献于 egorhowell/Medium-Articles 的开发…

github.com](https://github.com/egorhowell/Medium-Articles/blob/main/Statistics/Distributions/poisson.py?source=post_page-----5afd4d70b1d7--------------------------------)

另一件事!

我有一个免费的通讯,Dishing the Data,在这里我分享每周的提示,帮助你成为更好的数据科学家。没有“虚 fluff”或“点击诱饵”,只有来自实际数据科学家的纯粹可操作的见解。

[## Dishing The Data | Egor Howell | Substack

如何成为更好的数据科学家。点击阅读由 Egor Howell 发表的 Substack 文章《Dishing The Data》,…

newsletter.egorhowell.com](https://newsletter.egorhowell.com/?source=post_page-----5afd4d70b1d7--------------------------------)

与我联系!

参考文献和进一步阅读

各种逻辑回归模型的预测(第一部分)

原文:towardsdatascience.com/prediction-in-various-logistic-regression-models-2543281cd55a

R 系列统计

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Md Sohel Mahmood

·发表于Towards Data Science ·阅读时间 8 分钟·2023 年 4 月 16 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由Jen Theodore拍摄,来源于Unsplash

介绍

在之前的几篇文章中,我们已经涵盖了各种类型的逻辑回归模型。这些模型的目标是尽可能准确地预测未来的数据点以及中间数据点。在本文中,我们将讨论如何在 R 中进行预测分析,包括简单和多重逻辑回归,使用二分类和有序数据。

数据集

成人数据集将作为我们研究的一部分案例研究。该数据集中收集的数据包括超过 30,000 名个体的详细信息。数据包括每个人的种族、教育背景、职业、性别、工资、每周工作小时数、持有的工作数量以及他们的收入水平。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

成人数据集来自 UCI 机器学习库

数据集简述:

  • Bachelors:1 表示该人拥有学士学位,0 表示该人没有学士学位。

  • Income_greater_than_50k_code:1 表示家庭总收入超过 50,000 美元,0 表示家庭总收入低于 50,000 美元。

  • Marital_status_code:1 表示该人已婚,0 表示该人未婚或离婚。

  • Race_code:1 表示非白人,2 表示白人。

用于二分类数据的简单逻辑回归预测

我们将通过上述数据集来识别两个可以用于预测二元收入结果的变量,这些收入结果可以大于 $59K 或小于 $50K,利用教育水平和婚姻状况变量。 该研究提出了以下问题:

教育水平对收入的影响是什么?

要进行预测分析,首先需要安装ggpredict库。第一个命令将提供二元“学士学位”变量的预测概率。我们知道Bachelors变量可以有两个值:0 和 1。R 将为家庭收入(也是一个二元变量)提供大于 $50k 的概率。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

第一个输出提供了预测收入概率的表格数据。在这里,如果学士学位 = 1,则家庭收入大于 $50k 的概率变为 0.47,而如果没有学士学位,该概率降至 0.16。 这告诉我们教育对家庭收入有重要影响。

多重逻辑回归预测二元数据

使用上述数据集,我们将采用两个预测变量:教育水平和婚姻状况,以预测二元收入结果,这些结果可以大于 $50K 或小于 $50k。这里的研究问题是:

教育水平和婚姻状况对收入的综合影响是什么?

使用二元数据进行多重逻辑回归的实现与下面的简单逻辑回归非常相似。

在这里,我们希望将婚姻状况作为另一个预测变量来预测家庭收入。使用类似的ggpredict命令,我们得到以下结果。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

对于第二个有 2 个预测变量的模型,概率数据已经为另一个变量的均值进行了调整。这里我们有第二个预测变量为Marital_status_code,其均值为 0.47。 这告诉我们,数据集中有 47% 的人已婚,而 53% 的人要么未婚要么离婚。保持该值不变,如果一个人有学士学位,则家庭收入大于 $50k 的概率为 0.44。如果没有,则概率降至 0.17。

简单逻辑回归预测有序数据

有时我们可能会有多于 2 个响应水平的结果变量,且该变量是有序的。我们数据集中家庭收入变量只有两个结果水平,但如果响应变量有超过 2 个结果,则可以遵循相同的方法。

回归模型的目的是对以下数据集中的问题提供定量解释:

教育水平、性别和种族对收入的个体影响是什么?

为了定义一个响应变量为有序的逻辑回归模型,我们可以使用clm()命令来自ordinal包。首先,我们需要将预测变量和响应变量转换为因子。这里响应变量有两个以上的类别,通常这个模型被称为比例奇数(PO)模型。

我们将使用相同的数据集来进行 PO 模型的预测。预测时,我们也将使用来自ggeffects库的*ggpredict()*命令。

首先预测教育代码为 5、10 和 13 的家庭收入是否大于$50k,这些代码分别代表 9 年级、高中毕业和博士学位。这里响应水平 1 表示家庭收入低于$50k 的群体,响应水平 2 表示家庭收入高于$50k 的群体。如果个人拥有 9 年级教育,那么家庭收入低于$50k 的概率是 0.98,家庭收入高于$50k 的概率是 0.02。如果个人拥有博士学位,家庭收入低于$50k 的概率是 0.35,家庭收入高于$50k 的概率是 0.65。因此,教育水平越高,家庭收入一般也越高。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如果我们使用“学士学位”作为预测变量进行相同的预测,我们观察到没有学士学位的个体家庭收入低于$50k 的概率是 0.84,家庭收入高于$50k 的概率是 0.16。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

以“教育年限”作为预测变量,我们可以得出教育年限越长,获得高收入的概率越高。如果个体有 16 年的教育,家庭收入高于$50k 的概率是 0.69。值得注意的是,所有响应水平的概率总和等于 1。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

多重逻辑回归在有序数据中的预测

在有序逻辑回归中,预测变量可以是有序的、二元的或连续的,而响应变量是有序的。

考虑预测收入的例子,其中教育的等级只有两个响应级别。例如,要进行回归分析,我们可以从 1 年级到博士学位分配有序的数字。也可以使用二进制变量来预测收入。例如,我们可以将 1 分配给拥有学士学位的人,将 0 分配给没有学士学位的人。这有点像一个有两个级别的有序变量。最后,我们还可以使用连续变量如教育年限来预测收入。在这里,我们尝试定量回答以下问题。

教育水平、性别和种族对收入的综合影响是什么?

在第一个模型中,我们使用了所有三个预测变量来预测收入,结果如下所示。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

收入水平分为两种(收入超过 50000 美元和收入低于 50000 美元),因此也有两个响应级别(如上图所示的红色)。在第二列中,你会发现每个教育水平的预测概率。教育水平为 3(5-6 年级)时,收入低于 50000 美元的概率为 0.99,而教育水平为 13(博士学位)时,收入低于 50000 美元的概率为 0.36。根据第二个响应级别的预测结果,可以得出相同的结论。因此,很明显家庭收入与教育水平有正相关关系。较高的教育水平与家庭收入的提高相关联。与往常一样,这些结果使用了其他两个变量Gender_codeRace_code的均值进行调整。

当在第二个ggpredict命令中包含性别时,我们得到以下结果。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

现在我们有了性别和收入水平的 2 个层次,因此我们有了 4 个表格。如果一个人的性别代码为 1(女性),并且他们拥有博士学位(教育代码 13),那么他们预期收入超过 50000 美元的概率为 0.42;而如果这个人的性别代码为 2(男性),且他拥有博士学位,那么预期收入超过 50000 美元的概率为 0.74。换句话说,这表明女性在相同教育水平下没有得到与男性相等的薪酬。与往常一样,该模型只使用了Race_code的均值进行调整。

结论

在这篇文章中,我们对二项和序数逻辑回归模型进行了预测分析,使用了单一和多个预测变量。我们涵盖了这四个模型中*ggpredict()*命令的使用,并且定量讨论了结果。作为提醒,模型的整体表现将取决于数据清理的程度。存在不必要的数据、重复的数据或错误的数据会导致模型结果误导。

数据集致谢

Dua, D. 和 Graff, C. (2019). UCI 机器学习库 [http://archive.ics.uci.edu/ml]. 加州欧文:加州大学信息与计算机科学学院。

感谢阅读。

阅读 Md Sohel Mahmood 的所有故事

加入 medium

各种逻辑回归模型的预测(第二部分)

原文:towardsdatascience.com/prediction-in-various-logistic-regression-models-part-2-f8994e306a4c

R 语言统计系列

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Md Sohel Mahmood

·发布在Towards Data Science ·阅读时长 8 分钟·2023 年 4 月 27 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Vladimir Fedotov拍摄,发布在Unsplash

介绍

我们已经涵盖了二元和有序数据的逻辑回归模型,并演示了如何在 R 语言中实现该模型。此外,使用 R 库的预测分析也在早期文章中讨论过。我们已经看到单一和多个预测变量对响应变量的影响,并对其进行了量化。我们采用了二元和有序响应变量,展示了如何处理不同类型的数据。在本文中,我们将探讨四种额外的逻辑回归模型预测分析,即广义有序回归模型、部分比例奇数模型多项逻辑回归模型和泊松回归模型。

数据集

我们的研究将使用相同的UCI 机器学习库的成人数据集作为案例研究。该数据集中收集了超过 30000 名个体的人口统计数据。数据包括每个人的种族、教育、工作、性别、工资、持有的工作数量、每周工作小时数和收入。为了更好地理解,以下是所考虑的变量。

  • 教育:数值型和连续型。个人的健康状况可以受到教育的重大影响。

  • 婚姻状态:二元(0 代表未婚,1 代表已婚)。这个变量的影响可能会很小,但它仍然被纳入了分析中。

  • 性别:二元(0 代表女性,1 代表男性)。它也有可能影响较小,但这值得关注。

  • 家庭收入:二元(0 代表平均或低于平均,1 代表高于平均)。健康状况可能受到影响。

  • 健康状态:有序(1 代表差,2 代表一般,3 代表良好,4 代表优秀)

广义有序回归模型中的预测

考虑我们收集了数百个个体的数据。数据中包括有关个人的教育、年龄、婚姻状态、健康状态、性别、家庭收入和全职就业状态的信息。教育、性别、婚姻状态和家庭收入将作为回归模型中的预测变量。除了教育外,预测变量都是二元的,这意味着它们的值要么是 0,要么是 1。教育是一个连续变量,表示个人受教育的年限。以下变量被考虑用于此次回归分析。

  • 教育年限

  • 婚姻状态

  • 性别

  • 家庭收入

  • 健康状态

如果我们执行有序逻辑回归并保持比例奇数假设,那么每个预测变量的系数值将为 1。假设家庭收入的系数为‘x’,这意味着每单位家庭收入增加(在这种情况下从 0 到 1),健康状态的更高类别的对数概率或对数赔率将增加‘x’。因此,我们可以得出关于此模型的以下结论。

  1. 如果家庭收入增加到高于平均水平,从差健康状态变为平均健康状态的对数赔率是‘x’。

  2. 如果家庭收入增加到高于平均水平,从一般健康状态变为良好健康状态的对数赔率是‘x’。

  3. 如果家庭收入增加到高于平均水平,从良好健康状态变为优秀健康状态的对数赔率是‘x’。

比例奇数模型的特点是所有结果水平上的对数赔率相同。现实数据常常违反这一假设,因此我们不能使用比例奇数模型。如前所述,解决这一非比例奇数问题的两个可能解决方案是采用广义有序模型或部分比例奇数模型。

  • 广义有序回归模型 -> 所有预测变量的所有级别的效果可能会有所不同

  • 部分比例奇数模型 -> 所有/部分预测变量的某些级别的效果允许变化

我们已经在早期文章中实现了广义方法和部分比例奇数方法的模型。

## 广义有序回归模型在 R 中的应用

R 语言中的统计系列

[towardsdatascience.com ## 部分比例奇数模型在 R 中

R 统计系列

[towardsdatascience.com

现在我们将使用这些模型实现预测过程。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在这里,我们可以看到提供的教育值下,不同健康状态的累计预测概率。我们知道我们的健康状态有四个独特的值。

如果个人有 15 年的教育,

  • 健康一般及以上的累计概率是 96%

  • 健康良好及以上的累计概率是 77%

  • 健康优秀的累计概率是 24%

如果个人只有 5 年的教育,

  • 健康一般及以上的累计概率是 81%

  • 健康良好及以上的累计概率是 41%

  • 健康优秀的累计概率是 8%

因此,很明显,教育年限在决定个人健康状态方面起着重要作用。如果我们只想获得预测概率,可以执行以下命令。

ggpredict(model1, terms = “educ[5,10,15]”,ci=NA)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如果个人有 15 年的教育,

  • 健康不良的概率是 4%

  • 健康一般的概率是 20%

  • 健康良好的概率是 52%

  • 健康优秀的概率是 24%

如果个人只有 5 年的教育,

  • 健康不良的概率是 19%

  • 健康一般的概率是 40%

  • 健康良好的概率是 33%

  • 健康优秀的概率是 8%

显然,教育年限增加了拥有更好健康的概率。所有这些值都已针对婚姻状况、性别和全职工作状态的均值进行了调整。

部分比例奇数模型中的预测

在部分比例奇数模型中,我们可以选择希望不同结果水平的效应变化的预测变量。我们可以首先确定哪些预测变量违反了比例奇数假设,然后将这些变量放在parallel = FALSE ~ 命令之后。这里,我们将婚姻状况和家庭收入作为违反假设的预测变量。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如果个人有 15 年的教育,

  • 健康不良的概率是 4%

  • 健康一般的概率是 20%

  • 健康良好的概率是 52%

  • 健康优秀的概率是 24%

如果个人只有 5 年的教育,

  • 健康状况差的概率为 17%

  • 健康状况一般的概率为 41%

  • 健康良好的概率为 35%

  • 健康状况优秀的概率为 7%

累积概率也可以使用之前描述的方法计算。

多项回归模型中的预测

我们在以下文章中介绍了多项逻辑回归分析。

## R 中的多项逻辑回归

R 中的统计系列

[towardsdatascience.com

多项回归是一种统计方法,用于估计个体落入特定类别的可能性,相对于基准类别,利用对数几率或对数几率比的方法。实质上,当名义响应变量有多个结果时,它作为二项分布的扩展来工作。作为多项回归的一部分,我们需要定义一个参考类别,模型将基于参考类别确定各种二项分布参数。

在以下代码中,我们定义了健康状态的第一个级别作为参考水平,我们将基于这个参考水平比较多个二项回归模型。

我们的预测方法得出了以下结果。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如果个体接受了 15 年的教育,

  • 健康状况差的概率为 4%

  • 健康状况一般的概率为 19%

  • 健康状况良好的概率为 52%

  • 健康状况优秀的概率为 25%

再次,这些预测的概率是在保持其他预测变量均值的情况下计算的。在多项逻辑回归中,响应变量应为名义变量。然而,这里的响应被转换为序数变量以使用*ggpredict()*命令。

泊松回归模型中的预测

有时我们需要处理涉及计数的数据。为了对计数响应变量建模,例如博物馆访问次数,我们需要使用泊松回归。到医院的访问次数或特定学生群体修读的数学课程数量也可以作为示例。我们在以下文章中介绍了泊松回归。

## R 中的泊松回归

R 中的统计系列

[towardsdatascience.com

我们将使用相同的数据集,预测从教育年限、性别、婚姻状况、全职工作状态和家庭收入中得出的科学博物馆访问次数。代码块如下所示。

使用相同的*ggpredict()*命令,我们获得了不同教育年限以及不同性别的以下结果。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

  • 如果个体是女性(性别=0)且教育年限为 15 年,则预测的科学博物馆访问次数为 0.44。

  • 如果个体是男性(性别=1)且教育年限为 15 年,则预测的科学博物馆访问次数为 0.62。

  • 这表明女性访问科学博物馆的频率低于男性。结论已针对婚姻状况、全职工作状况和家庭收入的均值进行调整。

结论

在这篇文章中,我们覆盖了四种不同类型回归模型的预测分析。部分比例奇数模型可以视为广义序数回归模型的一个子集,因为 PPO 模型允许只有少数预测变量在不同层次上变化。多项式回归模型适用于类别没有顺序的名义响应变量。最后,泊松回归模型适合预测计数变量。我们展示了在所有四种回归模型中使用*ggpredict()*函数以及结果的解释。

数据集致谢

Dua, D. 和 Graff, C. (2019). UCI 机器学习库 [http://archive.ics.uci.edu/ml]。加州欧文:加州大学信息与计算机科学学院(CC BY 4.0)

感谢阅读。

[## 使用我的推荐链接加入 Medium - Md Sohel Mahmood

阅读 Md Sohel Mahmood(以及 Medium 上的成千上万其他作者)的每一个故事。您的会员费直接…

mdsohel-mahmood.medium.com](https://mdsohel-mahmood.medium.com/membership?source=post_page-----f8994e306a4c--------------------------------) [## 每当 Md Sohel Mahmood 发布文章时,您将收到电子邮件。

每当 Md Sohel Mahmood 发布文章时,您将收到电子邮件。通过注册,您将创建一个 Medium 账户(如果您还没有的话)…

mdsohel-mahmood.medium.com](https://mdsohel-mahmood.medium.com/subscribe?source=post_page-----f8994e306a4c--------------------------------)

请请我喝咖啡

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值