[PyTorch] Handwritten Digit Recognition on MNIST with an LSTM

Introduction

An LSTM model implemented in PyTorch for handwritten digit recognition on the MNIST dataset.

Usage

The code is hosted on GitHub: https://github.com/XavierJiezou/pytorch-lstm-mnist

git clone https://github.com/XavierJiezou/pytorch-lstm-mnist.git
cd pytorch-lstm-mnist
pip install -r requirements.txt
python ./code/train_on_mnist.py

Note: extract the mnist.7z archive before training.

Configuration

You can also adjust the configuration file:

data:
  data_root: ./data/mnist # Path to data
  train_ratio: 0.8 # Ratio of training set
  val_ratio: 0.1 # Ratio of validation set
  batch_size: 64 # How many samples per batch to load
  visualize_data_save: ./image/training_data_mnist.png # Path to save the data visualization

model:
  input_size: 28 # Number of expected features in the input
  hidden_size: 64 # Number of features in the hidden state
  num_layers: 1 # Number of recurrent layers
  output_size: 10 # Number of output classes

train:
  num_epochs: 100 # How many epochs to use for data training
  sequence_length: 28 # Length of the input sequence
  learning_rate: 0.001 # Learning rate
  device: cuda:0 # Device on which tensors will be allocated
  save_path: ./checkpoint/mnist.pth # Path to save the trained model

log:
  sink: ./log/mnist.log # Path to save the logging file
  level: INFO # Logging level: DEBUG | INFO | WARNING | ERROR | SUCCESS | CRITICAL
  format: '{message}' # logging output format. Example: '{time:YYYY-MM-DD at HH:mm:ss} {level} {message}'
  visualize_log_save: ./image/training_log_mnist.png # Path to save the visualization result
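A configuration like this can be read into the training script with PyYAML. A minimal sketch of that step (using PyYAML here is an assumption; the repository may load its config differently):

```python
import yaml  # PyYAML; pip install pyyaml

# A fragment of the config above, embedded as a string for illustration;
# a real script would read it from the config file instead.
cfg_text = """
model:
  input_size: 28
  hidden_size: 64
  num_layers: 1
  output_size: 10
train:
  learning_rate: 0.001
"""
cfg = yaml.safe_load(cfg_text)
print(cfg['model']['hidden_size'])    # 64
print(cfg['train']['learning_rate'])  # 0.001
```

`yaml.safe_load` returns nested dictionaries, so the values plug directly into model and optimizer constructors.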

Data

MNIST: http://yann.lecun.com/exdb/mnist/

(Figure: sample images from the MNIST training set)

Results

| Dataset | Sequence Length | Input Size | Accuracy |
| ------- | --------------- | ---------- | -------- |
| MNIST   | 28              | 28         | 0.9892   |
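With sequence length 28 and input size 28, each 28×28 image is fed to the LSTM as a sequence of 28 row vectors, each of dimension 28. A minimal sketch of a model wired this way (it mirrors the config values above, but the actual class in the repository may differ):

```python
import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    """Treats each 28x28 MNIST image as a sequence of 28 row vectors."""
    def __init__(self, input_size=28, hidden_size=64, num_layers=1, output_size=10):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x: (batch, seq_len=28, input_size=28)
        out, _ = self.lstm(x)          # out: (batch, 28, hidden_size)
        return self.fc(out[:, -1, :])  # classify from the last time step

model = LSTMClassifier()
images = torch.randn(4, 28, 28)  # a dummy batch standing in for 4 images
logits = model(images)
print(logits.shape)  # torch.Size([4, 10])
```

Classifying from the last time step means the hidden state has seen every row of the image before the linear head makes its prediction.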

(Figure: visualization of the training and validation curves)

EPOCH: 001/100 LR: 0.0010 TRAIN-LOSS: 0.6444 TRAIN-ACC: 0.7916  VAL-LOSS: 0.2733 VAL-ACC: 0.9208  EPOCH-TIME: 0m10s
EPOCH: 002/100 LR: 0.0010 TRAIN-LOSS: 0.1911 TRAIN-ACC: 0.9446  VAL-LOSS: 0.1564 VAL-ACC: 0.9533  EPOCH-TIME: 0m10s
EPOCH: 003/100 LR: 0.0010 TRAIN-LOSS: 0.1317 TRAIN-ACC: 0.9617  VAL-LOSS: 0.1323 VAL-ACC: 0.9621  EPOCH-TIME: 0m9s
EPOCH: 004/100 LR: 0.0010 TRAIN-LOSS: 0.1077 TRAIN-ACC: 0.9686  VAL-LOSS: 0.1204 VAL-ACC: 0.9643  EPOCH-TIME: 0m9s
EPOCH: 005/100 LR: 0.0010 TRAIN-LOSS: 0.0894 TRAIN-ACC: 0.9735  VAL-LOSS: 0.1078 VAL-ACC: 0.9690  EPOCH-TIME: 0m9s
EPOCH: 006/100 LR: 0.0010 TRAIN-LOSS: 0.0776 TRAIN-ACC: 0.9775  VAL-LOSS: 0.0802 VAL-ACC: 0.9756  EPOCH-TIME: 0m9s
EPOCH: 007/100 LR: 0.0010 TRAIN-LOSS: 0.0695 TRAIN-ACC: 0.9794  VAL-LOSS: 0.0675 VAL-ACC: 0.9791  EPOCH-TIME: 0m9s
EPOCH: 008/100 LR: 0.0010 TRAIN-LOSS: 0.0589 TRAIN-ACC: 0.9827  VAL-LOSS: 0.0769 VAL-ACC: 0.9773  EPOCH-TIME: 0m9s
EPOCH: 009/100 LR: 0.0010 TRAIN-LOSS: 0.0536 TRAIN-ACC: 0.9838  VAL-LOSS: 0.0701 VAL-ACC: 0.9790  EPOCH-TIME: 0m9s
EPOCH: 010/100 LR: 0.0010 TRAIN-LOSS: 0.0488 TRAIN-ACC: 0.9850  VAL-LOSS: 0.0669 VAL-ACC: 0.9808  EPOCH-TIME: 0m10s
EPOCH: 011/100 LR: 0.0010 TRAIN-LOSS: 0.0441 TRAIN-ACC: 0.9862  VAL-LOSS: 0.0646 VAL-ACC: 0.9808  EPOCH-TIME: 0m10s
EPOCH: 012/100 LR: 0.0010 TRAIN-LOSS: 0.0403 TRAIN-ACC: 0.9881  VAL-LOSS: 0.0564 VAL-ACC: 0.9837  EPOCH-TIME: 0m10s
EPOCH: 013/100 LR: 0.0010 TRAIN-LOSS: 0.0371 TRAIN-ACC: 0.9889  VAL-LOSS: 0.0485 VAL-ACC: 0.9853  EPOCH-TIME: 0m9s
EPOCH: 014/100 LR: 0.0010 TRAIN-LOSS: 0.0334 TRAIN-ACC: 0.9900  VAL-LOSS: 0.0562 VAL-ACC: 0.9834  EPOCH-TIME: 0m10s
EPOCH: 015/100 LR: 0.0010 TRAIN-LOSS: 0.0320 TRAIN-ACC: 0.9905  VAL-LOSS: 0.0575 VAL-ACC: 0.9834  EPOCH-TIME: 0m9s
EPOCH: 016/100 LR: 0.0010 TRAIN-LOSS: 0.0318 TRAIN-ACC: 0.9906  VAL-LOSS: 0.0617 VAL-ACC: 0.9821  EPOCH-TIME: 0m9s
EPOCH: 017/100 LR: 0.0010 TRAIN-LOSS: 0.0278 TRAIN-ACC: 0.9913  VAL-LOSS: 0.0573 VAL-ACC: 0.9837  EPOCH-TIME: 0m10s
EPOCH: 018/100 LR: 0.0010 TRAIN-LOSS: 0.0258 TRAIN-ACC: 0.9918  VAL-LOSS: 0.0554 VAL-ACC: 0.9841  EPOCH-TIME: 0m9s
EPOCH: 019/100 LR: 0.0010 TRAIN-LOSS: 0.0217 TRAIN-ACC: 0.9933  VAL-LOSS: 0.0490 VAL-ACC: 0.9853  EPOCH-TIME: 0m10s
EPOCH: 020/100 LR: 0.0010 TRAIN-LOSS: 0.0215 TRAIN-ACC: 0.9933  VAL-LOSS: 0.0460 VAL-ACC: 0.9864  EPOCH-TIME: 0m10s
EPOCH: 021/100 LR: 0.0010 TRAIN-LOSS: 0.0206 TRAIN-ACC: 0.9933  VAL-LOSS: 0.0586 VAL-ACC: 0.9847  EPOCH-TIME: 0m10s
EPOCH: 022/100 LR: 0.0010 TRAIN-LOSS: 0.0191 TRAIN-ACC: 0.9939  VAL-LOSS: 0.0533 VAL-ACC: 0.9863  EPOCH-TIME: 0m10s
EPOCH: 023/100 LR: 0.0010 TRAIN-LOSS: 0.0179 TRAIN-ACC: 0.9945  VAL-LOSS: 0.0470 VAL-ACC: 0.9856  EPOCH-TIME: 0m10s
EPOCH: 024/100 LR: 0.0010 TRAIN-LOSS: 0.0182 TRAIN-ACC: 0.9940  VAL-LOSS: 0.0434 VAL-ACC: 0.9884  EPOCH-TIME: 0m10s
EPOCH: 025/100 LR: 0.0010 TRAIN-LOSS: 0.0140 TRAIN-ACC: 0.9957  VAL-LOSS: 0.0458 VAL-ACC: 0.9877  EPOCH-TIME: 0m10s
EPOCH: 026/100 LR: 0.0010 TRAIN-LOSS: 0.0149 TRAIN-ACC: 0.9954  VAL-LOSS: 0.0599 VAL-ACC: 0.9838  EPOCH-TIME: 0m10s
EPOCH: 027/100 LR: 0.0010 TRAIN-LOSS: 0.0154 TRAIN-ACC: 0.9951  VAL-LOSS: 0.0508 VAL-ACC: 0.9858  EPOCH-TIME: 0m10s
EPOCH: 028/100 LR: 0.0010 TRAIN-LOSS: 0.0160 TRAIN-ACC: 0.9949  VAL-LOSS: 0.0471 VAL-ACC: 0.9878  EPOCH-TIME: 0m10s
EPOCH: 029/100 LR: 0.0010 TRAIN-LOSS: 0.0128 TRAIN-ACC: 0.9957  VAL-LOSS: 0.0477 VAL-ACC: 0.9860  EPOCH-TIME: 0m10s
EPOCH: 030/100 LR: 0.0010 TRAIN-LOSS: 0.0107 TRAIN-ACC: 0.9966  VAL-LOSS: 0.0552 VAL-ACC: 0.9856  EPOCH-TIME: 0m10s
EPOCH: 031/100 LR: 0.0010 TRAIN-LOSS: 0.0107 TRAIN-ACC: 0.9967  VAL-LOSS: 0.0514 VAL-ACC: 0.9867  EPOCH-TIME: 0m9s
EPOCH: 032/100 LR: 0.0010 TRAIN-LOSS: 0.0122 TRAIN-ACC: 0.9958  VAL-LOSS: 0.0491 VAL-ACC: 0.9883  EPOCH-TIME: 0m10s
EPOCH: 033/100 LR: 0.0010 TRAIN-LOSS: 0.0106 TRAIN-ACC: 0.9967  VAL-LOSS: 0.0491 VAL-ACC: 0.9878  EPOCH-TIME: 0m9s
EPOCH: 034/100 LR: 0.0010 TRAIN-LOSS: 0.0081 TRAIN-ACC: 0.9973  VAL-LOSS: 0.0516 VAL-ACC: 0.9873  EPOCH-TIME: 0m10s
EPOCH: 035/100 LR: 0.0010 TRAIN-LOSS: 0.0136 TRAIN-ACC: 0.9958  VAL-LOSS: 0.0547 VAL-ACC: 0.9851  EPOCH-TIME: 0m9s
EPOCH: 036/100 LR: 0.0010 TRAIN-LOSS: 0.0082 TRAIN-ACC: 0.9975  VAL-LOSS: 0.0484 VAL-ACC: 0.9878  EPOCH-TIME: 0m10s
EPOCH: 037/100 LR: 0.0010 TRAIN-LOSS: 0.0104 TRAIN-ACC: 0.9965  VAL-LOSS: 0.0533 VAL-ACC: 0.9870  EPOCH-TIME: 0m9s
EPOCH: 038/100 LR: 0.0010 TRAIN-LOSS: 0.0084 TRAIN-ACC: 0.9975  VAL-LOSS: 0.0653 VAL-ACC: 0.9840  EPOCH-TIME: 0m10s
EPOCH: 039/100 LR: 0.0010 TRAIN-LOSS: 0.0084 TRAIN-ACC: 0.9975  VAL-LOSS: 0.0516 VAL-ACC: 0.9880  EPOCH-TIME: 0m9s
EPOCH: 040/100 LR: 0.0010 TRAIN-LOSS: 0.0081 TRAIN-ACC: 0.9974  VAL-LOSS: 0.0474 VAL-ACC: 0.9886  EPOCH-TIME: 0m9s
EPOCH: 041/100 LR: 0.0010 TRAIN-LOSS: 0.0079 TRAIN-ACC: 0.9976  VAL-LOSS: 0.0574 VAL-ACC: 0.9858  EPOCH-TIME: 0m9s
EPOCH: 042/100 LR: 0.0010 TRAIN-LOSS: 0.0072 TRAIN-ACC: 0.9979  VAL-LOSS: 0.0564 VAL-ACC: 0.9866  EPOCH-TIME: 0m9s
EPOCH: 043/100 LR: 0.0010 TRAIN-LOSS: 0.0095 TRAIN-ACC: 0.9971  VAL-LOSS: 0.0557 VAL-ACC: 0.9868  EPOCH-TIME: 0m10s
EPOCH: 044/100 LR: 0.0010 TRAIN-LOSS: 0.0061 TRAIN-ACC: 0.9980  VAL-LOSS: 0.0469 VAL-ACC: 0.9896  EPOCH-TIME: 0m10s
EPOCH: 045/100 LR: 0.0010 TRAIN-LOSS: 0.0089 TRAIN-ACC: 0.9971  VAL-LOSS: 0.0569 VAL-ACC: 0.9868  EPOCH-TIME: 0m10s
EPOCH: 046/100 LR: 0.0010 TRAIN-LOSS: 0.0068 TRAIN-ACC: 0.9979  VAL-LOSS: 0.0540 VAL-ACC: 0.9867  EPOCH-TIME: 0m9s
EPOCH: 047/100 LR: 0.0010 TRAIN-LOSS: 0.0049 TRAIN-ACC: 0.9985  VAL-LOSS: 0.0569 VAL-ACC: 0.9881  EPOCH-TIME: 0m10s
EPOCH: 048/100 LR: 0.0010 TRAIN-LOSS: 0.0079 TRAIN-ACC: 0.9975  VAL-LOSS: 0.0524 VAL-ACC: 0.9880  EPOCH-TIME: 0m9s
EPOCH: 049/100 LR: 0.0010 TRAIN-LOSS: 0.0037 TRAIN-ACC: 0.9991  VAL-LOSS: 0.0541 VAL-ACC: 0.9876  EPOCH-TIME: 0m9s
EPOCH: 050/100 LR: 0.0010 TRAIN-LOSS: 0.0075 TRAIN-ACC: 0.9976  VAL-LOSS: 0.0564 VAL-ACC: 0.9866  EPOCH-TIME: 0m9s
EPOCH: 051/100 LR: 0.0001 TRAIN-LOSS: 0.0031 TRAIN-ACC: 0.9992  VAL-LOSS: 0.0504 VAL-ACC: 0.9890  EPOCH-TIME: 0m9s
EPOCH: 052/100 LR: 0.0001 TRAIN-LOSS: 0.0013 TRAIN-ACC: 0.9998  VAL-LOSS: 0.0497 VAL-ACC: 0.9893  EPOCH-TIME: 0m10s
EPOCH: 053/100 LR: 0.0001 TRAIN-LOSS: 0.0010 TRAIN-ACC: 0.9999  VAL-LOSS: 0.0491 VAL-ACC: 0.9901  EPOCH-TIME: 0m10s
EPOCH: 054/100 LR: 0.0001 TRAIN-LOSS: 0.0008 TRAIN-ACC: 0.9999  VAL-LOSS: 0.0496 VAL-ACC: 0.9897  EPOCH-TIME: 0m9s
EPOCH: 055/100 LR: 0.0001 TRAIN-LOSS: 0.0007 TRAIN-ACC: 0.9999  VAL-LOSS: 0.0504 VAL-ACC: 0.9897  EPOCH-TIME: 0m9s
EPOCH: 056/100 LR: 0.0001 TRAIN-LOSS: 0.0006 TRAIN-ACC: 0.9999  VAL-LOSS: 0.0503 VAL-ACC: 0.9896  EPOCH-TIME: 0m10s
EPOCH: 057/100 LR: 0.0001 TRAIN-LOSS: 0.0005 TRAIN-ACC: 0.9999  VAL-LOSS: 0.0503 VAL-ACC: 0.9900  EPOCH-TIME: 0m10s
EPOCH: 058/100 LR: 0.0001 TRAIN-LOSS: 0.0005 TRAIN-ACC: 0.9999  VAL-LOSS: 0.0525 VAL-ACC: 0.9896  EPOCH-TIME: 0m10s
EPOCH: 059/100 LR: 0.0001 TRAIN-LOSS: 0.0004 TRAIN-ACC: 0.9999  VAL-LOSS: 0.0532 VAL-ACC: 0.9896  EPOCH-TIME: 0m10s
EPOCH: 060/100 LR: 0.0001 TRAIN-LOSS: 0.0003 TRAIN-ACC: 0.9999  VAL-LOSS: 0.0534 VAL-ACC: 0.9898  EPOCH-TIME: 0m10s
EPOCH: 061/100 LR: 0.0001 TRAIN-LOSS: 0.0003 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0546 VAL-ACC: 0.9891  EPOCH-TIME: 0m10s
EPOCH: 062/100 LR: 0.0001 TRAIN-LOSS: 0.0003 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0555 VAL-ACC: 0.9894  EPOCH-TIME: 0m10s
EPOCH: 063/100 LR: 0.0001 TRAIN-LOSS: 0.0002 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0556 VAL-ACC: 0.9894  EPOCH-TIME: 0m10s
EPOCH: 064/100 LR: 0.0001 TRAIN-LOSS: 0.0002 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0586 VAL-ACC: 0.9898  EPOCH-TIME: 0m9s
EPOCH: 065/100 LR: 0.0001 TRAIN-LOSS: 0.0002 TRAIN-ACC: 0.9999  VAL-LOSS: 0.0584 VAL-ACC: 0.9894  EPOCH-TIME: 0m10s
EPOCH: 066/100 LR: 0.0001 TRAIN-LOSS: 0.0001 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0579 VAL-ACC: 0.9893  EPOCH-TIME: 0m9s
EPOCH: 067/100 LR: 0.0001 TRAIN-LOSS: 0.0001 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0599 VAL-ACC: 0.9887  EPOCH-TIME: 0m9s
EPOCH: 068/100 LR: 0.0001 TRAIN-LOSS: 0.0001 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0611 VAL-ACC: 0.9891  EPOCH-TIME: 0m10s
EPOCH: 069/100 LR: 0.0001 TRAIN-LOSS: 0.0001 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0634 VAL-ACC: 0.9890  EPOCH-TIME: 0m9s
EPOCH: 070/100 LR: 0.0001 TRAIN-LOSS: 0.0001 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0621 VAL-ACC: 0.9888  EPOCH-TIME: 0m9s
EPOCH: 071/100 LR: 0.0001 TRAIN-LOSS: 0.0001 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0620 VAL-ACC: 0.9890  EPOCH-TIME: 0m10s
EPOCH: 072/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0657 VAL-ACC: 0.9888  EPOCH-TIME: 0m10s
EPOCH: 073/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0667 VAL-ACC: 0.9887  EPOCH-TIME: 0m10s
EPOCH: 074/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0665 VAL-ACC: 0.9888  EPOCH-TIME: 0m10s
EPOCH: 075/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0675 VAL-ACC: 0.9888  EPOCH-TIME: 0m10s
EPOCH: 076/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0686 VAL-ACC: 0.9888  EPOCH-TIME: 0m10s
EPOCH: 077/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0692 VAL-ACC: 0.9888  EPOCH-TIME: 0m9s
EPOCH: 078/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0707 VAL-ACC: 0.9888  EPOCH-TIME: 0m10s
EPOCH: 079/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0725 VAL-ACC: 0.9884  EPOCH-TIME: 0m10s
EPOCH: 080/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0749 VAL-ACC: 0.9887  EPOCH-TIME: 0m10s
EPOCH: 081/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0750 VAL-ACC: 0.9886  EPOCH-TIME: 0m10s
EPOCH: 082/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0759 VAL-ACC: 0.9887  EPOCH-TIME: 0m9s
EPOCH: 083/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0784 VAL-ACC: 0.9884  EPOCH-TIME: 0m10s
EPOCH: 084/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0810 VAL-ACC: 0.9884  EPOCH-TIME: 0m10s
EPOCH: 085/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0788 VAL-ACC: 0.9887  EPOCH-TIME: 0m10s
EPOCH: 086/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0815 VAL-ACC: 0.9888  EPOCH-TIME: 0m10s
EPOCH: 087/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0817 VAL-ACC: 0.9886  EPOCH-TIME: 0m10s
EPOCH: 088/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0824 VAL-ACC: 0.9884  EPOCH-TIME: 0m10s
EPOCH: 089/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0850 VAL-ACC: 0.9887  EPOCH-TIME: 0m9s
EPOCH: 090/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0869 VAL-ACC: 0.9886  EPOCH-TIME: 0m10s
EPOCH: 091/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0892 VAL-ACC: 0.9883  EPOCH-TIME: 0m9s
EPOCH: 092/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0903 VAL-ACC: 0.9883  EPOCH-TIME: 0m10s
EPOCH: 093/100 LR: 0.0001 TRAIN-LOSS: 0.0001 TRAIN-ACC: 1.0000  VAL-LOSS: 0.1042 VAL-ACC: 0.9868  EPOCH-TIME: 0m9s
EPOCH: 094/100 LR: 0.0001 TRAIN-LOSS: 0.0003 TRAIN-ACC: 0.9999  VAL-LOSS: 0.0972 VAL-ACC: 0.9874  EPOCH-TIME: 0m10s
EPOCH: 095/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0969 VAL-ACC: 0.9874  EPOCH-TIME: 0m9s
EPOCH: 096/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0966 VAL-ACC: 0.9877  EPOCH-TIME: 0m9s
EPOCH: 097/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0962 VAL-ACC: 0.9877  EPOCH-TIME: 0m9s
EPOCH: 098/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0959 VAL-ACC: 0.9877  EPOCH-TIME: 0m9s
EPOCH: 099/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0956 VAL-ACC: 0.9877  EPOCH-TIME: 0m9s
EPOCH: 100/100 LR: 0.0001 TRAIN-LOSS: 0.0000 TRAIN-ACC: 1.0000  VAL-LOSS: 0.0953 VAL-ACC: 0.9877  EPOCH-TIME: 0m10s
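Note that the learning rate in the log drops from 0.0010 to 0.0001 at epoch 51. A decay like that could be produced by a step scheduler; the sketch below reproduces the schedule, though the exact scheduler used by the training script is an assumption:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=0.001)
# Multiply the LR by 0.1 every 50 epochs: 0.001 for epochs 1-50, 0.0001 afterwards.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

lrs = []
for epoch in range(100):
    # ... one epoch of training and validation would go here ...
    optimizer.step()  # no-op here (no gradients); stands in for the training updates
    lrs.append(optimizer.param_groups[0]['lr'])
    scheduler.step()

print(lrs[0], lrs[49], lrs[50])
```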

Here are the steps for MNIST handwritten digit recognition with PyTorch (this walkthrough uses a small CNN rather than the LSTM above):

1. Import the required libraries

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
```

2. Define the preprocessing transforms

```python
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])
```

3. Load the dataset

```python
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)
```

4. Define the data loaders

```python
batch_size = 64
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
```

5. Define the model

```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout(0.5)  # plain Dropout: its input here is already flattened
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = nn.functional.max_pool2d(nn.functional.relu(self.conv1(x)), 2)  # 28x28 -> 14x14
        x = nn.functional.max_pool2d(nn.functional.relu(self.conv2(x)), 2)  # 14x14 -> 7x7
        x = self.dropout1(x)
        x = torch.flatten(x, 1)  # shape: (N, 64*7*7)
        x = nn.functional.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return nn.functional.log_softmax(x, dim=1)

model = Net()
```

Pooling after each convolution keeps the flattened size at 64\*7\*7, matching `fc1`; a single pooling step would leave 14×14 feature maps and break the linear layer.

6. Define the optimizer and loss function

```python
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.5)
criterion = nn.NLLLoss()  # the model outputs log-probabilities, so use NLLLoss, not CrossEntropyLoss
```

7. Train the model

```python
epochs = 10
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
```

8. Evaluate the model

```python
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        output = model(data)
        # Sum the per-sample losses so dividing by the dataset size gives the true mean.
        test_loss += nn.functional.nll_loss(output, target, reduction='sum').item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100. * correct / len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset), accuracy))
```

These are the basic steps for MNIST handwritten digit recognition with PyTorch.
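For a single image, the predicted digit is simply the class with the highest log-probability. A minimal sketch of that final step (the log-probabilities below are made up for illustration; in practice they come from `model(image)`):

```python
import torch

# Fake log-probabilities for one image over 10 digit classes,
# shaped like the model's log_softmax output: (batch=1, classes=10).
log_probs = torch.log_softmax(torch.tensor([[0.1, 2.5, 0.3, 0.0, 0.2,
                                             0.1, 0.4, 0.9, 0.0, 0.1]]), dim=1)
pred = log_probs.argmax(dim=1)
print(pred.item())  # 1: the class with the highest log-probability
```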
