LSTM 缓解了 RNN 的长期依赖问题和梯度消失问题。(细胞状态 C 采用加法更新,类似残差连接,把梯度的连乘变为连加)
遗忘门(控制上一时刻细胞状态 C 保留的比重)
输入门(也称更新门,控制新的候选信息写入 C 的比重)
输出门(控制细胞状态 C 输出到隐藏状态 h 的比例)
实验(手写数字识别)
数据集:MNIST。
网络结构:LSTM + 全连接。
优化器:Adam。
损失函数:交叉熵(CrossEntropyLoss),内部自带 log-softmax,标签直接使用类别索引,无需手动转换为 one-hot。
输出:10 个类别的得分(logits),预测结果为得分最大的索引值。
网络(net.py)
import torch
from torch import nn
class MyNet(nn.Module):
    """LSTM classifier that reads an image as a sequence of rows.

    Each row of the image is one time step. The LSTM outputs of all time
    steps are concatenated and passed through a single linear layer that
    produces one raw score (logit) per class.

    Args:
        input_size: Number of features per time step (pixels per image row).
        hidden_size: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        seq_len: Number of time steps (image rows) the classifier expects.
        num_classes: Number of output classes.

    The defaults reproduce the original fixed 28x28 MNIST configuration.
    """

    def __init__(self, input_size: int = 28, hidden_size: int = 64,
                 num_layers: int = 2, seq_len: int = 28,
                 num_classes: int = 10) -> None:
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # Output layer: maps the concatenated per-step outputs to class logits
        # (not one-hot vectors; take argmax over dim 1 for the predicted class).
        self.mlp = nn.Linear(seq_len * hidden_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape [n, num_classes] for input [n, s, v]."""
        out, _ = self.lstm(x)
        if out.dim() == 2:
            # Unbatched input [s, v]: treat it as a batch of one sample.
            out = out.unsqueeze(0)
        # [n, s, h] -> [n, s*h]. Flatten per sample instead of reshape(-1, ...)
        # so that a wrong sequence length raises a shape error in the linear
        # layer rather than silently merging samples across the batch.
        out = out.reshape(out.size(0), out.size(1) * out.size(2))
        return self.mlp(out)
训练
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
from PIL import Image, ImageDraw, ImageFont
from matplotlib import pyplot as plt
from net import MyNet
# Run configuration.
batch_size = 100
net_path = r"modules/mynet.pth"
# True: train on the MNIST training split; False: run on the test split.
train_flag = False

# Dataset: convert images to tensors, then normalise with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
# The split, the download switch and shuffling all follow train_flag:
# training downloads and shuffles the training split, testing reads the
# already-present test split in order.
dataset = datasets.MNIST(r"data", train=train_flag, transform=transform, download=train_flag)
dataloader = DataLoader(dataset, batch_size, shuffle=train_flag)
def _load_net(path):
    """Build a MyNet and restore its weights from *path* when that file exists."""
    net = MyNet()
    if os.path.isfile(path):
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(path)
        if isinstance(checkpoint, nn.Module):
            # Legacy checkpoint written with torch.save(net, path): keep only
            # its weights so both formats end up in a fresh MyNet.
            checkpoint = checkpoint.state_dict()
        net.load_state_dict(checkpoint)
    return net


def _train(net, loader, path):
    """Train *net* on *loader* until interrupted, saving to *path* every epoch."""
    opt = torch.optim.Adam(net.parameters())
    loss_fn = nn.CrossEntropyLoss()
    # Create the checkpoint directory up front so a missing folder fails
    # immediately instead of after a full epoch of training.
    save_dir = os.path.dirname(path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    net.train()
    # Intentionally endless: stop the script manually once accuracy is good.
    while True:
        for i, (x, y) in enumerate(loader):
            # [n, 1, 28, 28] -> [n, 28, 28]: each image row is one time step.
            x = x.reshape(-1, 28, 28)
            out = net(x)
            loss = loss_fn(out, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # The output holds one score per class; the prediction is the argmax.
            result = torch.argmax(out, 1)
            acc = torch.mean(torch.eq(result, y).float())
            print("i:{},loss:{:.5},acc:{:.3}".format(i, loss.item(), acc.item()))
        # Save only the weights; this is the portable checkpoint format.
        torch.save(net.state_dict(), path)


def _evaluate(net, loader):
    """Show the first image of every batch with the predicted digit drawn on it."""
    net.eval()
    try:
        font = ImageFont.truetype(r"arial.ttf", size=10)
    except OSError:
        # arial.ttf is not installed everywhere; fall back to Pillow's built-in font.
        font = ImageFont.load_default()

    plt.ion()
    # No gradients are needed for inference.
    with torch.no_grad():
        for x, _ in loader:
            # [n, c, h, w] -> [h, w]. Undo Normalize([0.5], [0.5]) before
            # scaling so the preview holds real 0..255 pixel values.
            pixels = ((x[0][0] * 0.5 + 0.5) * 255).round().clamp(0, 255).to(torch.uint8)
            img = Image.fromarray(pixels.numpy())
            draw = ImageDraw.ImageDraw(img)
            out = net(x.reshape(-1, 28, 28))
            result = torch.argmax(out, 1)
            draw.text((0, 0), str(result[0].item()), 255, font)
            plt.imshow(img)
            plt.pause(0.5)
    plt.ioff()


if __name__ == '__main__':
    # Load the network (restoring saved weights if a checkpoint exists).
    net = _load_net(net_path)
    if train_flag:
        _train(net, dataloader, net_path)
    else:
        _evaluate(net, dataloader)