最近实现了一个简单的手写数字识别的程序,我安装的pytorch是gpu版(你也可以安装cpu版本的,根据个人需要),这里我介绍pytorch的gpu版本和cpu版本的安装以及训练手写数字识别时gpu和cpu之间的切换。
1、pytorch的安装
1.1 pytorch(带有gpu)安装
首先进入pytorch官网,选择自己所需要的版本,这里我选择的版本如下图所示。
![d1bf71250f8d2e253cf9154046991129.png](https://i-blog.csdnimg.cn/blog_migrate/38f95b9bf098e078c3c4d55f23f5cca9.jpeg)
然后打开anaconda Prompt,首先输入:conda activate py3激活py3(解释一下为什么是py3,因为我之前装的是python3.6,创建的名字为py3),然后输入:conda install pytorch torchvision cudatoolkit=9.0 -c pytorch安装pytorch,等待安装就好,如下图所示。
![f99bf2ce9b44db7c5ad26f378d19d3dd.png](https://i-blog.csdnimg.cn/blog_migrate/0bab36b8f10c4342d52c19b14e7ee66f.jpeg)
![324d76e4d19b9fe9eb4f571bd23a82eb.png](https://i-blog.csdnimg.cn/blog_migrate/65552bdfaa807dd87cbebd041b752015.jpeg)
1.2 pytorch(无gpu)安装
这时CUDA选择none即可
![95c13386968fc270c5c4b985428a65ac.png](https://i-blog.csdnimg.cn/blog_migrate/aa4ad6127a4ea8ead30abf19cda48324.jpeg)
打开anaconda终端,首先激活py3,然后输入这个命令:conda install pytorch-cpu torchvision-cpu -c pytorch,等待安装就好,如下图所示。
![c01ef25e9d34d88756998cbe7983c37a.png](https://i-blog.csdnimg.cn/blog_migrate/16bf0a530a851b40dd982327935440c9.jpeg)
1.3 测试是否安装成功
首先cmd打开终端,输入python即可查看当前安装的python的版本,然后import torch 等待几秒出现如下图所示,这样就成功安装了。
![be70a97a552d57f947b46f0aece5305e.png](https://i-blog.csdnimg.cn/blog_migrate/7643b9ef4c71d801287ccb20f000da82.jpeg)
2、选择cpu进行网络的训练(推荐下载带有gpu的)
因为下载gpu版本的,训练时可以选择gpu或者cpu进行训练。但是下载cpu版本,只能选择cpu进行训练。
2.1新建一个model.py模块
from
2.2 新建一个train.py模块(使用cpu训练的)
主要的格式为
def get_variable(x):
x = Variable(x)
return x.cpu() if torch.cuda.is_available() else x
...
...
...
cnn = CNN()
if torch.cuda.is_available():
cnn = cnn.cpu()
具体cpu训练实例如下所示
import
2.3 新建一个test.py模块
import
3、选择gpu进行网络的训练
3.1 model.py模块不变,可以参考以上2.1
3.2 train.py模块(gpu训练)代码如下,可以对比以上2.2
gpu训练需要补充:gpu训练可以选择gpu设备,详细请看以下主要格式部分的代码模块。
主要格式为
# 将数据处理成Variable, 如果有GPU, 可以转成cuda形式
def get_variable(x):
x = Variable(x)
return x.cuda() if torch.cuda.is_available() else x
...
...
...
cnn = CNN()
#这部分默认的是两个gpu训练
if torch.cuda.is_available():
cnn = cnn.cuda()
#这部分表示可以对gpu进行选择,只需要标明设备号
# if torch.cuda.device_count() > 1:
# cnn = nn.DataParallel(cnn, device_ids=[0])
具体gpu训练实例如下所示(以下这种gpu训练效果比较好)
import
3.3 test.py模块不变,可以参考以上2.3
4、最后附上测试结果
![10df8b502488370d0c9708b75d395e93.png](https://i-blog.csdnimg.cn/blog_migrate/da2262e2125df9ba7025fc59f84e6fbd.png)