车牌识别文字识别训练全过程解析 目前代码解读还不算完善 后续会补充
车牌识别github链接
车牌检测end2end实现过程
训练方式按照github上介绍就行
在解释前先定义几个名词和符号, 方便后面理解
plate_chr="#京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新学警港澳挂使领民航0123456789ABCDEFGHJKLMNPQRSTUVWXYZ危险品"
plate_name="京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新学警港澳挂使领民航0123456789ABCDEFGHJKLMNPQRSTUVWXYZ危险品"
plateDict2={'京':0, '沪':1, '津':2, ......, '-': 77}
plateDict={'#':0, '京':1, '沪':2 ......}
pchar表示车牌单个字符 如'京'
pstr表示每一个车牌字符串 如'云A008BC'
p_number表示车牌字符对应的数字, 即plateDict中的值, 如'京'对应1
解析数据集打上标签,生成train.txt和val.txt的程序
- 生成程序
python plateLabel.py --image_path your/train/img/path/ --label_file datasets/train.txt
python plateLabel.py --image_path your/val/img/path/ --label_file datasets/val.txt
- plateLabel.py解析
"""Generate a label file (train.txt / val.txt) from plate images.

Each image is expected to be named ``<plate text>_<anything>.jpg``; the plate
text before the first underscore is converted into character indices.
"""
import os
import argparse

from alphabets import plate_chr  # every character that may appear on a plate


def allFileList(rootfile, allFile):
    """Recursively append the path of every file under ``rootfile`` to ``allFile``.

    Args:
        rootfile: directory to walk.
        allFile: list that is extended in place with the file paths found.
    """
    for entry in os.listdir(rootfile):
        fileName = os.path.join(rootfile, entry)
        if os.path.isfile(fileName):
            allFile.append(fileName)
        else:
            allFileList(fileName, allFile)


def is_str_right(plate_name, alphabet=plate_chr):
    """Return True when every character of ``plate_name`` is in ``alphabet``.

    ``alphabet`` defaults to ``plate_chr`` so the function no longer depends on
    a global that only exists when this file is run as a script.
    """
    return all(ch in alphabet for ch in plate_name)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', type=str, default="datasets/val",
                        help='root directory of the plate images')
    parser.add_argument('--label_file', type=str, default='datasets/val.txt',
                        help='path of the label file to write')
    opt = parser.parse_args()
    rootPath = opt.image_path
    labelFile = opt.label_file

    palteStr = plate_chr
    print(len(palteStr))

    # Character -> index lookup (plateDict).
    plateDict = {ch: idx for idx, ch in enumerate(palteStr)}

    # Collect every file under rootPath before opening the output file.
    file = []
    allFileList(rootPath, file)

    picNum = 0
    # ``with`` guarantees the label file is flushed and closed even if a
    # path or character lookup fails halfway through.
    with open(labelFile, "w", encoding="utf-8") as fp:
        for jpgFile in file:
            print(jpgFile)
            jpgName = os.path.basename(jpgFile)  # e.g. 云A008BC_0.jpg
            name = jpgName.split("_")[0]         # plate text (pstr), e.g. 云A008BC
            # The label line is space separated, so a space in the plate text
            # would corrupt it; such files are skipped.
            if " " in name:
                continue
            # Skip plates containing a character that is not in the alphabet.
            if not is_str_right(name):
                continue
            # Leading space separates the path from the indices; every index is
            # followed by one space (the trailing space is kept on purpose so
            # the file format stays identical to what the dataset loader reads).
            # Example from the original write-up: " 25 52 42 42 50 53 54 "
            # (the exact numbers depend on the contents of plate_chr).
            labelStr = " " + "".join(str(plateDict[ch]) + " " for ch in name)
            picNum += 1
            # One line per image: "<image path> <idx> <idx> ... "
            fp.write(jpgFile + labelStr + "\n")
