Original code: https://github.com/argman/EAST
I. Network structure
The network structure is shown in the figure below:
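To detect text at different scales, the network takes the outputs of four convolutional stages, f1–f4. In the merging stages h2–h4, each deeper map is upsampled and concatenated along the channel axis with the next, shallower feature map (the code below counts from 0, so these are f0–f3 and h0–h3). The merged feature map is finally passed through three 1×1 convolutions with kernel shapes [1,1,1], [1,1,4] and [1,1,1] (i.e. 1, 4 and 1 output channels), giving three outputs: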
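1. score map: a [H/4, W/4, 1] tensor giving, for each pixel, the confidence that it lies inside a text box.
2. text boxes: a [H/4, W/4, 4] tensor giving, for each pixel inside a text box, its distances to the four sides of that box.
3. text rotation: a [H/4, W/4, 1] tensor giving the rotation angle of the text box the pixel belongs to.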
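To make the geometry output concrete, here is a minimal NumPy sketch that turns one pixel's four distances and angle back into the corners of its rotated box. It assumes the distance order d1 = top, d2 = right, d3 = bottom, d4 = left (the order used by the loss code in this post) and that the box is built in its own frame and then rotated by θ about the pixel, with the image y axis pointing down. The function name `decode_rbox` is hypothetical, and the repo's exact sign convention for θ may differ.

```python
import numpy as np


def decode_rbox(x, y, d, theta):
    """Corners (TL, TR, BR, BL) of the rotated box predicted at pixel (x, y).

    d = (d1, d2, d3, d4): distances to the top, right, bottom and left edges.
    theta: rotation in radians. Assumption: corners are first placed in the
    unrotated box frame, then rotated by theta around the pixel.
    """
    d1, d2, d3, d4 = d
    # Corner offsets relative to the pixel in the unrotated box frame.
    local = np.array([[-d4, -d1], [d2, -d1], [d2, d3], [-d4, d3]], dtype=float)
    c, s = np.cos(theta), np.sin(theta)
    rot = np.array([[c, -s], [s, c]])
    # Rotate every offset, then translate back to image coordinates.
    return local @ rot.T + np.array([x, y], dtype=float)


corners = decode_rbox(10, 20, (2, 3, 4, 5), 0.0)
print(corners)
```

With θ = 0 the box is axis-aligned, with corners (5, 18), (13, 18), (13, 24), (5, 24). For any θ the box keeps width d2 + d4 = 8 and height d1 + d3 = 6, so its area (d1 + d3)(d2 + d4) is the same quantity the IoU loss uses.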
The corresponding part of the source code:
import numpy as np
import tensorflow as tf
from tensorflow.contrib import slim

from nets import resnet_v1  # the repo's copy of slim's ResNet (nets/resnet_v1.py)

FLAGS = tf.app.flags.FLAGS
# mean_image_subtraction() and unpool() are helper functions defined in the same model.py


def model(images, weight_decay=1e-5, is_training=True):
    '''
    define the model, we use slim's implementation of resnet
    '''
    # Normalize the RGB pixel values, i.e. subtract the per-channel mean
    images = mean_image_subtraction(images)
    # Run the images through resnet_v1_50; the outputs of all stages
    # are collected in end_points
    with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)):
        logits, end_points = resnet_v1.resnet_v1_50(images, is_training=is_training, scope='resnet_v1_50')

    with tf.variable_scope('feature_fusion', values=[end_points.values]):
        batch_norm_params = {
            'decay': 0.997,
            'epsilon': 1e-5,
            'scale': True,
            'is_training': is_training
        }
        with slim.arg_scope([slim.conv2d],
                            activation_fn=tf.nn.relu,
                            normalizer_fn=slim.batch_norm,
                            normalizer_params=batch_norm_params,
                            weights_regularizer=slim.l2_regularizer(weight_decay)):
            # Outputs after the 5th, 4th, 3rd and 2nd pooling stages,
            # deepest first (strides 32, 16, 8, 4)
            f = [end_points['pool5'], end_points['pool4'],
                 end_points['pool3'], end_points['pool2']]
            for i in range(4):
                print('Shape of f_{} {}'.format(i, f[i].shape))
            g = [None, None, None, None]
            h = [None, None, None, None]
            num_outputs = [None, 128, 64, 32]
            for i in range(4):
                # From the architecture diagram, h0 = f0
                if i == 0:
                    h[i] = f[i]
                # Other stages: hi = conv3x3(conv1x1(concat(unpool(h(i-1)), fi)))
                else:
                    c1_1 = slim.conv2d(tf.concat([g[i-1], f[i]], axis=-1), num_outputs[i], 1)
                    h[i] = slim.conv2d(c1_1, num_outputs[i], 3)
                # h0, h1 and h2 are upsampled (unpool) before being concatenated with the next fi
                if i <= 2:
                    g[i] = unpool(h[i])
                # The last stage is not upsampled; a 3x3 conv produces the merged feature map
                else:
                    g[i] = slim.conv2d(h[i], num_outputs[i], 3)
                print('Shape of h_{} {}, g_{} {}'.format(i, h[i].shape, i, g[i].shape))

            # score map
            F_score = slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None)
            # text boxes: 4 distances, sigmoid output scaled to [0, text_scale]
            geo_map = slim.conv2d(g[3], 4, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) * FLAGS.text_scale
            # text rotation
            angle_map = (slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) - 0.5) * np.pi/2  # angle is between [-45, 45] degrees
            # Concatenate the distances and the angle into one geometry output
            F_geometry = tf.concat([geo_map, angle_map], axis=-1)

    return F_score, F_geometry
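To see how the merging loop changes tensor shapes, the NumPy sketch below replays it on random data. It is a shape-tracing illustration under several assumptions, not the repo's code: a 64×64 input, illustrative backbone channel counts, nearest-neighbour 2× upsampling standing in for the repo's `unpool` helper (not shown in this post), random weights, and no batch norm. `text_scale = 512` is an assumed value for `FLAGS.text_scale`. The helper names `conv` and `head` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)


def unpool(x):
    # 2x nearest-neighbour upsampling of an [N, H, W, C] array
    # (stand-in for the repo's unpool helper)
    return x.repeat(2, axis=1).repeat(2, axis=2)


def conv(x, out_channels, k):
    # k x k convolution, stride 1, SAME padding, ReLU; random weights, no batch norm
    n, hh, ww, c = x.shape
    w = rng.standard_normal((k, k, c, out_channels)) * 0.01
    p = k // 2
    xp = np.pad(x, ((0, 0), (p, p), (p, p), (0, 0)))
    out = np.zeros((n, hh, ww, out_channels))
    for dy in range(k):
        for dx in range(k):
            out += xp[:, dy:dy + hh, dx:dx + ww, :] @ w[dy, dx]
    return np.maximum(out, 0.0)


def head(x, out_channels):
    # 1x1 convolution followed by a sigmoid, like the output layers of model()
    w = rng.standard_normal((x.shape[-1], out_channels)) * 0.1
    return 1.0 / (1.0 + np.exp(-(x @ w)))


H = W = 64                          # toy input size (assumption)
strides = [32, 16, 8, 4]            # pool5, pool4, pool3, pool2
channels = [2048, 1024, 512, 256]   # illustrative channel counts (assumption)
f = [rng.standard_normal((1, H // s, W // s, c)) for s, c in zip(strides, channels)]

num_outputs = [None, 128, 64, 32]
h, g = [None] * 4, [None] * 4
for i in range(4):
    if i == 0:
        h[i] = f[i]
    else:
        c1_1 = conv(np.concatenate([g[i - 1], f[i]], axis=-1), num_outputs[i], 1)
        h[i] = conv(c1_1, num_outputs[i], 3)
    g[i] = unpool(h[i]) if i <= 2 else conv(h[i], num_outputs[i], 3)
    print('h_{} {}, g_{} {}'.format(i, h[i].shape, i, g[i].shape))

text_scale = 512  # assumed value of FLAGS.text_scale
F_score = head(g[3], 1)
geo_map = head(g[3], 4) * text_scale                 # distances in [0, text_scale]
angle_map = (head(g[3], 1) - 0.5) * np.pi / 2        # angle in (-pi/4, pi/4)
F_geometry = np.concatenate([geo_map, angle_map], axis=-1)
print(F_score.shape, F_geometry.shape)
```

Three 2× upsamplings take the stride-32 map back to stride 4, which is why both outputs end at H/4 × W/4 (16 × 16 here). The sigmoid-based heads bound the distances to [0, text_scale] and the angle to ±45°.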
II. Loss definition
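The loss consists of three parts:
1. Classification loss on the score map, i.e. on whether each pixel is predicted to lie inside a text region. The EAST paper uses class-balanced cross-entropy here; this code replaces it with a dice loss, as its docstring notes.
2. Angle loss: a simple error function on the predicted rotation angle, 1 − cos(θ_pred − θ_gt).
3. Localization loss: an IoU loss on the box described by the four predicted distances.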
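The three terms can be checked by hand with a small NumPy sketch. It mirrors the IoU and angle formulas in the TensorFlow code of this post; the dice loss assumes the common form 1 − 2|X∩Y|/(|X| + |Y|), because the repo's `dice_coefficient` helper is not shown here. The function names are hypothetical.

```python
import numpy as np


def dice_loss(y_true, y_pred, mask, eps=1e-5):
    # 1 - 2|X n Y| / (|X| + |Y|), counting only pixels where mask == 1
    inter = np.sum(y_true * y_pred * mask)
    union = np.sum(y_true * mask) + np.sum(y_pred * mask) + eps
    return 1.0 - 2.0 * inter / union


def iou_loss(d_gt, d_pred):
    # last axis holds d1..d4 = distances to top, right, bottom, left
    d1g, d2g, d3g, d4g = np.moveaxis(d_gt, -1, 0)
    d1p, d2p, d3p, d4p = np.moveaxis(d_pred, -1, 0)
    area_gt = (d1g + d3g) * (d2g + d4g)
    area_pred = (d1p + d3p) * (d2p + d4p)
    # Both boxes contain the same pixel, so the overlap width/height
    # is the sum of the smaller distances on each side.
    w_inter = np.minimum(d2g, d2p) + np.minimum(d4g, d4p)
    h_inter = np.minimum(d1g, d1p) + np.minimum(d3g, d3p)
    area_inter = w_inter * h_inter
    area_union = area_gt + area_pred - area_inter
    # +1.0 smoothing, as in the loss code in this post
    return -np.log((area_inter + 1.0) / (area_union + 1.0))


def angle_loss(theta_gt, theta_pred):
    return 1.0 - np.cos(theta_pred - theta_gt)


gt = np.array([1.0, 1.0, 1.0, 1.0])
print(iou_loss(gt, gt), iou_loss(gt, 2 * gt))
```

A perfect prediction gives zero IoU loss. Doubling every distance gives an intersection of 4 and a union of 16, so the loss is −log(5/17) ≈ 1.22. The angle term is zero when θ_pred = θ_gt and grows smoothly as the angles diverge.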
The code is short:
import tensorflow as tf
# dice_coefficient() is a helper defined in the same model.py


def loss(y_true_cls, y_pred_cls,
         y_true_geo, y_pred_geo,
         training_mask):
    '''
    define the loss used for training, containing two parts,
    the first part we use dice loss instead of weighted logloss,
    the second part is the iou loss defined in the paper
    :param y_true_cls: ground truth of text
    :param y_pred_cls: prediction of text
    :param y_true_geo: ground truth of geometry
    :param y_pred_geo: prediction of geometry
    :param training_mask: mask used in training, to ignore some text annotated by ###
    :return:
    '''
    # Score map loss: a dice loss (not cross-entropy), scaled by 0.01
    classification_loss = dice_coefficient(y_true_cls, y_pred_cls, training_mask)
    classification_loss *= 0.01
    # d1 -> top, d2->right, d3->bottom, d4->left
    # IoU loss
    d1_gt, d2_gt, d3_gt, d4_gt, theta_gt = tf.split(value=y_true_geo, num_or_size_splits=5, axis=3)
    d1_pred, d2_pred, d3_pred, d4_pred, theta_pred = tf.split(value=y_pred_geo, num_or_size_splits=5, axis=3)
    area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt)
    area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred)
    # width and height of the intersection (despite the variable names)
    w_union = tf.minimum(d2_gt, d2_pred) + tf.minimum(d4_gt, d4_pred)
    h_union = tf.minimum(d1_gt, d1_pred) + tf.minimum(d3_gt, d3_pred)
    area_intersect = w_union * h_union
    area_union = area_gt + area_pred - area_intersect
    L_AABB = -tf.log((area_intersect + 1.0)/(area_union + 1.0))
    # Angle loss
    L_theta = 1 - tf.cos(theta_pred - theta_gt)
    tf.summary.scalar('geometry_AABB', tf.reduce_mean(L_AABB * y_true_cls * training_mask))
    tf.summary.scalar