Pytorch环境下搭建lenet5网络实现手写数字识别
(文章采用cuda9.0+cudnn7.4+pytorch1.6环境)
数据集选用EMNIST dataset中的手写数据集,参考链接如下:
数据集下载地址
代码部分参考S.E作者的pytorch实现手写英文字母识别,链接如下:
代码部分
在原作者的基础下将整个运行环境搭载在GPU上,实现代码如下:
#查看GPU是否可用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#模型上GPU
cnn.to(device)
#数据上GPU
x=x.to(device)
y=y.to(device)
test_x=test_x.to(device)
test_y=test_y.to(device)
实现结果如下: