RNN手写数字分类（MNIST）

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd 
import time
import copy 
import torch
from torch import nn 
import torch.nn.functional  as F 
from torch.utils.data import DataLoader,Dataset
import torch.optim as optim
import torchvision
from torchvision import transforms
import torch.optim as optim
import hiddenlayer as hl
def _load_mnist(is_train):
    """Return the MNIST train (is_train=True) or test split, converting images to tensors.

    The data is stored under ./data/MNIST and downloaded if it is not already there.
    """
    return torchvision.datasets.MNIST(
        root="./data/MNIST",
        train=is_train,
        transform=transforms.ToTensor(),
        download=True,
    )


train = _load_mnist(True)
test = _load_mnist(False)
E:\Anaconda\lib\site-packages\torchvision\datasets\mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ..\torch\csrc\utils\tensor_numpy.cpp:189.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
# Mini-batches of 64 images. Only the training split is shuffled; the test split
# is iterated in a fixed order so evaluation is reproducible from run to run.
train_loader = DataLoader(dataset=train, batch_size=64, shuffle=True, num_workers=0)
test_loader = DataLoader(dataset=test, batch_size=64, shuffle=False, num_workers=0)

# Peek at the first training batch to check image and label layout.
i, (x, y) = next(enumerate(train_loader))
print(x.size())
print(y)
torch.Size([64, 1, 28, 28])
tensor([9, 8, 5, 9, 6, 9, 6, 8, 1, 9, 9, 8, 3, 1, 2, 2, 8, 6, 8, 6, 5, 1, 0, 1,
        6, 6, 0, 7, 3, 5, 7, 3, 6, 7, 7, 2, 9, 0, 8, 2, 1, 3, 0, 9, 7, 0, 5, 3,
        1, 3, 4, 0, 7, 8, 1, 1, 3, 1, 1, 7, 7, 9, 3, 1])
# Peek at the first test batch to confirm it has the same layout as the training data.
i, (x, y) = next(enumerate(test_loader))
print(x.size())
print(y)
torch.Size([64, 1, 28, 28])
tensor([7, 6, 0, 6, 8, 1, 3, 8, 8, 7, 7, 5, 9, 3, 6, 7, 3, 1, 8, 3, 3, 8, 8, 5,
        0, 2, 7, 6, 8, 6, 4, 5, 7, 5, 5, 7, 3, 5, 3, 4, 2, 6, 9, 3, 4, 5, 3, 1,
        6, 2, 7, 0, 1, 5, 0, 5, 0, 9, 1, 6, 5, 9, 5, 2])

搭建RNN网络

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device
device(type='cpu')
class RNNimc(nn.Module):
    """Recurrent classifier: an ``nn.RNN`` followed by a linear layer.

    The input sequence goes through ``nn.RNN``, which uses a ReLU
    non-linearity and the batch-first layout. Only the output at the final
    time step is passed to ``fc1``, which produces the class scores.

    Args:
        input_dim: number of features per time step.
        hidden_dim: size of the RNN hidden state.
        layer_dim: number of stacked RNN layers.
        output_dim: number of output logits (classes) produced by ``fc1``.
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.output_dim = output_dim
        self.rnn = nn.RNN(
            input_dim,
            hidden_dim,
            layer_dim,
            batch_first=True,
            nonlinearity="relu",
        )
        self.fc1 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map x of shape (batch, time_step, input_dim) to logits of shape (batch, output_dim)."""
        # seq_out: (batch, time_step, hidden_dim); the returned final hidden state is unused.
        # No initial hidden state is passed, so nn.RNN starts from its default (zeros).
        seq_out, _ = self.rnn(x)
        last_step = seq_out[:, -1, :]
        return self.fc1(last_step)
# Instantiate the model. Each 28x28 image is later viewed as 28 time steps of
# 28 features; the classifier emits 10 logits.
input_dim, hidden_dim, layer_dim, output_dim = 28, 128, 1, 10
MyRNNimc = RNNimc(input_dim, hidden_dim, layer_dim, output_dim)
print(MyRNNimc)
RNNimc(
  (rnn): RNN(28, 128, batch_first=True)
  (fc1): Linear(in_features=128, out_features=10, bias=True)
)
# Render the network graph with hiddenlayer, tracing it with a dummy (1, 28, 28) input.
dummy_input = torch.zeros([1, 28, 28])
hl_graph = hl.build_graph(MyRNNimc, dummy_input)
hl_graph.theme = hl.graph.THEMES["blue"].copy()
hl_graph

