CRNN Code Notes


These notes consist of five main modules:

  1. Dataset loading and splitting
  2. CRNN model implementation
  3. Training procedure
  4. Prediction procedure
  5. Evaluation during training


Dataset loading and splitting

import os
import glob

import torch
from torch.utils.data import Dataset
from scipy import signal
from scipy.io import wavfile
import cv2
from PIL import Image
import numpy as np

class Synth90kDataset(Dataset):
    CHARS = '0123456789abcdefghijklmnopqrstuvwxyz'
    CHAR2LABEL = {char: i + 1 for i, char in enumerate(CHARS)}  # labels start at 1; 0 is reserved for the CTC blank
    LABEL2CHAR = {label: char for char, label in CHAR2LABEL.items()}

    def __init__(self, root_dir=None, mode=None, paths=None, img_height=32, img_width=100):
        # Two ways to construct the dataset: from the raw annotation files of a split
        # (train/dev/test), or from an explicit list of image paths without labels (prediction).
        if root_dir and mode and not paths:
            paths, texts = self._load_from_raw_files(root_dir, mode)
        elif not root_dir and not mode and paths:
            texts = None
        self.paths = paths
        self.texts = texts
        self.img_height = img_height
        self.img_width = img_width

    def _load_from_raw_files(self, root_dir, mode):
        mapping = {}  # line index in lexicon.txt -> word
        with open(os.path.join(root_dir, 'lexicon.txt'), 'r') as fr:
            for i, line in enumerate(fr.readlines()):
                mapping[i] = line.strip()

        paths_file = None
        if mode == 'train':
            paths_file = 'train.txt'
        elif mode == 'dev':
            paths_file = 'val.txt'
        elif mode == 'test':
            paths_file = 'test.txt'

        paths = []
        texts = []
        with open(os.path.join(root_dir, paths_file), 'r') as fr:
            for line in fr.readlines():
                path, index_str = line.strip().split(' ')
                path = os.path.join(root_dir, path)
                index = int(index_str)
                text = mapping[index]
                paths.append(path)
                texts.append(text)
        return paths, texts

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]

        try:
            image = Image.open(path).convert('L')  # grey-scale
        except IOError:
            print('Corrupted image for %d' % index)
            return self[index + 1]

        image = image.resize((self.img_width, self.img_height), resample=Image.BILINEAR)
        image = np.array(image)
        image = image.reshape((1, self.img_height, self.img_width))
        image = (image / 127.5) - 1.0  # scale pixel values from [0, 255] to [-1, 1]

        image = torch.FloatTensor(image)
        if self.texts:  # texts holds the label (text) corresponding to each image
            text = self.texts[index]
            target = [self.CHAR2LABEL[c] for c in text]
            target_length = [len(target)]

            target = torch.LongTensor(target)
            target_length = torch.LongTensor(target_length)
            # If the DataLoader does not set a collate_fn, this return value is what
            # each iteration of the DataLoader yields.
            return image, target, target_length
        else:  # prediction/test mode needs no label
            return image

def synth90k_collate_fn(batch):
    # zip(*batch) unpacks the list of (image, target, target_length) samples
    images, targets, target_lengths = zip(*batch)
    # stack adds a new dimension and packs the tensors along it -- think of piling up
    # slices of bread: stacking along dim 0 creates the batch axis
    images = torch.stack(images, 0)
    # cat concatenates along an existing dimension without adding a new one, joining the
    # per-sample label vectors into one long 1-D vector
    targets = torch.cat(targets, 0)
    target_lengths = torch.cat(target_lengths, 0)
    # These return values are exactly what each iteration of train_loader yields:
    # images, targets, target_lengths
    return images, targets, target_lengths
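
Why stack the images but concatenate the targets? Below is a minimal sketch with made-up shapes (not data from the dataset) showing the difference between torch.stack and torch.cat; the flat 1-D targets plus a separate lengths tensor is also a layout that torch.nn.CTCLoss accepts.

import torch

a = torch.zeros(1, 32, 100)          # one "image" of shape (C, H, W)
b = torch.zeros(1, 32, 100)
print(torch.stack([a, b], 0).shape)  # torch.Size([2, 1, 32, 100]) -- a new batch dim is created

t1 = torch.tensor([19, 28, 29])      # a label of length 3
t2 = torch.tensor([5, 7])            # a label of length 2
print(torch.cat([t1, t2], 0))        # tensor([19, 28, 29,  5,  7]) -- one flat vector, no new dim
print(torch.cat([t1, t2], 0).shape)  # torch.Size([5])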
from torch.utils.data import DataLoader
from config import train_config as config
img_width = config['img_width']
img_height = config['img_height']
data_dir = config['data_dir']
train_batch_size = config['train_batch_size']
cpu_workers = config['cpu_workers']
train_dataset = Synth90kDataset(root_dir=data_dir, mode='train',
                                    img_height=img_height, img_width=img_width)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=cpu_workers,
    collate_fn=synth90k_collate_fn)
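
As a quick sanity check (a sketch that assumes data_dir already points at a prepared Synth90k-style directory), fetching one batch from train_loader returns exactly the three tensors built by synth90k_collate_fn:

images, targets, target_lengths = next(iter(train_loader))
print(images.shape)          # (train_batch_size, 1, img_height, img_width)
print(targets.shape)         # 1-D: the label lengths of the whole batch summed together
print(target_lengths.shape)  # (train_batch_size,)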
train_data = train_dataset.__getitem__(2)
print(f'type of train_data: {type(train_data)}')
print(f'length of train_data: {len(train_data)}')
type of train_data: <class 'tuple'>
length of train_data: 3
train_data
(tensor([[[0.3804, 0.3804, 0.4196,  ..., 0.4824, 0.4745, 0.4745],
          [0.5451, 0.5451, 0.5451,  ..., 0.4902, 0.4902, 0.4902],
          [0.4824, 0.4824, 0.4667,  ..., 0.4902, 0.4824, 0.4824],
          ...,
          [0.4353, 0.4353, 0.4196,  ..., 0.5059, 0.5059, 0.5059],
          [0.6078, 0.6078, 0.6078,  ..., 0.4902, 0.4824, 0.4824],
          [0.3255, 0.3255, 0.3804,  ..., 0.4667, 0.4745, 0.4745]]]),
 tensor([19, 28, 29]),
 tensor([3]))
img = train_data[0]
label_idx = train_data[1]
label_length = train_data[2]
print(f'type of img: {type(img)}')
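
For reference, the label tensor can be mapped back to its text with LABEL2CHAR; for the sample shown above, tensor([19, 28, 29]) decodes as follows (a small illustrative snippet):

decoded = ''.join(Synth90kDataset.LABEL2CHAR[l.item()] for l in label_idx)
print(decoded)       # irs
print(label_length)  # tensor([3]) -- the length of the label sequence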