【阿里天池新人赛】之街景字符识别(4)
正式赛时间:2020.5.14-2020.6.24
比赛网址:https://tianchi.aliyun.com/competition/entrance/531795/introduction
上次介绍了如何训练模型,这次介绍如何用训练好的模型对测试集进行预测、生成提交文件并提交结果。
数据预测(test.py)
import pandas as pd
import os, sys, glob, shutil, json
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from dataset import SVHNDataset
import numpy as np
from model import SVHN_Model1
def predict(test_loader, model, tta=10):
    """Run ``model`` over ``test_loader`` ``tta`` times and sum the raw outputs.

    Test-time augmentation (TTA): the whole loader is iterated ``tta`` times
    and the per-sample scores of every pass are added together. Repeated
    passes only differ when the loader's dataset applies random transforms.

    Args:
        test_loader: Iterable yielding ``(images, target)`` batches. The
            target is ignored.
        model: Module whose forward pass returns five score tensors, one per
            character position, each with the batch on axis 0.
        tta: Number of passes over ``test_loader``.

    Returns:
        A NumPy array holding, for every sample, the five heads' scores
        concatenated along axis 1 and summed over all passes. ``None`` when
        ``tta`` is less than 1 (no pass is run).
    """
    model.eval()

    # Send each batch to wherever the model lives instead of consulting a
    # module-level flag; a parameter-less model leaves the batch untouched.
    first_param = next(model.parameters(), None)
    device = None if first_param is None else first_param.device

    test_pred_tta = None

    # One iteration per TTA pass.
    for _ in range(tta):
        test_pred = []

        with torch.no_grad():
            for images, _target in test_loader:
                if device is not None:
                    images = images.to(device)

                c0, c1, c2, c3, c4 = model(images)
                output = np.concatenate(
                    [head.detach().cpu().numpy() for head in (c0, c1, c2, c3, c4)],
                    axis=1,
                )
                test_pred.append(output)

        test_pred = np.vstack(test_pred)
        if test_pred_tta is None:
            test_pred_tta = test_pred
        else:
            # Accumulate in place; test_pred is rebuilt from scratch each pass.
            test_pred_tta += test_pred

    return test_pred_tta
# Inference script: predict a character string for every test image and write
# the predictions into the sample submission file.

# Sorted so that predictions come out in file-name order.
# NOTE(review): assumes the rows of the sample submission follow the same
# sorted file order — confirm against the CSV.
test_path = sorted(glob.glob('data/mchar_test_a/*.png'))
# Placeholder labels passed to SVHNDataset; the test images have none here.
test_label = [[1]] * len(test_path)
print(len(test_path), len(test_label))
if not test_path:
    # Fail here with a clear message rather than later inside np.vstack.
    raise FileNotFoundError("no test images matched 'data/mchar_test_a/*.png'")

# RandomCrop stays in the test pipeline so repeated TTA passes see different
# crops of each image.
test_transform = transforms.Compose([
    transforms.Resize((64, 128)),
    transforms.RandomCrop((60, 120)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

test_data = SVHNDataset(img_path=test_path, img_label=test_label, transform=test_transform)
# shuffle=False keeps predictions aligned with test_path.
test_loader = DataLoader(dataset=test_data, batch_size=40, shuffle=False, num_workers=0)

model = SVHN_Model1()
# Use the GPU only when one is actually present, so the script also runs on
# CPU-only machines.
use_cuda = torch.cuda.is_available()
if use_cuda:
    model = model.cuda()

# Load the best saved weights. map_location='cpu' lets a checkpoint saved on a
# GPU load anywhere; load_state_dict then copies into the model's own device.
model.load_state_dict(torch.load('model.pt', map_location='cpu'))

test_predict_label = predict(test_loader, model, 1)

# The first 55 score columns are 5 character positions x 11 classes each;
# take the best class per position.
NUM_POSITIONS = 5
NUM_CLASSES = 11
BLANK_CLASS = 10  # class index that is dropped from the decoded string
test_predict_label = (
    test_predict_label[:, :NUM_POSITIONS * NUM_CLASSES]
    .reshape(-1, NUM_POSITIONS, NUM_CLASSES)
    .argmax(axis=2)
)

# Join the remaining class indices of each row into one string.
test_label_pred = [
    ''.join(map(str, row[row != BLANK_CLASS])) for row in test_predict_label
]

df_submit = pd.read_csv('data/mchar_sample_submit_A.csv')
df_submit['file_code'] = test_label_pred
# NOTE(review): the output name is spelled 'renset50' (possibly meant
# 'resnet50'); kept unchanged so anything expecting this file still finds it.
df_submit.to_csv('renset50.csv', index=False)
结果
提交后的得分大约在 0.51–0.57 之间,具体取决于训练轮数。下次介绍如何提高模型的精度。