Datawhale
作者:阿水、陈信达 Datawhale成员
本文针对阿里天池《零基础入门CV赛事-街景字符编码识别》,给出了百行代码Baseline,帮助cv学习者更好地结合赛事实践。同时,从赛题数据分析和解题思路分析两方面进行了详细的解读,以便于大家进阶学习。
数据及背景
https://tianchi.aliyun.com/competition/entrance/531795/information(阿里天池-零基础入门CV赛事)
百行Baseline
Baseline以定长字符识别为解题思路,进行了必要的注释和代码实现,分数在0.6左右,运行时长:CPU大约需要2小时,GPU大约10分钟。
import glob, json
from PIL import Image
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
class SVHNDataset(Dataset):
    """Dataset of character images with fixed-length digit labels.

    Every label sequence is padded with the filler class 10, or truncated,
    so that it always has exactly five entries.

    Args:
        img_path: Indexable collection of image file paths.
        img_label: Indexable collection of per-image digit sequences. Pass
            ``None`` or an empty collection for unlabeled data, in which
            case items are returned without a label.
        transform: Optional callable applied to the loaded RGB PIL image.
            When ``None`` the PIL image is returned unchanged.
    """

    _MAX_DIGITS = 5  # fixed label length expected downstream
    _PAD_CLASS = 10  # filler class for positions without a digit

    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        self.transform = transform

    @staticmethod
    def _encode_label(digits):
        """Return *digits* as an int64 tensor of length five, padded with 10."""
        # np.int was removed in NumPy 1.24; int64 is what torch's .long() uses.
        values = np.asarray(digits, dtype=np.int64).reshape(-1)
        values = values[:SVHNDataset._MAX_DIGITS]
        padded = np.full(
            SVHNDataset._MAX_DIGITS, SVHNDataset._PAD_CLASS, dtype=np.int64
        )
        padded[:values.size] = values
        return torch.from_numpy(padded)

    def _has_labels(self):
        """Return True when a non-empty label collection was supplied."""
        # len() instead of truthiness so NumPy arrays of labels also work.
        return self.img_label is not None and len(self.img_label) > 0

    def __getitem__(self, index):
        """Load item *index*.

        Returns:
            ``(image, label_tensor)`` when labels are available, otherwise
            just ``image``.
        """
        # The context manager closes the underlying file once the pixel
        # data has been copied by convert().
        with Image.open(self.img_path[index]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self._has_labels():
            return img, self._encode_label(self.img_label[index])
        return img

    def __len__(self):
        """Return the number of image paths."""
        return len(self.img_path)
# 定义模型
class SVHN_Model1(nn.Module):
    """ResNet-50 feature extractor feeding five parallel linear classifiers.

    The backbone is a pretrained ResNet-50 with its final fully connected
    layer removed. Its flattened 2048-wide output is passed to five
    independent ``Linear(2048, 11)`` layers, one per output position.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # Swap the pooling layer for an adaptive one with a 1x1 output.
        backbone.avgpool = nn.AdaptiveAvgPool2d(1)
        # Keep every child module except the last one (the linear layer).
        self.cnn = nn.Sequential(*list(backbone.children())[:-1])
        self.fc1 = nn.Linear(2048, 11)
        self.fc2 = nn.Linear(2048, 11)
        self.fc3 = nn.Linear(2048, 11)
        self.fc4 = nn.Linear(2048, 11)
        self.fc5 = nn.Linear(2048, 11)

    def forward(self, img):
        """Return a 5-tuple of outputs, one per classifier, for batch *img*."""
        features = self.cnn(img)
        # Collapse everything after the batch dimension into one axis.
        flat = features.view(features.size(0), -1)
        heads = (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5)
        return tuple(head(flat) for head in heads)
def train(train_loader, model, criterion, optimizer):
model.train() # 切换模型为训练模式
train_loss = []
for input, target in tqdm(train_loader): # 取出数据与对应标签
if use_cuda: # 如果是gpu版本
input, target = input.cuda(), target.cuda()
target = target.long(