EAST (TensorFlow): Understanding the Code (Part 1)

Original code: https://github.com/argman/EAST

I. Network Structure

The network structure is shown in the figure below:
[Figure: EAST network structure]
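To detect text at different scales, the network takes the outputs of four convolutional stages of the backbone (f1 to f4; f[0] to f[3] in the 0-indexed code), upsamples them step by step through the merging modules (h2 to h4), and concatenates each upsampled result with the next stage's output along the channel axis. The merged feature map then goes through three 1x1 convolutions with 1, 4, and 1 output channels (kernels [1,1,1], [1,1,4], [1,1,1]), giving 3 outputs: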

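1. score map: a tensor of shape [H/4, W/4, 1]; each value is the confidence that the pixel lies inside a text box.
2. text boxes: a tensor of shape [H/4, W/4, 4]; for a pixel inside a text box, the distances from that pixel to the four sides of the box.
3. text rotation: a tensor of shape [H/4, W/4, 1]; the rotation angle of the text box the pixel belongs to.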
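To make the three outputs concrete, here is a minimal NumPy sketch of how one pixel's prediction could be mapped back to a box in input-image coordinates. It rests on several assumptions: the four distances are ordered top, right, bottom, left (the order noted in the loss code later in this post) and are measured in input-image pixels, the rotation angle is ignored, and a pixel at (row, col) of the output map corresponds to image position (4 * col, 4 * row). The name `pixel_to_box` is hypothetical and is not part of the EAST repository.

```python
import numpy as np

def pixel_to_box(row, col, distances, stride=4):
    """Return (x_min, y_min, x_max, y_max) of an unrotated box.

    distances: (top, right, bottom, left), assumed to be in input-image pixels.
    stride: assumed ratio between input and output resolution (H/4 -> 4).
    """
    top, right, bottom, left = (float(d) for d in distances)
    x, y = col * stride, row * stride  # pixel position in the input image
    return (x - left, y - top, x + right, y + bottom)

# Toy geometry map of shape [H/4, W/4, 4] for a 32x32 input.
geo_map = np.zeros((8, 8, 4), dtype=np.float32)
geo_map[2, 3] = [5.0, 10.0, 7.0, 6.0]  # top, right, bottom, left

box = pixel_to_box(2, 3, geo_map[2, 3])
print(box)
```

A rotated box would additionally need the angle from the text rotation map; that step is left out here to keep the sketch short.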

Below is the corresponding part of the source code:

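import numpy as np
import tensorflow as tf
from tensorflow.contrib import slim

from nets import resnet_v1  # resnet_v1 module bundled with the EAST repository

# mean_image_subtraction, unpool and FLAGS are defined elsewhere in the repository's model.py

def model(images, weight_decay=1e-5, is_training=True):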
    '''
    define the model, we use slim's implementation of resnet
    '''
    # Normalize the RGB pixel values by subtracting the mean
    images = mean_image_subtraction(images)
 
    # Run the images through resnet_v1 first;
    # the outputs of all its stages are stored in end_points
    with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)):
        logits, end_points = resnet_v1.resnet_v1_50(images, is_training=is_training, scope='resnet_v1_50')
 
    with tf.variable_scope('feature_fusion', values=[end_points.values]):
        batch_norm_params = {
            'decay': 0.997,
            'epsilon': 1e-5,
            'scale': True,
            'is_training': is_training
        }
        with slim.arg_scope([slim.conv2d],
                            activation_fn=tf.nn.relu,
                            normalizer_fn=slim.batch_norm,
                            normalizer_params=batch_norm_params,
                            weights_regularizer=slim.l2_regularizer(weight_decay)):
            # Take the outputs after the 2nd, 3rd, 4th, and 5th pooling stages (deepest first)
            f = [end_points['pool5'], end_points['pool4'],
                 end_points['pool3'], end_points['pool2']]
            for i in range(4):
                print('Shape of f_{} {}'.format(i, f[i].shape))
            g = [None, None, None, None]
            h = [None, None, None, None]
            num_outputs = [None, 128, 64, 32]
            for i in range(4):
                # From the network diagram, h0 = f0
                if i == 0:
                    h[i] = f[i]
                # For the other h_i: h_i = conv3x3(conv1x1(concat(unpool(h_{i-1}), f_i)))
                else:
                    c1_1 = slim.conv2d(tf.concat([g[i-1], f[i]], axis=-1), num_outputs[i], 1)
                    h[i] = slim.conv2d(c1_1, num_outputs[i], 3)
                # h0, h1, h2 are unpooled first so they can be concatenated with the next f_i;
                # the last one (h3) goes through a 3x3 conv instead
                if i <= 2:
                    g[i] = unpool(h[i])
                else:
                    g[i] = slim.conv2d(h[i], num_outputs[i], 3)
                print('Shape of h_{} {}, g_{} {}'.format(i, h[i].shape, i, g[i].shape))
 
            # score map
            F_score = slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None)
            # text boxes
            geo_map = slim.conv2d(g[3], 4, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) * FLAGS.text_scale
            # text rotation
            angle_map = (slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) - 0.5) * np.pi/2 # angle is between [-45, 45]
            # Concatenate the distance and angle maps into a single geometry output
            F_geometry = tf.concat([geo_map, angle_map], axis=-1)
 
    return F_score, F_geometry
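The feature-merging loop is easier to follow with concrete shapes. The sketch below traces (height, width, channels) through the same loop in plain Python, without TensorFlow. It assumes that f[0] to f[3] have strides 32, 16, 8, and 4 relative to the input, that `unpool` doubles height and width, and that the convolutions use stride 1 with 'SAME' padding. The function name `trace_merge_shapes` and the channel counts passed in are illustrative only; check the `Shape of f_i` lines printed by the real model for your backbone.

```python
def trace_merge_shapes(height, width, f_channels, num_outputs=(None, 128, 64, 32)):
    """Trace (height, width, channels) of h[i] and g[i] through feature merging.

    f_channels: channels of f[0..3] (pool5, pool4, pool3, pool2), supplied by the caller.
    Assumes strides 32, 16, 8, 4 for f[0..3] and that unpool doubles height and width.
    """
    strides = (32, 16, 8, 4)
    f = [(height // s, width // s, c) for s, c in zip(strides, f_channels)]
    h, g = [None] * 4, [None] * 4
    for i in range(4):
        if i == 0:
            h[i] = f[i]
        else:
            # concat adds channels; the 1x1 and 3x3 convs then map them to num_outputs[i]
            assert g[i - 1][:2] == f[i][:2], "spatial sizes must match before concat"
            h[i] = (f[i][0], f[i][1], num_outputs[i])
        if i <= 2:
            g[i] = (h[i][0] * 2, h[i][1] * 2, h[i][2])  # unpool: 2x upsampling
        else:
            g[i] = h[i]  # 3x3 conv with num_outputs[3] channels keeps the shape
    return h, g

# Illustrative channel counts only, not read from the real backbone.
h, g = trace_merge_shapes(512, 512, f_channels=(2048, 512, 256, 64))
for i in range(4):
    print('h_{} {}, g_{} {}'.format(i, h[i], i, g[i]))
```

Under these assumptions g[3] ends up at one quarter of the input resolution with 32 channels, which matches the [H/4, W/4, ...] shapes of the three outputs.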

II. Loss Definition

The total loss has three main parts:

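1. Classification loss: the loss on the score map's per-pixel prediction of whether a pixel lies inside text. This implementation uses dice loss here rather than cross-entropy (see the docstring in the code).
2. Angle loss: a simple error function on the predicted rotation angle, 1 - cos(theta_pred - theta_gt).
3. Localization loss: an IoU loss on the predicted box.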
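The `dice_coefficient` helper called by the loss code is not shown in this post. As a rough illustration, here is a NumPy sketch of the common form of a masked dice loss. It assumes the helper follows this standard definition; the function name `dice_loss` and the `eps` value are choices made for this example, not taken from the source.

```python
import numpy as np

def dice_loss(y_true, y_pred, mask, eps=1e-5):
    """1 - Dice coefficient, computed only over pixels kept by mask."""
    intersection = np.sum(y_true * y_pred * mask)
    total = np.sum(y_true * mask) + np.sum(y_pred * mask) + eps
    return 1.0 - 2.0 * intersection / total

y_true = np.array([[1.0, 1.0], [0.0, 0.0]])
mask = np.ones_like(y_true)

perfect = dice_loss(y_true, y_true, mask)      # prediction equals ground truth
wrong = dice_loss(y_true, 1.0 - y_true, mask)  # prediction has no overlap at all
print(perfect, wrong)
```

The loss is near 0 when the prediction overlaps the ground truth completely and reaches 1 when there is no overlap. Pixels where the mask is 0 contribute to neither sum, which is how ignored regions drop out of the loss.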

The code is short:

def loss(y_true_cls, y_pred_cls,
         y_true_geo, y_pred_geo,
         training_mask):
    '''
    define the loss used for training, containing two parts,
    the first part we use dice loss instead of weighted logloss,
    the second part is the iou loss defined in the paper
    :param y_true_cls: ground truth of text
    :param y_pred_cls: prediction of text
    :param y_true_geo: ground truth of geometry
    :param y_pred_geo: prediction of geometry
    :param training_mask: mask used in training, to ignore some text annotated by ###
    :return:
    '''
    # classification loss on the score map (dice loss, not cross-entropy)
    classification_loss = dice_coefficient(y_true_cls, y_pred_cls, training_mask)
    classification_loss *= 0.01
 
    # d1 -> top, d2->right, d3->bottom, d4->left
    # IoU loss
    d1_gt, d2_gt, d3_gt, d4_gt, theta_gt = tf.split(value=y_true_geo, num_or_size_splits=5, axis=3)
    d1_pred, d2_pred, d3_pred, d4_pred, theta_pred = tf.split(value=y_pred_geo, num_or_size_splits=5, axis=3)
    area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt)
    area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred)
    w_union = tf.minimum(d2_gt, d2_pred) + tf.minimum(d4_gt, d4_pred)
    h_union = tf.minimum(d1_gt, d1_pred) + tf.minimum(d3_gt, d3_pred)
    area_intersect = w_union * h_union
    area_union = area_gt + area_pred - area_intersect
    L_AABB = -tf.log((area_intersect + 1.0)/(area_union + 1.0))
 
    # angle loss
    L_theta = 1 - tf.cos(theta_pred - theta_gt)
 
    
    tf.summary.scalar('geometry_AABB', tf.reduce_mean(L_AABB * y_true_cls * training_mask))
    tf.summary.scalar
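The geometry terms can be checked by hand on a single pixel. The sketch below repeats the same arithmetic as `L_AABB` and `L_theta` in NumPy for scalar inputs. The function names are hypothetical, and the variables the source calls `w_union` and `h_union` are renamed here because they hold the width and height of the intersection.

```python
import numpy as np

def aabb_iou_loss(d_gt, d_pred):
    """-log(IoU) for two boxes given as (top, right, bottom, left) distances from one pixel."""
    t_gt, r_gt, b_gt, l_gt = d_gt
    t_p, r_p, b_p, l_p = d_pred
    area_gt = (t_gt + b_gt) * (r_gt + l_gt)
    area_pred = (t_p + b_p) * (r_p + l_p)
    # Both boxes contain the pixel, so the overlap extends min(...) in each direction.
    w_intersect = min(r_gt, r_p) + min(l_gt, l_p)
    h_intersect = min(t_gt, t_p) + min(b_gt, b_p)
    area_intersect = w_intersect * h_intersect
    area_union = area_gt + area_pred - area_intersect
    # The +1.0 terms mirror the source and keep the ratio defined for empty boxes.
    return -np.log((area_intersect + 1.0) / (area_union + 1.0))

def angle_loss(theta_gt, theta_pred):
    """1 - cos of the angle difference: 0 when the angles agree."""
    return 1.0 - np.cos(theta_pred - theta_gt)

same = aabb_iou_loss((10, 20, 10, 20), (10, 20, 10, 20))   # identical boxes
smaller = aabb_iou_loss((10, 20, 10, 20), (5, 10, 5, 10))  # prediction half as wide and tall
print(same, smaller, angle_loss(0.0, np.pi / 4))
```

Identical boxes give a loss of 0, and the loss grows as the predicted box overlaps less of the ground truth. In the TensorFlow code the same expressions run element-wise over the whole [H/4, W/4] map.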