TowardsDataScience 博客中文翻译 2020(四百二十六)

原文:TowardsDataScience Blog

协议:CC BY-NC-SA 4.0

机器学习的概要

原文:https://towardsdatascience.com/headstart-ml-part-1-distilled-summary-of-ml-d1f038c6ff9a?source=collection_archive---------37-----------------------

介绍最大似然法的现状及其所有重要概念

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

ML 与传统编程有何不同?

传统的编程包括手动或脚手架(之前已经写好,因此是克隆/重复的)逻辑过程,该逻辑过程编写一种方法,以根据某些预定义的约束将某个输入转换成期望的输出。简而言之:

传统编程:
输入+程序(通过编码器)=输出

相比之下,机器学习(ML)及其相关领域通过查看一大组正确的输入及其相应的正确输出来自动化这种提出逻辑程序的手动过程,并“学习”这种输入到输出的映射,称为 模型

然后用这个生成的模型 用一组新的输入来预测一组新的输出。简而言之:

机器学习:
a)输入+输出=模型
b)模型+新输入=新预测输出

这个【学习】是怎么发生的?

学习发生在训练阶段,此时大量的正确输入及其相应的正确输出被组合起来,并用于训练,以获得一个模型,该模型随后可用于对新的输入值进行预测。

这个庞大的正确输入和输出集合一起被称为训练集**。**

最后,当模型完成训练后,我们可以预测一组新输入值的输出。

这一组被称为测试组**,我们在其上测试网络的准确性。**

历史:

如果我们看看过去 10 年(2010 年至 2019 年)的技术世界,一个取得巨大进步的研究领域是 ML。

在这十年之前,由于缺乏计算资源来支持最大似然算法所需的大量并行化任务,这些算法仅用于一小部分电子商务和数字市场网站以提供推荐。

在本世纪初,当 GPU 开始成为趋势,计算的大量并行化成为现实时,该领域和子领域(如深度学习)以及其他相关领域(如计算机视觉)都取得了快速发展,每 3 至 6 个月发布一次最新模型,在一个或多个指标(如 准确性 泛化 )方面比前代模型表现更好

这些术语的含义是什么?

准确度 是指 ML 模型能够预测正确输出的程度。该输出可以是分类变量(即是或否的答案)或数字(即某公司股票的价格)。

泛化 指的是在被“训练”以对相似任务域进行预测之后,ML 模型如何能够对任务域进行预测。
也就是说,一个经过训练可以区分猫和狗的模型可以扩展到其他动物吗,比如人类和黑猩猩?

健壮性 是指 ML 模型如何处理输入的“怪异边缘情况”。也就是说,当人穿着狗装而狗穿着 t 恤时,被训练来区分人和狗的模型也能区分它们吗?

****易训练性衡量模型“学习”预测给定正确输出所花费的计算资源和训练时间(时间)用于训练模型在不同的测试输入集上达到给定的精确度。如果模型 A 需要 4 个小时的训练时间来以 80%的准确度区分人和狗,而模型 B 在 12 个小时的训练中以相同的准确度在 1000 张新的人和狗的图像上进行测试,那么模型 A 更好,因为它花费更少的时间来训练和重新训练。

****注:除了训练时间之外,不同模型的推理时间(即模型对给定输入做出预测所用的时间)也可能不同,但通常这一时间比训练时间小得多,因此除非提供的输入非常大,否则这不是一个大问题。

