DataWhale-天池街景数字识别竞赛-task1-赛题理解

背景

2020年5月的DW组队学习选择了天池的街景字符编码识别,在这个入门竞赛中,数据集来自Google街景图像中的门牌号数据集(The Street View House Numbers Dataset, SVHN),并根据一定方式采样得到比赛数据集。而数据集共分为训练集(3W)、验证集(1W)和测试集(4W)。

为了降低难度,比赛提供了数据集中字符的位置框(左上角坐标X,字符高度,左上角坐标Y,字符宽度),并结合字符的编码(label)一起放入到一个json文件中。

评测标准为测试集预测结果的准确率,即编码识别正确的数量测试集图片数量的比率。

环境搭建

这次比赛的教程是基于pytorch框架的深度学习模型,所以学习前需要配置一下pytorch。

由于几个月前曾经配过1.2版本的pytorch,但后来因为各种原因被我删除了,需要重新配一下。为了方便以后再配置,这里简单写一下这次配置1.5版本pytorch的小感想。

  1. 流程:利用Anaconda,首先新建虚拟环境,然后conda安装pytorch,然后可以补充安装各种库。
  2. 以前配环境配多了:上一次配置时,跟着某个教程另外配置了CUDA,CuDnn,感觉有点麻烦,这次完全略过了这一步,经过查找资料,才发现原来一直不用配,除非你直接下了pytorch源码。具体详情

If you install PyTorch via the binaries (e.g., pip wheels or conda), it already comes with CUDA and cuDNN pre-packaged. The only case where you need to install CUDA & cuDNN yourself is when you are compiling it from source.

                                                                                                                             ----------by rasbt Sebastian Raschka

3.PyTorch下得有点慢:即使配置了清华镜像源,torch包还是下得很慢,下了5、6次才成功了。

4.PyTorch直接可以搭载tensorboard:之前一直不太了解PyTorch中怎么用tensorboard,以及tensorboard与tensorboardX的区别。经过查阅资料,我发现X版本是为了适配tensorflow以外的框架,但由于1.1(大概)以后PyTorch已经可以支持原生tensorboard,所以直接安装tensorboard即可。具体详情。

TensorboardX was an third-party adaptation of the Tensorboard lib for pytorch. However, due to its popularity, it was recently included in the official pytorch repo. So, just use the one on the repo.

                                                                                                                                                       ------------by tuts_boy

思路简述

 本章教程内容由阿水编写。对于这次的字符识别,考虑到每个样本的数字个数虽然不同,但普遍较少,最多也就6个,所以可以统一转化为定长(6个数字)的数字字符识别

例如,45转化为45XXXX(其中X为填充字符),加上填充字符,就相对于一个11分类的问题,类别为填充字符意味着该位为空。可以搭建一个简单的卷积神经网络对字符进行识别与分类。

除了定长字符识别,还有不定长字符识别,这需要如CRNN这类的模型;对于赛题数据,已经给出了字符所在的位置,但实际上若不给出,还需要进行目标检测,引入物体检测模型SSD或YOLO。

JSON处理代码

以下jupyter代码的链接:在这里

由于赛题需要结合json文件对图片进行预处理,那么就要先熟悉一下利用json库对图片进行处理。

当然,最开始肯定要先看看数据集,由于文件统一以六位数命名,如000000.png、000123.png等,那么我们可以用下面的语句将整型转化为以0补足6位的格式化字符串从而形成路径。(默认在同级的data文件夹下)

num = 123
num_str = '{:0>6d}'.format(num)
path = 'data/mchar_train/'+num_str+'.png'

有了这几条语句,可以获取到训练集、验证集、测试集的图片文件路径,进一步封装成3个函数:

def get_train_path(num: 'int >= 0 && int <= 29999'):
    if num > 29999 or num < 0:
        print('index out of bound!')
        return 'data/mchar_train/'+'000000'+'.png'

    num_str = '{:0>6d}'.format(num) # 格式化字符串,左边补0 直至6位
    return 'data/mchar_train/'+num_str+'.png'

def get_test_path(num: 'int >= 0 && int <= 9999'):
    if num > 9999 or num < 0:
        print('index out of bound!')
        return 'data/mchar_test_a/'+'000000'+'.png'

    num_str = '{:0>6d}'.format(num) # 格式化字符串,左边补0 直至6位
    return 'data/mchar_test_a/'+num_str+'.png'

