【阿里天池新人赛】之街景字符识别(4)

【阿里天池新人赛】之街景字符识别(4)

正式赛时间:2020.5.14-2020.6.24
比赛网址:https://tianchi.aliyun.com/competition/entrance/531795/introduction
上次介绍了如何训练模型,这次介绍生成测试文件和提交结果。

数据预测(test.py)

import pandas as pd
import os, sys, glob, shutil, json
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from dataset import SVHNDataset
import numpy as np
from model import SVHN_Model1


def predict(test_loader, model, tta=10):
    model.eval()
    test_pred_tta = None

    # TTA 次数
    for _ in range(tta):
        test_pred = []

        with torch.no_grad():
            for i, (input, target) in enumerate(test_loader):
                if use_cuda:
                    input = input.cuda()

                c0, c1, c2, c3, c4 = model(input)
                output = np.concatenate([c0.data.cpu().numpy(), c1.data.cpu().numpy(), c2.data.cpu().numpy(), c3.data.cpu().numpy(), c4.data.cpu().numpy()], axis=1)
                test_pred.append(output)

        test_pred = np.vstack(test_pred)
        if test_pred_tta is None:
            test_pred_tta = test_pred
        else:
            test_pred_tta += test_pred

    return test_pred_tta


test_path = glob.glob('data/mchar_test_a/*.png')
test_path.sort()
test_label = [[1]] * len(test_path)
print(len(test_path), len(test_label))

test_transform = transforms.Compose([transforms.Resize((64, 128)),
                                     transforms.RandomCrop((60, 120)),
                                     # transforms.ColorJitter(0.3, 0.3, 0.2),
                                     # transforms.RandomRotation(5),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_data = SVHNDataset(img_path=test_path, img_label=test_label, transform=test_transform)
test_loader = DataLoader(dataset=test_data, batch_size=40, shuffle=False, num_workers=0)

model = SVHN_Model1()
use_cuda = True

if use_cuda:
    model = model.cuda()

# 加载保存的最优模型
model.load_state_dict(torch.load('model.pt'))
test_predict_label = predict(test_loader, model, 1)
test_label = [''.join(map(str, x)) for x in test_loader.dataset.img_label]
test_predict_label = np.vstack([
 test_predict_label[:, :11].argmax(1),
 test_predict_label[:, 11:22].argmax(1),
 test_predict_label[:, 22:33].argmax(1),
 test_predict_label[:, 33:44].argmax(1),
 test_predict_label[:, 44:55].argmax(1),
]).T

test_label_pred = []
for x in test_predict_label:
    test_label_pred.append(''.join(map(str, x[x != 10])))


df_submit = pd.read_csv('data/mchar_sample_submit_A.csv')
df_submit['file_code'] = test_label_pred
df_submit.to_csv('renset50.csv', index=None)

结果

大概能到0.51-0.57的分数,取决于训练轮数,下次介绍如何提高模型的精度。
在这里插入图片描述

  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值