网络模型(LSTM-带门的 RNN)

解决了 RNN 的长期依赖问题、梯度问题。(残差把连乘变为连加)
LSTM
忘记门(控制 C 的比重)
忘记门
更新门(控制 h 的比重)
更新门1
更新门2
输出门(控制 C 和 h 的比例)
输出门

实验(手写数字识别)

数据集:MNIST。
网络结构:LSTM + 全连接。
优化器:Adam。
损失函数:交叉熵(CrossEntropyLoss),自带 one-hot 类型和 softmax。
输出:one-hot 类型,结果为最大的索引值。

网络

import torch
from torch import nn


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(28, 64, 2, batch_first=True)
        # 输出层:返回 one-hot 类型
        self.mlp = nn.Linear(28 * 64, 10)

    def forward(self, x):
        out, _ = self.lstm(x)
        # [n,s,v] → [n,s*v]
        out = out.reshape(-1, 28 * 64)
        return self.mlp(out)

训练

import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
from PIL import Image, ImageDraw, ImageFont
from matplotlib import pyplot as plt

from net import MyNet


batch_size = 100
net_path = r"modules/mynet.pth"

train_flag = False

# 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
if train_flag:
    dataset = datasets.MNIST(r"data", train=True, transform=transform, download=True)
    dataloader = DataLoader(dataset, batch_size, shuffle=True)
else:
    dataset = datasets.MNIST(r"data", train=False, transform=transform, download=False)
    dataloader = DataLoader(dataset, batch_size, shuffle=False)


if __name__ == '__main__':
    # 加载网络
    if os.path.isfile(net_path):
        net = torch.load(net_path)
    else:
        net = MyNet()
    opt = torch.optim.Adam(net.parameters())
    loss_fn = nn.CrossEntropyLoss()

    if train_flag:
        # 训练
        net.train()
        while True:
            for i, (x, y) in enumerate(dataloader):
                x = x.reshape(-1, 28, 28)
                out = net(x)
                loss = loss_fn(out, y)
                opt.zero_grad()
                loss.backward()
                opt.step()
                # 结果是 one-hot 类型,取最大索引
                result = torch.argmax(out, 1)
                acc = torch.mean(torch.eq(result, y).float())
                print("i:{},loss:{:.5},acc:{:.3}".format(i, loss, acc))
            # 保存网络
            torch.save(net, net_path)
    else:
        # 测试
        net.eval()
        font = ImageFont.truetype(r"arial.ttf", size=10)
        plt.ion()
        for x, y in dataloader:
            # [n,c,h,w] → [h,w]
            img_array = x[0][0] * 255
            img = Image.fromarray(img_array.numpy())
            draw = ImageDraw.ImageDraw(img)

            x = x.reshape(-1, 28, 28)
            out = net(x)
            result = torch.argmax(out, 1)
            draw.text((0, 0), str(result[0].item()), 255, font)

            plt.imshow(img)
            plt.pause(0.5)
        plt.ioff()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
