这个文章是用pytorch和matplotlib实现一个二元分类器并且可视化。
思路
自己生成两团数据。
定义自己的神经网络类。
训练网络
打印出边界。
先放效果图
关于可视化
定义网络、训练网络主要没什么好说的啦其实,毕竟有pytorch这么好的框架,已经提供了如此简单的代码工作。
主要是可视化的技巧。
主要是matplotlib中有个contourf,本身是画等高线用的,就是地理中那个三维图投射到二维图的那种图。
我们可以把这个用到可视化上来(当然只是3维的,如果是更高维度就没法用这个可视化了)。
具体怎么可视化的?
首先,先自己生成200个训练数据(这步对应getData函数),然后把属于不同类别的数据染上不同颜色;
然后,进行网络的训练(对应run函数);
然后,用同样的数据让网络进行预测。因为二元分类器最后预测的结果要么是0,要么是1,所以可以利用matplotlib中的画等高线的函数,来近似画出决策边界。这一步主要对应showBoundary函数。
使用conturf函数
这个函数我自己在用的时候有点懵逼,使用这个要先meshgrid,mesh合并的意思,grid网格的意思,要把两个列表先合成一个网格,这个形式我也不是很喜欢。
勉勉强强参考了一些博客才写了出来。具体我也没办法一一讲述,还请各位原谅。
不过其中,cmap是画出来的图的风格参数,可以是camp&