Classification based on Pycharm
本实验讲解基于Pycharm开发工具利用Pytorch框架中的神经网络实现简单的分类问题。本电脑实验环境如下:①Anaconda 3环境②Pycharm编译器采用Anaconda环境中的Python③Pytorch(根据自己电脑环境安装,本实验采用Cuda10,Python3版本)。分类(Classification)即找一个函数判断输入数据所属的类别,可以是二类别问题(是/不是),也可以是多类别问题(在多个类别中判断输入数据具体属于哪一个类别)。与回归问题(Regression)相比,分类问题的输出不再是连续值,而是离散值,用来指定其属于哪个类别。分类问题在现实中应用非常广泛,比如垃圾邮件识别,手写数字识别,人脸识别,语音识别等。这里直接对代码实现部分进行讲解,基于torch框架,对随机生成的一些数据点进行分类,旨在帮助刚入门的学者从代码端对分类问题进行一个入门了解。
加载相关库文件模块
创建随机数据n_data,定义数据点分布一个为正向(标0),一个为负向(标1),由于n_data是随机生成的,因此数据点可以理解为随机生成的。同时,利用torch将他们存放在一起,这里注意要保持他们的维度相同。其中input_data只能为FloatTensor,即torch.float32,标签只能为LongTensor,即torch.int64。
这里搭建NN
实例化所搭建的网络,设置优化函数为SGD,损失函数为CrossEntropy。
for循环代表训练次数,if判断用于数据训练过程的可视化打印。这里面注意准确率的计算,采用float型。读者应自行学习,plt相关函数用法。最后的plt.pause()用于停顿,否则无法见到训练过程的数据可视化,这里是代码要注意的地方,下图为实验结果。