pytorch 三维点分类_pytorch 实现简单二元分类器以及可视化

本文介绍如何使用PyTorch构建一个二元分类器,并通过matplotlib进行可视化。首先生成两组数据,然后定义神经网络模型并进行训练。通过contourf函数画出决策边界,展示3D点在2D平面上的分类效果。最后,展示了一个预测示例。
摘要由CSDN通过智能技术生成

这个文章是用pytorch和matplotlib实现一个二元分类器并且可视化。

思路

自己生成两团数据。

定义自己的神经网络类。

训练网络

打印出边界。

先放效果图

关于可视化

定义网络、训练网络主要没什么好说的啦其实,毕竟有pytorch这么好的框架,已经提供了如此简单的代码工作。

主要是可视化的技巧。

主要是matplotlib中有个contourf,本身是画等高线用的,就是地理中那个三维图投射到二维图的那种图。

我们可以把这个用到可视化上来(当然只是3维的,如果是更高维度就没法用这个可视化了)。

具体怎么可视化的?

首先,先自己生成200个训练数据(这步对应getData函数),然后把属于不同类别的数据染上不同颜色;

然后,进行网络的训练(对应run函数);

然后,用同样的数据让网络进行预测。因为二元分类器最后预测的结果要么是0,要么是1,所以可以利用matplotlib中的画等高线的函数,来近似画出决策边界。这一步主要对应showBoundary函数。

使用conturf函数

这个函数我自己在用的时候有点懵逼,使用这个要先meshgrid,mesh合并的意思,grid网格的意思,要把两个列表先合成一个网格,这个形式我也不是很喜欢。

勉勉强强参考了一些博客才写了出来。具体我也没办法一一讲述,还请各位原谅。

不过其中,cmap是画出来的图的风格参数,可以是camp&

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值