天池——街景字符编码识别 1

比赛地址:https://tianchi.aliyun.com/competition/entrance/531795/introduction
学习内容:https://github.com/datawhalechina/team-learning/tree/master/03%20%E8%AE%A1%E7%AE%97%E6%9C%BA%E8%A7%86%E8%A7%89/%E8%AE%A1%E7%AE%97%E6%9C%BA%E8%A7%86%E8%A7%89%E5%AE%9E%E8%B7%B5%EF%BC%88%E8%A1%97%E6%99%AF%E5%AD%97%E7%AC%A6%E7%BC%96%E7%A0%81%E8%AF%86%E5%88%AB%EF%BC%89

使用工具:Google Colab
使用语言:Python

赛题:

以计算机视觉中字符识别为背景,要求选手预测真实场景下的字符识别,这是一个典型的字符识别问题。
数据集报名后可见并可下载,该数据来自真实场景的门牌号。训练集数据包括3W张照片,验证集数据包括1W张照片,每张照片包括颜色图像和对应的编码类别和具体位置;为了保证比赛的公平性,测试集A包括4W张照片,测试集B包括4W张照片。

在这里插入图片描述

数据集样本展示

Baseline思路:将不定长字符转换为定长字符的识别问题,并使用CNN完成训练和验证,具体包括以下几个步骤:

  • 赛题数据读取(封装为Pytorch的Dataset和DataLoder)
  • 构建CNN模型(使用Pytorch搭建)
  • 模型训练与验证
  • 模型结果预测
定义Dataset类:
def __init__()		#主要是图像数据的获取
def __getitem__()	#主要是对图像数据的转换和图像增强
def __len__()		#主要是整个数据集的长度
定义train数据:
train_path(顺便排了个序),train_json,train_label
定义val数据:
val_path(也排了序),val_json,val_label
定义dataloader:
train_loader,val_loader
载入模型model:
Resnet18作为特征提取模块
定义训练模块:
# 定义优化器 
optimizer = 
# 定义损失函数
loss = criterion(...)
定义验证模块:
def validate(val_loader, model, criterion)		# 用训练好的模型验证测试数据集。
定义预测模块:
# 定义损失函数
loss = criterion(...)

—— TTA(测试时增强) test time augmentation :为原始图像造出多个不同的版本,包括不同区域裁剪和更改缩放程度,并将它们输入到模型中,然后对多个版本进行计算得到平均输出,作为图像的最终输出分数。
—— 这种技术很有效,因为原始图像显示的区域可能会缺少一些重要特征,在模型中输入图像的多个版本并取平均值,能解决上述问题。

训练:

acc(accuracy)——精度
val_acc ——验证集精度

criterion = nn.CrossEntropyLoss()	#交叉熵
for epoch in range(2):		#训练两个epoch
对测试集样本进行预测,生成提交文件

定义test_loader:

test_loader = torch.utils.data.DataLoader(
    SVHNDataset(test_path, test_label,
                transforms.Compose([
                    transforms.Resize((64, 128)),
                    transforms.RandomCrop((60, 120)),
                    # transforms.ColorJitter(0.3, 0.3, 0.2),
                    # transforms.RandomRotation(5),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])), 
    batch_size=40, 
    shuffle=False, 
    num_workers=10,
)

预测的结果标签:

test_predict_label = np.vstack([
    test_predict_label[:, :11].argmax(1),
    test_predict_label[:, 11:22].argmax(1),
    test_predict_label[:, 22:33].argmax(1),
    test_predict_label[:, 33:44].argmax(1),
    test_predict_label[:, 44:55].argmax(1),
]).T		#转置
生成输出文件
import pandas as pd
df_submit = pd.read_csv('../input/test_A_sample_submit.csv')
df_submit['file_code'] = test_label_pred
df_submit.to_csv('renset18.csv', index=None)
  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值