pytorch初学概念整理-RNN-MNIST数据集

这篇博客介绍了如何在PyTorch中使用RNN处理MNIST数据集。文章详细解析了源代码,包括库的引入、超参数定义、RNN网络类的定义,特别是nn.RNN和nn.Linear()的使用。此外,还讨论了初始化函数_init_以及训练过程中的优化器和损失函数nn.CrossEntropyLoss。
摘要由CSDN通过智能技术生成

源代码

源代码直接采用找到的已写好的代码,可以看到基本这套代码用的是LSTM结构,对mnist数据进行训练,下面解读一下,顺便修改+对比一下LSTM和RNN的训练区别

import torch
from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os

超参数定义

这部分定义训练网络结构的一些参数
代表的意思用注释标在代码行里了

# Hyper Parameters
EPOCH = 3 # 喂几轮数据
BATCH_SIZE = 64 #一次前向传播训练里面用多少个样本
INPUT_SIZE = 28  #一个样本的大小,这里是像素点,一个图像28个像素点。换成文本就是看one-hot字典集的长度
LR = 0.01 #Adam优化器的下降步长
DOWNLOAD_MNIST = False

定义RNN网络类

这里使用的是对nn.Module进行继承和修改。相关定义和理解参考链接

事实上,在pytorch里面自定义层也是通过继承自nn.Module类来实现的,我前面说过,pytorch里面一般是没有层的概念,层也是当成一个模型来处理的,这里和keras是不一样的。前面介绍过,我们当然也可以直接通过继承torch.autograd.Function类来自定义一个层,但是这很不推荐,不提倡,至于为什么后面会介绍。记住一句话,keras更加注重的是层Layer、pytorch更加注重的是模型Module。
————————————————
版权声明:本文为CSDN博主「LoveMIss-Y」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_27825451/article/details/90705328

class RNNnet(nn.Module): #继承nn.Module
    def __init__(self): #这里self指的实例本身,self必须是第一个参数
        super(RNNnet, self).__init__()
        self.rnn = nn.RNN(
            input_size=INPUT_SIZE,
            hidden_size=64,  
            num_layers=1,
            batch_first=True
        )
        self.out = nn.Linear(64, 10)

    def forward(self, x):
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, output_size)
        # h_n shape (n_layers, batch, hidden_size)
        # h_c shape (n_layers, batch, hidden_size)
        r_out, h_n = self.rnn(x, None)
        # choose r_out at the last time step
        out = self.out(r_out[:, -1, :])
        return out

_init_函数基本定义

其中&

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值