作者:Jason Brownlee
翻译:wwl校对:车前子
本文约4000字,建议阅读3分钟本文介绍了haberman乳腺癌生存二分类数据集,进行神经网络模型拟合。包含数据准备、MLP模型学习机制、模型稳健性评估。
根据新数据集开发神经网络预测模型是一个挑战。
一种方法是先对数据集进行探查,然后思考什么模型适用于这个数据集,先尝试一些简单的模型,最后再开发并调优一个稳健的模型。
这个流程适用于为分类、回归预测模型问题开发高效的神经网络。
本教程中,你将学习如何开发一个多层感知机神经网络模型,用于癌症生存二分类数据集。
完成本教程后,你将了解到:
如何加载和汇总癌症生存数据集,根据结果来进行数据准备和模型配置。
如何探索MLP模型拟合数据的学习机制。
如何得到稳健的模型,调优并做预测。
开始吧!
Bernd Thaller拍摄
概览
本教程分为4部份:
Haberman 乳腺癌生存数据集
神经网络学习机制
模型鲁棒性评估
最终的模型及预测
Haberman 乳腺癌生存数据集
首先,定义数据集并作数据探查。
我们使用的是“haberman”标准二分类数据集。
数据集描述的是乳腺癌患者的数据,结局事件是患者生存,具体是指病人是否生存了五年活以上,或患者是否存活。
这是学习不平衡数据分类问题的标准的数据集。数据集的背景描述表明,研究是在1958年到1970年期间,在芝加哥大学的Billings医院开展的。
数据集有306个样本,3个输入变量:
病人在手术期间的年龄;
手术的两位数年份;
检测到的腋窝淋巴结阳性数,这是衡量癌症是否已扩散的一种手段。
我们只有以上数据,无法选择组成数据集合的病例,以及病例的特征。
尽管这个数据集描述的是乳腺癌患者的生存情况,但考虑到数据集的样本量少,以及这些数据是基于发生在几十年前的乳腺癌病例,因此基于这个数据集的模型并不具备泛化能力。
备注:声明,我们不是要治愈乳腺癌,而是在探索一种标准的分类数据集。
以下是数据集的前5行的抽样。
从以下链接,可以对这个数据集有更多了解:
Haberman Survival Dataset (haberman.csv)(https://github.com/jbrownlee/Datasets/blob/master/haberman.csv)
Haberman Survival Dataset Details (haberman.names)(https://github.com/jbrownlee/Datasets/blob/master/haberman.names)
可以直接从URL中加载数据集,保存为pandas DataFrame,如下:
执行这个例子,可以直接从这个URL加载数据,获得数据集的维度。
本例中,我们可以确定,数据集有4个变量(3个输入1个输出变量),有306行数据。