时间片轮转源码_MTL for OCR源码解析

前言

这篇文章带来一份OCR源码(PyTorch)的梳理,代码地址为:https://github.com/bityigoss/mtl-text-recognition [1]。这份源码使用了CTC loss和Attention Loss的多任务模型,也可以说是基于CRNN和Image Caption[3]两个任务的多任务模型。

1. MTL详解

575b847f1b8867f5c2def0809d8b3b66.png
图1:MTL网络结构图

1.1 网络概览

MTL的网络结构的后半部分如图1所示,在它之前是一个由CNN组成的特征提取网络,最终得到的Feature Map会以列为时间片为单位输入到RNN中,也就是输入到图像的

。RNN之后有两个Head,一个是CTC,另外一个是Attention Decoder,他们两个共同组成网络的损失函数。如果只考虑左侧CTC的话,那么它就是一个标准的CRNN模型。

1.2 代码梳理

1.2.1 执行脚本

要梳理MTL的代码流程,我们先要知道网络的一些超参,在源码的README中,多任务模型的调用方式如下(源文件README有误):

CUDA_VISIBLE_DEVICES=0 python mtl_train.py 
    --train_data data/synch/lmdb_train 
    --valid_data data/synch/lmdb_val 
    --select_data / --batch_ratio 1 
    --sensitive 
    --num_iter 400000 
    --output_channel 512 
    --hidden_size 256 
    --Transformation None 
    --FeatureExtraction ResNet 
    --SequenceModeling BiLSTM 
    --Prediction CTC 
    --mtl 
    --without_prediction 
    --experiment_name none_resnet_bilstm_ctc 
    --continue_model saved_models/pretrained_model.pth

前面四项是用来控制读取数据的超参。5-7个比较直观,第8个--Transformation是用来控制是否使用STN,第9个--FeatureExtraction是提取图像特征的网络结构,第10个--SequenceModeling是图1中‘Shared Encoder’的结构。--Prediction是预测的时候选择图1中的CTC的分支或者是Attention Decoder分支。--mtl是选择模型的训练方式,是选择一个任务进行训练还是训练多任务模型。--without_prediction是指模型加载的方式是否需要预测模块。

除了上面列出的,在train.py或者mtl_train.py文件中还有很多可以调整的超参,例如优化方式中涉及的学习策略,学习率;数据处理方式的图像尺寸等。

作者在github的README中有个错误

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值