import torch
from torch import nn
import torchvision.datasets as dsets
import torch.utils.data as Data
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
EPOCH = 5
BATCH_SIZE = 64
TIME_STEP = 28 # 考虑了多少时间点的数据
INPUT_SIZE = 28
LR = 0.001
USE_GPU = torch.cuda.is_available()
print('GPU:', USE_GPU)
# 准备数据集
train_dataset = dsets.MNIST(root='../../data_sets/mnist',
train=True,
transform=transforms.ToTensor(),
download=False)
test_dataset = dsets.MNIST(root='../../data_sets/mnist',
train=False,
transform=transforms.ToTensor(),
download=False)
train_data_loader = Data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_data_loa
LSTM 循环神经网络实践(基于Mnist数据集)
最新推荐文章于 2024-07-06 23:59:46 发布
使用LSTM循环神经网络在Mnist数据集上进行实践,测试准确率达到98.68%,展示了网络的优秀性能。训练过程中loss函数图像呈现出典型的锯齿形变化。
摘要由CSDN通过智能技术生成