LSTM 手动实现车牌识别 Pytorch代码

循环神经网络RNN和长短期记忆网络LSTM的原理,许多文章都讲的很清晰,我就不到处抄了……
听说实现车牌识别还挺简单的,来尝试一下叭~

首先找车牌图片,虽然有一些生成车牌的软件,但是一般不能批量生成,而且我们还要拿到标签进行训练,好叭,自己先写一个看看。

软件生成的车牌:

在这里插入图片描述
我用最简单的代码生成的车牌:

在这里插入图片描述
emmm,怎么说呢,假得很有层次感。
不管了,先把效果跑出来再说,真实数据集反正咱也没办法,让老板花钱去买好了??[裂开]

import os
import torch
import numpy as np
from PIL import Image
import torch.utils.data as data
from torchvision import transforms


class Sampling(data.Dataset):
    def __init__(self, root, is_train):
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        self.imgs = []
        self.labels = []
        folder = "train" if is_train else "test"
        for filenames in os.listdir(os.path.join(root, folder)):
            x = os.path.join(root, folder, filenames)
            y = filenames.split('.')[0]
            self.imgs.append(x)
            self.labels.append(y)
        self.l1 = ['新', '苏', '浙', '赣', '鄂', '桂', '甘', '晋', '蒙', '陕', '吉', '闽', '贵', '粤', '青',
                   '藏', '川', '宁', '琼', '辽', '黑', '湘', '皖', '鲁', '京', '津', '沪', '渝', '冀', '豫', '云']
        self.l2 = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R',
                   'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        img_path = self.imgs[index]
        img = Image.open(img_path)
        img = self.transform(img)
        label = self.labels[index]
        label = self.one_hot(label)
        return img, label

    def one_hot(self, x):
        z = torch.zeros(7, 36)
        index = self.l1.index(x[0])
        z[0][index] = 1
        for i in range(1, 7):
            index = self.l2.index(x[i])
            z[i][index] = 1
        return z

汉字、字母、数字一共67个值,标签可以选择按67个类别进行独热编码,比较省事儿,但是会有点浪费。因为第一个识别的汉字只使用了前面31个编码位,剩余的编码位没有用,而字母数字的识别只用了后面36个编码位,类别的增多还会提升网络的学习难度。

我选择把31个汉字放一组,36个字母和数字放另一种组,用36位独热编码,这样编码位的利用率比较高。

import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import sampling
import matplotlib.pyplot as plt


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(126, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU()
        )
        self.lstm = nn.LSTM(input_size=128,
                            hidden_size=256,
                            num_layers=2,
                            batch_first=True)   # (S,N,V) -> (N,S,V)

    def forward(self, x):
        # (N,3,42,130) -> (N,126,130) -> (N,130,126) -> (N*130,126) -> (N*130,128) -> (N,130,128) -> (N,128) -> (N,256)
        x = x.reshape(-1, 126, 130).permute(0, 2, 1)
        x = x.reshape(-1, 126)
        fc1 = self.fc1(x)
        fc1 = fc1.reshape(-1, 130, 128)
        lstm, (h_n, h_c) = self.lstm(fc1, None)
        out = lstm[:, -1, :]

        return out


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.lstm = nn.LSTM(input_size=256,
                            hidden_size=128,
                            num_layers=2,
                            batch_first=True)
        self.out = nn.Linear(128, 36)

    def forward(self, x):
        # (N,256) -> (N,7,256) -> (N,7,128) -> (N*7,128) -> (N*7,36) -> (N,7,36)
        x = x.reshape(-1, 1, 256)
        x = x.expand(-1, 7, 256)
        lstm, (h_n, h_c) = self.lstm(x, None)
        y1 = lstm.reshape(-1, 128)
        out = self.out(y1)
        output = out.reshape(-1, 7, 36)
        return output