代码解析 按照train.py的代码一步一步解析, 只阐述重点的地方
训练代码解析
-
加载config: config = parse_arg()
def parse_arg():
    """Parse the command line and build the training configuration.

    The YAML file given by ``--cfg`` is loaded and wrapped in an ``edict`` so
    that nested keys can be read as attributes (``config.DATASET.ALPHABETS``
    instead of ``config['DATASET']['ALPHABETS']``). The alphabet, class count
    and model input size are then overridden on the loaded config.

    Returns:
        The ``edict`` configuration object.
    """
    arg_parser = argparse.ArgumentParser(description="train crnn")
    arg_parser.add_argument(
        '--cfg',
        help='experiment configuration filename',
        default='./lib/config/360CC_config.yaml',
        type=str,
    )  # configuration file
    arg_parser.add_argument('--img_h', type=int, default=48, help='height')  # model input height
    arg_parser.add_argument('--img_w', type=int, default=168, help='width')  # model input width
    cli = arg_parser.parse_args()

    with open(cli.cfg, 'r') as cfg_file:
        loaded = yaml.load(cfg_file, Loader=yaml.FullLoader)
    config = edict(loaded)

    # Per the original write-up, plateName is the plate alphabet; its exact
    # contents are defined elsewhere in the project.
    config.DATASET.ALPHABETS = plateName
    # Number of classes is simply the alphabet length.
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
    config.HEIGHT = cli.img_h  # input image height
    config.WIDTH = cli.img_w   # input image width
    return config
-
所有保存文件的输出路径: output_dict = utils.create_log_folder(config, phase='train')
def create_log_folder(cfg, phase='train'):
    """Create the checkpoint and TensorBoard directories for one run.

    Layout: ``<OUTPUT_DIR>/<dataset>/<model>/<YYYY-mm-dd-HH-MM>/{checkpoints,log}``

    Args:
        cfg: configuration exposing ``OUTPUT_DIR``, ``DATASET.DATASET`` and
            ``MODEL.NAME``.
        phase: not used by this function; kept so existing callers still work.

    Returns:
        dict with ``'chs_dir'`` (checkpoint directory) and ``'tb_dir'``
        (TensorBoard log directory), both as strings.
    """
    root_output_dir = Path(cfg.OUTPUT_DIR)
    if not root_output_dir.exists():
        print('=> creating {}'.format(root_output_dir))
        # parents=True: OUTPUT_DIR may be nested (e.g. "runs/output").
        # exist_ok=True: another process may create it between the check and here.
        root_output_dir.mkdir(parents=True, exist_ok=True)

    dataset = cfg.DATASET.DATASET  # dataset name, e.g. '360CC'
    model = cfg.MODEL.NAME         # model name, e.g. 'crnn'
    time_str = time.strftime('%Y-%m-%d-%H-%M')  # e.g. 2023-12-14-16-09
    run_dir = root_output_dir / dataset / model / time_str

    checkpoints_output_dir = run_dir / 'checkpoints'  # model checkpoints
    print('=> creating {}'.format(checkpoints_output_dir))
    checkpoints_output_dir.mkdir(parents=True, exist_ok=True)

    tensorboard_log_dir = run_dir / 'log'  # TensorBoard logs
    print('=> creating {}'.format(tensorboard_log_dir))
    tensorboard_log_dir.mkdir(parents=True, exist_ok=True)

    return {'chs_dir': str(checkpoints_output_dir), 'tb_dir': str(tensorboard_log_dir)}