请注意,这与解决任务 a 并给出正确输出的传统编程非常相似。如果一段代码 a 以 O(n)线性时间复杂度完成任务,另一段代码 b 以 O(n^2 多项式时间复杂度完成任务,那么第一段代码 a 会被认为更好,因为它对于巨大的输入会执行得更快。

即*如果用一个模型来决定何时卖出或买入某家公司的股票。由一名 ML 工程师向一名经验丰富的技术股票分析师解释训练后的模型考虑了哪些因素,这被视为其可解释性***

在 2017-2018 年之后,模型对于领域利益相关者来说更具解释力,因此,像它是否偏向于某种类型的输入,或者它是否公平地对待所有输入,或者没有违反生态系统的任何规则这样的问题可以由可能不是 ML 专家的领域专家来验证。

由于最近对消费者数据保护的共同兴趣,以及复杂的 ML 算法在做出任何决定时会考虑哪些数据点,以及它们在这样做时是否遵循所有道德和法律规则,这些进步获得了急需的支持。

什么是不同的类型的 ML 算法?****

基于提供给它们的训练集,ML 算法大致分为:

监督学习:
监督 学习中,既提供了模型在最佳情况下应该预测的输入(也称为
非目标属性
)及其正确输出(也称为目标属性地面真值)。****

任务是学习一个函数 F,它取非目标属性 X 并输出一个逼近目标属性的值,即F(X)≈y目标属性 y 作为指导学习任务的老师,因为它提供了学习结果的基准。因此,这项任务被称为监督学习。**

**在虹膜数据集中,虹膜花的类别可以作为目标属性。带有目标属性的数据通常被称为 标注为 数据。基于上述定义,对于使用标记数据预测鸢尾花的类别的任务,可以看出这是监督学习。

无监督学习:

与监督学习不同,我们在非监督学习任务中没有基础事实。人们期望从数据中学习潜在的模式或规则,而没有预先定义的基础事实作为基准。**

聚类算法是非监督学习的例子之一。

半监督学习:

在数据集很大但标记样本很少的情况下,人们可能会发现监督和非监督学习的应用。我们可以把这个任务称为 半监督学习

**如果想要预测图像的标记,但是只有 10%的图像被标记。通过应用监督学习,我们用已标记的数据训练一个模型,然后我们应用该模型预测未标记的数据。很难说服我们自己这个模型足够通用,毕竟,我们只是从少数数据集中学习。更好的策略可能是首先将图像聚类成组(无监督学习),然后对每个组单独应用监督学习算法。

第一阶段的无监督学习可以帮助我们缩小学习范围,以便第二阶段的监督学习可以获得更好的精度。

基于最大似然算法进行预测的类型,最大似然算法大致分为:

回归算法:

预测连续数值的算法,如一天中的温度预测,或特定时间给定股票的股价,称为回归算法。

回归算法的例子有线性回归多元回归等。

分类算法:

预测离散分类值的算法被称为分类算法,例如基于胸部 X 射线预测某人是否患有新冠肺炎病(,或者给定图像中是否有猫、狗或人()。****

分类算法的一个例子是逻辑回归。

ML 模型预测的常见问题有哪些:

尽管使用 ML 模型进行预测存在许多问题,但我们将重点关注两大类型:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

预测房价的线性回归模型的拟合不足、拟合恰到好处和拟合过度的示例。来源:公开开放 MOOC 斯坦福的机器学习来自 Coursera

欠拟合 欠拟合模型是与训练数据拟合不佳的模型,明显偏离训练集目标变量。

拟合不足的原因之一可能是模型对数据过于简化,因此无法捕捉数据中隐藏的关系。

从上图可以看出,在第一部分中,为了预测房价几乎线性的线无法正确预测房价,预测价格与实际价格相差太大。在这里,一个简单的线性模型(一条线)不能够“拟合”价格曲线,从而导致价格预测的重大误差。**

作为避免上述欠拟合原因的对策,可以选择能够从训练数据集生成更复杂模型的替代算法。

人们也可以有一个大而多样的训练集,以避免模型的欠适应。

欠拟合也可以直观地与 高偏差 联系起来,因为这里的模型是在高偏差的原则下,即房价随着房子的面积(大小)线性增加。

过拟合

过拟合模型是与训练数据很好地拟合的模型,即很少或没有误差,然而,它不能很好地推广到看不见的数据。

与拟合不足的情况相反,一个能够拟合每一点数据的过于复杂的模型会陷入噪声和误差的陷阱。

从上图中可以看出,在第 3 部分中,模型在训练数据中的价格预测误差几乎为零,但它更有可能会在看不见的数据上出错。

类似于欠拟合的情况,为了避免过拟合,可以尝试另一种算法,该算法可以从训练数据集生成更简单的模型。

或者,保持生成过拟合模型的原始算法,但是将正则化项添加到该算法中,惩罚过于复杂的模型,使得该算法在拟合数据的同时生成不太复杂的模型。

过度拟合也可以直观地与 高方差 相关联,因为这里模型预测非常具体的价格与房屋面积相比突然下降或上升,只是小于或大于房屋的当前输入面积,因为它在训练数据上过度拟合。

因此,证明了我们的直觉,当相对于独立变量,即房屋的面积(大小)作图时,过度拟合的模型对于房屋价格具有高的

资源

从现在开始,您可以从以下网址进一步了解上述主题:

谢谢你一直读到最后一点!

我是 Ravi Vats,软件工程师,毕业于班加罗尔 Ramaiah 理工学院的计算机科学与工程专业。我感兴趣的领域是深度学习、ML、算法和数据结构、可伸缩和并发系统、数据分析和可视化。

您可以通过我的 LinkedIn 个人资料与我联系。

或者,我也可以在推特脸书InstagramQuora 上找到。

我希望你觉得这个系列有趣且足智多谋。我随时欢迎任何编辑或建议,以增强本系列中提供的信息。

为学习干杯!😃

医疗聊天机器人可以帮助疫情

原文:https://towardsdatascience.com/healthcare-chatbots-can-help-with-the-pandemic-bcc07fc606c9?source=collection_archive---------59-----------------------

随着全球疫情给医疗系统带来压力,这将是对远程健康聊天机器人的最终考验

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

沃洛季米尔·赫里先科在 Unsplash 上的照片

电子医学已经存在了很多年,随着交流方式的进步,它也在不断变化和发展。亚历山大·格雷厄姆·贝尔最早遇到的情况是,当你发高烧时,他把酸洒在他的裤子上让你的妈妈打电话给医生,于是他用电话寻求帮助。随着时间的推移,远程医疗发生了巨大的变化。如今,高速互联网、视频会议和物联网设备让远程医疗变得更加可行。

新冠肺炎极大地改变了我们互动的方式,随着虚拟医生咨询的出现,远程医疗正迅速成为一种新常态。例如,Babylon Health 通过医患视频聊天、咨询历史和疾病报告在英国迅速扩张。不出所料,随着新冠肺炎疫情的推出,该平台的人气飙升。随着医疗保健行业和技术的不断发展,人们不禁要问:人工智能在远程医疗中能扮演什么角色?

第一批聊天机器人可以追溯到 20 世纪 60 年代。第一批著名的聊天机器人之一是伊莱扎,创造了一个对最初的精神病学采访的模仿。自 2010 年以来,大多数聊天机器人都用于客户服务。如果软件中有错误或者你想报告反馈,你可以使用聊天机器人,它会指引你正确的方向,而不是搜索用户文档。然而,聊天机器人刚刚开始在医疗保健行业站稳脚跟。

患者参与

对于习惯于即时信息的人来说,使用现有的聊天机器人如 Alexa、Siri 和 Facebook Messenger 来回答问题并改善与医疗专业人员的联系可以加强提供者和患者之间的关系。现有的聊天机器人提供的服务包括安排预约、定位医疗设施和鼓励临床试验。

此外,增加患者参与度有助于解决具体问题。在一个例子中,人工智能解决方案有助于减少结肠镜检查的“缺席”。聊天机器人有助于满足我们“按需”社会中消费者的现有期望,并改善患者体验。

患者诊断

最近,人工智能聊天机器人已经开始转变病人诊断和家庭自我护理。对于新冠肺炎疫情,有必要进行自我治疗,以减少疾病的传播和医疗保健专业人员的负担。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

沃洛季米尔·赫里先科在 Unsplash 上的照片

它是如何工作的?假设你头痛,并开始发冷。你量了体温,有轻度发烧。你吃止痛药吗?你去看医生吗?聊天机器人将检查你的症状,询问与医疗保健专业人员相同的筛查问题,以排除病症。在告诉聊天机器人你的症状,并为你提供可能的诊断后,使用自然语言处理。简单来说,我们不一定知道该问什么样的问题。聊天机器人可以帮助识别你可能没有意识到的潜在症状。

更好的信息

聊天机器人可以帮助他们以更高的准确性缩小潜在原因的范围,而不是让 WebMD 上的患者匹配自己的症状。

Buoy Health 的准确率为 92 %,而 WebMD 的准确率为 56 %, Health line 的准确率为 53 %, Mayo Clinic 的准确率为 38%

例如 Ada Health 就是一个这样做的聊天机器人。如果你感觉不舒服,Ada 会“问”你一系列越来越具体的关于你症状的问题。该应用程序在审查可能的原因时为您提供可能性指标,以及您可以采取的后续措施。此外,Ada 允许您跟踪您的症状,如果它们升级到与医疗保健提供者共享。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

阿达健康聊天机器人

总之,这些功能有助于患者更准确地推断他们的疾病,使人们能够采取更好的自我护理措施,并防止可能的歇斯底里。此外,医疗保健聊天机器人有助于提供多样化和个性化的资源。例如, Youper 是一个聊天机器人,为患者提供心理健康的个性化资源。人工智能可以建议一些技术来监控你的进展,测量健康状况,并与你的医疗保健专业人员分享信息。同样,内置于脸书 messenger 的癌症聊天机器人为癌症患者、护理人员和家人提供最新资源。与简单地在线阅读静态信息不同,根据患者的需求定制资源非常有益。

数据聚合

由于数据驱动的方法,使用大型数据库有可能提供更准确的信息。 Buoy Health ,另一个用于自我保健诊断的聊天机器人,在他们的解决方案中使用了大量数据。与其他聊天机器人类似,它提供了一个基于文本的症状检查器,将您的输入与可能的情况进行匹配。该公司使用来自 18000 篇医学论文(500 万名患者,1700 种情况)的临床数据进行训练,以模仿医生。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

浮标健康助手聊天机器人

例如,在 Buoy Health 推出之前,它进行的一项测试是如何解释咳嗽。据首席执行官乐介绍,“他们研究了 100 个标准化病例,涉及 33 种不同的诊断,从良性咳嗽到危及生命的肺栓塞,以及罕见疾病的流行,如组织胞浆菌病和普通感冒。Buoy 在 92%的情况下都是正确的,相比之下,WebMD 的正确率为 56%,Healthline 为 53%,Mayo Clinic 为 38%。(来源)

此外,Buoy 随着时间的推移向用户学习,测量症状的流行程度。该公司正在使用该软件收集传统研究中无法测量的大量数据。迄今为止,Buoy Health 还没有发表多少关于他们从其平台上了解到的数据的研究文章,但他们确实有一个工作人员写作团队为症状诊断创建已发布的内容。

不去看医生

对于更小的健康问题,聊天机器人可以完全避免去看医生。据统计,78%去急诊室的病人本可以在急诊室外得到照顾。你的。MD 是一个人工智能聊天机器人,其主要功能是自我保健检查器。它决定了你正在经历的症状是否值得去看医生,或者你只是在经历一种自我护理的状况,如普通感冒或过敏。它可以加快患者的分类,并为患者提供更个性化的体验。然而,如果聊天机器人确定你的病情更加严重,它还可以快速联系你所在地区的医疗服务提供者

不去看医生还有很多“隐藏”的好处。除了节省时间之外,依靠保险看医生可能会很贵。聊天机器人如 Ada,Buoy 和 Your。MD 可以免费使用。有趣的是,这些聊天机器人通过向你推荐医疗机构来创收。在不去看医生的情况下,呆在家里可以防止疾病传播。另一方面,你也限制了自己在医疗机构感染其他疾病的可能性。此外,与家庭医生和急救设施不同,聊天机器人是全天候可用的。即使对于简单的医疗问题,如果你在半夜醒来打寒战,这也是一个很好的便利。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

创建于 1976 年,这张历史性的照片展示了一位公共卫生科学家从一位老年妇女身上抽血,在 1976 年 10 月开始的全国范围的猪流感疫苗接种活动中进行测试。|图片由疾控中心Unsplash 拍摄

聊天机器人的挑战

不愿采用新技术

聊天机器人缺乏医生给予的个人接触和直觉。尽管人们毫不犹豫地使用聊天机器人进行基本的客户支持,但许多人不愿意使用聊天机器人来讨论敏感的医疗信息。

Conversa Health 是一个聊天机器人,它简化了患者和提供者之间的客户服务和管理。首席执行官 Kouris Kalligas 是医疗保健创新的倡导者,但提醒该领域的其他人,他们需要认识到系统中现有的限制。

“创新者应该记住,推动技术发展的是看似合理的需求,而不仅仅是技术上的可能性。”

—首席执行官康沃斯·卡利加斯

在过去几年中,像 Marriot、MyFitnessPal、Target 和 MyHeritage 这样的公司都经历过重大数据泄露。尤其是对于敏感的医疗数据,提供安全的解决方案和保护技术以确保数据不会落入他人之手至关重要。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

阿瑟尼·托古列夫Unsplash 上的照片

聊天机器人永远不会完全取代医生

你的。MD 注册为1 类医疗器械,符合欧盟医疗器械指令,可用作医生流程的辅助设备。它指出,这些服务在法律上不提供诊断、医疗建议或治疗。尽管理论上聊天机器人会比使用谷歌的自我研究提供更有针对性的结论。更多的是明显的实际限制。有了医疗设备和程序,聊天机器人在诊断和解决医疗疾病时只能触及表面。

作为医疗助理的聊天机器人

如前所述,聊天机器人的局限性将阻止它们完全取代医生。随着人工智能聊天机器人的改进和日益成为主流,聊天机器人和医生一起工作可能会成为现实。例如, Sensely 目前向医疗保健提供商提供解决方案,以提供患者服务和健康评估。它有助于护理机构为预约提供动态登记服务,并帮助消费者处理保险索赔。此外,它通过问卷和反馈为医生提供了来自患者行为的实时信息。

这种类型的“候诊室问卷”已经存在了几十年,但聊天机器人解决方案可以提供更多针对患者习惯的个性化问题,自动附加到患者记录中,并让医生立即了解患者预先筛选的健康状况。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

苹果的新冠肺炎筛选工具,照片由 Brian McGowanUnsplash 上拍摄

此外,聊天机器人还有其他明显的局限性,例如对婴儿和老人的治疗,病人无法使用这项技术。然而,随着技术的发展,聊天机器人甚至可以帮助替代最基本的任务。因为聊天机器人会做琐碎的工作,医生会利用聊天机器人的资源做出诊断。随着医生和护士的压力减少,他们将能够看到更多需要人工智能聊天机器人无法提供的关键治疗的病人。

医疗保健行业中交织的聊天机器人显然应用了社会公益的概念。然而,重要的是要始终考虑技术可以帮助的程度,以及实施技术会在哪些方面阻碍整体用户体验。此外,经常提到的是,医疗科技初创公司从风险投资公司获得了大量资金。透过烟雾来看哪些公司提供了最准确的结果是很重要的。

鉴于医疗行业的某些方面已经过时,而另一些方面却非常先进,聊天机器人可以改善患者的整体体验,以提供准确的医疗信息。

新冠肺炎(新型冠状病毒肺炎)

虽然全球疫情给医疗系统带来了巨大压力,但看看聊天机器人的使用是否有所增加将是一件有趣的事情。在疫情期间,许多初级保健机构和特护医生已经远程切换。越来越多的人已经习惯于联系他们的医生,并通过视频分享关于他们健康的私密细节。就在几个月前,医生对待病人的方式还被认为是“不恰当的”,现在他们也变得更加自如了。然而,一些专业(例如整形外科)将面临为患者提供虚拟护理的挑战。

Twitter 允许员工永远在家工作

随着许多公司考虑永久转移到远程环境,医疗保健行业将继续发展。这将是对远程医疗的最终考验,看看虚拟咨询、患者参与平台和自我护理聊天机器人等程序是否会在疫情飓风过后继续存在。

彼得·巴鲁(@彼得鲁巴 )是利哈伊大学 CS +设计专业的学生。他联合创立了一个教育技术安全平台 SmartPass,并且喜欢在⛷滑雪

指导撰写Daniel p . lop resti教授,Data X Initiative 主任,利哈伊大学计算机科学教授(CSE 398: AI for Social Good)

参考文献&延伸阅读

[1]魏曾鲍姆,约瑟夫。伊莱扎——一个用于研究人机自然语言交流的计算机程序

[2]麦克,希瑟。“数字健康初创公司 Buoy 推出人工智能驱动的症状检查聊天机器人。” MobiHealthNews ,2017 年 3 月 8 日

使用深度学习的心律失常检测

原文:https://towardsdatascience.com/heart-arrhythmia-detection-using-deep-learning-a659848f2742?source=collection_archive---------12-----------------------

如何利用 Tensorflow / Keras 结合 CNN 和 LSTM 进行时间序列分类

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

西蒙·米加杰Unsplash 上拍摄的照片

深度学习最近正在帮助解决的一个常见问题涉及时间序列分类。解决这类问题的经典方法是从我们拥有的信号中生成特征,并训练一个机器学习模型。

手工制作功能的过程可能会占用您项目日程的很大一部分。就此而言,采用卷积神经网络与长短期记忆递归神经网络的组合。这种架构已被证明是有效的,可以减少花费在特征工程上的时间。

在本文中,我们将训练几个模型来检测不规则的心律。其思想是逐步展示如何为时间序列分类建立一个序列模型。如果你想深入了解这个问题,我强烈推荐这篇来自斯坦福 ML 小组的文章

处理数据集

在这个实验中,我们将使用麻省理工学院-BIH 心律失常数据库,其中包含从 47 名受试者获得的 48 个半小时的双通道动态心电图记录摘录。心脏病专家对数据集进行了注释,所有标签和对数据收集方式的完整解释可在此处找到。

为了这个实验的目的,让我们只考虑两类:正常节拍和异常节拍。从标签中丢弃无效的节拍。这些被认为是标签,不在这些列表上的被标记为正常。

每个记录由一个信号及其注释组成。这两个文件都是通过库 Python WFDB 读取的。可以使用以下方式读取记录:

record = wfdb.rdrecord(filename)
annotation = wfdb.rdann(filename, "atr")

注释包含每个节拍注释。在我们的模型中,每个样本将是一个作为目标的节拍标签和一个围绕这个节拍的序列作为输入(每边 3 秒)。

现在让我们定义两个函数来帮助我们处理记录。一个简单的分类记录,给定上面的列表,分为正常和异常的心跳。另一个函数将为我们的 CNN/LSTM·1D 模型构建一个样本。

为了更好地理解第二个函数,我们需要理解 CNN 或 LSTM 模型的输入是如何格式化的。你必须输入一个形状为*的三维数组(batch_size,sequence_size,number_of_features)。*在这种情况下,每个样本将是 (1,序列大小,1) ,因为我们的模型中只有一个特征。稍后,你会看到,对于 CNN 和 LSTM,我们还需要定义另一个维度(子序列)。

接下来,下面的代码将为每个患者构建一个序列列表(输入)标签列表(目标)和一个地图。其思想是根据异常心跳的比率进行分层训练和验证。

培训模式

让我们尝试不同的模型来训练我们的分类器。这里的主要目的是展示如何设置 keras 层,这些架构设计得不是很好。第一个是 CNN 唯一的模型。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

该模型如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

为了训练模型,让我们调用 keras 的 fit 方法:

使用这种架构,我们可以在验证集中实现 0.82 的准确度。这可能并不意味着一个惊人的模型,尤其是因为没有改善的时代的损失。

另一个用于序列模型的通用架构是 CNN 和 LSTM。这个想法是用 CNN 层作为特征提取层,用 LSTM 来解释子序列的顺序。

这种建筑看起来是这样的:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用顺序 API 构建模型与上面所做的没有太大变化。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用这种架构,验证集的精确度为 0.8。

下一步,你可以通过丢弃层或内核正则化来减少过度拟合,从而改进这些模型。

这个实验的代码在这篇文章的报告中。

心脏病分类

原文:https://towardsdatascience.com/heart-disease-classification-8359c26c7d83?source=collection_archive---------34-----------------------

深入分析

XGBoost 会显示与医生使用的相同的风险因素吗?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Alexander SinnUnsplash 上拍摄的照片

介绍

心血管疾病(CVD)或心脏病是美国人死亡的主要原因之一。疾病控制预防中心估计每年有 647,000 人死亡。CVD 是一个总括术语,包含不同的心脏状况,包括血管病变(动脉粥样硬化或血管炎)、结构问题(心脏肥大)和心律不齐(心律不齐)。在心血管疾病中,美国最常见的心脏病是冠状动脉疾病。大多数时候,CVD 是“无声的”,在个体经历心脏病发作、心力衰竭或心律失常的迹象或症状之前,没有诊断。研究已经确定了与发展心血管疾病相关的危险因素。这些风险因素可以是不可修改的,其中因素不能被改变,或者是可修改的因素,其中因素可以被改变。

不可修改的风险因素有:

  • 年龄增长
  • 生物性别——男性比女性面临更大的风险
  • 遗传的

可修改的风险因素包括:

  • 吸用烟草
  • 高血胆固醇
  • 高血压
  • 身体不活动
  • 肥胖
  • 糖尿病

医生可以利用这些风险因素获得洞察力,为患者建议生活方式的改变或治疗策略。我想研究的一个好奇的问题是,XGBoost 树模型能否根据医生使用的这些风险因素预测某人是否患有 CVD。

数据

用于进行这项分析的数据来自克利夫兰、匈牙利、瑞士和弗吉尼亚州长滩的四家医院汇编的数据集。这些数据被称为 UCI 心脏病数据集。该数据集由 303 个具有 14 种属性的个体组成,其中 138 个个体没有 CVD,165 个个体有 CVD。最初,有 76 个属性,但是发表的实验提到使用只有 14 个属性的子集。目标变量是使用任何主要血管中的直径变窄来诊断心脏病。截断百分比为 50%(见下文属性#14)。

只用了 14 个属性:
1。年龄:以年为单位的年龄
2。性别:性别(1 =男性;
0 =女性)3。cp:胸痛型— 1:典型心绞痛,2:不典型心绞痛,3:非心绞痛性疼痛,4:无症状
4。trestbps:静息血压(入院时以毫米汞柱为单位)
5。chol:血清胆固醇的毫克/分升
6。fbs:(空腹血糖> 120 mg/dl) (1 =真;
0 =假)7。restecg:静息心电图结果— 0:正常,1:ST-T 波异常(T 波倒置和/或 ST 抬高或压低> 0.05 mV),2:根据 Estes 标准
8,显示可能或明确的左心室肥大。thalach:达到最大心率
9。exang:运动诱发心绞痛(1 =是;
0 =否)10。oldpeak =运动相对于休息诱发的 ST 段压低
11。坡度:最大运动 ST 段的坡度— 1:上坡,2:平地,3:下坡
12。ca:透视着色的主要血管数(0-3)13。thal: 3 =正常;6 =修复缺陷;7 =可逆缺陷
14。目标:心脏病的诊断(血管造影疾病状态)— 0: <直径缩小 50%,1: >直径缩小 50%

数据集和所有变量信息可以在这里找到:https://archive.ics.uci.edu/ml/datasets/Heart+Disease

方法

使用 XGBoost 树模型有两个原因:1)该模型是通过分割特定特征创建的,2)它对其他形式的决策树 models⁴.更健壮

特征选择

没有进行特征工程,因为 1)只有 14 个特征,以及 2)每个特征被视为彼此独立的变量,以查看哪些特征有助于预测。此外,为了确保这一点,我检查了 14 个属性之间是否存在共线性。从皮尔逊相关性来看,变量之间不存在强相关性(图 1)。

此外,数据被分成 70:30 的比例(训练:测试比例)。这种分割是必要的,因为数据集只有 303 个人,这是一个相对较小的数据集。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图一

基线模型

为了验证模型是否得到了适当的调优,我使用了 XGBoost 树的默认值,目标是“binary: logistic”。然后,使用分层 k-fold 的 10 倍交叉验证以准确度作为验证的度量进行拟合。测试准确度约为 81%,每个折叠的平均验证率为 77.81%,标准偏差为 11.09%。因此,基线模型过度拟合了数据。此外,每个折叠之间的度量标准偏差为 11.09%。这表明一些折叠表现不佳。因此,超参数调整解决了这个问题。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2:基线模型的混淆矩阵

超参数调谐

使用随机搜索 10 倍交叉验证来确定 XGBoost 树的最佳参数。最佳参数为:

{colsample_bytree:1,
learning_rate:0.1,
max_depth:4,
min_child_weight:1e-05,
n_estimators:200,
objective:’binary:logistic’,
subsample:0.5}

超参数调谐模型

与基线模型相似,使用精确度指标对调整后的模型进行了 10 重交叉验证。每个折叠的平均准确度为 82.12%,标准偏差为 7.47%,比基线好得多。测试的准确度为 84.6%,曲线下面积度量为 0.84。XGBoost 在没有过度拟合的情况下比基线表现得更好。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3:超参数调整模型的混淆矩阵

