Datawhale-天池入门赛街景字符编码识别-Task3:字符识别模型

接上:《Datawhale-天池入门赛街景字符编码识别-Task1:赛题理解》、《Datawhale-天池入门赛街景字符编码识别-Task2:数据读取与数据增强》

思路

定长序列预测(baseline)

  1. 将序列补齐至相同长度
  2. 训练时预测11个类,0-9,无字符
  3. 测试时将无字符的预测删去

增加length属性,对序列长度预测结果做约束(尝试中)

部分代码如下:

class SVHNDataset(Dataset):
    """SVHN street-view dataset yielding fixed-length (5) digit label sequences.

    Each sample is (image, original_label_length, padded_label_tensor) where
    the label is padded with class 10 ("no digit") up to 5 positions and
    truncated past 5.
    """

    def __init__(self, img_path, img_label, transform=None):
        # img_path: list of image file paths
        # img_label: list of per-image digit lists (ints 0-9)
        # transform: optional torchvision-style transform applied to the image
        self.img_path = img_path
        self.img_label = img_label
        self.transform = transform

    def __getitem__(self, idx):
        img = PIL.Image.open(self.img_path[idx]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        # Normalize every label to exactly 5 entries over 11 classes
        # (digits 0-9 plus class 10 meaning "no digit"): pad short labels,
        # truncate long ones. The returned length is the original length.
        raw = self.img_label[idx]
        seq_len = len(raw)
        padded = (raw + (5 - seq_len) * [10])[:5]
        return img, seq_len, torch.from_numpy(np.array(padded))

    def __len__(self):
        return len(self.img_path)

class SVHN_Model1(nn.Module):
    """Multi-head classifier over a wide ResNet-101-2 backbone.

    Heads: fc0 predicts the sequence length (5 classes); fc1..fc5 each
    predict one character position over 11 classes (digits 0-9 plus
    class 10 = "no digit").
    """

    def __init__(self):
        super(SVHN_Model1, self).__init__()

        # Backbone without its final fully-connected layer; only the
        # convolutional feature extractor (incl. avgpool) is kept.
        backbone = models.wide_resnet101_2(pretrained=False)
        self.conv = nn.Sequential(*list(backbone.children())[:-1])

        # Dimensionality reduction 2048 -> 1024 -> 512 before the heads.
        # NOTE(review): no non-linearity between fca and fcb — presumably
        # intentional in the original; confirm before changing.
        self.fca = nn.Linear(2048, 1024)
        self.fcb = nn.Linear(1024, 512)

        # Length head (5 classes) followed by five 11-way character heads.
        self.fc0 = nn.Linear(512, 5)
        self.fc1 = nn.Linear(512, 11)
        self.fc2 = nn.Linear(512, 11)
        self.fc3 = nn.Linear(512, 11)
        self.fc4 = nn.Linear(512, 11)
        self.fc5 = nn.Linear(512, 11)

    def forward(self, img):
        feat = self.conv(img)
        # Flatten the pooled feature map before the fully-connected stack.
        feat = feat.view(feat.shape[0], -1)
        feat = self.fcb(self.fca(feat))
        heads = (self.fc0, self.fc1, self.fc2, self.fc3, self.fc4, self.fc5)
        return tuple(head(feat) for head in heads)


def train(train_loader, model, criterion, optimizer, device=None):
    """Run one training epoch and return the mean batch loss.

    Args:
        train_loader: iterable of (image_batch, length_batch, target_batch),
            where target_batch is (B, 5) class indices over 11 classes.
        model: network returning 6 logit tensors (length, then 5 characters).
        criterion: classification loss, e.g. nn.CrossEntropyLoss.
        optimizer: optimizer over model's parameters.
        device: torch device for the batch tensors; defaults to CUDA when
            available, else CPU. (The original hard-coded `.cuda()`, which
            crashed on CPU-only machines.) The model is assumed to already
            be on this device.

    Returns:
        Mean loss over all batches as a float (NaN for an empty loader).
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.train()  # enable training-mode layers (dropout, batchnorm updates)
    train_loss = []

    for images, length, target in train_loader:
        images = images.to(device)
        length = length.to(device)
        target = target.to(device).long()

        c0, c1, c2, c3, c4, c5 = model(images)
        # Length-prediction loss plus the five per-position character losses.
        loss_length = criterion(c0, length)
        loss_chars = sum(criterion(logits, target[:, pos])
                         for pos, logits in enumerate((c1, c2, c3, c4, c5)))
        loss = loss_length + loss_chars

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
    return np.mean(train_loss)

def validate(val_loader, model, criterion, device=None):
    """Evaluate the model and return the mean batch loss.

    Args:
        val_loader: iterable of (image_batch, length_batch, target_batch),
            same layout as the training loader.
        model: network returning 6 logit tensors (length, then 5 characters).
        criterion: classification loss, e.g. nn.CrossEntropyLoss.
        device: torch device for the batch tensors; defaults to CUDA when
            available, else CPU. (The original hard-coded `.cuda()`, which
            crashed on CPU-only machines.) The model is assumed to already
            be on this device.

    Returns:
        Mean loss over all batches as a float (NaN for an empty loader).
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()  # inference-mode layers (dropout off, batchnorm frozen)
    val_loss = []

    with torch.no_grad():  # no gradient bookkeeping during evaluation
        for images, length, target in val_loader:
            images = images.to(device)
            length = length.to(device)
            target = target.to(device).long()

            c0, c1, c2, c3, c4, c5 = model(images)
            # Same objective as training: length loss + 5 character losses.
            loss_length = criterion(c0, length)
            loss_chars = sum(criterion(logits, target[:, pos])
                             for pos, logits in enumerate((c1, c2, c3, c4, c5)))
            val_loss.append((loss_length + loss_chars).item())
    return np.mean(val_loss)

def predict(test_loader, model, tta=10, device=None):
    """Predict character logits with test-time augmentation (TTA).

    Runs `tta` full passes over the loader (random augmentations, if any,
    live in the loader's transform) and sums the logits across passes.
    The length head (c0) is computed by the model but not used here.

    Args:
        test_loader: iterable of (image_batch, length_batch, target_batch);
            only the images are used.
        model: network returning 6 logit tensors (length, then 5 characters).
        tta: number of augmentation passes; with tta=0 this returns None
            (matches the original behavior).
        device: torch device for the input batches; defaults to CUDA when
            available, else CPU. (The original hard-coded `.cuda()`, which
            crashed on CPU-only machines.) The model is assumed to already
            be on this device.

    Returns:
        numpy array of shape (num_samples, 55): the five 11-way character
        heads concatenated along axis 1, summed over TTA passes.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()
    test_pred_tta = None

    for _ in range(tta):
        test_pred = []
        with torch.no_grad():
            for images, _length, _target in test_loader:
                images = images.to(device)
                _c0, c1, c2, c3, c4, c5 = model(images)
                # Concatenate the 5 character heads: (B, 5*11).
                output = np.concatenate(
                    [head.cpu().numpy() for head in (c1, c2, c3, c4, c5)],
                    axis=1)
                test_pred.append(output)

        test_pred = np.vstack(test_pred)
        # Accumulate logits across TTA passes.
        if test_pred_tta is None:
            test_pred_tta = test_pred
        else:
            test_pred_tta += test_pred

    return test_pred_tta

备注:加长度预测后,训练变慢,结果还未出,后期考虑进一步对不同的loss加上权重,如,加大对长度预测的要求。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值