PyTorch深度学习框架学习记录(1)--安装,MNIST手写数字识别

框架发展

  • Caffe框架(2015)

    • 优势:只写配置文件

    • 劣势:运行环境配置繁琐

  • Tensorflow 1.x(2016)

    • 开发成本高,上手较难
  • keras(2017)

    • 封装的API很好用
  • Tensorfolw 2.x(2018)

  • PyTorch (2019)

    • 上手比较容易,直接套模板,学习成本较低
    • 目前的主流框架
    • 向下支持较好,新版本可以使用老版本的API

安装

cpu版本

直接pip安装

pip install torch

GPU版本

  1. 安装CUDA和cudnn

    查看显卡支持的CUDA版本,版本是向下兼容的,所装的不能高于显示支持的版本

    nvidia -smi
    

在这里插入图片描述

提前去PyTorch官网Start Locally | PyTorch看一下torch目前所支持的版本,然后去CUDA的官网CUDA Toolkit 12.0 Downloads | NVIDIA Developer选择对应版本安装,

显卡支持版本-torch支持版本-CUDA版本-CUDNN版本都要对应

**可以看其他相关博客,我当时装的时候有借鉴下面的这些文章,对作者们表示感谢!**比如:

安装完后在命令行输入:

    nvidia -v

显示CUDA版本就是成功了

在这里插入图片描述

  1. 安装torch

    就在PyTorch官网查看版本,我是使用conda安装的

在这里插入图片描述

​ 在conda创建虚拟环境以后就可以将那行命令输入。

​ 如果速度慢可以更改conda为国内镜像源

中间遇到很多问题,但是耐心,耐心,耐心,有问题慢慢解决,总会有解的,真的有很多博客给了我大量帮助,感谢作者,感谢CSDN。

安装完成后运行以下代码

import torch

print(torch.__version__)  # torch 版本
print(torch.cuda.is_available())  # torch 是否为gpu版本, 是的话会返回True

学习记录

MNIST手写数字识别

数据的准备工作

非常重要,但是只使用MNIST学习过程,所以并不需要深究,不同的数据集的处理都不一样

因为MNIST数据集很简单,由28×28的灰度图像组成,所以每张图片都是784个灰度数字。有脚本可以直接下载:

"""
download_mnist.py
下载数据
"""
from pathlib import Path
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
        content = requests.get(URL + FILENAME).content
        (PATH / FILENAME).open("wb").write(content)

但是经常有网络问题,也可以去找MNIST的pkl格式的数据。

mnist.pkl.gz百度网盘链接

链接:https://pan.baidu.com/s/1nx2k5IPAnP1u6CkRR8NXfw?pwd=zbqy
提取码:zbqy
为了方便管理代码,所以单独将设置路径的代码写了一个py文件

"""
path_setting.py
设置数据所在路径
"""
import pickle
import gzip
from pathlib import Path

"""保存路径data/mnist/mnist.pkl.gz"""
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
FILENAME = "mnist.pkl.gz"

如果想看下数据集的内容,可以使用下面的方式:

"""
show.py
查看数据内容
"""
import matplotlib.pyplot as plt
import pylab
from path_setting import *

"""读取图像"""
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")  # 读取数据
    """
    x_train: 训练数据
    y_train: 训练标签
    x_valid: 验证数据
    y_valid: 验证标签
    """


print(x_train.shape)  # 查看x_train的形状: (50000, 784)
print(x_valid.shape)  # 查看x_valid的形状: (10000, 784)
print(y_train.shape)  # 查看y_train的形状: (50000, )
print(x_train[0])  # 0号数据(第一个数据)的784个灰度值
print(y_train[0])  # 0号数据(第一个数据)的标签: 5
print(x_train[0].shape)  # 看一下x_train中一个图像的形状: (784,) 784个灰度值
plt.imshow(x_train[0].reshape(28, 28), cmap="gray")  # 更改图像的形状为28 × 28
pylab.show()  # 展示图像

运行结果:

在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值