分析

模型的有效性

Dinh 等人的一篇论文使用 dataset⁶.国家健康和营养调查(NHANES)的 XGBoost 模型来预测糖尿病和心血管疾病对于心血管疾病分类,他们的 XGBoost AUC 指标为 0.831。尽管这些数据集略有不同,但在我的模型的关键特征和 Dinh 等人的关键特征之间,前五个特征有一些重叠。Dinh 等人的 XGboost 确定了 1)年龄,2)收缩压,3)自我报告的体重,4)胸痛的发生,和 5)舒张压为关键因素。而在我的模型中确定的主要特征是:1)年龄,2)胆固醇(chol),3)达到的最大心率(thalach),4)静息血压(trestbps),和 5)相对于静息运动诱发的 ST 段压低(oldpeak)。尽管变量不同,但我的模型和 Dinh 等人的模型在三个特征上是一致的:年龄、血压和胸痛(注意,旧峰是运动诱发的 st 段压低,这是诊断阻塞性冠状动脉粥样硬化的可靠发现,阻塞性冠状动脉粥样硬化可因血流减少而引起胸痛)

特征重要性图显示了 XGBoost 模型对特征进行分割的次数(f 分数而非 f1 分数)。因此,年龄被分成大量的时间来确定 CVD 的存在。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 XGBoost 模型用于心脏病分类的主要特征。f 得分是模型使用特定特征进行分割的次数

特征分析

年龄

XGBoost tree 用来区分一个人是否患有心脏病的首要特征之一是年龄,这是一个不可修改的风险因素。从生理学角度来看,年龄是心血管疾病的决定性风险因素。随着年龄的增长,主动脉和颈动脉的顺应性降低。这意味着我们的主动脉和颈动脉变得更硬,从而使老年人的血压比正常人高,这是心血管疾病和动脉粥样硬化的危险因素。此外,65 岁及以上的年龄组更有可能患上 CVD⁷.图 5 描述了患有 CVD 的年龄组中的人的百分比。从这张图表中可以看出两点。首先,60 岁以上的人都有心血管疾病,但请注意,并不是所有 65 岁以上的人都有心血管疾病的风险。这可能是由于成功的老化。成功的衰老是指个体通过适当的 exercise⁸.维持身体的生理功能而通常的衰老是没有明显的心脏病理,但有一些功能下降。其次,图表显示有大量的年轻人或中年人患有心血管疾病。这可能是由于导言中提到的可变风险因素。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 5:每个年龄组的总人数。百分比是患有心血管疾病的人数。

不可改变的风险因素。

根据前五个特征,胆固醇和血压是已知的心血管疾病的危险因素,医学研究已证明会导致心脏病。然而,XGBoost 用来确定某人是否患有 CVD 的一个因素是达到的最大心率(MHR)。MHR 告诉我们运动时心脏每分钟应该跳动的平均次数。数据字典没有具体说明它是如何计算的,但 MHR 通常是通过从 220 减去一个人的年龄来计算的。值得注意的趋势是最大心率随着年龄的增长而降低。这可能是由于随着年龄的增长,凋亡的窦房结细胞数量减少。下面是美国心脏协会按年龄分类的表格。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

表 1 来自[10]

下面的两张图(图 6 和图 7)描绘了有或没有 CVD 的平均 MHR。描绘的结果对我来说有点奇怪。两组都没有达到美国心脏协会指出的 MHR,但是有 CVD 的组比没有 CVD 的组更有可能达到 MHR。根据学术文献,最大运动诱发心率与心血管死亡率呈负相关。所以,MHR 越高,心血管疾病的机会就越少。一种可能性是,由于心脏的代偿机制,这些人的静息心率已经设置得很高。由此需要注意的另一个重要方面是机器学习模型的性质。XGBoost 是一个数学模型,它只根据输入到模型中的数字进行分类。因此,由于两者之间存在显著差异,该模型使用 MHR 作为进行分割的特征。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 6:患有心脏病的每个年龄组的平均最大心率

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 7:没有心脏病的每个年龄组的平均最大心率

限制

这种分析的局限性是不能在其他基于树的模型上运行数据,比如 light-XGBoost 或 random forest。这些模型本可以给出更好的结果。这种分析的另一个限制是数据集没有单独的测试集。因此,我必须从 303 个人中创建我的测试集,这减少了可能影响结果的训练样本的数量。

结论

XGBoost 模型确实揭示了与医生评估 CVD 潜在风险相似的风险因素。这个小评估证明了在医学中使用数据科学算法的有效性。在更大的心血管数据集上创建稳健的机器学习模型作为 CVD 的初步筛选工具是有效的。

密码

https://github.com/DharaRan/Heart-Disease

参考文献

[1]SISON,G. (2020 年 5 月 09 日)。这就是心脏病对美国人的影响。检索于 2020 年 5 月 30 日,来自https://www . single care . com/blog/news/heart-disease-statistics/

【2】关于心脏病。(2020 年 3 月 20 日)。于 2020 年 5 月 30 日从https://www.cdc.gov/heartdisease/about.htm检索

[3]了解你的风险,以防止心脏病发作。(2016).检索于 2020 年 5 月 30 日,来自https://www . heart . org/en/health-topics/heart-attack/understand-your-risks-to-prevent-a-a-heart-attack

[4]格伦,S. (2019 年 7 月 28 日)。决策树 vs 随机森林 vs 梯度提升机器:简单解释。2020 年 5 月 30 日检索,来自https://www . datascience central . com/profiles/blogs/decision-tree-vs-random-forest-vs-boosted-trees-explained

[5]库马尔,A. (2020 年 01 月 04 日)。心脏病预测模型(5+模型)。2020 年 5 月 30 日检索,来自https://medium . com/@ AK 8427916/heart-disease-prediction-model-5-models-c6aa 269 FB 74

[6] Dinh,A .,Miertschin,s .,Young,A. 一种用机器学习预测糖尿病和心血管疾病的数据驱动方法。 BMC Med 通知决策制定 19 211 (2019)。https://doi.org/10.1186/s12911-019-0918-5

[7]北 BJ,辛克莱·达。衰老和心血管疾病的交汇点。保监会决议。2012;110(8):1097‐1108.doi:10.1161/circresaha 11.246686766766

[8] UT 西南医学中心。(2018 年 1 月 8 日)。适当的锻炼可以逆转心脏老化带来的损害。科学日报。2020 年 5 月 29 日从 www.sciencedaily.com/releases/2018/01/180108090132.htm检索

[9]弗莱彻,J. (2020 年 1 月 5 日)。不同年龄的胆固醇水平:差异和建议。).于 2020 年 5 月 30 日从https://www.medicalnewstoday.com/articles/315900检索

10]美国心脏协会。(2015).了解你运动、减肥和健康的目标心率。2020 年 5 月 30 日检索,来自https://www . heart . org/en/healthy-living/fitness/fitness-basics/target-heart-rates

[11] Sandvik,l .,Erikssen,j .,Ellestad,m .,Erikssen,g .,Thaulow,e .,Mundal,r .,和 Rodahl,K. (1995 年)。运动时心率增加和最大心率是心血管死亡率的预测因素:一项对 1960 名健康男性的 16 年随访研究。冠状动脉疾病6 (8),667–679。

使用 Apache Spark ML-二进制分类进行心脏病预测

原文:https://towardsdatascience.com/heart-disease-prediction-using-apache-spark-ml-808073f52495?source=collection_archive---------43-----------------------

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

凯利·西克玛在 Unsplash 上的照片

介绍

有各种医学参数影响患有心脏病的人。可能是年龄、胆固醇、血糖、静息血压等等。

这里,我们将使用机器学习中的分类来创建预测模型。分类是一项受监督的机器学习任务,我们希望自动将数据分类到一些预定义的分类方法中。基于数据集中的特征,我们将创建一个模型来预测患者是否患有心脏病。我们将在 Apache spark 中使用各种分类算法,并根据预测得分选择最佳算法。

关于数据集

我们有一个基于特征的心脏病患者或非心脏病患者的详细资料数据集。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者图片

注:数据集可以从 Kaggle 下载。

用的 spark 版本是 3.0.0。

读取数据

Spark 可以读取 CSV、Parquet、Avro 和 JSON 等不同格式的数据。这里的数据是 CSV 格式,使用下面的代码读取。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者图片

机器学习的数据准备

我们有几节课?

对于分类任务,检查因变量中的类别不平衡很重要。如果存在严重不足或过度表示的类,模型预测的准确性可能会因为模型本质上存在偏差而受到影响。

如果我们看到类不平衡,一种常见的纠正方法是引导或重新采样数据帧。

df.groupBy('target').count().show()+------+-----+
|target|count|
+------+-----+
|     1|  165|
|     0|  138|
+------+-----+

由于班级人数差不多,我们可以继续了。

检查空值

我们必须查看数据帧中有多少列具有空值。如果空值的百分比非常小,我们可以丢弃数据。因为我们的数据集是干净的,没有空值,所以我们可以开始了。

处理偏斜度和异常值

偏斜度衡量值的分布偏离平均值的对称程度。零值意味着分布是对称的,而正偏斜度表示较小值的数量较大,负值表示较大值的数量较大。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者图片

处理偏斜度的一个常见建议是对正偏斜数据进行对数变换,或者对负偏斜数据进行指数变换。

异常值
纠正异常值的一种常见方法是通过下限和上限,这意味着编辑高于或低于某个阈值(第 99 个百分点或第 1 个百分点)的任何值,使其回到该百分点的最高/最低值。例如,如果第 99 个百分位数是 96,而值是 1,000,您可以将该值更改为 96。

将数据分成测试和训练数据集

现在,我们可以使用随机拆分方法将数据拆分为测试和训练数据集。我们将按 70/30 的比例分割数据。

train,test **=** final_data.randomSplit([0.7,0.3])

训练和评估算法

既然我们已经清理和矢量化了数据,我们就可以将它输入到我们的训练算法中了。我们将测试不同的算法,如逻辑回归,随机森林分类器,梯度推进树分类器和决策树分类器。

交互效度分析

Spark 有一个名为 CrossValidator 的内置函数来进行交叉验证,该函数首先将训练数据集分成一组“折叠”,用作单独的训练和测试数据集。例如,当 k=5 倍时,CrossValidator 将生成 5 个不同的(训练、测试)数据集对,每个数据集对使用 4/5 的数据进行训练,1/5 的数据进行测试。为了评估一个特定的参数(在 paramgrid 中指定),CrossValidator 计算 5 个模型的平均评估度量,这 5 个模型是通过在 5 个不同的(训练、测试)数据集对上拟合评估器而产生的,并在完成后告诉您哪个模型执行得最好。

在确定了最佳参数图之后,CrossValidator 最终使用最佳参数图和整个数据集重新拟合估计器。

现在,我们可以使用最佳模型来评估模型。fitModel 自动使用最佳模型,所以我们这里不需要使用 best model。

运行评估器后,我们将得到如下结果。

!!!!!Final Results!!!!!!!!
+----------------------+------+
|Classifier            |Result|
+----------------------+------+
|LogisticRegression    |89.32 |
|RandomForestClassifier|86.27 |
|GBTClassifier         |84.31 |
|DecisionTreeClassifier|80.82 |
+----------------------+------+

我们可以看到,逻辑回归具有最高的预测得分,因此我们可以使用逻辑回归的最佳模型来创建我们的最终模型。

我们可以通过 spark 中的特征选择技术对模型进行更多的调整。我们可以使用 ChiSqSelector 特征选择方法。ChiSqSelector 代表卡方特征选择。它对带有分类特征的标记数据进行操作。ChiSqSelector 使用卡方独立性检验来决定选择哪些特性。

因此,在对数据进行特征选择、训练和评估后,我们得到以下结果作为逻辑回归分类算法的预测得分。

+------------------+------+
|Classifier        |Result|
+------------------+------+
|LogisticRegression|91.08 |
+------------------+------+

最后我们的分类算法模型得到了 91%的预测分数。我们可以通过调整 spark 提供的更多参数来提高模型的预测得分。

源代码可以在 Github 上找到。

参考

https://spark.apache.org/docs/latest/ml-guide.html

使用机器学习预测心脏病发作风险

原文:https://towardsdatascience.com/heart-disease-risk-assessment-using-machine-learning-83335d077dad?source=collection_archive---------4-----------------------

用机器学习的力量预防疾病

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

安娜·科洛舒克在 Unsplash 上的照片

1.介绍

心脏病是全球发病率和死亡率的主要原因:它每年造成的死亡人数超过任何其他原因。根据 T4 世卫组织的数据,2016 年估计有 1790 万人死于心脏病,占全球死亡人数的 31%。超过四分之三的死亡发生在低收入和中等收入国家。

在所有心脏病中,冠心病(又名心脏病发作)是最常见也是最致命的。以美国为例,据估计每 40 秒就有一人心脏病发作,每年约有80.5 万美国人心脏病发作( CDC 2019 )。

令人欣慰的是,心脏病发作是高度可预防的,简单的生活方式改变(如减少酒精和烟草的使用;健康饮食和锻炼)加上早期治疗大大改善了其预后。然而,由于诸如糖尿病、高血压、高胆固醇等几个促成风险因素的多因素性质,很难识别高风险患者。这就是机器学习和数据挖掘的救援之处。

医生和科学家都转向机器学习(ML)技术来开发筛选工具,这是因为与其他传统的统计方法相比,它们在模式识别和分类方面具有优势。

在这篇文章中,我将向您介绍一种筛选工具的开发过程,该工具使用不同的机器学习技术在 Framingham 数据集上预测患者是否有 10 年患冠心病的风险。

这篇文章的代码可以在我的 Github 库中找到,或者直接从它的伙伴 Kaggle 笔记本中找到

2.数据集描述

该数据集在 Kaggle 网站上公开,它来自一项正在进行的对马萨诸塞州弗雷明汉镇居民的心血管研究。分类目标是预测患者是否有未来冠心病(CHD)的 10 年风险。数据集提供了患者的信息。它包括 4000 多条记录和 15 个属性。每个属性都是潜在的风险因素。既有人口统计学上的,也有行为和医学上的风险因素。

属性:

  1. 人口统计:
  • 性别:男性或女性(名义)
  • 年龄:患者的年龄;(连续——尽管记录的年龄被截断为整数,但年龄的概念是连续的)

2.教育:未提供进一步信息

3.行为:

  • 当前吸烟者:患者是否是当前吸烟者(名义上)
  • 每日吸烟量:一个人一天内平均吸烟的数量。(可以认为是连续的,因为一个人可以有任何数量的香烟,甚至半支香烟。)

4.病史信息:

  • 血压药物:患者是否在服用降压药(标称值)
  • 流行性中风:患者以前是否曾患过中风(名义上)
  • 流行性高血压:患者是否患有高血压(正常)
  • 糖尿病:患者是否患有糖尿病(名义上)

5.当前医疗状况信息:

  • 总胆固醇:总胆固醇水平(连续)
  • 系统血压:收缩压(持续)
  • Dia BP:舒张压(持续)
  • 身体质量指数:身体质量指数(连续)
  • 心率:心率(连续)——在医学研究中,心率等变量虽然实际上是离散的,但由于有大量的可能值,因此被认为是连续的。)
  • 葡萄糖:血糖水平(持续)

要预测的目标变量:

10 年患冠心病的风险——(二进制:“1”,表示“有风险”,“0”,表示“没有风险”)

3.工具开发

这篇文章的完整代码可以在这里找到。它是用 Python 实现的,使用了不同的分类算法。以下是我所采用的一般方法的简要描述:

  1. 数据清理和预处理:这里我检查并处理了数据集中缺失和重复的变量,因为这些变量会严重影响不同机器学习算法的性能(许多算法不允许缺失数据)。
  2. 探索性数据分析:这里我想从数据中获得重要的统计见解,我检查了不同属性的分布、属性之间的相关性以及目标变量,并计算了分类属性的重要概率和比例。
  3. 特征选择:由于数据集中不相关的特征会降低所应用模型的准确性,我使用了 Boruta 特征选择技术来选择最重要的特征,这些特征后来被用于构建不同的模型。
  4. 模型开发和比较:我使用了四种分类模型,分别是逻辑回归、K 近邻、决策树和支持向量机,之后我使用模型的精确度和 F1 值来比较模型的性能。然后我选择了表现最好的模特。

3.1 数据清理和预处理

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据集的前 20 条记录

数据集中没有重复条目,但有些条目缺少值,下表给出了这些条目的汇总:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

每个要素缺失值的百分比

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在 9.15%的情况下,血糖条目的缺失数据百分比最高。其他功能几乎没有遗漏条目。

丢失的条目只占全部数据的 12%,因此可以在不丢失大量数据的情况下被删除。

3.2 探索性数据分析

第一步是检查不同属性的分布,这可以通过直方图最好地可视化

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据分布

从分布图中挑选出分类变量和连续变量是很容易的。此外,可以看出,受访者中没有人患有流行性中风,极少数人患有糖尿病、服用降压药或高血压。这些分布也引起了数据集可能不适当平衡的怀疑,为了证实这一点,我比较了阳性和阴性病例的数量,与我的怀疑相符,有 3179 名应答者没有冠心病,572 名患者有冠心病。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

不平衡数据集

为了更深入地了解数据,我检查了每一类中阳性和阴性病例的比例。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

分类可变比例

由于数据集的不平衡性质,很难做出结论,但根据观察到的情况,可以得出以下结论:

  • 患冠心病的男性略多于女性。
  • 吸烟者和不吸烟者患冠心病的比例几乎相等。
  • 与没有类似疾病的人相比,糖尿病患者和高血压患者中患冠心病的比例更高。
  • 更大比例的冠心病患者正在服用降压药。

我检查的另一个有趣的趋势是冠心病患者的年龄分布,患病人数通常随着年龄的增长而增加,63 岁时达到高峰。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

最后一步是检查不同特征与目标变量以及彼此之间的相关性,因为这不仅可以很好地估计作为冠心病预测指标的特征强度,还可以揭示特征之间的任何共线性

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

相关矩阵

从矩阵中可以看出,没有任何特征与 10 年内发生冠心病的风险之间的相关性超过 0.5,这表明这些特征是不良的预测因素。然而,相关性最高的特征是年龄、流行性高血压和收缩压。

此外,有几个特征彼此高度相关,使用这两个特征来构建机器学习模型是没有意义的。其中包括:血糖和糖尿病(明显);收缩压和舒张压;吸烟和每天吸烟的数量。

3.3 功能选择

相关矩阵的结果提示需要进行特征选择。为此,我采用了 Boruta 特征选择算法,这是一种围绕随机森林分类算法构建的包装器方法。它试图捕捉与结果变量相关的数据集中所有重要的、有趣的特征。

它的工作方式如下:

  • 首先,它通过创建所有特征(称为阴影特征)的混洗副本,给给定的数据集增加了随机性。
  • 然后,它在扩展数据集上训练一个随机森林分类器,并应用一个特征重要性度量(默认为 Mean Decrease Accuracy )来评估每个特征的重要性,其中分数越高意味着特征越重要。
  • 在每次迭代中,它会检查真实要素的重要性是否高于其最佳阴影要素(即,该要素的 Z 值是否高于其阴影要素的最大 Z 值),并不断移除被认为非常不重要的要素。
  • 最后,当所有特征被确认或拒绝时,或者当算法达到随机森林运行的指定限制时,算法停止。

点击查看完整描述

在运行该算法 100 次迭代后,最先选择的特征是:年龄、总胆固醇、收缩压、舒张压、身体质量指数、心率和血糖。

然后我计算了主要特征的优势比和十年内患冠心病的风险,结果如下:

CI 5%     CI 95%    Odds Ratio
age        1.011381  1.033813    1.022536
totChol    0.994963  0.999184    0.997071
sysBP      1.018236  1.031493    1.024843
diaBP      0.962258  0.984627    0.973378
BMI        0.929304  0.973798    0.951291
heartRate  0.963690  0.977730    0.970685
glucose    1.001074  1.007518    1.004291

在所有其他特征保持不变的情况下,年龄和收缩压每增加一个百分点,患心脏病的几率就会增加 2 个百分点。

其他因素显示没有明显的积极的可能性。

3.4 模型开发和对比

不建议在不平衡数据集上训练分类器,因为它可能偏向一个类别,从而实现高精度,但具有较差的灵敏度或特异性。

在我们的案例中,阴性病例的数量(3179)大大超过了阳性病例的数量(572)。例如,如果我们训练一个总是预测负面类别的模型,它将达到 84.75 %(3179/(3179+572) x 100)的高准确度,但灵敏度为 0% (0/(0+572) x 100),因为它从不预测正面案例。

来解决这个问题。我使用合成少数过采样技术(SMOTE)平衡数据集。它是这样工作的:

SMOTE 首先随机选择一个少数类实例 x,并找到它的 k 个最近的少数类邻居。然后,通过随机选择 k 个最近邻居 x 中的一个并连接 x 和 x 以在特征空间中形成线段来创建合成实例。合成实例是作为两个选定实例 x 和 x 的凸组合生成的。—第 47 页,不平衡学习:基础、算法和应用,2013 年。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由胡风通过辛达维 (CC0)拍摄

该过程可用于根据需要为少数类创建尽可能多的合成示例。它建议首先使用随机欠采样来减少多数类中的样本数量,然后使用 SMOTE 对少数类进行过采样来平衡类分布。

使用这种技术后,得到的数据集更加平衡,有 3178 个阴性病例和 2543 个阳性病例

Numbers before  {positive: 3179, negative: 572} 
Numbers after   {positive: 3178, negative: 2543}

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

通过 SMOTE 平衡数据集

在平衡数据集之后,我缩放特征以加速分类器的训练,然后将数据分别以 0.8 比 0.2 的比例分成训练集和测试集

使用训练集,我训练了四个分类器,即:

  1. 逻辑回归:对数据点属于特定类别的概率进行建模,并根据选择的阈值为该点分配适当的标签。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由詹姆斯·乐通过数据营 (CC0)

2。K-nearest neighbors :试图通过查看一个数据点周围的数据点来确定该数据点属于哪个组。例如,给定一个数据点 C,如果它周围的大多数点都在组 A 中,那么很可能该数据点属于组 A 而不是组 B,反之亦然。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由柴坦尼亚·雷迪·帕托拉通过媒介 (CC BY-NC-ND 2.0)

3。决策树:基于树状图形,节点代表我们选择属性和提出问题的位置;边代表问题的答案;树叶代表实际的输出或类别标签。决策树通过从根到某个叶节点对示例进行排序来对示例进行分类,叶节点为示例提供分类。树中的每个节点都充当某个属性的测试用例,从该节点开始向下延伸的每条边都对应于测试用例的一个可能答案。这个过程本质上是递归的,并且对以新节点为根的每个子树重复进行。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像由 Rahul Saxena 通过 Dataaspirant (CCo)拍摄

4。支持向量机:由分离超平面形式定义的判别分类器。换句话说,给定标记的训练数据,该算法输出最佳超平面,该超平面基于新示例相对于它位于哪一侧来对新示例进行分类。在二维空间中,这个超平面是一条将平面分成两部分的线,每一类位于两边。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由弗朗索瓦·德·瑞克尔通过 Github 提供(CC BY-NC-ND 2.0)

在训练每个模型并使用网格搜索调整它们的超参数之后,我使用以下指标评估并比较了它们的性能:

  1. **准确率:**正确预测数与输入样本总数之比。它衡量算法对数据进行正确分类的趋势。
  2. F1 分数:定义为测试精度和召回的加权调和平均值。通过使用精确度和召回率,它给出了测试性能的更现实的度量。(精确度,也称为阳性预测值,是真正阳性的阳性结果的比例。回忆,也称为敏感性,是一项测试正确识别阳性结果以获得真实阳性率的能力)。
  3. ROC曲线下的面积(AUC): 提供了对所有可能的分类阈值性能的综合衡量。它给出了模型对随机正例的排序高于随机负例的概率

结果如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

不同型号的性能得分

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

精确度比较

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

F1 分数对比

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

AUC 的比较

支持向量机是所有指标中表现最好的模型。它的最佳参数是径向核,C 值为 10,gamma 值为 1。其高 AUC 和 F1 分数也表明该模型具有高的真阳性率,因此对于预测一个人是否具有发展为 CHD 的高风险(即在 10 年内患心脏病发作)是敏感的。

4.结论

然后,该模型可以用作简单的筛选工具,我们所需要做的就是输入以下数据:年龄、身体质量指数、收缩压和舒张压、心率和血糖水平,然后该模型可以运行并输出预测。

像所有优秀的科学家一样,我决定用我的个人数据来测试这个工具,结果如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

个人预测

我可以有 84%的信心说,我在未来 10 年内没有患冠心病的风险🎉🎉。我相信这些预测,因为我确实锻炼过,至少偶尔会锻炼一下。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由迪伦·克尔通过 imgur 提供

然而,作为健全性检查,关于阳性病例的大多数数据是使用 SMOTE 人工合成的,因此它们可能不是实际人群数据的真实代表,因此需要更多的数据,特别是关于阳性病例的数据,以建立更好的模型和更有效的筛选工具。

感谢您的阅读。请随意分享你的想法和主意。

心脏病 UCI-诊断和预测

原文:https://towardsdatascience.com/heart-disease-uci-diagnosis-prediction-b1943ee835a7?source=collection_archive---------0-----------------------

使用逻辑回归进行预测,准确率为 87%

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Robina WeermeijerUnsplash 上拍摄的照片

每天,人类的平均心脏跳动约 10 万次,将 2000 加仑的血液输送到全身。在你的体内有 60,000 英里长的血管。

女性心脏病发作的迹象远不如男性明显。在女性中,心脏病发作可能会感到胸部中心挤压、压力、胀满或疼痛等不适。它也可能导致一个或两个手臂,背部,颈部,下巴或胃部疼痛,呼吸急促,恶心和其他症状。男性经历心脏病发作的典型症状,如胸痛、不适和压力。他们也可能经历其他部位的疼痛,如手臂、颈部、背部和下巴,以及呼吸急促、出汗和类似烧心的不适。

对于一个就像一个大拳头,重量在 8 到 12 盎司之间的器官来说,这是一个很大的工作量。

来源:healthblog.uofmhealth

由 Hardik 编码:

Google colab 笔记本链接:https://colab . research . Google . com/drive/16 ifrpq 0 VX _ czy po 4 zyj _ qtlrvds 3 fldb?usp =分享

GitHub:https://GitHub . com/smarthardk 10/Heart-Disease-UCI-诊断-预测

UCI 心脏病数据集:

数据集来源:https://archive.ics.uci.edu/ml/datasets/Heart+Disease