-
数据集的读取
class _360CC(data.Dataset):
    """Plate-image dataset driven by a label file.

    Every line of the label file is ``<image path> <idx> <idx> ...`` where each
    index refers to a character of ``plate_chr``. ``__getitem__`` returns the
    normalised image and the sample index; the decoded plate text is kept in
    ``self.labels`` as one ``{image path: plate text}`` dict per sample.
    """

    def __init__(self, config, input_w=168, input_h=48, is_train=True):
        self.root = config.DATASET.ROOT
        self.is_train = is_train
        self.inp_h = config.MODEL.IMAGE_SIZE.H
        self.inp_w = config.MODEL.IMAGE_SIZE.W
        self.input_w = input_w  # width the image is resized to
        self.input_h = input_h  # height the image is resized to
        self.dataset_name = config.DATASET.DATASET
        self.mean = np.array(config.DATASET.MEAN, dtype=np.float32)
        self.std = np.array(config.DATASET.STD, dtype=np.float32)

        # Read from the config but not used below: the character table is built
        # from plate_chr instead of from this file.
        char_file = config.DATASET.CHAR_FILE

        # index -> character. Index 0 is relabelled "blank": per the original
        # write-up this slot is the CTC blank symbol.
        index_to_char = dict(enumerate(ch.strip() for ch in plate_chr))
        index_to_char[0] = "blank"

        split = 'train' if is_train else 'val'
        txt_file = config.DATASET.JSON_FILE[split]

        # Turn "path idx idx ..." lines into {path: plate text} entries.
        self.labels = []
        with open(txt_file, 'r', encoding='utf-8') as file:
            for line in file:
                fields = line.strip(" \n").split(' ')
                imgname, indices = fields[0], fields[1:]
                string = ''.join(index_to_char[int(i)] for i in indices)
                self.labels.append({imgname: string})

        print("load {} images!".format(len(self)))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Each entry of self.labels has exactly one key: the image path.
        img_name = next(iter(self.labels[idx]))
        img = cv_imread(os.path.join(self.root, img_name))
        # Drop the alpha channel of 4-channel images.
        if img.shape[-1] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
        # Values are unused; the unpacking still requires a 3-dimensional image.
        _, _, _ = img.shape

        resized = cv2.resize(img, (self.input_w, self.input_h)).astype(np.float32)
        normalised = (resized / 255. - self.mean) / self.std
        # [h, w, c] -> [c, h, w]. No BGR -> RGB conversion is done here; per the
        # original write-up the end-to-end inference code does not convert
        # either, so training and inference stay consistent.
        return normalised.transpose([2, 0, 1]), idx
-
模型的训练过程中也有一些注意的事项 ———— lib/core/function.py (def train)
# 注意这里的idx还是索引, 可以从上面的数据集读取上看到return idx # labels: ['苏A8C4A8', '川AE00K0', '冀EL2392', '鲁ATN619', '川A5E1Z9', '闽FQZ790', '辽MD7792'......] len=256 labels = utils.get_batch_label(dataset, idx)
# inference # preds: torch.Size([21, 256, 78]) # 21: 车牌预测的字符个数的最大上限 也就是一张车牌最多预测21个字符pchar # 256:图片的batchsize # 78:车牌字符集:77 + 1 (77是车牌字符串plate_name的长度 1是blank, 相当于#) # 这个空白#的存在其实和CTCLoss有关 这里就不过多介绍了 preds = model(inp).cpu()
计算损失这里可能难以理解
# compute loss # batchsize: 256 batch_size = inp.size(0) # text: tensor([11, 52, 50, ..., 47, 42, 47], dtype=torch.int32) shape: torch.Size([1798]) # length: tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, ......] shape:torch.Size([256]) # 从上面的输出结果可以看出text中的是将labels中所有的车牌字符串pstr都拼接在一起, 其中的值代表的是每一个plate_chr{#京沪......}对应的下标 # length中的值则可以很轻松的看出是每一个车牌pstr的长度 text, length = converter.encode(labels) # length = 一个batch中的总字符长度, text = 一个batch中的字符所对应的下标 # preds_size: tensor([21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, ......] shape:torch.Size([256]) preds_size = torch.IntTensor([preds.size(0)] * batch_size) # timestep * batchsize # torch官网上的CTCLoss的使用的参数要求 可以直接看官网 官网更详细 # preds: (T, N, C) T=input length N=batch size C=number of classes(including blank) # text: (N, S) or (sum(target_lengths)) sum就是将所有的字符串pstr拼接在一起并转化为对应的plateDict下标 其中0是blank # preds_size: (N, ) # length: (N, ) loss = criterion(preds, text, preds_size, length)
-
text, length = converter.encode(labels)中的converter问题: converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
# encode turns plate strings into numbers using the one-to-one character -> index
# mapping held in self.dict (described in the write-up as {'#': 0, '京': 1, ...}).
def encode(self, text):
    """Encode one string or a batch of strings into flat index tensors.

    Args:
        text (str, bytes, or iterable of str/bytes): text(s) to convert.
            A bare ``str``/``bytes`` is treated as a single sample. ``bytes``
            items are decoded as UTF-8.

    Returns:
        torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: the indices
            of all characters, concatenated in order.
        torch.IntTensor [n]: length of each text.

    Raises:
        KeyError: if a character is missing from ``self.dict``.
    """
    # Without this, a single string would be iterated character by character
    # and reported as n samples of length 1.
    if isinstance(text, (str, bytes)):
        text = [text]

    length = []
    result = []
    for item in text:
        if isinstance(item, bytes):
            item = item.decode('utf-8', 'strict')
        length.append(len(item))
        result.extend(self.dict[char] for char in item)
    return (torch.IntTensor(result), torch.IntTensor(length))
