阿里天池CV比赛快速入门(1.1) Datawhale 零基础入门CV赛事字符识别-Baseline

比赛链接:
https://tianchi.aliyun.com/competition/entrance/531795/introduction?spm=5176.12281973.1005.1.3dd51f54utcdrq

本次新人赛是Datawhale与天池联合发起的零基础入门系列赛事第二场 —— 零基础入门CV赛事之街景字符识别。

赛题以计算机视觉中字符识别为背景,要求选手预测真实场景下的字符识别,这是一个典型的字符识别问题。通过这道赛题可以引导大家走入计算机视觉的世界,主要针对竞赛选手上手视觉赛题,提高对数据建模能力。

为了更好的引导大家入门,我们同时为本赛题定制了系列学习方案,其中包括数据科学库、通用流程和baseline方案学习三部分。通过对本方案的完整学习,可以帮助掌握数据竞赛基本技能。同时我们也将提供专属的视频直播学习通道。

本文是基于对阿水大佬的baseline代码链接的理解,写的博客,适合新手一步步入门。

对于数据科学比赛汇总的一些分享在我的另一些博客,博客链接是:
https://blog.csdn.net/csphillip/article/details/106327073

baseline思路:使用CNN进行定长字符分类;
运行系统要求:Python2/3,现在多建议用python3吧,内存4G(内存大的话处理大量数据时你的电脑就不会那么卡),有无GPU都可以(GPU对于处理尤其是图片数据有很大的效果,能加速训练速度,当然没有也可以,因为目前数据量也还好,大不了等嘛,hh)

线上得分取决于训练轮数,假设比赛数据路径为…/input(这个地址改成你放数据的地址)

import os, sys, glob, shutil, json
os.environ["CUDA_VISIBLE_DEVICES"] = '0'   //如果使用GPU,就只使用第一个。
import cv2     //OpenCV的库,图像处理库

from PIL import Image  //PIL库,也是图片操作库
import numpy as np			

from tqdm import tqdm, tqdm_notebook  //进度条

%pylab inline   //魔法函数,在notebook上画图

import torch
torch.manual_seed(0)//为CPU设置种子用于生成随机数,以使得结果是确定的 
torch.backends.cudnn.deterministic = False //保证模型可重复性的函数,使得两次出现的结果一样,但这里不要一样可见:https://www.cnblogs.com/yongjieShi/p/9614905.html

torch.backends.cudnn.benchmark = True //设置 torch.backends.cudnn.benchmark=True 将会让程序在开始时花费一点额外时间,为整个网络的每个卷积层搜索最适合它的卷积实现算法,进而实现网络的加速。适用场景是网络结构固定(不是动态变化的),网络的输入形状(包括 batch size,图片大小,输入的通道)是不变的,其实也就是一般情况下都比较适用。反之,如果卷积层的设置一直变化,将会导致程序不停地做优化,反而会耗费更多的时间。具体可见:https://zhuanlan.zhihu.com/p/73711222

import torchvision.models as models//torchvision,会定义预训练模型model,经典数据集类似于minist,还有一些图像的transforms方法
更多见https://www.cnblogs.com/yjphhw/p/9773333.html
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset

上述代码即是导入一些需要的库
定义读取数据集dataset

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

pytorch定义数据集的经典办法,固定套路,定义
init 初始化
getitem 获得对单个样本的索引,用于索引数据集中的数据
len 获得长度定义整个数据集的长度

Pytorch用torch.utils.data.Dataset构建数据集,想要构建自己的数据集,则需继承Dataset类,并重写两个方法: getitem len
还可以看看知乎上的这个对于这三个函数的理解:
https://zhuanlan.zhihu.com/p/38191551

定义读取数据dataloader
假设数据存放在…/input文件夹(换成你放数据集的路径)下,并进行解压。
很经典的,定义完dataset之后,在定义为dataloader,这是使用pytorch的一般套路,具体流程,还有dataset与dataloader的区别,可以看我的另一篇博客:
https://blog.csdn.net/csphillip/article/details/106327073

train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]
print(len(train_path), len(
  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值