def get_val_path(num: 'int >= 0 && int <= 39999'):
    if num > 39999 or num < 0:
        print('index out of bound!')
        return 'data/mchar_val/'+'000000'+'.png'

    num_str = '{:0>6d}'.format(num) # 格式化字符串,左边补0 直至6位
    return 'data/mchar_val/'+num_str+'.png'

如果要读取图片,可以用cv2的imread函数,传入路径即可,展示可以用plt的imshow:

path = get_val_path(12)
img = cv2.imread(path)
plt.imshow(img)

这里展示的是验证集中的000012.png:

下面可以利用json库读取json文件:

train_json = json.load(open('data\mchar_train.json'))
val_json = json.load(open('data\mchar_val.json'))
list(val_json.values())[12]

输出为(居然是13,看起来像19):

{'height': [23, 23],
 'label': [1, 3],
 'left': [157, 164],
 'top': [106, 106],
 'width': [9, 12]}

下面可以对json文件提供的信息进行提取,返回一个numpy数组,方便后续处理:

# 训练集,验证集的json位置提取,默认0为训练集
def parse_json(num: int, mode=0):
    num_str = '{:0>6d}'.format(num) + '.png'
    if mode==0:
        d = train_json[num_str]
    elif mode==1:
        d = val_json[num_str]
    else:
        print('Mode error!')
        return train_json[num_str]

    arr = np.array([d['top'], d['height'], d['left'], d['width'], d['label']])
    arr = arr.astype(int)
    return arr

然后可以定义一个函数进行提取,并结合plt的函数进行展示:

# 训练集,验证集的数字位置提取,默认为训练集
def show_loc(num: int, mode=0):
    if mode==0:
        path = get_train_path(num)
        arr = parse_json(num)
    elif mode==1:
        path = get_val_path(num)
        arr = parse_json(num, 1)
    else:
        print('Mode error!')
        return 'error!'
    
    img = cv2.imread(path)
    shape = arr.shape[1]
    plt.figure(figsize=(10, 10))
    plt.subplot(1, shape+1, 1)
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])

    for idx in range(shape):
        plt.subplot(1, shape+1, idx+2)
        plt.imshow(img[arr[0, idx]:arr[0, idx]+arr[1, idx], arr[2, idx]:arr[2, idx]+arr[3, idx]])
        plt.title(arr[4, idx])
        plt.xticks([])
        plt.yticks([])

若输入:

show_loc(12, 1)

则输出为:

至此,利用json文件对图片数字进行提取的部分就完成了。

最后

此次学习的教程由Datawhale提供,学习手册的链接为:点这里

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
天池是一个著名的数据科学竞赛平台,而datawhale是一家致力于数据科学教育和社群建设的组织。街景字符编码识别是指通过计算机视觉技术,对街道场景中的字符进行自动识别和分类。 街景字符编码识别是一项重要的研究领域,对于提高交通安全、城市管理和智能驾驶技术都具有重要意义。街道场景中的字符包括道路标志、车牌号码、店铺招牌等。通过对这些字符进行准确的识别,可以辅助交通管理人员进行交通监管、道路规划和交通流量分析。同时,在智能驾驶领域,街景字符编码识别也是一项关键技术,可以帮助自动驾驶系统准确地识别理解道路上的各种标志和标识,为自动驾驶提供可靠的环境感知能力。 天池datawhale联合举办街景字符编码识别竞赛,旨在吸引全球数据科学和计算机视觉领域的优秀人才,集思广益,共同推动该领域的研究和发展。通过这个竞赛,参赛选手可以使用各种机器学习和深度学习算法,基于提供的街景字符数据集,设计和训练模型,实现准确的字符编码识别。这个竞赛不仅有助于促进算法研发和技术创新,也为各参赛选手提供了一个学习、交流和展示自己技能的平台。 总之,天池datawhale街景字符编码识别是一个具有挑战性和实际应用需求的竞赛项目,旨在推动计算机视觉和智能交通领域的技术发展,同时也为数据科学爱好者提供了一个学习和展示自己能力的机会。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值