Pytorch使用CRNN CTCLoss实现OCR系统

卷积递归神经网络

此项目使用CNN + RNN + CTCLoss实现OCR系统,灵感来自CRNN网络。

一、用法

python ./train.py --help

二、演示

1、使用TestDataset数据生成器训练简单的OCR。训练60-100次。

python train.py --test-init True --test-epoch 10 --output-dir <path_to_folder_with_snapshots>

2、运行test来训练可视化模型。

python test.py --snapshot <path_to_folder_with_snapshots>/crnn_resnet18_10_best --visualize True

三、训练自定义数据集

1、创建数据集

  • 数据集的结构:
<root_dataset_dir>
---- data
-------- <img_filename_0>
...
-------- <img_filename_1>
---- desc.json
  • desc.json的结构:
    {
    "abc": <symbols_in_aphabet>,
    "train": [
    {
    "text": <text_on_image>
    "name": <img_filename>
    },
    ...
    {
    "text": <text_on_image>
    "name": <img_filename>
    }
    ],
    "test": [
    {
    "text": <text_on_image>
    "name": <img_filename>
    },
    ...
    {
    "text": <text_on_image>
    "name": <img_filename>
    }
    ]
    }

    2、使用自定义数据集训练简单的OCR。

    python train.pt --test-init True --test-epoch 10 --output-dir <path_to_folder_with_snapshots> --data-path <path_to_custom_dataset>

    3、运行test来训练可视化模型。

python test.py --snapshot <path_to_folder_with_snapshots>/crnn_resnet18_10_best --visualize True --data-path <path_to_custom_dataset>

四、依赖

五、相关文章

项目地址:BelBES/crnn-pytorch

prtorch参考手册 https://ptorch.com/docs/8/torch-nn

  • 0
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
PyTorch中的CTCLoss是指Connectionist Temporal Classification Loss,它是一种用于解决神经网络标签和输出不对齐问题的方法。CTCLoss的优点是不需要强制对齐标签且标签可以是可变长度的。它主要应用于场景文本识别、语音识别和手写字识别等工程场景。在PyTorch 1.0.x版本内,已经内置了CTCLoss接口,可以直接使用。下面是一个使用CTCLoss的代码示例: ```python import torch import torch.nn as nn ctc_loss = nn.CTCLoss() log_probs = torch.randn(50, 16, 20).log_softmax(2).requires_grad_() targets = torch.randint(1, 20, (16, 30), dtype=torch.long) input_lengths = torch.full((16,), 50, dtype=torch.long) target_lengths = torch.randint(10, 30, (16,), dtype=torch.long) loss = ctc_loss(log_probs, targets, input_lengths, target_lengths) loss.backward() ``` 在这个示例中,我们首先创建了一个CTCLoss实例,然后生成了一些随机的log probabilities作为网络的输出。接着,我们生成了一些随机的目标标签和输入长度以及目标长度。最后,我们使用CTCLoss计算了损失,并进行了反向传播。\[2\] 在创建CTCLoss实例时,可以通过设置参数来自定义一些属性。例如,可以使用`blank`参数来指定空白符的序号,`reduction`参数来指定损失的计算方式。\[3\] 希望这个回答对你有帮助! #### 引用[.reference_title] - *1* [如何使用pytorch内置torch.nn.CTCLoss的方法&&车牌识别应用](https://blog.csdn.net/CSDNwei/article/details/120223026)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [pytorch的torch.nn.CTCLoss方法](https://blog.csdn.net/benben044/article/details/125130411)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [Pytorch中的CTC loss](https://blog.csdn.net/fidbdiej/article/details/124587812)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值