cv第一次打卡
2020.05.20
pytorch环境安装
-
使用centos 无GPU版pytorch
-
首先使用如下命令 在Anaconda中创建一个专门用于本次天池练习赛的虚拟环境。
conda create -n py37_torch131 python=3.7 -
然后激活环境,并安装pytorch1.3.1
source activate py37_torch131
conda install pytorch=1.3.1 torchvision cudatoolkit=10.0 -
通过下面的命令一键安装所需其它依赖库
pip install jupyter tqdm opencv-python matplotlib pandas -
启动notebook,即可开始baseline代码的学习
jupyter-notebook
-
-
完成安装和启动后, 根据本机环境进行适当修改和配置后, 运行baseline, 没有遇到问题.
比赛规则
- 比赛允许使用CIFAR-10和ImageNet数据集的预训练模型,不允许使用其他任何预训练模型和任何外部数据;
- 报名成功后,从天池下载经过预处理的数据,在本地调试算法,提交结果(不应该使用SVHN原始数据进行训练和预测);
- 提交后将进行实时评测;每天排行榜更新时间为12:00和20:00,按照评测指标得分从高到低排序;排行榜将选择历史最优成绩进行展示。
赛题数据
- 包括三个数据集: 训练集,验证集和测试集.
- 训练集数据包括3W张照片,照片中的待识别数字和数字位置框保存在json文件中.
- 验证集数据包括1W张照片,照片中的待识别数字和数字位置框也保存在json文件中.
-
json文件中的字段表
所有的数据(训练集、验证集和测试集)的标注使用JSON格式,并使用文件名进行索引。Field Description top 左上角坐标X height 字符高度 left 左上角最表Y width 字符宽度 label 字符编码
-
- 测试数据集包括4W张照片, 未提供测试数据集的数字位置框.
评分标准
以编码整体识别准确率为评价指标, 这意味着必须对图片中的所有数字全部预测正确, 才算一次正确的预测.
S
c
o
r
e
=
编
码
识
别
正
确
的
数
量
/
测
试
集
图
片
数
量
Score=编码识别正确的数量/测试集图片数量
Score=编码识别正确的数量/测试集图片数量
思路与分析:
-
测试集未提供数字位置框, 使用baseline测试时发现, 准确率较低, 根据阿水在钉钉群的提示, 首先应该利用训练集的照片和数字位置框训练一个标注模型, 对测试集的照片中的数字预测其位置框, 然后再根据baseline的方法, 使用resnet的预训练模型进行训练和预测.
- 在赛题数据中已经给出了训练集、验证集中所有图片中字符的位置,因此可以首先将字符的位置进行识别,利用物体检测的思路完成.
- 此种思路需要参赛选手构建字符检测模型,对测试集中的字符进行识别。选手可以参考物体检测模型SSD或者YOLO来完成。
-
观察发现, 照片中的数据个数并不固定, 一般是2到4个, 这给预测增加了些难度. 处理方法为, 将赛题抽象为一个六个字符的定长字符识别问题,字符23填充为23XXXX,字符231填充为231XXX.最后再处理掉填充内容.
- 在字符识别研究中,有特定的方法来解决此种不定长的字符识别问题,比较典型的有CRNN字符识别模型。
- 在本次赛题中给定的图像数据都比较规整,可以视为一个单词或者一个句子。
-
由于训练集和测试集的数量相等, 均为3w, 相对而言训练集的数据量较小, 在进行训练前, 还需要使用某种数据扩增方法对训练集进行扩增.
# json文件的读取和处理代码(根据task1 赛题理解.ipynb)
import json
train_json = json.load(open('../input/train.json'))
# 数据标注处理
def parse_json(d):
arr = np.array([
d['top'], d['height'], d['left'], d['width'], d['label']
])
arr = arr.astype(int)
return arr
img = cv2.imread('../input/train/000000.png')
arr = parse_json(train_json['000000.png'])
plt.figure(figsize=(10, 10))
plt.subplot(1, arr.shape[1]+1, 1)
plt.imshow(img)
plt.xticks([]); plt.yticks([])
for idx in range(arr.shape[1]):
plt.subplot(1, arr.shape[1]+1, idx+2)
plt.imshow(img[arr[0, idx]:arr[0, idx]+arr[1, idx],arr[2, idx]:arr[2, idx]+arr[3, idx]])
plt.title(arr[4, idx])
plt.xticks([]); plt.yticks([])