PyTorch--双向长短期记忆网络(BiRNN)在MNIST数据集上的实现与分析

前言

本代码实现了一个基于PyTorch的双向长短期记忆网络(BiRNN),用于对MNIST数据集中的手写数字进行分类。MNIST数据集是一个广泛使用的计算机视觉数据集,包含了大量的手写数字图像,适合用来训练和测试深度学习模型。

代码的关键特点包括:

  1. 数据加载与预处理:使用torchvision库加载MNIST数据集,并应用了标准化变换以准备数据输入模型。

  2. BiRNN模型定义:模型使用nn.LSTM模块构建双向LSTM层,能够处理序列数据,并通过nn.Linear层进行最终的分类。

  3. 设备无关性:通过torch.device自动选择GPU或CPU,提高了代码的通用性。

  4. 训练与测试:实现了模型的训练循环和测试循环,包括损失计算、反向传播和参数更新。

  5. 可视化工具:集成了数据可视化和模型架构可视化功能,使用matplotlib库展示数据样本和训练进度。

  6. 模型保存:训练完成后,使用torch.save保存模型参数,方便后续的加载和使用。

  7. 超参数设置:提供了灵活的超参数设置,包括隐藏层大小、层数、批次大小、训练轮数和学习率。

代码结构清晰,易于理解和修改,适合作为深度学习入门和实践的参考。通过本代码,用户可以了解如何使用PyTorch构建和训练一个BiRNN模型,并对MNIST数据集进行分类任务。

说明

  • 确保安装了PyTorch、torchvision和matplotlib。
  • 调整超参数以适应不同的训练需求。
  • 运行代码,观察训练过程和测试结果。
  • 使用可视化工具了解数据和模型架构。

完整代码

import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',
                                           train=True, 
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data/',
                                          train=False, 
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, 
                                          shuffle=False)

# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(BiRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, num_classes)  # 2 for bidirection
    
    def forward(self, x):
        # Set initial states
        h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device) # 2 for bidirection 
        c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device)
        
        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)
        
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)


# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

# Test the model
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) 

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

代码解析

1.导入库

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

这部分代码导入了编写神经网络所需的PyTorch库及其子模块。

2.设备配置

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

根据是否有可用的GPU,设置计算设备,优先使用GPU以加速训练。

3.超参数设置

sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003

设置了模型训练所需的超参数,包括时间序列的长度、输入数据的尺寸、隐藏层的尺寸、LSTM层数、类别数、批次大小、训练轮数和学习率。

4.数据集加载

train_dataset = torchvision.datasets.MNIST(..., transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(..., transform=transforms.ToTensor())

加载MNIST数据集的训练集和测试集,并使用transforms.ToTensor()将图像数据转换为张量。

5.数据加载器

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

创建了两个数据加载器,分别用于训练和测试数据的批量加载。

6.定义BiRNN模型

class BiRNN(nn.Module):
    # 定义双向循环神经网络模型

创建了一个双向LSTM的模型,包含初始化方法和前向传播方法。

7.实例化模型并移动到设备

model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)

实例化BiRNN模型,并将模型移动到之前设置的计算设备上。

8.损失函数和优化器

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

定义了交叉熵损失函数和Adam优化器。

9.训练模型

for epoch in range(num_epochs):
    # 训练循环

在每个epoch中,遍历训练数据的每个批次,执行前向传播、计算损失、反向传播和参数更新。

10.测试模型

with torch.no_grad():
    # 测试循环

在测试阶段,关闭梯度计算,遍历测试数据的每个批次,计算模型的预测准确率。

11.保存模型

torch.save(model.state_dict(), 'model.ckpt')

保存模型的参数到文件,以便于后续的加载和使用。

这段代码实现了一个完整的训练和测试流程,适合用于分类任务,特别是涉及序列数据的任务。对于MNIST数据集,尽管它不是序列数据,但通过将图像的每一行视为序列的一部分,可以使用RNN进行处理。

