PyTorch的HelloWord之旅

前言

PyTorch 是一个基于 Python
的深度学习平台,它简单易用上手快的同时功能十分强大。

本篇文章首先将介绍 PyTorch 的基本数据结构 Tensor
的一些操作;随后给出神经网络中 的 HelloWorld
例子:用最经典的卷积神经网络(LeNet5)训练手写数据集 MNIST

PyTorch 中的 Tensor

以下内容来自:
https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html

Tensor 简单讲就是多维数组,用来表示各种维度的数据。

Tensor 的创建修改

  • 创建一个未初始化的 5*3 的 Tensor
import torch
torch.empty(5, 3)
  • 初始化一个随机 5*3 矩阵
torch.rand(5, 3)
  • 初始化一个全零 5*3 矩阵,类型为 long
torch.zeros(5, 3, dtype=torch.long)
  • 用数据初始化一个tensor
torch.tensor([5.3, 4.6])
  • 从一个存在的tensor创建一个tensor
x = torch.tensor([5.3, 4.6])
x = x.new_ones(5, 3, dtype=torch.double)
print(x)

x = torch.randn_like(x, dtype=torch.float)
print(x)
  • 返回大小
x.size()

Tensor 的一些操作

运算(如加法)

  • +号

    x = torch.rand(5, 3)
    y = torch.rand(5, 3)
    print(x + y)
    
  • torch.add

    torch.add(x, y)
    
  • 提供一个参数保存结果

    result = torch.empty(5, 3)
    torch.add(x, y, out=result)
    print(result)
    
    • 原地相加
        y.add_(x)
        print(y)
    

改变维度,类似 numpy 中的 reshape

使用 tensor.view 改变 tensor 的大小。

x = torch.randn(4, 4)
y = x.view(16)
z = x.view(-1, 8)  # -1 被推断成其它维度
print(x.size(), y.size(), z.size())

返回值

对于只有一个元素的 tensor 通过 .item() 得到它的值

x = torch.randn(1)
print(x)
print(x.item())

Tensor 转换成 Numpy

a = torch.ones(5)
print(a)
b = a.numpy()
print(b)

a 和 b 共享内存

a.add_(1)
print(a)
print(b)

Numpy 转换成 Tensor

import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print(a)
print(b)

CUDA Tensor

Tensor 可以移动到任意设备,通过 .to 方法

if (torch.cuda.is_available()):
    device = torch.device('cuda')
    y = torch.ones_like(x, device=device)
    x = x.to(device)
    z = x + y
    print(z)
    print(z.to('cpu', torch.double))

搭建 LeNet5 训练 MNIST 数据集

MNIST 数据集处理

虽然 PyTorch 中已经预置了 MNIST
数据集的处理代码,但是我们要有自己处理数据集的能
力,特别是在学习阶段,所以本文会自己处理数据集,然后结合 PyTorch
的数据处理机制。

MNIST 数据集的结构

MNIST
数据集包含60000张训练用的图片,10000张测试用的图片,每个图片均有对应的标签。
每张图片的像素是 28 * 28,每个像素值的范围是 0 -
255,用8个比特表示。数据集有下
面四个二进制文件,对应训练图片,训练标签,测试图片,测试标签:

train-images.idx3-ubyte

train-labels.idx1-ubyte

t10k-images.idx3-ubyte

t10k-labels.idx1-ubyte

图片(idx3)的格式:首先是32位的整数,是一个magic
数字,接下来32位整数表示图片的
数量,接下来的两个32位整数是分别是图片行数和列数,

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值