训练过程中的验证代码解析
-
模型验证过程中也会有一些注意的地方 ———— lib\core\function.py (def validate)
# preds: shape:(21, 128, 78) # 在max之后 _: 是最大值 shape(21, 128) preds: 是最大值的索引 shape(21, 128) # 这个部分主要是选出车牌字符串集plateDict 78中最大概率的那个作为该位置的输出 _, preds = preds.max(2) # preds: torch.Size([2688]) # 先转化为[128, 21], 主要是为了decode中每一个相邻的21个位置都是同一个车牌上的预测结果 preds = preds.transpose(1, 0).contiguous().view(-1) # preds: tensor([30, 0, 0, ..., 43, 43, 0]) torch.Size([2688]) # preds_size: tensor([21, 21, 21, ......]) shape: torch.Size([128]) sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
-
decode代码解析 lib\utils\utils.py (def decode())
这里的converter.decode是跟CTCLoss进行配合的, 需要好好理解一下
这个部分的作用首先是将预测输出的torch.size([128, 21])中的所有21个对应的最大概率的索引(下标为0-77)转化为对应的plateDict2中的字符, 这里为什么不是{‘-’: 0, …}即blank作为0, 这个原因可以见decode代码解析
然后将转化好的21个字符中去掉重复的字符, 直接举个例子吧, -代表是blank, 以这个间隔, 删掉重复的字母
左边是转化好的21个字符, 右边是去重并去掉blank后得到的字符串 (gt 为真实标签):
苏---E--D--3-S-11-22-- => 苏ED3S12 , gt: 苏ED3S12
赣--EE--K--1--3-22-0-- => 赣EK1320 , gt: 赣EK1320
渝-AA---66-P--8--6-11- => 渝A6P861 , gt: 渝A6P861
新--LL--66-0-11-77-99- => 新L60179 , gt: 新L60179
鄂---N--66-MM-Y-11-22- => 鄂N6MY12 , gt: 鄂N6MY12
豫---B--D--8--2-11-11- => 豫BD8211 , gt: 渝BD8211
苏--E---1--1-LL-33-LL- => 苏E11L3L , gt: 苏E11L3L
赣---C--55-55-T--6-22- => 赣C55T62 , gt: 赣C55T62
川-AA---88-8--Y--3-55- => 川A88Y35 , gt: 川A88Y35
例如 川-AA---88 => 会删掉一个重复的A, 删掉所有的-, 删掉一个重复的8, 得到 川A8
def decode(self, t, length, raw=False):
    """Decode index sequences back into strings (CTC greedy collapse).

    Args:
        t (torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]):
            predicted class indices of all samples, concatenated. Index 0 is
            the blank; index k > 0 maps to ``self.alphabet[k - 1]``.
        length (torch.IntTensor [n]): number of entries of ``t`` that belong
            to each sample.
        raw (bool): when True, map every index to its character without
            collapsing (index 0 then maps to ``self.alphabet[-1]``).

    Raises:
        AssertionError: when ``t`` and ``length`` do not match.

    Returns:
        str when ``length`` has a single entry, otherwise a list of str with
        one entry per sample.
    """
    # Work on plain Python ints: indexing tensors element by element inside
    # the loops is slow and yields 0-dim tensors instead of ints.
    codes = t.reshape(-1).tolist()
    lengths = length.reshape(-1).tolist()
    alphabet = self.alphabet

    def _decode_one(seq):
        # Decode the indices of ONE plate, e.g. 川-AA---88-8--Y--3-55- => 川A88Y35
        if raw:
            return ''.join(alphabet[code - 1] for code in seq)
        chars = []
        previous = None
        for code in seq:
            # Keep a symbol only if it is not the blank (0) AND it is either
            # the first symbol or different from the one just before it.
            if code != 0 and code != previous:
                chars.append(alphabet[code - 1])  # index -> plate character
            previous = code
        return ''.join(chars)

    if len(lengths) == 1:
        if len(codes) != lengths[0]:
            raise AssertionError(
                "text with length: {} does not match declared length: {}".format(
                    len(codes), lengths[0]))
        return _decode_one(codes)

    # Batch mode: split the flat sequence (e.g. [128 * 21]) into one slice per
    # plate and decode each slice on its own.
    total = sum(lengths)
    if len(codes) != total:
        raise AssertionError(
            "texts with length: {} does not match declared length: {}".format(
                len(codes), total))
    texts = []
    start = 0
    for count in lengths:
        texts.append(_decode_one(codes[start:start + count]))
        start += count
    return texts