数据集列:

  • 年龄:以年为单位的人员年龄
  • 性别:人的性别(1 =男性,0 =女性)
  • cp:胸痛类型
    —值 0:无症状
    —值 1:不典型心绞痛
    —值 2:非心绞痛性疼痛
    —值 3:典型心绞痛
  • trestbps:患者的静息血压(入院时为毫米汞柱)
  • 胆固醇:人体的胆固醇含量,单位为毫克/分升
  • fbs:人的空腹血糖(> 120 mg/dl,1 =真;0 =假)
  • restecg:静息心电图结果
    —值 0:根据 Estes 标准
    —值 1:正常
    —值 2:ST-T 波异常(T 波倒置和/或 ST 段抬高或压低> 0.05 mV)
  • thalach:人达到的最大心率
  • exang:运动诱发的心绞痛(1 =是;0 =否)
  • oldpeak:运动相对于休息诱发的 st 段压低(“ST”与心电图图上的位置有关。点击此处查看更多)
  • 斜率:运动 ST 段峰值的斜率— 0:下降;1:平;2:上坡
    0:下坡;1:平;2:上坡
  • ca:主要血管的数量(0-3)
  • 地中海贫血:一种称为地中海贫血的血液疾病值 0:空(从先前的数据集中删除
    值 1:固定缺陷(心脏的某些部分没有血流)
    值 2:正常血流
    值 3:可逆缺陷(观察到血流但不正常)
  • 目标:心脏病(1 =否,0=是)

语境:

这是多变量类型的数据集,这意味着提供或涉及各种单独的数学或统计变量,多变量数值数据分析。它由 14 个属性组成,即年龄、性别、胸痛类型、静息血压、血清胆固醇、空腹血糖、静息心电图结果、达到的最大心率、运动诱发的心绞痛、运动相对于静息诱发的老峰-st 段压低、运动 ST 段峰值斜率、主要血管数和地中海贫血。该数据库包括 76 个属性,但所有发表的研究都涉及其中 14 个属性的子集的使用。克利夫兰数据库是迄今为止 ML 研究人员唯一使用的数据库。该数据集的主要任务之一是基于患者的给定属性来预测该特定人是否患有心脏病,其他任务是实验性任务,以诊断并从该数据集找出各种见解,这有助于更好地理解问题。

数据集由:- 创建

1.匈牙利心脏病研究所。布达佩斯:医学博士安朵斯·雅诺西。瑞士苏黎世大学医院:William Steinbrunn,医学博士。瑞士巴塞尔大学医院:医学博士马蒂亚斯·普菲斯特勒。弗吉尼亚医疗中心,长滩和克利夫兰诊所基金会:罗伯特·德特拉诺,医学博士,哲学博士。

目录

  1. 导入和读取数据集
  2. 数据描述

3.数据分析

4.数据可视化

5.数据预处理

6.逻辑回归

7.结论

1.导入和读取数据集

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inlinedf = pd.read_csv('/content/drive/My Drive/dataset/heart.csv')df.head()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

heart.csv

2.数据描述

形容

关于元数据有很多混乱,因为有各种不同的元数据可用。下面我从 kaggle 得到了两个最常用的元数据描述。所以我们要遵循第二个描述(2 —描述)。

1 —描述

•    age: The person's age in years
•    sex: The person's sex (1 = male, 0 = female)
•    cp: The chest pain experienced (Value 1: typical angina, Value 2: atypical angina, Value 3: non-anginal pain, Value 4: asymptomatic)
•    trestbps: The person's resting blood pressure (mm Hg on admission to the hospital)
•    chol: The person's cholesterol measurement in mg/dl
•    fbs: The person's fasting blood sugar (> 120 mg/dl, 1 = true; 0 = false)
•    restecg: Resting electrocardiographic measurement (0 = normal, 1 = having ST-T wave abnormality, 2 = showing probable or definite left ventricular hypertrophy by Estes' criteria)
•    thalach: The person's maximum heart rate achieved
•    exang: Exercise induced angina (1 = yes; 0 = no)
•    oldpeak: ST depression induced by exercise relative to rest ('ST' relates to positions on the ECG plot. See more here)
•    slope: the slope of the peak exercise ST segment (Value 1: upsloping, Value 2: flat, Value 3: downsloping)
•    ca: The number of major vessels (0-3)
•    thal: A blood disorder called thalassemia (3 = normal; 6 = fixed defect; 7 = reversable defect)
•    target: Heart disease (0 = no, 1 = yes)

2 —描述

 cp: chest pain type
-- Value 0: asymptomatic
-- Value 1: atypical angina
-- Value 2: non-anginal pain
-- Value 3: typical angina

restecg: resting electrocardiographic results
-- Value 0: showing probable or definite left ventricular hypertrophy by Estes' criteria
-- Value 1: normal
-- Value 2: having ST-T wave abnormality (T wave inversions and/or ST elevation or depression of > 0.05 mV)

slope: the slope of the peak exercise ST segment
0: downsloping; 1: flat; 2: upsloping

thal
Results of the blood flow observed via the radioactive dye.

Value 0: NULL (dropped from the dataset previously)
Value 1: fixed defect (no blood flow in some part of the heart)
Value 2: normal blood flow
Value 3: reversible defect (a blood flow is observed but it is not normal)
This feature and the next one are obtained through a very invasive process for the patients. But, by themselves, they give a very good indication of the presence of a heart disease or not.

target : 0 = disease, 1 = no diseasedf.info()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

df.describe()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

检查空值

df.isnull().sum()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

#visualizing Null values if it exists 
plt.figure(figsize=(22,10))plt.xticks(size=20,color='grey')
plt.tick_params(size=12,color='grey')plt.title('Finding Null Values Using Heatmap\n',color='grey',size=30)sns.heatmap(df.isnull(),
            yticklabels=False,
            cbar=False,
            cmap='PuBu_r',
            )

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据集没有空值

熊猫-侧写

!pip install https://github.com/pandas-profiling/pandas-profiling/archive/master.zipimport pandas_profiling as pp
pp.ProfileReport(df)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

详细查看:https://colab . research . Google . com/drive/16 ifrpq 0 VX _ czypo 4 zyj _ qtlrvds 3 fldb # scroll to = 5 cribqgn 9 fii&line = 1&uniqifier = 1

3.数据分析

功能选择

  1. 单变量选择—统计测试可用于挑选与性能变量有最佳关系的某些特征。
    sci kit-learn 库提供了 SelectKBest 类,可用于在一套不同的统计测试中选择特定数量的特征。
    以下示例使用非负特征的卡方(chi2)统计测试从数据集中选择 13 个最佳特征。
from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import chi2
data = df.copy()
X = data.iloc[:,0:13]  #independent columns
y = data.iloc[:,-1]    #target column 
#apply SelectKBest class to extract top best features
bestfeatures = SelectKBest(score_func=chi2, k=10)
fit = bestfeatures.fit(X,y)
dfscores = pd.DataFrame(fit.scores_)
dfcolumns = pd.DataFrame(X.columns)
#concat two dataframes for better visualization 
featureScores = pd.concat([dfcolumns,dfscores],axis=1)
featureScores.columns = ['Specs','Score']  #naming the dataframe columns
print(featureScores.nlargest(12,'Score'))  #print best features

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

2.要素重要性-通过使用模型特征属性,可以获得数据集每个要素的重要性。
特征值为您的结果的每个功能给出一个分数,分数越高,性能变量越重要或合适。
特征重要性是基于树的分类器自带的内置类,我们将使用额外的树分类器来提取数据集的顶部特征。

from sklearn.ensemble import ExtraTreesClassifiermodel = ExtraTreesClassifier()
model.fit(X,y)
print(model.feature_importances_) #use inbuilt class feature_importances of tree based classifiers
#plot graph of feature importances for better visualization
feat_importances = pd.Series(model.feature_importances_, index=X.columns)
feat_importances.nlargest(13).plot(kind='barh')
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

3.带热点图的关联矩阵-关联表示要素之间或要素与目标变量之间的关联方式。
相关性可能是正的(增加一个特性值会增加目标变量的值)或负的(增加一个特性值会减少目标变量的值)
热图可以很容易地对与目标变量最相关的特性进行分类,我们将使用 seaborn 库绘制热图的相关特性。

相关性显示特征是否相互关联或与目标变量相关。相关性可以是正的(增加一个值,目标变量的值增加)或负的(增加一个值,目标变量的值减少)。从该热图中,我们可以观察到“cp”胸痛与目标变量高度相关。与其他两个变量之间的关系相比,我们可以说胸痛在预测心脏病的存在方面贡献最大。医疗急救是心脏病发作。当血块阻塞了流向心脏的血液时,通常会发生心脏病。组织在没有血液的情况下失去氧气并死亡,导致胸痛。

plt.figure(figsize=(12,10))
sns.heatmap(df.corr(),annot=True,cmap="magma",fmt='.2f')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

for i in df.columns:
    print(i,len(df[i].unique()))

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

4.数据可视化

Seaborn

sns.set_style('darkgrid')
sns.set_palette('Set2')

准备数据

df2 = df.copy()def chng(sex):
    if sex == 0:
        return 'female'
    else:
        return 'male'df2['sex'] = df2['sex'].apply(chng)def chng2(prob):
    if prob == 0:
        return ‘Heart Disease’
    else:
        return ‘No Heart Disease’df2['target'] = df2['target'].apply(chng2)

1.计数图

df2['target'] = df2['target'].apply(chng2)
sns.countplot(data= df2, x='sex',hue='target')
plt.title('Gender v/s target\n')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

根据克利夫兰的数据,男性比女性更容易患心脏病。男性比女性更容易患心脏病。70% — 89%的男性经历过突发心脏病。女性可能会在完全没有胸部压力的情况下经历心脏病发作,她们通常会经历恶心或呕吐,这往往与胃酸倒流或流感相混淆。

sns.countplot(data= df2, x='cp',hue='target')
plt.title('Chest Pain Type v/s target\n')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

有四种类型的胸痛,无症状、非典型心绞痛、非心绞痛性疼痛和典型心绞痛。大多数心脏病患者被发现有无症状的胸痛。这些人群可能会出现非典型症状,如消化不良、流感或胸肌劳损。与任何心脏病发作一样,无症状的发作会导致流向心脏的血流受阻,并可能对心肌造成损害。无症状心脏病发作的危险因素与有心脏病症状的相同。这些因素包括:

年龄

糖尿病

过重

心脏病家族史

高血压

高胆固醇

缺乏锻炼

先前心脏病发作

烟草使用

无症状的心脏病发作会增加你再次心脏病发作的风险,这可能是致命的。再次心脏病发作也会增加并发症的风险,如心力衰竭。没有测试来确定你无症状心脏病发作的可能性。判断你是否有无症状发作的唯一方法是通过心电图或超声心动图。这些测试可以揭示预示心脏病发作的变化。

sns.countplot(data= df2, x='sex',hue='thal')
plt.title('Gender v/s Thalassemia\n')
print('Thalassemia (thal-uh-SEE-me-uh) is an inherited blood disorder that causes your body to have less hemoglobin than normal. Hemoglobin enables red blood cells to carry oxygen')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

β地中海贫血性心肌病主要以两种不同的表型为特征,扩张型伴有左心室扩张和收缩性受损,而限制性表型伴有限制性左心室感觉、肺动脉高压和右心衰竭。心脏问题、充血性心力衰竭和心律异常可能与严重的地中海贫血有关。

sns.countplot(data= df2, x='slope',hue='target')
plt.title('Slope v/s Target\n')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

sns.countplot(data= df2, x='exang',hue='thal')
plt.title('exang v/s Thalassemia\n')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

2.距离图

plt.figure(figsize=(16,7))
sns.distplot(df[df['target']==0]['age'],kde=False,bins=50)
plt.title('Age of Heart Diseased Patients\n')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

心脏病在 60 岁及以上年龄组的老年人中非常常见,在 41 至 60 岁年龄组的成年人中也很常见。但是在 19 到 40 岁的年龄组中很少见,在 0 到 18 岁的年龄组中非常少见。

plt.figure(figsize=(16,7))
sns.distplot(df[df['target']==0]['chol'],kde=False,bins=40)
plt.title('Chol of Heart Diseased Patients\n')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

  • 总胆固醇
  • 低密度脂蛋白——“坏胆固醇”
  • 高密度脂蛋白——“好胆固醇”

在成人中,总胆固醇水平低于 200 毫克/分升(mg / dL)被认为是理想的。在 200 至 239 毫克/分升和 240 毫克/分升及以上之间的边界线被认为是高的。低密度脂蛋白应含有少于 100 毫克/分升的胆固醇。对于没有任何健康问题的人来说,100 mg / dl 的剂量率是合适的,但对于有心脏问题或有心脏病风险因素的人来说可能更合适。水平介于 130 至 159 毫克/分升和 160 至 189 毫克/分升之间。读数很高,达到或超过 190 毫克/分升。高密度脂蛋白水平应保持在较高水平。心血管疾病的危险因素被称为读数低于 40 毫克/分升。临界低被认为是在 41 毫克/分升和 59 毫克/分升之间。高密度脂蛋白水平最高可达 60 毫克/分升。

plt.figure(figsize=(16,7))
sns.distplot(df[df['target']==0]['thalach'],kde=False,bins=40)
plt.title('thalach of Heart Diseased Patients\n')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

3.接合图

准备数据

df3 = df[df['target'] == 0 ][['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach',
       'exang', 'oldpeak', 'slope', 'ca', 'thal', 'target']] 
#target 0 - people with heart diseasepal = sns.light_palette("blue", as_cmap=True)print('Age vs trestbps(Heart Diseased Patinets)')
sns.jointplot(data=df3,
              x='age',
              y='trestbps',
              kind='hex',
              cmap='Reds'

              )

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

sns.jointplot(data=df3,
              x='chol',
              y='age',
              kind='kde',
              cmap='PuBu'
              )

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

seaborn 的联合图有助于我们理解两个特征之间的趋势。从上面的图中可以看出,大多数 50 多岁或 60 多岁的心脏病患者的胆固醇含量在 200 毫克/分升到 300 毫克/分升之间。

sns.jointplot(data=df3,
              x='chol',
              y='trestbps',
              kind='resid',

              )

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

4.箱线图/紫线图

sns.boxplot(data=df2,x='target',y='age')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

plt.figure(figsize=(14,8))
sns.violinplot(data=df2,x='ca',y='age',hue='target')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

sns.boxplot(data=df2,x='cp',y='thalach',hue='target')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

plt.figure(figsize=(10,7))
sns.boxplot(data=df2,x='fbs',y='trestbps',hue='target')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

plt.figure(figsize=(10,7))
sns.violinplot(data=df2,x='exang',y='oldpeak',hue='target')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

plt.figure(figsize=(10,7))
sns.boxplot(data=df2,x='slope',y='thalach',hue='target')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

sns.violinplot(data=df2,x='thal',y='oldpeak',hue='target')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

sns.violinplot(data=df2,x='target',y='thalach')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

5.聚类图

sns.clustermap(df.corr(),annot=True)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

6.配对图

sns.pairplot(df,hue='cp')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

分类树

from sklearn.tree import DecisionTreeClassifier # Import Decision Tree Classifier
from sklearn.model_selection import train_test_split # Import train_test_split function
from sklearn import metrics #Import scikit-learn metrics module for accuracy calculation
X = df.iloc[:,0:13] # Features
y = df.iloc[:,13] # Target variable
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1) # 70% training and 30% test# Create Decision Tree classifer object
clf = DecisionTreeClassifier()# Train Decision Tree Classifer
clf = clf.fit(X_train,y_train)#Predict the response for test dataset
y_pred = clf.predict(X_test)print("Accuracy:",metrics.accuracy_score(y_test, y_pred))

Accuracy: 0.7142857142857143

feature_cols = ['age', 'sex', 'cp', 'trestbps','chol', 'fbs', 'restecg', 'thalach','exang', 'oldpeak', 'slope', 'ca', 'thal']from sklearn.tree import export_graphviz
from sklearn.externals.six import StringIO  
from IPython.display import Image  
import pydotplusdot_data = StringIO()
export_graphviz(clf, out_file=dot_data,  
                filled=True, rounded=True,
                special_characters=True,feature_names = feature_cols  ,class_names=['0','1'])
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())  
graph.write_png('diabetes.png')
Image(graph.create_png())

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

# Create Decision Tree classifer object
clf = DecisionTreeClassifier(criterion="entropy", max_depth=3)# Train Decision Tree Classifer
clf = clf.fit(X_train,y_train)#Predict the response for test dataset
y_pred = clf.predict(X_test)# Model Accuracy, how often is the classifier correct?
print("Accuracy:",metrics.accuracy_score(y_test, y_pred))