class MainNet (nn.Module):
    def __init__(self):
        super(MainNet, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        encoder = self.encoder(x)
        decoder = self.decoder(encoder)

        return decoder


if __name__ == '__main__':
    BATCH = 128
    EPOCH = 100
    save_path = r'params/seq2seq.pth'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = MainNet().to(device)

    opt = torch.optim.Adam(net.parameters())
    loss_func = nn.MSELoss()

    if os.path.exists(save_path):
        net.load_state_dict(torch.load(save_path))
    else:
        print("No Params!")

    train_data = sampling.Sampling(root="./plate", is_train=True)
    train_loader = data.DataLoader(dataset=train_data, batch_size=BATCH, shuffle=True, num_workers=4)
    valid_data = sampling.Sampling(root="./plate", is_train=False)
    valid_loader = data.DataLoader(dataset=valid_data, batch_size=BATCH, shuffle=False, num_workers=4)

    train_acc_list = []
    valid_acc_list = []
    for epoch in range(EPOCH):
        print("epoch--{}".format(epoch))
        net.train()
        true_num = 0
        for i, (x, y) in enumerate(train_loader):
            batch_x = x.to(device)
            batch_y = y.float().to(device)

            output = net(batch_x)
            loss = loss_func(output, batch_y)

            opt.zero_grad()
            loss.backward()
            opt.step()

            label_y = torch.argmax(y, 2).detach().numpy()
            out_y = torch.argmax(output, 2).cpu().detach().numpy()
            true_num += np.sum(out_y == label_y)
            if i % 50 == 0:
                print("loss:{:.6f}".format(loss.item()))
                print("label_y:\t", label_y[0])
                print("out_y:\t\t", out_y[0])

        train_acc = true_num / (len(train_data) * 7)
        print("train_acc:{:.2f}%".format(train_acc * 100))
        train_acc_list.append(train_acc)

        net.eval()
        true_num = 0
        with torch.no_grad():
            for i, (x, y) in enumerate(valid_loader):
                batch_x = x.to(device)
                batch_y = y.float().to(device)
                output = net(batch_x)
                label_y = torch.argmax(y, 2).detach().numpy()
                out_y = torch.argmax(output, 2).cpu().detach().numpy()
                true_num += np.sum(out_y == label_y)
            valid_acc = true_num / (len(valid_data) * 7)
            print("valid_acc:{:.2f}%".format(valid_acc * 100))
            valid_acc_list.append(valid_acc)

        plt.clf()
        plt.plot(train_acc_list, label='train_acc')
        plt.plot(valid_acc_list, label='valid_acc')
        plt.title('accuracy')
        plt.legend()
        plt.savefig('graph/acc_{}'.format(epoch + 1))

        torch.save(net.state_dict(), save_path)

车牌识别用CNN也是可以做的,这次我们来玩一下循环网络编解码结构(Encoder-Decoder),也叫Seq2Seq模型。
img
Pytorch中调用LSTM挺简单的,就是shape变换有点烦,代码中已备注。
LSTM的输入格式为 ( N , S , V ) (N,S,V) (N,S,V),可以理解为将图片从左到右进行扫描,每次扫描得到的向量依次传入循环网络。 S S S为扫描多少步,就是图片宽度130,也就是上图中输入 x x x的个数; V V V为每步扫描得到的向量,就是上图中的每个 x x x,为图片高度×通道数=42×3。
在这里插入图片描述

因为这个车牌图片非常简单,没有添加噪声、角度倾斜、亮度变化等等数据增强,所以训练是比较容易的,5W张训练图片,1W张验证图片,用上述网络训练13轮准确率就到达100%了。验证精度较高的原因是我们训练完一轮才进行验证,算的是平均精度。可以看到精度上升在0.86左右停滞了几轮,这是网络已经学会了字母和数字,正在攻克第一个最难识别的汉字。

测试代码也贴一下叭:

import os
import torch
import torchvision as tv
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from seq2seq import MainNet

img_dir = "./plate/test"
img_names = os.listdir(img_dir)
net = MainNet()
net.load_state_dict(torch.load(r"params\seq2seq.pth"))
net.eval()

trans = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5,), (0.5,))
    ])
l1 = ['新', '苏', '浙', '赣', '鄂', '桂', '甘', '晋', '蒙', '陕', '吉', '闽', '贵', '粤', '青',
      '藏', '川', '宁', '琼', '辽', '黑', '湘', '皖', '鲁', '京', '津', '沪', '渝', '冀', '豫', '云']
l2 = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R',
      'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
a = np.random.randint(0, len(img_names), 9)
with torch.no_grad():
    for i in range(9):
        img_path = os.path.join(img_dir, img_names[a[i]])
        img = Image.open(img_path)
        data = trans(img).unsqueeze(0)
        out = net(data).squeeze()
        out = torch.argmax(out, 1).numpy()
        predict = l1[out[0]]
        for idx in out[1:]:
            predict += l2[idx]
        plt.subplot(3, 3, i + 1)
        plt.imshow(img)
        plt.axis('off')
        plt.title(predict, fontdict={'family': 'SimHei', 'size': 20})
plt.show()

测试效果:

在这里插入图片描述
在实际的车牌识别项目中,除了要使用真实数据集之外,往往要先定位到车牌。根据实际需求,可以先检测车辆位置,再从车辆上检测车牌位置,最后识别车牌号。

  • 3
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值