CRNN Code Notes

This post records in detail the process of implementing a CRNN model in PyTorch, covering dataset loading and splitting, model construction, training, prediction, and validation. It focuses on how the CTCLoss function is used and on running training and prediction on both GPU and CPU.

The implementation consists of five modules:

  1. Dataset loading and splitting
  2. CRNN model implementation
  3. Training
  4. Prediction
  5. Evaluation during training


Dataset Loading and Splitting

import os

import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
class Synth90kDataset(Dataset):
    CHARS = '0123456789abcdefghijklmnopqrstuvwxyz'
    # Labels start at 1; label 0 is reserved for the CTC blank
    CHAR2LABEL = {char: i + 1 for i, char in enumerate(CHARS)}
    LABEL2CHAR = {label: char for char, label in CHAR2LABEL.items()}

    def __init__(self, root_dir=None, mode=None, paths=None, img_height=32, img_width=100):
        if root_dir and mode and not paths:
            # Train/dev/test: read image paths and label texts from the annotation files
            paths, texts = self._load_from_raw_files(root_dir, mode)
        elif not root_dir and not mode and paths:
            # Prediction: image paths are given directly and there are no labels
            texts = None
        self.paths = paths
        self.texts = texts
        self.img_height = img_height
        self.img_width = img_width
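
    # Two ways to construct the dataset (an illustrative sketch; 'demo.jpg' is a
    # hypothetical path, not from the original post):
    #   Synth90kDataset(root_dir=data_dir, mode='train')  -> paths and texts read from train.txt
    #   Synth90kDataset(paths=['demo.jpg'])                -> prediction mode, texts stays None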

    def _load_from_raw_files(self, root_dir, mode):
        # lexicon index -> word
        mapping = {}
        with open(os.path.join(root_dir, 'lexicon.txt'), 'r') as fr:
            for i, line in enumerate(fr.readlines()):
                mapping[i] = line.strip()

        paths_file = None
        if mode == 'train':
            paths_file = 'train.txt'
        elif mode == 'dev':
            paths_file = 'val.txt'
        elif mode == 'test':
            paths_file = 'test.txt'

        paths = []
        texts = []
        with open(os.path.join(root_dir, paths_file), 'r') as fr:
            for line in fr.readlines():
                path, index_str = line.strip().split(' ')
                path = os.path.join(root_dir, path)
                index = int(index_str)
                text = mapping[index]
                paths.append(path)
                texts.append(text)
        return paths, texts
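
    # Assumed annotation file layout (illustrative, inferred from the parsing above):
    #   lexicon.txt                     one word per line; the line number is that word's index
    #   train.txt / val.txt / test.txt  one "<relative image path> <lexicon index>" pair per line,
    #                                   e.g. "images/0001.jpg 42" (hypothetical example)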

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]

        try:
            image = Image.open(path).convert('L')  # grey-scale
        except IOError:
            print('Corrupted image for %d' % index)
            # Fall back to the next sample (wrap around so the index stays in range)
            return self[(index + 1) % len(self)]

        image = image.resize((self.img_width, self.img_height), resample=Image.BILINEAR)
        image = np.array(image)
        image = image.reshape((1, self.img_height, self.img_width))
        image = (image / 127.5) - 1.0  # scale pixel values to [-1, 1]

        image = torch.FloatTensor(image)
        if self.texts:  # texts are the labels corresponding to the images
            text = self.texts[index]
            target = [self.CHAR2LABEL[c] for c in text]
            target_length = [len(target)]

            target = torch.LongTensor(target)
            target_length = torch.LongTensor(target_length)
            # If the DataLoader has no collate_fn, this is what you get when iterating it
            return image, target, target_length
        else:  # test/inference mode has no labels
            return image
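
As a quick check on the label encoding (a minimal sketch; the word 'cat' is just an illustrative input), the class attributes can be used directly:

word = 'cat'
target = [Synth90kDataset.CHAR2LABEL[c] for c in word]
print(target)  # [13, 11, 30] -- labels start at 1, since 0 is the CTC blank
print(''.join(Synth90kDataset.LABEL2CHAR[t] for t in target))  # 'cat'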

def synth90k_collate_fn(batch):
    # zip(*batch) unpacks the list of (image, target, target_length) tuples
    images, targets, target_lengths = zip(*batch)
    # stack piles tensors along a new dimension: like stacking slices of bread,
    # dim 0 here becomes the batch dimension
    images = torch.stack(images, 0)
    # cat concatenates along an existing dimension: several 1-D label tensors become
    # one longer 1-D tensor, with no new dimension added
    targets = torch.cat(targets, 0)
    target_lengths = torch.cat(target_lengths, 0)
    # This is exactly what train_loader yields on each iteration: three values per batch
    return images, targets, target_lengths
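
To make the stack/cat distinction above concrete, here is a small shape check (illustrative tensors only, not tied to the real dataset):

a = torch.zeros(1, 32, 100)
b = torch.zeros(1, 32, 100)
print(torch.stack([a, b], 0).shape)  # torch.Size([2, 1, 32, 100]) -- new batch dimension added
t1 = torch.LongTensor([19, 28, 29])  # a 3-character label
t2 = torch.LongTensor([5, 7])        # a 2-character label
print(torch.cat([t1, t2], 0).shape)  # torch.Size([5]) -- labels joined end to end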
from torch.utils.data import DataLoader
from config import train_config as config
img_width = config['img_width']
img_height = config['img_height']
data_dir = config['data_dir']
train_batch_size = config['train_batch_size']
cpu_workers = config['cpu_workers']
train_dataset = Synth90kDataset(root_dir=data_dir, mode='train',
                                    img_height=img_height, img_width=img_width)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=cpu_workers,
    collate_fn=synth90k_collate_fn)
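
The config module is not shown in this post; a hypothetical train_config consistent with the keys read above might look like this (all values are illustrative assumptions):

train_config = {
    'data_dir': '/path/to/Synth90k',  # dataset root containing lexicon.txt, train.txt, ...
    'img_width': 100,
    'img_height': 32,
    'train_batch_size': 32,
    'cpu_workers': 4,
}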
train_data = train_dataset.__getitem__(2)
print(f'type of train_data: {type(train_data)}')
print(f'length of train_data: {len(train_data)}')
type of train_data: <class 'tuple'>
length of train_data: 3
train_data
(tensor([[[0.3804, 0.3804, 0.4196,  ..., 0.4824, 0.4745, 0.4745],
          [0.5451, 0.5451, 0.5451,  ..., 0.4902, 0.4902, 0.4902],
          [0.4824, 0.4824, 0.4667,  ..., 0.4902, 0.4824, 0.4824],
          ...,
          [0.4353, 0.4353, 0.4196,  ..., 0.5059, 0.5059, 0.5059],
          [0.6078, 0.6078, 0.6078,  ..., 0.4902, 0.4824, 0.4824],
          [0.3255, 0.3255, 0.3804,  ..., 0.4667, 0.4745, 0.4745]]]),
 tensor([19, 28, 29]),
 tensor([3]))
img = train_data[0]
label_idx = train_data[1]
label_length = train_data[2]
print(f'type of img: {type(img)}')
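
Iterating train_loader yields the batched values returned by synth90k_collate_fn. A rough sketch of what to expect (exact sizes depend on train_batch_size in the config); the concatenated targets plus per-sample target_lengths are the format torch.nn.CTCLoss consumes later during training:

images, targets, target_lengths = next(iter(train_loader))
print(images.shape)          # torch.Size([train_batch_size, 1, 32, 100])
print(targets.shape)         # 1-D tensor: all labels in the batch, concatenated
print(target_lengths.shape)  # torch.Size([train_batch_size]) -- one length per sample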