1. ARIMA 2. SARIMA 3. VAR 4. Auto-ARIMA 5. Auto-SARIMA 6. LSTM 7. GRU 8. RNN 9. CNN 10. MLP 11. DNN 12. MLP-LSTM 13. MLP-GRU 14. MLP-RNN 15. MLP-CNN 16. LSTM-ARIMA 17. LSTM-MLP 18. LSTM-CNN 19. GRU-ARIMA 20. GRU-MLP 21. GRU-CNN 22. RNN-ARIMA 23. RNN-MLP 24. RNN-CNN 25. CNN-ARIMA 26. CNN-MLP 27. CNN-LSTM 28. CNN-GRU 29. ARIMA-SVM 30. SARIMA-SVM 31. VAR-SVM 32. Auto-ARIMA-SVM 33. Auto-SARIMA-SVM 34. LSTM-SVM 35. GRU-SVM 36. RNN-SVM 37. CNN-SVM 38. MLP-SVM 39. LSTM-ARIMA-SVM 40. LSTM-MLP-SVM 41. LSTM-CNN-SVM 42. GRU-ARIMA-SVM 43. GRU-MLP-SVM 44. GRU-CNN-SVM 45. RNN-ARIMA-SVM 46. RNN-MLP-SVM 47. RNN-CNN-SVM 48. CNN-ARIMA-SVM 49. CNN-MLP-SVM 50. CNN-LSTM-SVM 51. CNN-GRU-SVM 52. ARIMA-RF 53. SARIMA-RF 54. VAR-RF 55. Auto-ARIMA-RF 56. Auto-SARIMA-RF 57. LSTM-RF 58. GRU-RF 59. RNN-RF 60. CNN-RF 61. MLP-RF 62. LSTM-ARIMA-RF 63. LSTM-MLP-RF 64. LSTM-CNN-RF 65. GRU-ARIMA-RF 66. GRU-MLP-RF 67. GRU-CNN-RF 68. RNN-ARIMA-RF 69. RNN-MLP-RF 70. RNN-CNN-RF 71. CNN-ARIMA-RF 72. CNN-MLP-RF 73. CNN-LSTM-RF 74. CNN-GRU-RF 75. ARIMA-XGBoost 76. SARIMA-XGBoost 77. VAR-XGBoost 78. Auto-ARIMA-XGBoost 79. Auto-SARIMA-XGBoost 80. LSTM-XGBoost 81. GRU-XGBoost 82. RNN-XGBoost 83. CNN-XGBoost 84. MLP-XGBoost 85. LSTM-ARIMA-XGBoost 86. LSTM-MLP-XGBoost 87. LSTM-CNN-XGBoost 88. GRU-ARIMA-XGBoost 89. GRU-MLP-XGBoost 90. GRU-CNN-XGBoost 91. RNN-ARIMA-XGBoost 92. RNN-MLP-XGBoost 93. RNN-CNN-XGBoost 94. CNN-ARIMA-XGBoost 95. CNN-MLP-XGBoost 96. CNN-LSTM-XGBoost 97. CNN-GRU-XGBoost 98. ARIMA-ANN 99. SARIMA-ANN 100. VAR-ANN 上面这些缩写模型的全称及相关用途功能详细解释
07-15
以下是对缩写模型的全称及相关用途功能的详细解释: 1. ARIMA (AutoRegressive Integrated Moving Average): 适用于平稳时间序列数据的预测,结合了自回归 (AR) 和移动平均 (MA) 的模型。 2. SARIMA (Seasonal ARIMA): 在ARIMA模型基础上添加了对季节性因素的建模,适用于带有季节性的时间序列数据的预测。 3. VAR (Vector Autoregression): 用于多变量时间序列数据的预测,基于自回归模型,能够捕捉变量之间的相互依赖关系。 4. Auto-ARIMA: 自动选择ARIMA模型的参数,通过对多个模型进行评估和选择来实现自动化。 5. Auto-SARIMA: 自动选择SARIMA模型的参数,通过对多个模型进行评估和选择来实现自动化。 6. LSTM (Long Short-Term Memory): 长短期记忆网络,一种适用于处理长期依赖关系的循环神经网络,用于时间序列数据的建模和预测。 7. GRU (Gated Recurrent Unit): 一种类似于LSTM的循环神经网络,具有更简化的结构,适用于时间序列数据的建模和预测。 8. RNN (Recurrent Neural Network): 适用于处理序列数据的神经网络模型,能够捕捉时间序列的动态特性。 9. CNN (Convolutional Neural Network): 卷积神经网络,主要用于图像处理,但也可以用于时间序列数据的预测,特别擅长局部模式的识别。 10. MLP (Multi-Layer Perceptron): 多层感知机,一种前馈神经网络模型,适用于处理非线性关系的时间序列数据。 11. DNN (Deep Neural Network): 深度神经网络,具有多个隐藏层的神经网络模型,能够学习更复杂的特征表示。 12. MLP-LSTM: 结合了多层感知机和长短期记忆网络模型,用于时间序列数据的建模和预测。 13. MLP-GRU: 结合了多层感知机和控循环单元网络模型,用于时间序列数据的建模和预测。 14. MLP-RNN: 结合了多层感知机和循环神经网络模型,用于时间序列数据的建模和预测。 15. MLP-CNN: 结合了多层感知机和卷积神经网络模型,用于时间序列数据的建模和预测。 这些模型可以根据具体问题和数据的特性来选择和使用,以获得最佳的时间序列预测性能。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值