![RNN网络结构图（hiddenlayer 生成）](https://wangzhenxi971006.oss-cn-beijing.aliyuncs.com/output_9_0.svg)

RNN分类器的训练和预测

# Sanity-check the reshape used during training: images go from
# (batch, 1, 28, 28) to (batch, 28, 28); labels stay a flat (batch,) vector.
i, (x, y) = next(enumerate(train_loader))
for t in (x, y):
    print(t.shape)
x, y = x.view(-1, 28, 28), y.view(x.shape[0])
for t in (x, y):
    print(t.shape)
torch.Size([64, 1, 28, 28])
torch.Size([64])
torch.Size([64, 28, 28])
torch.Size([64])
import gc

MyRNNimc = MyRNNimc.to(device)
# Train the model.
optimizer = optim.RMSprop(MyRNNimc.parameters(), lr=0.0003)
criterion = nn.CrossEntropyLoss()
# Per-epoch metric histories, stored as plain Python floats so they can be
# printed and plotted directly (no autograd graph attached).
train_loss_all = []
train_acc_all = []
test_loss_all = []
test_acc_all = []
num_epochs = 30
for epoch in range(1, num_epochs + 1):
    print(f"Epoch:{epoch}/{num_epochs}")

    # ---- training pass ----
    MyRNNimc.train()
    train_loss = 0.0
    corrects = 0
    train_num = 0
    for b_x, b_y in train_loader:
        # (batch, 1, 28, 28) -> (batch, 28, 28): each image row is one time step.
        x = b_x.view(-1, 28, 28).to(device, non_blocking=True)
        y = b_y.view(x.shape[0]).to(device, non_blocking=True)
        sigma = MyRNNimc(x)  # call the module (not .forward) so hooks still run
        loss = criterion(sigma, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        pre_lab = torch.argmax(sigma, 1)
        # loss.item() is the batch mean; weight it by batch size so the epoch
        # average is exact. Keeping a separate float accumulator (instead of
        # adding into `loss`) sums over batches and frees each batch's graph.
        train_loss += loss.item() * x.size(0)
        corrects += (pre_lab == y).sum().item()
        train_num += x.size(0)
    train_loss_all.append(train_loss / train_num)
    train_acc_all.append(corrects / train_num)
    print(f"Epoch:{epoch} Train Loss:{train_loss_all[-1]:.4f} Train Acc:{train_acc_all[-1]:.4f}")

    # ---- evaluation pass ----
    MyRNNimc.eval()
    test_loss = 0.0
    corrects = 0
    test_num = 0
    with torch.no_grad():  # no gradients are needed to evaluate
        for x, y in test_loader:
            x = x.view(-1, 28, 28).to(device, non_blocking=True)
            y = y.view(x.shape[0]).to(device, non_blocking=True)
            sigma = MyRNNimc(x)
            pre_lab = torch.argmax(sigma, 1)
            loss = criterion(sigma, y)
            test_loss += loss.item() * x.size(0)
            corrects += (pre_lab == y).sum().item()
            test_num += x.size(0)
    test_loss_all.append(test_loss / test_num)
    test_acc_all.append(corrects / test_num)
    print(f"Epoch:{epoch} Test Loss:{test_loss_all[-1]:.4f} Test Acc:{test_acc_all[-1]:.4f}")

    # Drop the last batch's tensors before the next epoch; this mainly helps on GPU.
    del x, y, sigma, loss, pre_lab
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
Epoch:1/30
Epoch:1 Train Loss:0.0001 Train Acc:0.9606
Epoch:1 Test Loss:0.0004 Test Acc:0.9165
Epoch:2/30
Epoch:2 Train Loss:0.0001 Train Acc:0.9639
Epoch:2 Test Loss:0.0005 Test Acc:0.9668
Epoch:3/30
Epoch:3 Train Loss:0.0000 Train Acc:0.9673
Epoch:3 Test Loss:0.0003 Test Acc:0.9696
Epoch:4/30
Epoch:4 Train Loss:0.0001 Train Acc:0.9697
Epoch:4 Test Loss:0.0000 Test Acc:0.9686
Epoch:5/30
Epoch:5 Train Loss:0.0001 Train Acc:0.9715
Epoch:5 Test Loss:0.0000 Test Acc:0.9748
Epoch:6/30
Epoch:6 Train Loss:0.0000 Train Acc:0.9734
Epoch:6 Test Loss:0.0000 Test Acc:0.9744
Epoch:7/30
Epoch:7 Train Loss:0.0001 Train Acc:0.9750
Epoch:7 Test Loss:0.0000 Test Acc:0.9703
Epoch:8/30
Epoch:8 Train Loss:0.0000 Train Acc:0.9765
Epoch:8 Test Loss:0.0000 Test Acc:0.9793
Epoch:9/30
Epoch:9 Train Loss:0.0000 Train Acc:0.9781
Epoch:9 Test Loss:0.0009 Test Acc:0.9743
Epoch:10/30
Epoch:10 Train Loss:0.0001 Train Acc:0.9792
Epoch:10 Test Loss:0.0001 Test Acc:0.9757
Epoch:11/30
Epoch:11 Train Loss:0.0000 Train Acc:0.9795
Epoch:11 Test Loss:0.0001 Test Acc:0.9746
Epoch:12/30
Epoch:12 Train Loss:0.0000 Train Acc:0.9802
Epoch:12 Test Loss:0.0001 Test Acc:0.9733
Epoch:13/30
Epoch:13 Train Loss:0.0000 Train Acc:0.9807
Epoch:13 Test Loss:0.0000 Test Acc:0.9777
Epoch:14/30
Epoch:14 Train Loss:0.0000 Train Acc:0.9815
Epoch:14 Test Loss:0.0004 Test Acc:0.9763
Epoch:15/30
Epoch:15 Train Loss:0.0000 Train Acc:0.9828
Epoch:15 Test Loss:0.0004 Test Acc:0.9780
Epoch:16/30
Epoch:16 Train Loss:0.0000 Train Acc:0.9831
Epoch:16 Test Loss:0.0001 Test Acc:0.9758
Epoch:17/30
Epoch:17 Train Loss:0.0000 Train Acc:0.9838
Epoch:17 Test Loss:0.0005 Test Acc:0.9821
Epoch:18/30
Epoch:18 Train Loss:0.0002 Train Acc:0.9836
Epoch:18 Test Loss:0.0000 Test Acc:0.8875
Epoch:19/30
Epoch:19 Train Loss:0.0000 Train Acc:0.9844
Epoch:19 Test Loss:0.0000 Test Acc:0.9765
Epoch:20/30
Epoch:20 Train Loss:0.0000 Train Acc:0.9846
Epoch:20 Test Loss:0.0000 Test Acc:0.9801
Epoch:21/30
Epoch:21 Train Loss:0.0001 Train Acc:0.9853
Epoch:21 Test Loss:0.0000 Test Acc:0.9759
Epoch:22/30
Epoch:22 Train Loss:0.0001 Train Acc:0.9856
Epoch:22 Test Loss:0.0000 Test Acc:0.9604
Epoch:23/30
Epoch:23 Train Loss:0.0000 Train Acc:0.9861
Epoch:23 Test Loss:0.0001 Test Acc:0.9771
Epoch:24/30
Epoch:24 Train Loss:0.0000 Train Acc:0.9866
Epoch:24 Test Loss:0.0001 Test Acc:0.9713
Epoch:25/30
Epoch:25 Train Loss:0.0000 Train Acc:0.9869
Epoch:25 Test Loss:0.0000 Test Acc:0.9798
Epoch:26/30
Epoch:26 Train Loss:0.0000 Train Acc:0.9874
Epoch:26 Test Loss:0.0000 Test Acc:0.9828
Epoch:27/30
Epoch:27 Train Loss:0.0000 Train Acc:0.9880
Epoch:27 Test Loss:0.0000 Test Acc:0.9801
Epoch:28/30
Epoch:28 Train Loss:0.0000 Train Acc:0.9879
Epoch:28 Test Loss:0.0000 Test Acc:0.9832
Epoch:29/30
Epoch:29 Train Loss:0.0000 Train Acc:0.9890
Epoch:29 Test Loss:0.0011 Test Acc:0.9806
Epoch:30/30
Epoch:30 Train Loss:0.0000 Train Acc:0.9886
Epoch:30 Test Loss:0.0001 Test Acc:0.9791

可视化训练过程

# Plot the per-epoch training and validation curves side by side:
# loss on the left, accuracy on the right.
panels = [
    (train_loss_all, test_loss_all, "Train Loss", "Val Loss", "Loss"),
    (train_acc_all, test_acc_all, "Train acc", "Val acc", "acc"),
]
plt.figure(figsize=(14, 5))
for idx, (train_curve, val_curve, train_label, val_label, y_label) in enumerate(panels, start=1):
    plt.subplot(1, 2, idx)
    plt.plot(train_curve, "ro-", label=train_label)
    plt.plot(val_curve, "bs-", label=val_label)
    plt.legend()
    plt.xlabel("epoch")
    plt.ylabel(y_label)
plt.show()

（图：训练集与验证集的 Loss 和 Accuracy 随 epoch 变化曲线）

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值