-
验证集指标的出现 ———— lib\core\function.py (def validate)
# Compare every decoded prediction with its ground-truth plate string;
# a plate counts as correct only when the whole string matches exactly.
# NOTE(review): `sum` shadows the Python builtin inside this function.
for pred, target in zip(sim_preds, labels):
    sum+=1
    if pred == target:
        n_correct += 1
# Fraction of plates predicted completely correctly.
# NOTE(review): raises ZeroDivisionError if no sample was counted - confirm the validation set is never empty.
accuracy = n_correct / sum
很明显, 这个指标 = 完全预测正确的车牌数 / 参与验证的车牌总数 (整块车牌的所有字符都预测正确才算正确)
最终训练结果展示
(lvxiaoleother) C:\Users\HUST\Desktop\crnn_plate_recognition> c: && cd c:\Users\HUST\Desktop\crnn_plate_recognition && cmd /C "C:\Users\HUST\anaconda3\envs\lvxiaoleother\python.exe c:\Users\HUST\.vscode\extensions\ms-python.python-2023.4.1\pythonFiles\lib\python\debugpy\adapter/../..\debugpy\launcher 52925 -- C:\Users\HUST\Desktop\crnn_plate_recognition\train.py "
=> creating output\360CC\crnn\2023-12-14-17-17\checkpoints
=> creating output\360CC\crnn\2023-12-14-17-17\log
layer name gradient parameters shape mu sigma
0 feature.0.weight True 1200 [16, 3, 5, 5] -0.00129 0.0672
1 feature.0.bias True 16 [16] -0.00503 0.0724
2 feature.1.weight True 16 [16] 1 0
3 feature.1.bias True 16 [16] 0 0
4 feature.3.weight True 2304 [16, 16, 3, 3] 0.00171 0.0482
5 feature.3.bias True 16 [16] 0.0188 0.0519
6 feature.4.weight True 16 [16] 1 0
7 feature.4.bias True 16 [16] 0 0
8 feature.6.weight True 4608 [32, 16, 3, 3] 4.04e-05 0.0484
9 feature.6.bias True 32 [32] -0.00626 0.0452
10 feature.7.weight True 32 [32] 1 0
11 feature.7.bias True 32 [32] 0 0
12 feature.9.weight True 9216 [32, 32, 3, 3] 0.000206 0.0341
13 feature.9.bias True 32 [32] 0.00341 0.0318
14 feature.10.weight True 32 [32] 1 0
15 feature.10.bias True 32 [32] 0 0
16 feature.13.weight True 18432 [64, 32, 3, 3] 6.27e-05 0.034
17 feature.13.bias True 64 [64] 0.0084 0.0305
18 feature.14.weight True 64 [64] 1 0
19 feature.14.bias True 64 [64] 0 0
20 feature.16.weight True 36864 [64, 64, 3, 3] 6.22e-07 0.0241
21 feature.16.bias True 64 [64] -0.00337 0.0243
22 feature.17.weight True 64 [64] 1 0
23 feature.17.bias True 64 [64] 0 0
24 feature.20.weight True 55296 [96, 64, 3, 3] -3.25e-05 0.0241
25 feature.20.bias True 96 [96] 0.00238 0.0236
26 feature.21.weight True 96 [96] 1 0
27 feature.21.bias True 96 [96] 0 0
28 feature.23.weight True 82944 [96, 96, 3, 3] 6.25e-06 0.0196
29 feature.23.bias True 96 [96] -0.000941 0.02
30 feature.24.weight True 96 [96] 1 0
31 feature.24.bias True 96 [96] 0 0
32 feature.27.weight True 110592 [128, 96, 3, 3] -2.34e-05 0.0196
33 feature.27.bias True 128 [128] -0.000311 0.0204
34 feature.28.weight True 128 [128] 1 0
35 feature.28.bias True 128 [128] 0 0
36 feature.30.weight True 294912 [256, 128, 3, 3] 2.84e-05 0.017
37 feature.30.bias True 256 [256] 0.00124 0.016
38 feature.31.weight True 256 [256] 1 0
39 feature.31.bias True 256 [256] 0 0
40 newCnn.weight True 19968 [78, 256, 1, 1] 0.000316 0.036
41 newCnn.bias True 78 [78] -0.000652 0.0396
Model Summary: 42 layers, 638814 parameters, 638814 gradients
load 62863 images!
load 2014 images!
Epoch: [0][0/246] Time 1200.660s (1200.660s) Speed 0.2 samples/s Data 11.834s (11.834s) Loss 10.75087 (10.75087)
Epoch: [0][100/246] Time 0.060s (16.706s) Speed 4266.7 samples/s Data 0.002s (0.236s) Loss 0.80793 (2.61764)
Epoch: [0][200/246] Time 0.063s (8.425s) Speed 4063.5 samples/s Data 0.002s (0.120s) Loss 0.11421 (1.48051)
苏---E--D--3-S-11-22-- => 苏ED3S12 , gt: 苏ED3S12
赣--EE--K--1--3-22-0-- => 赣EK1320 , gt: 赣EK1320
渝-AA---66-P--8--6-11- => 渝A6P861 , gt: 渝A6P861
新--LL--66-0-11-77-99- => 新L60179 , gt: 新L60179
鄂---N--66-MM-Y-11-22- => 鄂N6MY12 , gt: 鄂N6MY12
豫---B--D--8--2-11-11- => 豫BD8211 , gt: 渝BD8211
苏--E---1--1-LL-33-LL- => 苏E11L3L , gt: 苏E11L3L
赣---C--55-55-T--6-22- => 赣C55T62 , gt: 赣C55T62
川-AA---88-8--Y--3-55- => 川A88Y35 , gt: 川A88Y35
渝--BB--NN-77-9-88-66- => 渝BN7986 , gt: 渝BN7986
1540
128000
Test loss: 0.1785, accuray: 0.7646
is best: True
best acc is: 0.7646474677259185