Accuracy: 0.7362637362637363

from sklearn.externals.six import StringIO  
from IPython.display import Image  
from sklearn.tree import export_graphviz
import pydotplus
dot_data = StringIO()
export_graphviz(clf, out_file=dot_data,  
                filled=True, rounded=True,
                special_characters=True, feature_names = feature_cols,class_names=['0','1'])
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())  
graph.write_png('diabetes.png')
Image(graph.create_png())

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

5.数据预处理

预处理

更改列的名称

df.columns = ['age', 'sex', 'chest_pain_type', 'resting_blood_pressure', 'cholesterol', 'fasting_blood_sugar', 'rest_ecg_type', 'max_heart_rate_achieved',
       'exercise_induced_angina', 'st_depression', 'st_slope_type', 'num_major_vessels', 'thalassemia_type', 'target']df.columns

Index([‘age’, ‘sex’, ‘chest_pain_type’, ‘resting_blood_pressure’, ‘cholesterol’, ‘fasting_blood_sugar’, ‘rest_ecg_type’, ‘max_heart_rate_achieved’, ‘exercise_induced_angina’, ‘st_depression’, ‘st_slope_type’, ‘num_major_vessels’, ‘thalassemia_type’, ‘target’], dtype=’object’)

我们有 4 个分类列,如使用 pandas profiling 的数据描述所示:

cp —胸痛类型

restecg — rest_ecg_type

坡度 st _ slope _ type

地中海贫血型

生成分类列值

#cp - chest_pain_type
df.loc[df['chest_pain_type'] == 0, 'chest_pain_type'] = 'asymptomatic'
df.loc[df['chest_pain_type'] == 1, 'chest_pain_type'] = 'atypical angina'
df.loc[df['chest_pain_type'] == 2, 'chest_pain_type'] = 'non-anginal pain'
df.loc[df['chest_pain_type'] == 3, 'chest_pain_type'] = 'typical angina'#restecg - rest_ecg_type
df.loc[df['rest_ecg_type'] == 0, 'rest_ecg_type'] = 'left ventricular hypertrophy'
df.loc[df['rest_ecg_type'] == 1, 'rest_ecg_type'] = 'normal'
df.loc[df['rest_ecg_type'] == 2, 'rest_ecg_type'] = 'ST-T wave abnormality'#slope - st_slope_type
df.loc[df['st_slope_type'] == 0, 'st_slope_type'] = 'downsloping'
df.loc[df['st_slope_type'] == 1, 'st_slope_type'] = 'flat'
df.loc[df['st_slope_type'] == 2, 'st_slope_type'] = 'upsloping'#thal - thalassemia_type
df.loc[df['thalassemia_type'] == 0, 'thalassemia_type'] = 'nothing'
df.loc[df['thalassemia_type'] == 1, 'thalassemia_type'] = 'fixed defect'
df.loc[df['thalassemia_type'] == 2, 'thalassemia_type'] = 'normal'
df.loc[df['thalassemia_type'] == 3, 'thalassemia_type'] = 'reversable defect'

一个热编码

data = pd.get_dummies(df, drop_first=False)
data.columns

Index([‘age’, ‘sex’, ‘resting_blood_pressure’, ‘cholesterol’, ‘fasting_blood_sugar’, ‘max_heart_rate_achieved’, ‘exercise_induced_angina’, ‘st_depression’, ‘num_major_vessels’, ‘target’, ‘chest_pain_type_asymptomatic’, ‘chest_pain_type_atypical angina’, ‘chest_pain_type_non-anginal pain’, ‘chest_pain_type_typical angina’, ‘rest_ecg_type_ST-T wave abnormality’, ‘rest_ecg_type_left ventricular hypertrophy’, ‘rest_ecg_type_normal’, ‘st_slope_type_downsloping’, ‘st_slope_type_flat’, ‘st_slope_type_upsloping’, ‘thalassemia_type_fixed defect’, ‘thalassemia_type_normal’, ‘thalassemia_type_nothing’, ‘thalassemia_type_reversable defect’], dtype=’object’)

df_temp = data['thalassemia_type_fixed defect']
data = pd.get_dummies(df, drop_first=True)
data.head()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

由于一个热编码删除了“地中海贫血 _ 类型 _ 固定缺陷”列,与空列“地中海贫血 _ 类型 _ 无”相比,这是一个有用的列,因此我们删除了“地中海贫血 _ 类型 _ 无”和串联的“地中海贫血 _ 类型 _ 固定缺陷”

frames = [data, df_temp]
result = pd.concat(frames,axis=1)
result.drop('thalassemia_type_nothing',axis=1,inplace=True)
resultc = result.copy()# making a copy for further analysis in conclusion section

6.逻辑回归

1.收集列

X = result.drop('target', axis = 1)
y = result['target']

2.拆分数据

from sklearn.model_selection import train_test_splitX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

3.正常化

最小-最大归一化法用于归一化数据。此方法将数据范围缩放到[0,1]。在大多数情况下,标准化也是基于特性的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

X_train=(X_train-np.min(X_train))/(np.max(X_train)-np.min(X_train)).valuesX_test=(X_test-np.min(X_test))/(np.max(X_test)-np.min(X_test)).values

4.适合模型

from sklearn.linear_model import LogisticRegression
logre = LogisticRegression()
logre.fit(X_train,y_train)

LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True, intercept_scaling=1, l1_ratio=None, max_iter=100, multi_class=’auto’, n_jobs=None, penalty=’l2', random_state=None, solver=’lbfgs’, tol=0.0001, verbose=0, warm_start=False)

5.预言;预测;预告

y_pred = logre.predict(X_test)
actual = []
predcition = []for i,j in zip(y_test,y_pred):
  actual.append(i)
  predcition.append(j)dic = {'Actual':actual,
       'Prediction':predcition
       }result  = pd.DataFrame(dic)import plotly.graph_objects as go

fig = go.Figure()

fig.add_trace(go.Scatter(x=np.arange(0,len(y_test)), y=y_test,
                    mode='markers+lines',
                    name='Test'))
fig.add_trace(go.Scatter(x=np.arange(0,len(y_test)), y=y_pred,
                    mode='markers',
                    name='Pred'))

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

红点代表 0 或 1 的预测值,蓝线和点代表特定患者的实际值。红点和蓝点不重叠的地方是错误的预测值,而红点和蓝点重叠的地方是正确的预测值。

6.模型评估

from sklearn.metrics import accuracy_score
print(accuracy_score(y_test,y_pred))

0.8688524590163934

from sklearn.metrics import classification_report
print(classification_report(y_test,y_pred))

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

该模型的分类报告显示,91%的无心脏病预测被正确预测,83%的有心脏病预测被正确预测。

from sklearn.metrics import confusion_matrix
print(confusion_matrix(y_test,y_pred))
sns.heatmap(confusion_matrix(y_test,y_pred),annot=True)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

混乱矩阵

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

混淆矩阵真正值是 24,真负值是 29。假阳性是 3,假阴性是 5。

ROC 曲线

ROC 曲线总结了使用不同概率阈值的预测模型的真阳性率和假阳性率之间的权衡。

ROC 曲线的准确性为 87.09%。

from sklearn.metrics import roc_curve
fpr, tpr, thresholds = roc_curve(y_test, y_pred)
plt.plot(fpr,tpr)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.title('ROC curve for Heart disease classifier')
plt.xlabel('False positive rate (1-Specificity)')
plt.ylabel('True positive rate (Sensitivity)')
plt.grid(True)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

import sklearn
sklearn.metrics.roc_auc_score(y_test,y_pred)

0.8709150326797386

7.结论

1.系数

print(logre.intercept_)
plt.figure(figsize=(10,12))
coeffecients = pd.DataFrame(logre.coef_.ravel(),X.columns)
coeffecients.columns = ['Coeffecient']
coeffecients.sort_values(by=['Coeffecient'],inplace=True,ascending=False)
coeffecientsts

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

2.分析

准备用于分析的数据

df4 = df[df['target'] == 0 ][['age', 'sex', 'chest_pain_type', 'resting_blood_pressure','cholesterol', 'fasting_blood_sugar', 'rest_ecg_type', 'max_heart_rate_achieved', 'exercise_induced_angina', 'st_depression','st_slope_type', 'num_major_vessels', 'thalassemia_type', 'target']] #target 0 - people with heart disease

心脏病患者的可视化

plt.figure(figsize=(16,6))
sns.distplot(df4['max_heart_rate_achieved'])

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

正常的心率在 60 到 100 次/分之间。在心脏病发作期间,由于缺乏血液,心肌的一些区域将开始死亡。一个人的脉搏可能会变得更慢(心动过缓)或更快(心动过速),这取决于他们所经历的心脏病发作的类型。

plt.figure(figsize=(20,6))
sns.boxenplot(data=df4,x='rest_ecg_type',y='cholesterol',hue='st_slope_type')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在正常类型的休息心电图证明是重要的预测模型,随着向下倾斜的 st 斜率。由这两种特征组成的病人通常胆固醇水平在 170-225mg/dl 之间。其余心电图的其他类型的偏差似乎更分散,更不简洁。

plt.figure(figsize=(20,6))
sns.boxenplot(data=df4,x='chest_pain_type',y='max_heart_rate_achieved',hue='thalassemia_type')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Shap

形状值

!pip install shap 
import shap
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test,check_additivity=False)shap.summary_plot(shap_values[1], X_test, plot_type="bar")

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

模型解释的形状值

shap.summary_plot(shap_values[1], X_test)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

def patient_analysis(model, patient):
  explainer = shap.TreeExplainer(model)
  shap_values = explainer.shap_values(patient)
  shap.initjs()
  return shap.force_plot(explainer.expected_value[1], shap_values[1], patient)

两个病人的报告

patients = X_test.iloc[3,:].astype(float)
patients_target = y_test.iloc[3:4]
print('Target : ',int(patients_target))
patient_analysis(model, patients)

Target : 0

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

patients = X_test.iloc[33,:].astype(float)
patients_target = y_test.iloc[33:34]
print('Target : ',int(patients_target))
patient_analysis(model, patients)

Target : 1

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

# dependence plotshap.dependence_plot('num_major_vessels', shap_values[1], X_test, interaction_index = "st_depression")

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

shap_values = explainer.shap_values(X_train.iloc[:50],check_additivity=False)shap.initjs()shap.force_plot(explainer.expected_value[1], shap_values[1], X_test.iloc[:50])

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

3.结论

  • ROC 曲线下的面积为 87.09%,这是令人满意的。
  • 模型预测准确率为 86.88%。该模型更具体而非敏感。
  • 根据该模型,有助于预测模型精度的主要特征在热图中按升序显示。
plt.figure(figsize=(10,12))coeffecients = pd.DataFrame(logre.coef_.ravel(),X.columns)coeffecients.columns = ['Coeffecient']coeffecients.sort_values(by=['Coeffecient'],inplace=True,ascending=False)sns.heatmap(coeffecients,annot=True,fmt='.2f',cmap='Set2',linewidths=0.5)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

有助于预测准确性的重要特征通过热图以降序显示。在银色码中,最有贡献的特征、胸痛类型和达到的最大心率被证明更有价值 1.28 至 1.03 个单位。

参考

【https://www.kaggle.com/ronitf/heart-disease-uci/kernels 号

鳍。

Seaborn 的热图基础

原文:https://towardsdatascience.com/heatmap-basics-with-pythons-seaborn-fb92ea280a6c?source=collection_archive---------0-----------------------

如何使用 Matplotlib 和 Seaborn 创建热图的指南

这个想法很简单,用颜色代替数字。

现在,这种可视化风格已经从简单的彩色编码表格发展到现在。它被广泛用于地理空间数据。它通常用于描述变量的密度或强度,可视化模式,方差,甚至异常。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

相关矩阵——谷物样品的成分

有了这么多的应用,这个基本方法值得注意。本文将介绍热图的基础知识,并了解如何使用 Matplotlib 和 Seaborn 创建它们。

亲自动手

我们将使用熊猫和 Numpy 来帮助我们处理数据争论。

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sb
import numpy as np

本例中的数据集是一个对美元汇率的时间序列。

我不想用通常的折线图来表示一段时间内的值,而是想用一个彩色编码的表来显示这些数据,以月份为列,以年份为行。

我将尝试绘制折线图和热图,以了解这将是什么样子。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

画板. app 绘制

折线图在显示数据方面会更有效;比较一个点在直线上的高度比区分颜色更容易。

热图将产生更大的影响,因为它们不是显示此类数据的传统方式。它们会失去一些准确性,特别是在这种情况下,因为我们需要在几个月内合计这些值。但总的来说,他们仍然能够显示模式并总结我们数据中的时期。

让我们读取数据集并根据草图重新排列数据。

# read file
df = pd.read_csv('data/Foreign_Exchange_Rates.csv', 
                 usecols=[1,7], names=['DATE', 'CAD_USD'], 
                 skiprows=1, index_col=0, parse_dates=[0])

对于这个示例,我们将使用第 1 列和第 7 列,它们是’*时间序列’*和’加拿大—加元/美元’

让我们将这些列重命名为’*DATE '*和’CAD _ USD ',T15,因为我们要传递标题,所以我们也需要跳过第一行。

我们还需要解析第一列,所以值是日期时间格式,我们将日期定义为我们的索引。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据框。

让我们确保所有的值都是数字,并删除空行。

df['CAD_USD'] = pd.to_numeric(df.CAD_USD, errors='coerce')
df.dropna(inplace=True)

我们需要按月汇总这些值。让我们为月和年创建单独的列,然后我们对新列进行分组并获得平均值。

# create a copy of the dataframe, and add columns for month and year
df_m = df.copy()
df_m['month'] = [i.month for i in df_m.index]
df_m['year'] = [i.year for i in df_m.index]# group by month and year, get the average
df_m = df_m.groupby(['month', 'year']).mean()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

前五行分组。

剩下要做的就是拆分索引,这样我们就有了自己的表。

df_m = df_m.unstack(level=0)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

经过整形的数据框。

彩色地图

一切就绪。现在我们可以使用 Seaborn 的.heatmap绘制我们的第一张图表。

fig, ax = plt.subplots(figsize=(11, 9))sb.heatmap(df_m)plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

第一张热图。

好了,在这个可视化准备好之前,还有很多事情要做。

颜色是我们图表中最重要的部分,颜色图有点过于复杂。我们不需要这个。相反,我们可以使用只有两种颜色的序列cmap

