一、赛题理解
1.数据集
训练集数据包括3W张照片
验证集数据包括1W张照片
每张照片包括颜色图像和对应的编码类别和具体位置
目标:识别图片中所有的字符
2.数据标签
标签文件是.json格式,(top,height,left,width,label)
同一张照片可能有多个数字,(预测结果需要考虑按x坐标升序排列)
3.评价指标
以编码整体识别准确率为评价指标。任何一个字符错误都为错误,最终评测指标结果越大越好,具体计算公式如下:
Score=编码识别正确的数量/测试集图片数量
4.代码实践
#json格式数据读取
import json
train_json = json.load(open(’…/input/train.json’))
#数据标注处理
def parse_json(d):
arr = np.array([
d[‘top’], d[‘height’], d[‘left’], d[‘width’], d[‘label’]
])
arr = arr.astype(int)
return arr
img = cv2.imread(’…/input/train/000000.png’)
arr = parse_json(train_json[‘000000.png’])
plt.figure(figsize=(10, 10))
plt.subplot(1, arr.shape[1]+1, 1)
plt.imshow(img)
plt.xticks([]); plt.yticks([])
for idx in range(arr.shape[1]):
plt.subplot(1, arr.shape[1]+1, idx+2)
plt.imshow(img[arr[0, idx]:arr[0, idx]+arr[1, idx],arr[2, idx]:arr[2, idx]+arr[3, idx]])
plt.title(arr[4, idx])
plt.xticks([]); plt.yticks([])
二、初步分析:
1.赛题虽然只要求预测字符,不要求具体位置,是分类问题,但可以尝试用目标检测解决
2.查看数据集发现:很多图片本身像素较低,有些图片上的字符非常模糊(‘012677.png’标注异常,字符高度大于图片高度)
3.字符颜色、排列变化多样,考虑数据增强有可能提高效果
4.用yolo跑了一遍,效果不是很好