加载tf模型 正确率很低_深度学习之文本检测系列(1)-CTPN模型实现

参考CTPN代码链接:

https://github.com/eragonruan/text-detection-ctpn

环境版本:

python 3
tensorflow 1.12

我改过的完整代码:

https://github.com/orangesdk/text-detection-ctpn

具体步骤如下:

vgg16模型准备

vgg16模型文件vgg_16.ckpt,需要的可以联系我。

ICDAR2015数据集预处理

cd utils/prepare
python split_label.py

预处理思路: ICDAR2015数据集合将image图片和label标签分开存储,label对应image的图片名并添加了'gt_'前缀,还有我们暂时不需要label中的文本具体内容进行训练,只需要文本的上右下左四个点坐标即可。

读取label标签文件需要注意事项: 设置encoding='UTF-8-sig'为了去除ufeff ,参考链接:https://www.cnblogs.com/chongzi1990/p/8694883.html

        # 设置encoding='UTF-8-sig'为了去除ufeff 
        with open(gt_path, 'r', encoding='UTF-8-sig') as f:

使用cpython编译类库

cd utils/bbox
chmod +x make.sh
./make.sh

编译完成会在对应目录下生成2个动态链接库文件:

bbox.cpython-36m-darwin.so
nms.cpython-36m-darwin.so

训练模型

cd main
python train.py

训练完成会产生训练模型参数文件如下:

39a7ad905c7388ec119efa2d92f44fc5.png

测试模型

cd main
python test.py

取出一个测试结果图片如下:

4b79c2dda7cb3055a76ce89032a25abd.png

构建web server

问题:每次都加载模型会很慢,如何让模型只加载一次?

我们这里使用lru_cache对模型只做一次加载,并定义内嵌函数predict_image,并返回内嵌函数返回结果,关键代码如下(详细代码参考server.py文件):

@lru_cache()
def load_model_and_predict():
    input_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_image')
    input_im_info = tf.placeholder(tf.float32, shape=[None, 3], name='input_im_info')
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
    bbox_pred, cls_pred, cls_prob = model.model(input_image)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    saver = tf.train.Saver(variable_averages.variables_to_restore())

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    checkpoint_state = tf.train.get_checkpoint_state(checkpoint_path)
    model_path = os.path.join(checkpoint_path, os.path.basename(checkpoint_state.model_checkpoint_path))
    print('Restore from {}'.format(model_path))
    saver.restore(sess, model_path)

    def predict_image(image):
        img, (rh, rw) = resize_image(image)
        h, w, c = img.shape
        im_info = np.array([h, w, c]).reshape([1, 3])
        bbox_pred_val, cls_prob_val = sess.run([bbox_pred, cls_prob],
                                               feed_dict={input_image: [img], input_im_info: im_info})

        text_seg, _ = proposal_layer(cls_prob_val, bbox_pred_val, im_info)
        scores = text_seg[:, 0]
        text_seg = text_seg[:, 1:5]

        text_detector = TextDetector(DETECT_MODE='H')
        boxes = text_detector.detect(text_seg, scores[:, np.newaxis], img.shape[:2])
        boxes = np.array(boxes, dtype=np.int)

        for i, box in enumerate(boxes):
            cv2.polylines(img, [box[:8].astype(np.int32).reshape((-1, 1, 2))], True, color=(0, 255, 0),
                          thickness=2)
        img = cv2.resize(img, None, None, fx=1.0 / rh, fy=1.0 / rw, interpolation=cv2.INTER_LINEAR)
        _, buffer = cv2.imencode('.jpg', img)
        pic_str = base64.b64encode(buffer)
        pic_str = 'data:image/jpg;base64,' + pic_str.decode()
        return render_template('index.html', image_base_data=pic_str)

    return predict_image

启动web server

python server.py

执行预览结果如下:

2b648a42ff8eb81297ff11d20e230aec.png
CTPN文本检测模型https://www.zhihu.com/video/1096099537076269056
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值