我们也可以通过定义vminvmax来明确色图的界限。熊猫.min.max可以帮助我们找出它们的最佳价值。

fig, ax = plt.subplots(figsize=(11, 9))# plot heatmap
sb.heatmap(df_m, cmap="Blues", vmin= 0.9, vmax=1.65,
           linewidth=0.3, cbar_kws={"shrink": .8})plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

第二张热图。

定制的

使用.heatmap还有很多其他的论点可以探讨。

例如linewidth定义了框之间的线的大小,我们甚至可以用cbar_kws将参数直接传递给颜色栏。

颜色看起来不错,现在我们可以把注意力转移到蜱上了。我不认为 CAD_USD-1 是一月的正确名称。让我们用一些更友好的文字来代替它们。

将刻度移动到图表的顶部会提高可视化效果,使它看起来更像一个表。我们也可以去掉 x 和 y 标签,因为我们的轴中的值是不言自明的,标题也会使它们变得多余。

# figure
fig, ax = plt.subplots(figsize=(11, 9))# plot heatmap
sb.heatmap(df_m, cmap="Blues", vmin= 0.9, vmax=1.65, square=True,
           linewidth=0.3, cbar_kws={"shrink": .8})# xticks
ax.xaxis.tick_top()
xticks_labels = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
plt.xticks(np.arange(12) + .5, labels=xticks_labels)# axis labels
plt.xlabel('')
plt.ylabel('')# title
title = 'monthly Average exchange rate\nValue of one USD in CAD\n'.upper()
plt.title(title, loc='left')plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

最终热图。

我传递给热图的最后一个参数是square。这将使我们的矩阵的细胞呈正方形,而不管图形的大小。

总的来说,看起来不错。我们可以看到,在 21 世纪初,美元比加拿大元高出近 50%,这种情况在 2003 年左右开始改变。这种美元贬值一直持续到 2014 年末,在 2008 年金融危机期间有所变化。

到 2015 年,它已经稳定在 1.20~1.40 左右,直到 2019 年,即我们的记录结束时,月平均值的变化相对较小。

相关矩阵

在下面的例子中,我将通过一个关联矩阵来查看 Seaborn 热图的更多功能。

数据集是 80 种不同谷物的样本,我想看看它们的成分。

为了构建相关矩阵,我们可以使用 Pandas .corr()

# read dataset
df = pd.read_csv('data/cereal.csv')# get correlations
df_corr = df.corr()# irrelevant fields
fields = ['rating', 'shelf', 'cups', 'weight']# drop rows
df_corr.drop(fields, inplace=True)# drop cols
df_corr.drop(fields, axis=1, inplace=True)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

谷物数据框。

相关矩阵中有很多冗余。表格上部的三角形与下部的三角形具有相同的信息。

面具

幸运的是,我们可以在 Seaborn 的热图中使用遮罩,Numpy 有建立一个遮罩的功能。

np.ones_like(df_corr, dtype=np.bool)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一的矩阵(布尔型)

Numpy .ones_like可以创建一个与我们的数据框形状相同的布尔矩阵,而.triu将只返回该矩阵的上三角。

mask = np.triu(np.ones_like(df_corr, dtype=np.bool))

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

面具

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

第一相关矩阵。

面具可以有所帮助,但在我们的矩阵中仍有两个空细胞。

没什么不好。这些值可以增加我们绘图的对称性,也就是说,如果两个列表以相同的值开始和结束,就更容易知道它们是相同的。

如果像我一样,你对此感到困扰,你可以在绘图时过滤掉它们。

fig, ax = plt.subplots(figsize=(10, 8))# mask
mask = np.triu(np.ones_like(df_corr, dtype=np.bool))# adjust mask and df
mask = mask[1:, :-1]
corr = df_corr.iloc[1:,:-1].copy()# plot heatmap
sb.heatmap(corr, mask=mask, annot=True, fmt=".2f", cmap='Blues',
           vmin=-1, vmax=1, cbar_kws={"shrink": .8})# yticks
plt.yticks(rotation=0)plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

第二相关矩阵。

酷,唯一没有提到的是注释。我们可以用参数annot来设置它们,并且我们可以用fmt传递一个格式化函数给它。

发散调色板

我们仍然需要一个标题,大写的刻度会更好看,但这还不是最重要的。

相关性范围从-1 到 1,所以它们有两个方向,在这种情况下,发散调色板比顺序调色板更好。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Seaborn 发散调色板

Seaborn 有一个有效的方法,叫做.diverging_palette,它用来建立我们需要的彩色地图,每边一种颜色,在中间会聚成另一种颜色。

那种方法使用 HUSL 颜色,所以你需要色相饱和度和*明度。*我用hsluv.org来选择这张图表的颜色。

fig, ax = plt.subplots(figsize=(12, 10))# mask
mask = np.triu(np.ones_like(df_corr, dtype=np.bool))# adjust mask and df
mask = mask[1:, :-1]
corr = df_corr.iloc[1:,:-1].copy()# color map
cmap = sb.diverging_palette(0, 230, 90, 60, as_cmap=True)# plot heatmap
sb.heatmap(corr, mask=mask, annot=True, fmt=".2f", 
           linewidths=5, cmap=cmap, vmin=-1, vmax=1, 
           cbar_kws={"shrink": .8}, square=True)# ticks
yticks = [i.upper() for i in corr.index]
xticks = [i.upper() for i in corr.columns]plt.yticks(plt.yticks()[0], labels=yticks, rotation=0)
plt.xticks(plt.xticks()[0], labels=xticks)# title
title = 'CORRELATION MATRIX\nSAMPLED CEREALS COMPOSITION\n'
plt.title(title, loc='left', fontsize=18)plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

最终相关矩阵。

非常酷,我们为相关矩阵建立了一个漂亮的可视化。现在更容易看到最显著的相关系数,比如纤维和钾。

密度

通常,在一个相关矩阵之后,我们可以更好地了解具有强关系的变量。

在这种情况下,我们没有太多的数据来研究,所以散点图将足以开始调查这些变量。

散点图的问题是,如果数据太多,它们会变得难以阅读,因为点开始重叠。这时,热图会回到场景中来可视化密度。

fig, ax = plt.subplots(1, figsize=(12,8))sb.kdeplot(df.potass, df.fiber, cmap='Blues',
           shade=True, shade_lowest=False, clip=(-1,300))plt.scatter(df.potass, df.fiber, color='orangered')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

密度估计

如果你有兴趣更多地了解 KDE,我建议你看看马修·康伦关于这个话题的文章。

我们探索了热图中的大多数基础知识,并研究了它们如何通过彩色地图、条形图、蒙版和密度估计来增加复杂性。

感谢阅读我的文章。我希望你喜欢它。
更多教程 | 推特

资源:
Seaborn 示例—相关矩阵
海鸟选择颜色托盘
Seaborn 发散调色板
Seaborn 核密度估计图

对冲基金复制 ETF 策略

原文:https://towardsdatascience.com/hedge-fund-replication-etf-strategy-634068ed6629?source=collection_archive---------37-----------------------

美国股市的系统性板块轮动策略

目标

与传统资产管理相比,对冲基金的投资者可能普遍担心流动性较低、杠杆利用率较高以及费用较高。通过使用行业 ETF,我的目标是模仿对冲基金在行业配置中的风险/回报状况,但同时享受使用低成本高流动性资产的好处。板块轮动策略允许投资者从市场周期的不同阶段获取回报,同时通过增加(减少)上涨(下跌)板块的权重来分散投资。行业 ETF 提供了一种投资不同行业数百只股票的廉价工具。

战略

其策略是训练一个模型来预测代表每个行业的 11 只 ETF 的未来回报,方法是利用最大对冲基金过去的行业配置,基于它们最新的 13F 股票申报文件。测试期间的预期回报将用于均值-方差优化,以获得每个部门的最佳权重。这些权重将被分配到 11 只行业 ETF 的投资中。

数据

美国 SEC 13F 文件

Form 13F 是由 AUM 超过 1 亿美元的机构投资经理向美国 SEC 提交的季度报告,列出了管理的所有股票资产。该表格需要在一个日历季度结束后的 45 天内提交,它提供了对冲基金持仓的滞后快照,将用于预测未来行业 ETF 回报的模型中。虽然有专门提供高质量备案数据的数据供应商,但我采用了免费但肮脏的方式,从 SEC 记录中搜集顶级对冲基金的所有持股。

投资经理会得到一个 CIK id,我们可以通过这个 id 查询他们的档案。以下是 Bridgewater Associates,LP (CIK: 0001350694)在 2020 年 5 月 14 日提交的 2020 年第一季度文件的示例。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

【https://www.sec.gov/cgi-bin/browse-edgar?action=getcompany】T4&CIK = 0001350694&所有者=包括&计数=40

将股票的 CUSIP 映射到 ticker

将 CUSIP 映射到 ticker 允许我们查询每只股票的板块。我无法访问任何付费服务,所以我使用了 SEC 的未能交付数据,其中包含 CUSIP 和股票代码列。我用历史失败数据创建了一个映射表,并通过查询 CUSIP(例如 00206R102)和从 Investing.com 抓取符号(例如 T)来补充它。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

从 ticker 获取扇区

现在我们已经有了每只股票的股票代码,我们可以通过查询股票代码并从 Yahoo Finance 中抓取 sector 来将 SECTOR 列添加到映射表中。我补充了对 marketbeat.com 的搜索。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

对冲基金的行业持股

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我根据我的映射表汇总了每个部门的持股价值,需要注意的是,我排除了在雅虎财经和 marketbeat.com 找不到的股票。因此,结果的质量取决于输入数据和使用的映射。我选择了 AUM 的 13 个顶级高频交易,但我们肯定可以研究更多,以找到重要的美国股票交易者。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这些文件最早可以追溯到 1998 年,但我们受到行业 ETF 价格范围的限制。

行业 ETF 回报

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我使用 yfinance API 得出了自 2000 年以来每个领域最大的 ETF 的价格。在 0.02 的无风险利率下,我计算了每个 ETF 的年化持有期超额回报,其中持有期由以下因素决定:

  • 条目:当我们列表中的最后一个 HF 已经为最后一个季度提交时,当天的收盘价
  • 退出:当我们的名单中的第一个 HF 为当前季度提交时,当天的收盘价

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

生成阿尔法

方法

这个想法是为每个 HF 训练一个预测模型,其中输入是 HF 的最后一个季度的部门分配,输出是部门 ETF 的持有期超额收益。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

基本模型是一个多变量多输入随机森林回归器,能够纳入因变量(部门阿尔法值)之间的关系。元模型是 13 个 HF 模型中每一个的预测输出的打包,以预测最终输出。由于一些基金是在 2000 年之后开始的,13 个 HF 模型中的每一个都有不同的培训开始日期,从 2000 年第二季度到 2006 年第四季度,到 2015 年第四季度结束。因此,培训周期从 37 到 60 个季度不等。从 2016 年第一季度到 2019 年第四季度(16 个季度),所有 13 款 HF 车型的测试周期一致。整体 bagging 模型预测样本外测试期间的 alphas 值。

假设

我想使用状态空间/混合模型来解释每个 HF 中部门持股的嵌套结构,但是样本大小对于这样一个高参数空间来说太小了。bagging 方法假设每个 HF 的部门轮换对市场的影响相等。虽然我们无法与 HFs 一起先行或投资,但我们可以尝试预测市场对其分配的延迟影响。一个很大的假设是,市场对 HFs 行业轮动的反应很慢,而且它们的波动足够大,足以影响市场。

结果

然后,将测试季度 x 部门 ETF 的总体 alphas 值输入均值-方差优化,以获得每个部门的最佳权重。

Test Period from 2016Q1 to 2019Q4
---------------------------------
Expected annual return: 14.4%
Annual volatility: 10.1%
Sharpe Ratio: 1.43

分配给每个行业 ETF 的权重:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

然后,在最新季度报告(2020 年第一季度)的延期集合中,对分配的权重进行测试,并与所有部门的权重平均分配进行比较。

Allocated Weights
-----------------
Annualized Expected Return: -48.76%
Annualized Volatility: 64.11%
Annualized Sharpe Ratio: -1.042Equal Weights
-----------------
Annualized Expected Return: -52.82%
Annualized Volatility: 67.19%
Annualized Sharpe Ratio: -1.116

尽管通过降低风险和提高回报,该策略的表现略好于同等权重,但很明显,该策略仍严重暴露于美国股市整体低迷的风险之中。在上面的图表中,行业 ETF 的历史回报率似乎也是高度相关的。

如果模型足够精确,对冲基金回报能否成功复制的问题可以归结为市场风险因素(贝塔)与基金优势(阿尔法)。尽管如此,回报是出了名的难以预测,尤其是当我们的 alpha 可能会随时间而丢失,并且由于模型的数据和稳健性而质量相当差的时候。

Hello Danfo:用于 Javascript 的熊猫,来自 Tensorflow

原文:https://towardsdatascience.com/hello-danfo-pandas-for-javascript-from-tensorflow-3d1d0ea3f3be?source=collection_archive---------34-----------------------

Tensorflow.js 刚刚获得了更多端到端

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图多尔·巴休在 Unsplash 上的照片

什么和为什么

如今,绝大多数数据科学家都生活在 python 和熊猫的世界里。有了 python 的金星科学计算堆栈,很难解释为什么它应该有所不同。通过绑定将“尽可能的方便”和“必须的性能”结合起来是无与伦比的。

然而,有一个类似的流行使用渠道,它为应用程序开发人员提供了更多的服务,而不是普通的数据科学家:javascript 中的机器学习。

等等……什么

Javascript 中的机器学习

这并不像听起来那么荒谬。嗯,大部分是。

当然,绝大多数数据科学家应该继续使用现有的和流行的基于 python 的框架(PyTorch、Tensorflow、ONNX 等)。这些框架已经过高度优化,以支持小规模和大规模的快速研究和应用。

然而,机器学习的民主化正在进行中,从爱好者到应用程序开发者再到物理学家,每个人都想分一杯羹。自 NodeJS 兴起以来,应用程序开发人员经历了端到端功能的爆炸式增长,因此有理由希望机器学习也能在浏览器中实现。请注意,这意味着浏览器真正地将样本处理成预测,而不仅仅是将请求发送到远程后端。

这种浏览器内的 ML 运动已经发展了一段时间,主要是通过 TensorflowJS。谷歌基于 Javascript 的 Tensorflow 库为大量全栈开发人员带来了机器学习,这些开发人员不想仅仅为了让一个基于 python 的服务可以完成所有算法处理而旋转多个服务。

