pytorchOCR之PSEnet

pytorchOCR之PSEnet

论文链接
官方代码

论文解读这里就不做了,网上很多。这里只对项目代码解读。

标签制作

在这里插入图片描述

  • 借用论文里的图,如图所示,需要生成若干个(自己设定,论文中为6)黑白图,文字部分为白即为1,背景部分为黑即为0. 白色最大的为文字分割图,最小的文中叫做kernel图,通过这样可以分开临近的文本。
  • 在ptocr/dataloader/DetLoad/MakeSegMap.py里的
def shrink(self,bboxes, rate, max_shr=20):
        rate = rate * rate
        shrinked_bboxes = []
        for bbox in bboxes:
            area = plg.Polygon(bbox).area()
            peri = self.perimeter(bbox)

            pco = pyclipper.PyclipperOffset()
            pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
            offset = min((int)(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)

            shrinked_bbox = pco.Execute(-offset)
            if len(shrinked_bbox) == 0:
                shrinked_bboxes.append(bbox)
                continue

            shrinked_bbox = np.array(shrinked_bbox)[0]
            shrinked_bbox = np.array(shrinked_bbox)
            if shrinked_bbox.shape[0] <= 2:
                shrinked_bboxes.append(bbox)
                continue

            shrinked_bboxes.append(shrinked_bbox)

        return np.array(shrinked_bboxes)

通过这个函数将标注框缩小,得到每个缩小的框。最后用opencv生成分割图。

模型解读

该检测方法是基于分割,论文使用FPN作为分割网络,其中backbone为resnet50,参看
ptocr/model/backbone/det_resnet.py部分代码如下

 def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x2 = self.layer1(x)
        x3 = self.layer2(x2)
        x4 = self.layer3(x3)
        x5 = self.layer4(x4)

        return x2, x3, x4, x5

经过该backbone返回四个map(x2,x3,x4,x5),分别为原图的1/4,1/8,1/16,1/32.此四个map 进入ptocr/model/head/det_FPNHead.py,如下:
该部分是fpn不同深度的map融合部分

self.toplayer = ConvBnRelu(in_channels[-1], inner_channels, kernel_size=1, stride=1,padding=0,bias=bias)  # Reduce channels
# Smooth layers
self.smooth1 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1,bias=bias)
self.smooth2 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1,bias=bias)
self.smooth3 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1,bias=bias)
# Lateral layers
self.latlayer1 = ConvBnRelu(in_channels[-2], inner_channels, kernel_size=1, stride=1, padding=0,bias=bias)
self.latlayer2 = ConvBnRelu(in_channels[-3], inner_channels, kernel_size=1, stride=1, padding=0,bias=bias)
self.latlayer3 = ConvBnRelu(in_channels[-4], inner_channels, kernel_size=1, stride=1, padding=0,bias=bias)
# Out map
self.conv_out = ConvBnRelu(inner_channels * 4, inner_channels, kernel_size=3, stride=1, padding=1,bias=bias)

在config的yaml中需要设置in_channels和inner_channels,其中in_channels分别对应着不同尺度输出map(x2,x3,x4,x5)的channel数目,如果你想改变backbone,这里也要根据实际情况做相应改变,inner_channels可以随意设置,但是一般根据backbone来调整。

def forward(self, x):
		c2, c3, c4, c5 = x
        ##
        p5 = self.toplayer(c5)
        c4 = self.latlayer1(c4)
        p4 = upsample_add(p5, c4)
        p4 = self.smooth1(p4)
        c3 = self.latlayer2(c3)
        p3 = upsample_add(p4, c3)
        p3 = self.smooth2(p3)
        c2 = self.latlayer3(c2)
        p2 = upsample_add(p3, c2)
        p2 = self.smooth3(p2)
        ##
        p3 = upsample(p3, p2)
        p4 = upsample(p4, p2)
        p5 = upsample(p5, p2)

        fuse = torch.cat((p2, p3, p4, p5), 1)
        fuse = self.conv_out(fuse)
        return fuse

这里操作就是将深层map向上做插值和上一层的map做融合,最后将不同尺度的map进行concat,论文中对此有描述。至此FPN部分完成。于是进入ptocr/model/segout/det_PSE_segout.py

class SegDetector(nn.Module):
    def __init__(self,inner_channels=256,classes=7):
        super(SegDetector,self).__init__()
        self.binarize = nn.Conv2d(inner_channels,classes,1,1,0)
    def forward(self, x,img):
        x = self.binarize(x)
        x = upsample(x,img)
        if self.training:
            pre_batch = dict(pre_text=x[:,0])
            pre_batch['pre_kernel'] = x[:,1:]
            return pre_batch
        return x