常用函数

  1. torch.device

    • 格式:torch.device(device_str)
    • 参数:device_str —— 指定设备类型(如'cuda''cpu')的字符串。
    • 样式:属性访问器。
    • 示例:
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      
  2. torchvision.datasets.MNIST

    • 格式:torchvision.datasets.MNIST(root, train, transform, download)
    • 参数:
      • root —— 数据集存放的根目录。
      • train —— 是否加载训练集。
      • transform —— 对图像进行的变换操作。
      • download —— 是否下载数据集。
    • 样式:类方法调用。
    • 示例:
      train_dataset = torchvision.datasets.MNIST(root='../../data/', train=True, transform=transforms.ToTensor(), download=True)
      
  3. torchvision.transforms.Compose

    • 格式:torchvision.transforms.Compose(transforms_list)
    • 参数:transforms_list —— 包含多个变换操作的列表。
    • 样式:类方法调用。
    • 示例:
      transform = transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize((0.1307,), (0.3081,))
      ])
      
  4. torch.utils.data.DataLoader

    • 格式:torch.utils.data.DataLoader(dataset, batch_size, shuffle)
    • 参数:
      • dataset —— 加载的数据集。
      • batch_size —— 每个批次的样本数。
      • shuffle —— 是否在每个epoch开始时打乱数据。
    • 样式:类方法调用。
    • 示例:
      train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
      
  5. nn.Module

    • 格式:class YourModelClass(nn.Module)
    • 参数:继承自nn.Module的类定义。
    • 样式:类继承。
    • 示例:
      class BiRNN(nn.Module):
          def __init__(self, ...):
              super(BiRNN, self).__init__()
              ...
      
  6. nn.LSTM

    • 格式:nn.LSTM(input_size, hidden_size, num_layers, batch_first, bidirectional)
    • 参数:
      • input_size —— 输入特征的维度。
      • hidden_size —— 隐藏层的维度。
      • num_layers —— LSTM层的数量。
      • batch_first —— 输入和输出张量的第一个维度是否为批次大小。
      • bidirectional —— 是否使用双向LSTM。
    • 样式:类方法调用。
    • 示例:
      self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
      
  7. nn.Linear

    • 格式:nn.Linear(in_features, out_features)
    • 参数:
      • in_features —— 输入特征的数量。
      • out_features —— 输出特征的数量。
    • 样式:类方法调用。
    • 示例:
      self.fc = nn.Linear(hidden_size * 2, num_classes)
      
  8. nn.CrossEntropyLoss

    • 格式:nn.CrossEntropyLoss()
    • 参数:无默认参数。
    • 样式:类方法调用。
    • 示例:
      criterion = nn.CrossEntropyLoss()
      
  9. torch.optim.Adam

    • 格式:torch.optim.Adam(params, lr)
    • 参数:
      • params —— 模型参数。
      • lr —— 学习率。
    • 样式:类方法调用。
    • 示例:
      optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
      
  10. .to(device)

    • 格式:.to(device)
    • 参数:device —— 指定的计算设备。
    • 样式:方法调用。
    • 示例:
      images = images.to(device)
      
  11. .reshape

    • 格式:.reshape(shape)
    • 参数:shape —— 要重塑成的新形状。
    • 样式:方法调用。
    • 示例:
      images = images.reshape(-1, sequence_length, input_size)
      
  12. torch.zeros

    • 格式:torch.zeros(size, device)
    • 参数:
      • size —— 张量的形状。
      • device —— 张量所在的设备。
    • 样式:函数调用。
    • 示例:
      h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)
      
  13. torch.max

    • 格式:torch.max(input, dim, keepdim)
    • 参数:
      • input —— 输入张量。
      • dim —— 要计算最大值的维度。
      • keepdim —— 是否保留计算维度。
    • 样式:函数调用。
    • 示例:
      _, predicted = torch.max(outputs.data, 1)
      
  14. torch.no_grad()

    • 格式:torch.no_grad()
    • 参数:无参数。
    • 样式:上下文管理器。
    • 示例:
      with torch.no_grad():
          ...
      
  15. torch.save

    • 格式:torch.save(object, filename)
    • 参数:
      • object —— 要保存的对象。
      • filename —— 文件名。
    • 样式:函数调用。
    • 示例:
      torch.save(model.state_dict(), 'model.ckpt')
      
  16. plt.imshow

    • 格式:plt.imshow(X, cmap)
    • 参数:
      • X —— 要显示的图像数据。
      • cmap —— 颜色映射。
    • 样式:函数调用。
    • 示例:
      plt.imshow(images[j].squeeze().cpu(), cmap='gray')
      
  17. plt.show

    • 格式:plt.show()
    • 参数:无参数。
    • 样式:函数调用。
    • 示例:
      plt.show()
      
  18. plt.figure

    • 格式:plt.figure(figsize)
    • 参数:figsize —— 图形的尺寸。
    • 样式:函数调用。
    • 示例:
      plt.figure(figsize=(20, 4))
      
  19. plt.subplot

    • 格式:plt.subplot(nrows, ncols, index)
    • 参数:
      • nrows —— 子图的行数。
      • ncols —— 子图的列数。
      • index —— 当前子图的索引。
    • 样式:函数调用。
    • 示例:
      plt.subplot(1, num_samples, j+1)
      

这些函数覆盖了从数据预处理、模型构建、训练、测试到结果可视化的整个流程。

  • 10
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值