一个肮脏的数据仓库

尽管如此,仍有一个关键环节缺失。向普通开发人员提供 ML 建模只有在没有强大灵活的数据处理库的情况下才有意义。

输入 DanfoJS。

Danfo.js 是一个开源的 JavaScript 库,为操作和处理结构化数据提供了高性能、直观且易于使用的数据结构。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由亚伦·伯顿Unsplash 上拍摄

想象一下,如果没有我们心爱的 pandas、NumPy、scikit-learn 和其他工具使预处理和 ETL 代码变得更加简单和简洁,会有多么冗长。除了许多额外的定制工作时间之外,每个人的 ML 管道最终看起来比 python 环境中的更加不同和支离破碎。丹佛的开发者明白。只需看看主页上宣传的一些要点:

-轻松处理浮点和非浮点数据
中缺失的数据(表示为NaN)-大小可变性:可以从数据帧
中插入/删除列-自动和显式对齐:对象可以显式对齐到一组标签,或者用户可以简单地忽略标签,让[Series](https://danfo.jsdata.org/api-reference/series)[DataFrame](https://danfo.jsdata.org/api-reference/dataframe)等。在计算中为您自动对齐数据
-强大、灵活的分组功能,可对数据集执行拆分-应用-组合操作,用于聚合和转换数据
-轻松将数组、JSON、列表或对象、张量和不同索引的数据结构转换为 DataFrame 对象
-智能的基于标签的切片、花式索引和大型数据集查询
-直观的合并和连接数据集
-强大的 IO 工具,用于从平面文件加载数据
-强大、灵活、直观的 API,用于交互式绘制数据帧和系列。
-特定于时间序列的功能:日期范围生成以及日期和时间属性。

第一眼

请看下面的片段,它摘自一个示例笔记本,该笔记本用 TensorflowJS 训练了一个泰坦尼克号生存预测模型。如果你问我的话,这和典型的pandas语法没有太大区别。

仅从这个片段中就可以看出一些东西:

  • python 数据生态系统用户非常熟悉该语法
  • 代码还有一个 MinMaxScaler 助手
  • 数据类型对张量有一流的支持

很好,不用再搜索“np.array to tensor”😅。

该库的产品包括建模库中通常存在的附加缩放/标记助手: OneHotEncoderStandardScalerMinMaxScalerLabelEncoder

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Max DuzijUnsplash 上拍摄的照片

结论

许多人认为浏览器内的机器学习是一只跛脚鸭。我认为它有有效的用例,比如端到端的 javascript 应用程序。随着 DanfoJS 的发展,TensorflowJS 中的 ML 管道可以清理很多,也可以适当地从其他地方移入数据处理代码。

除了验证这一运动,我更感兴趣的是随着时间的推移会有什么样的项目出现。你有什么项目想法吗?在这里让我知道

资源

资源

保持最新状态

除了在 Medium 上,用 LifeWithData 博客、机器学习 UTD 时事通讯和我的 Twitter 让自己保持更新。通过这些平台,我分别提供了更多的长文和简文思想。

如果你不是电子邮件和社交媒体的粉丝,但仍然想留在圈子里,可以考虑在 Feedly 聚合设置中添加lifewithdata.org/bloglifewithdata.org/newsletter

你好 PyTorch —安装和数量比较

原文:https://towardsdatascience.com/hello-pytorch-installation-numpy-comparison-9879fd677af3?source=collection_archive---------33-----------------------

在进入深度学习之前掌握基础知识

如果你没有生活在岩石下,你很可能听说过深度学习现在是一件事。为了执行深度学习任务而不太令人头疼,需要一个复杂的库。

这就是 PyTorch 的用武之地——它是一个用于 NLP 和计算机视觉等任务的开源机器学习库。它主要是由脸书开发的,最重要的是,它有一个完善的 Python 接口。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由伊戈尔·莱皮林Unsplash 上拍摄

在开始编写代码之前,首先需要在您的机器上安装这个库。安装本身会因用户而异,这取决于你是否有一个 GPU 。如果你有,那么请参考这篇文章,它演示了如何安装 CUDAcud nn——如果你想在 GPU 上训练深度学习模型,这是一个先决条件:

[## 在 Windows 10 上安装 Tensorflow-GPU 2.0 的简单指南

根据你的网速,不会超过 15 分钟。

towardsdatascience.com](/an-utterly-simple-guide-on-installing-tensorflow-gpu-2-0-on-windows-10-198368dc07a1)

我知道它说的是“张量流”,所以你只需要遵循文章的第一部分。

CUDA 和 cuDNN 装了? 很好,可以进行了。

PyTorch 装置

请导航至该网站并点击与您的机器相关的选项:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

正如你所看到的,PyTorch.org 的人会生成一个终端命令,你必须执行这个命令。GPU 版本大约有 750MB,所以下载和安装可能需要一段时间。

一旦完成,你就可以打开一个 Jupyter 笔记本环境。

导入方面,唯一需要的两个库是 NumpyTorch ,它们可以很容易地导入,如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

声明数组/张量

在 Numpy 库中,数组的概念是已知的,与数组是一维还是多维无关。在 Torch 中为相同的概念,但是名字张量被使用并且被用来概括一个 n 维数组的概念。

下面是如何在 Numpy 中声明一个二维数组:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

下面是如何在 PyTorch 中做同样的事情:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

可以看出,如果不考虑不同的方法名,声明背后的思想或多或少是相同的。

但是还有什么可以从 Numpy 转移到 Torch 呢?让我们看看如何声明随机数的数组。Numpy 实现优先:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

现在在火炬中:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

语法还是差不多一样。

数据科学家发现他们自己一直在做的一件更常见的事情是检查数组的形状。在这个领域,形状不匹配的错误是很常见的,为了避免这种错误,我们需要知道是什么原因造成的。

下面是如何在 Numpy 中检查数组的形状:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这一次在 Torch 中,语法实际上是相同的,这对加速您的学习过程很有帮助:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

太好了。展示了一些基本的东西,希望你现在对这个库感觉更舒服了。让我们在下一节检查更多的常见任务。

矩阵乘法

要在 Numpy 中执行矩阵乘法,首先需要创建两个可以相乘的矩阵**—因此第一个矩阵的列数必须与第二个矩阵的行数相同。**

让我们看看 Numpy 是怎么做的:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

与 Numpy 不同,Torch 使用.mm()方法来乘矩阵:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

然而,如果.mm()听起来不像你想的那样具有描述性,也可以使用.matmul()方法:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如果目标不是执行矩阵乘法,而是逐元素执行乘法**,则矩阵可以像任何其他常规数一样相乘,并且 Numpy 和 Torch 的语法是相同的:**

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

用矩阵还能做什么?

其他一些常见的操作是声明一个具有特定形状的零数组。在两个库中,这样做的语法是相同的:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

然而,如果您希望创建一个任意的具有特定形状的一个的数组,您可以使用方便的.ones()函数。两个库的语法也是相同的:****

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

从 Numpy 转换到 Torch,反之亦然

借助于.from_numpy()函数,可以很容易地将任意数量的数组转换成 Torch 张量:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

以类似的方式,Torch 张量可以转换为 Numpy 数组。语法有点不同,所以请记住:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

下一步是什么?

本文的目的是介绍 PyTorch,并将其与您可能熟悉的东西 Numpy 库进行比较。

Numpy 很棒,但是不像 PyTorch in 不能在 GPU 上运行,所有的深度学习都需要 GPU 在合理的时间内进行训练**。接下来的文章将更加面向深度学习,如果你对此感兴趣,请继续关注。**

然而,在实际使用它们来解决现实世界的问题之前,您应该了解您正在使用的库的基础知识,所以这就是本文试图涵盖的内容。

感谢阅读,敬请期待更多内容。

喜欢这篇文章吗?成为 中等会员 继续无限制学习。如果你使用下面的链接,我会收到你的一部分会员费,不需要你额外付费。

** [## 通过我的推荐链接加入 Medium-Dario rade ci

作为一个媒体会员,你的会员费的一部分会给你阅读的作家,你可以完全接触到每一个故事…

medium.com](https://medium.com/@radecicdario/membership)**

你好世界!为了机器学习

原文:https://towardsdatascience.com/hello-world-for-machine-learning-4dc9af0a7430?source=collection_archive---------31-----------------------

从零开始构建您的第一个模型,开始机器学习!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 VisionPic 提供。来自的网像素

我不断被问到的最常见的问题是“ ”我如何开始机器学习? ”。我知道这可能会让人不知所措——网上有各种各样的工具和资源,你不知道从哪里开始。相信我,我也经历过。因此,在本文中,我将尝试让您开始构建机器学习模型,并让您熟悉行业中正在使用的实践。

我希望你知道一个小 python,我们将用它来编码我们的模型。如果没有,这里是你开始的好地方:https://www.w3schools.com/python/

什么是机器学习?

机器学习为系统提供了自动学习和改进的能力,而无需显式编程。

“哎呀!那似乎很复杂……"

简单来说,机器学习就是模式识别而已。就是这样!

想象一个婴儿在玩玩具。他必须把积木放在正确的槽里,否则就放不进去了。他试着把方砖粘到圆孔上。没用!他又试了几次,最终把方块放进了方孔。现在,每当他看到一个方块,他就会知道把它放进方孔里!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

来自 PexelsTatiana Syrikova 的照片

这是机器学习如何工作的一个非常简单的想法。机器学习(显然)更复杂,但让我们从简单的开始。

在机器学习中,不是试图定义规则并用编程语言表达,而是你提供答案(通常称为标签)和数据,机器将推断出确定答案和数据之间关系的规则。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这些数据和标签用于创建机器学习算法(“规则”),通常称为模型

使用这种模型,当机器获得新数据时,它可以预测或正确地标记它们。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

例如,如果我们用猫和狗的标签图像训练模型,模型将能够预测何时显示一张新图像,它是猫还是狗。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如何训练一个模特!

现在我们有了基本的了解,让我们开始编码吧!

创建您的第一个机器学习模型

考虑下面一组数字。你能弄清楚他们之间的关系吗?

X: -1 0 1 2 3 4 5

Y: -1 1 3 5 7 9 11

当你从左向右阅读时,你可能会注意到 X 值增加了 1,相应的 Y 值增加了 2。所以你可能会想 Y=2X 加减什么的。然后你可能会看到 X 上的 0,看到 Y = 1,你会得出关系 Y=2X+1

现在如果给我们一个 X 的值 6,我们可以准确的预测出 Y 的值为 2*6 + 1 = 13。

对你来说想明白这一点一定很容易。现在让我们试着用电脑来解决这个问题。戴上你的编码帽子,因为我们即将编码我们的第一个机器学习模型!

设置环境

我们将使用 Google Colab 来编写我们的代码。

那么 Google Colab 是什么?

这是一个令人难以置信的基于浏览器的在线平台,允许我们免费在机器上训练我们的模型!听起来好得难以置信,但多亏了谷歌,我们现在可以处理大型数据集,建立复杂的模型,甚至与他人无缝共享我们的工作。

所以基本上这就是我们训练和使用模型的地方。你需要一个谷歌账户来使用 Colab。一旦完成,创建一个新的笔记本。瞧啊。你有了你的第一个笔记本。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如果你以前没有使用过 Google Colab,请查看本教程,了解如何使用它。

现在我们真正地写代码了!

让我们开始编码吧

完整的笔记本可以在这里这里(GitHub) 获得。

进口

我们正在导入 TensorFlow,为了方便使用,将其命名为 tf。

接下来,我们导入一个名为 numpy 的库,它帮助我们轻松快速地将数据表示为列表。

将模型定义为一组连续层的框架称为 keras,所以我们也导入了它。

创建数据集

如前所示,我们有 7x 和 7y,我们发现它们之间的关系是 Y = 2X + 1

一个名为 numpy 的 python 库提供了许多数组类型的数据结构,这实际上是一种标准的方法。我们通过使用 np.array[]在 numpy 中将值指定为数组来声明我们想要使用这些

定义模型

接下来,我们将创建尽可能简单的神经网络。它有一层,那层有一个神经元,它的输入形状只是一个值。

你知道在函数中,数字之间的关系是 y=2x+1

当计算机试图“学习”时,它会进行猜测…也许是 y=10x+10 。损失函数根据已知的正确答案来衡量猜测的答案,并衡量它做得好或坏。

接下来,该模型使用优化器函数进行另一次猜测。基于损失函数的结果,它将尝试最小化损失。此时,它可能会得出类似于 y=5x+5 的结果。虽然这仍然很糟糕,但它更接近正确的结果(即损失更低)。

这个模型将会重复这个过程,你很快就会看到。

但首先,我们告诉它如何对损失使用均方差,对优化器使用随机梯度下降(sgd)。你还不需要理解这些的数学,但是你可以看到它们是有效的!😃

随着时间的推移,你将了解不同的和适当的损失和优化功能不同的情况。

训练模型

训练是模型学习的过程,就像我们之前说的那样。model.fit 函数用于将我们创建的训练数据拟合到模型中。

当您运行这段代码时,您会看到每个时期的损失都会打印出来。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

您可以看到,对于最初的几个时期,损失值相当大,并且随着每一步的进行,损失值越来越小。

使用训练好的模型进行预测

终于!我们的模型已经训练好了,准备好面对现实世界了。让我们看看,给定 x,我们的模型能多好地预测 Y 的值。

我们使用 model.predict 方法计算出 x 的任意值。

所以,如果我们取 X 的值为 8,那么我们知道 Y 的值是 2*8 + 1 = 17 。让我们看看我们的模型是否能做好。

[[17.00325]]

这比我们预期的值多了一点。

机器学习处理概率,所以给定我们提供给模型的数据,它计算出 X 和 Y 之间的关系很有可能是 Y=2X+1 ,但只有 7 个数据点我们无法确定。因此,8 的结果非常接近 17,但不一定是 17。

就是这样!您已经介绍了可以在不同场景中使用的机器学习的核心概念。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

pettergra 的照片来自 Meme Economy

我们在这里使用的过程/步骤是您在构建复杂模型时会做的。

如果你已经建造了你的第一个模型,那么提升一个等级,用你的新技能建造一个 计算机视觉 模型怎么样?

[## 时尚服装分类-计算机视觉入门

通过创建一个对时尚服装图像进行分类的模型,开始学习计算机视觉。

towardsdatascience.com](/classifying-fashion-apparel-getting-started-with-computer-vision-271aaf1baf0)

编码快乐!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值