这里就是输出分割图,并把分割图插值成原图大小,这里输出7个分割图,其中第0个为最大对应着图片中文字的分割图,依次不断减小,kernel就是最小的一个分割图即第6个kernel图作用就是用来区分密集文本。

loss 部分

这里用到了分割常用的dice loss,在ptocr/model/loss/basical_loss.py如下:

class DiceLoss(nn.Module):
    def __init__(self,eps=1e-6):
        super(DiceLoss,self).__init__()
        self.eps = eps
    def forward(self,pre_score,gt_score,train_mask):
        pre_score = pre_score.contiguous().view(pre_score.size()[0], -1)
        gt_score = gt_score.contiguous().view(gt_score.size()[0], -1)
        train_mask = train_mask.contiguous().view(train_mask.size()[0], -1)

        pre_score = pre_score * train_mask
        gt_score = gt_score * train_mask

        a = torch.sum(pre_score * gt_score, 1)
        b = torch.sum(pre_score * pre_score, 1) + self.eps
        c = torch.sum(gt_score * gt_score, 1) + self.eps
        d = (2 * a) / (b + c)
        dice_loss = torch.mean(d)
        return 1 - dice_loss

这里共需要三个输入,一个网络输出的7个图,一个标签制作好的7个图,以及这七个图的train_mask,这里train_mask的作用就是使得部分像素不参与loss计算(即这部分的loss为0)。
这里用到了ohem如下:

def ohem_single(score, gt_text, training_mask):
   pos_num = (int)(np.sum(gt_text > 0.5)) - (int)(np.sum((gt_text > 0.5) & (training_mask <= 0.5)))

   if pos_num == 0:
       # selected_mask = gt_text.copy() * 0 # may be not good
       selected_mask = training_mask
       selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
       return selected_mask

   neg_num = (int)(np.sum(gt_text <= 0.5))
   neg_num = (int)(min(pos_num * 3, neg_num))

   if neg_num == 0:
       selected_mask = training_mask
       selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
       return selected_mask

   neg_score = score[gt_text <= 0.5]
   neg_score_sorted = np.sort(-neg_score)
   threshold = -neg_score_sorted[neg_num - 1]

   selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5)
   selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
   return selected_mask

这里就是选取负样本中loss排序大的,选择正负样本为1:3,假如正样本有3个,负样本像素就要选9个。选择loss最大的九个。

说明:文中图均来自论文

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
PyTorchOCR是一个基于PyTorch框架的OCR(光学字符识别)工具包。它包含了文本检测和文本识别两个主要模块。在文本检测方面,PyTorchOCR使用了icdar2015数据集进行算法效果对比的训练数据。训练数据包括标注图片和标注文件,标注文件中存放着标注框的坐标和标注框的label。训练时需要使用一个train_list.txt文件来指定图片和标注文件的绝对地址,用于训练时读取。在验证时,需要使用一个test_list.txt文件来指定验证数据的图片的绝对地址,并在config文件中指定test_gt_path来指定验证数据的标注文件地址。\[3\] 在PyTorchOCR中,生成pkl文件的代码可以通过定义一个_label_path_from_index函数来实现。该函数会读取一个train_pkl文件,并返回其内容。具体的代码如下: ```python def _label_path_from_index(self): label_file = os.path.join(self.label_path, "train_pkl") assert os.path.exists(label_file, "path dose not exits:{}".format(label_file)) gt_file = open(label_file, "rb") label_file = cPickle.load(gt_file) gt_file.close() return label_file ```\[1\] 在加载数据集的位置,可以通过修改OCRIter类中的初始化加载函数来指定训练集和测试集的图片路径和标签pkl文件的路径。具体的代码如下: ```python if train_flag: self.data_path = os.path.join(os.getcwd(), "data", "train", "text") self.label_path = os.path.join(os.getcwd(), "data", "train") else: self.data_path = os.path.join(os.getcwd(), "data", "test", "text") self.label_path = os.path.join(os.getcwd(), "data", "test") ```\[2\] 总结起来,PyTorchOCR是一个基于PyTorch框架的OCR工具包,包含文本检测和文本识别两个主要模块。在文本检测方面,使用icdar2015数据集进行训练,训练数据包括标注图片和标注文件。在加载数据集时,可以通过修改代码中的路径来指定训练集和测试集的图片路径和标签pkl文件的路径。生成pkl文件的代码可以通过定义一个函数来实现。 #### 引用[.reference_title] - *1* *2* [PyTorch实现 | 车牌OCR识别,《PyTorch深度学习之目标检测》](https://blog.csdn.net/lgzlgz3102/article/details/129210978)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [pytorchOCR之数据篇](https://blog.csdn.net/fxwfxw7037681/article/details/111933435)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值