Datawhale-天池入门赛街景字符编码识别-Task4:模型训练与验证

接上:Datawhale-天池入门赛街景字符编码识别-Task1:赛题理解Datawhale-天池入门赛街景字符编码识别-Task2:数据读取与数据增强Datawhale-天池入门赛街景字符编码识别-Task3:字符识别模型

近期进展

近期开始尝试使用检测模型,首先就要对label进行适当处理,自己尝试通过部分可视化的手段,帮助自己对坐标框进行处理,代码如下:

class Visualization:
    
    def __init__(self, image_path, label_path):
        
        self.image_path = image_path
        self.label_path = label_path
        self.image_name = [i.split('\\')[-1] for i in self.image_path]
        
    def show_box(self,save_path):
        
        save_path = save_path
        if not os.path.exists(save_path):
            os.mkdir(save_path)
            
        image_path = self.image_path
        label_path = self.label_path
        image_name = self.image_name
        for i in range(len(image_path)):
            image = cv2.imread(image_path[i])
            for j in range(len(label_path[image_name[i]]['label'])):
                left = label_path[image_name[i]]['left'][j]
                top = label_path[image_name[i]]['top'][j]
                height = label_path[image_name[i]]['height'][j]
                width = label_path[image_name[i]]['width'][j]
                cv2.rectangle(image, (int(left), int(top)), (int(left+width), int(top+height)), (0, 0, 255), 1)
            cv2.imwrite(os.path.join(save_path, train_name[i]), image)
    
    def show_max_box(self,save_path):
        save_path = save_path
        if not os.path.exists(save_path):
            os.mkdir(save_path)
            
        image_path = self.image_path
        label_path = self.label_path
        image_name = self.image_name
        for i in range(len(image_path)):
            image = cv2.imread(image_path[i])
            x1 = min(label_path[image_name[i]]['left'])
            y1 = min(label_path[image_name[i]]['top'])
            x2 = label_path[image_name[i]]['left'][-1] + label_path[image_name[i]]['width'][-1] 
            y2 = max(label_path[image_name[i]]['top']) + max(label_path[image_name[i]]['height'])
            cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 1)
            cv2.imwrite(os.path.join(save_path, train_name[i]), image)
            
    def count_label(self):
        
        label_path = self.label_path
        image_name = self.image_name
        label_len = [len(label_path[image_name[i]]['label']) for i in range(len(label_path))]
        for i in range(6):
            print('{} {} {:' '<6d} {:.6f}'.format(i+1,':',label_len.count(i+1), label_len.count(i+1)/len(label_path)))
        plt.hist(label_len)
        
    def show_outlier(self):
        
        image_path = self.image_path
        label_path = self.label_path
        image_name = self.image_name
        for i in range(len(label_path)):
            if len(label_path[image_name[i]]['label']) in [5,6]:
                print(image_name[i])
                
    def show_image_size(self):
        
        image_path = self.image_path
        
        width = [cv2.imread(image_path[i]).shape[0] for i in range(len(image_path))]
        height = [cv2.imread(image_path[i]).shape[0] for i in range(len(image_path))]
        print(sum(width) / len(width), sum(height) / len(height))
        plt.scatter(width,height)

通过以上代码,可以可视化每个数字的坐标框。
在这里插入图片描述
也可以将所有数字的坐标框合并,做序列预测。
在这里插入图片描述
此外,通过一定的统计可以发现,数字位数集中在1-4,5和6位的数字很少,甚至可以当作离群点处理。
在这里插入图片描述另外,还可以发现,所有图像的宽高比基本一致,分布比较散。
在这里插入图片描述
另外,还对模型的训练代码做了一些补充。

def train(train_loader, model, criterion, optimizer):
    # 切换模型为训练模式
    model.train()
    train_loss = []
    
    for i, (input, target) in enumerate(train_loader):
        input = input.cuda()
        target = target.cuda().long()    
        c1, c2, c3, c4 ,c5 = model(input)
        loss =  criterion(c1, target[:, 0]) + \
                criterion(c2, target[:, 1]) + \
                criterion(c3, target[:, 2]) + \
                criterion(c4, target[:, 3]) + \
                criterion(c5, target[:, 4])
        # 梯度累加
        accumulation_steps = 10
        # loss = loss/accumulation_steps
        loss.backward()
        if((i+1)%accumulation_steps)==0:
            optimizer.step()
            optimizer.zero_grad()
        # lr_scheduler.step(epoch)
        train_loss.append(loss.item())
    return np.mean(train_loss)

上述代码用到了梯度累加的技巧,可以变相地提升batchsize。

def set_random_seed(seed = 0,deterministic=False,benchmark=True):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
    if benchmark:
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

上述代码可以固定一些随机种子,提升代码的可复现性。

RESUME = False
    if RESUME:
        path_checkpoint = r"result\0528_135935\checkpoints\ckpt_epoch150.pth"  # 断点路径
        checkpoint = torch.load(path_checkpoint)  # 加载断点
        model = SVHN_Model1().cuda()
        model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
        optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
        start_epoch = checkpoint['epoch']  # 设置开始的epoch
        # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    else:
        model = SVHN_Model1().cuda()
        # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True, weight_decay=5e-4)
        # lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 120], gamma=0.1) 
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

上述代码可以从断点恢复模型训练,当然,还要配合以下代码。

# 保存断点
checkpoint = {
        "net": model.state_dict(),
        'optimizer': optimizer.state_dict(),
        "epoch": epoch,
        # 'lr_schedule': lr_scheduler.state_dict()
    }    
checkpoint_path = time_path + '/checkpoints'
if (epoch+1) %5 ==0:    
    if not os.path.exists(checkpoint_path):
        os.mkdir(checkpoint_path) 
    torch.save(checkpoint, checkpoint_path + '/ckpt_epoch%s.pth' % (str(epoch+1)))
    print("Save checkpoint at epoch:", epoch+1)

下一步计划

目前还是没有成功尝试新的模型,不过打算以这次机会,熟悉全套字符识别流程,并掌握常用的Pytorch方法和实用的Pytorch技巧,希望近期内能实现一个检测加识别的新baseline,将分数刷到0.8+。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
天池是一个著名的数据科学竞平台,而datawhale是一家致力于数据科学教育和社群建设的组织。街景字符编码识别是指通过计算机视觉技术,对街道场景中的字符进行自动识别和分类。 街景字符编码识别是一项重要的研究领域,对于提高交通安全、城市管理和智能驾驶技术都具有重要意义。街道场景中的字符包括道路标志、车牌号码、店铺招牌等。通过对这些字符进行准确的识别,可以辅助交通管理人员进行交通监管、道路规划和交通流量分析。同时,在智能驾驶领域,街景字符编码识别也是一项关键技术,可以帮助自动驾驶系统准确地识别和理解道路上的各种标志和标识,为自动驾驶提供可靠的环境感知能力。 天池datawhale联合举办街景字符编码识别,旨在吸引全球数据科学和计算机视觉领域的优秀人才,集思广益,共同推动该领域的研究和发展。通过这个竞,参选手可以使用各种机器学习深度学习算法,基于提供的街景字符数据集,设计和训练模型,实现准确的字符编码识别。这个竞不仅有助于促进算法研发和技术创新,也为各参选手提供了一个学习、交流和展示自己技能的平台。 总之,天池datawhale街景字符编码识别是一个具有挑战性和实际应用需求的竞项目,旨在推动计算机视觉和智能交通领域的技术发展,同时也为数据科学爱好者提供了一个学习和展示自己能力的机会。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值