Pytorch的学习——RNN

本文介绍了PyTorch中RNN的使用,详细阐述了input_size、hidden_size和num_layers等参数,并通过分类问题(手写数字识别)和回归问题(sin函数预测cos函数)两个实例展示了RNN的训练过程与效果。
摘要由CSDN通过智能技术生成

RNN

在pytorch中RNN(循环神经网络)由 torch.nn中的RNN()函数进行循环训练,其参数有input_size,hidden_size, num_layers。

input_size:输入的数据个数
hidden_size:隐藏层的神经元个数
num_layers:隐藏层的层数,数值越大RNN能力越强,相应的训练消耗时间越多

分类问题

这里通过手写数字的一个小例子来了解pytorch中的rnn

import torch
import torch.utils.data as Data
import torch.nn as nn
import torchvision

EPOCH = 1                # 批处理的次数
BATCH_SIZE = 64          # 批处理时,每次提取的数据个数,这里是图片的个数
TIME_STEP = 28           # 读取图片的步数,数值为图片的height,这里选择的mnist数据图片长宽都是28
INPUT_SIZE = 28          # 一次读取的数据量,数值为图片的weight
LR = 0.001               # 学习效率
DOWNLOAD_MNIST = False   # 有没有下载mnist手写数字数据,如果没有改成True

# 加载数据
train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST
)
# shuffle=True代表乱序提取图片shuffle=False代表按顺序提取图片,
# 这里还有一个参数num_workers代表线程个数,默认是0,Window上Pytorch并不支持多线程训练,Linux支持,可以加快训练速度
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 加载测试数据
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# 改变数据格式,这里除255是为了把数据放入0-1区间之内
test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:2000]/255.
test_y = test_data.targets[:2000]

# 搭建神经网络
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        """
        这里小编没有用RNN()而是LSTM(),即长短期记忆这是RNN的另一种形式,因为通常的RNN可能会出现梯度弥散
        或者梯度爆炸的情况,而LSTM()就解决了该情况。
        """
        self.rnn = nn.LSTM
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值