手把手教你使用pytorch+flask搭建草图检索系统(一)

6 篇文章 0 订阅
5 篇文章 2 订阅

手把手教你使用pytorch+flask搭建草图检索系统(一)



一. 写在前面

弄好了模型该怎么展示,不可能就简单地输入输出把结果拿给别人看吧,总要个界面吧,总要让人操作吧。模型训练大多都是在python上运行的,虽然c语言系的web框架有很多,但把用python写的模型转成c++能用的模型或许有点困难(有,比如pytorch1.0支持的c++ interface, torch::jit),行,找个python能使的web框架吧——flask,要的就是它的简单、轻量、易上手、开发速度快,不用动太多脑筋,前后台开发我可是一窍不通的。

更新:目前pytorch1.3.0及以上版本,转换成c++能运行的模型,实际上已经比较简单了,使用如下命令就可以一键转换了,简单说就是trace输入tensor经过网络的所有信息,不过无法trace例如printf、for循环等内容,lstm可能有点麻烦。

import torch
input = torch.rand(n, c, w, h).to(device)
traced_script_module = torch.jit.trace(model, (input, input, input))
traced_script_module.save(output_path)

1.1 系统框架说明

1.1.1 检索网络

模型是采用的孪生神经网络SketchTriplet,原文为CVIU 2016 paper “Compact Descriptors for Sketch-based Image Retrieval using a Triplet loss Convolutional Neural Network”[Repo|Page],github上的代码是基于caffe的,我这里有份我自己写的pytorch版的代码,网络结构与原版近似,结果也是比较好的。

该网络使用pytorch自带的Triplet Loss,在训练过程中,网络输入为手绘草图与[同类自然图像, 异类自然图像],encoder获得输入图像的特征,Triplet Loss使得encoder在获取特征上让手绘草图与同类自然图像的特征更加近似、手绘草图与异类自然图像的特征的差异更加明显,训练中不断优化这个encoder,在训练结束后,就可以用优化好的encoder提取数据库中自然图像的特征向量并保存,在检索阶段,对输入草图使用训练得到的encoder获取草图特征向量,与数据库中自然图像的特征向量进行相似度比较,从而得到检索结果。

在这里插入图片描述

论文中使用了anchorpositivenegative这三个标签,所以按照上面的思路,就是将anchor的特征与positive尽量接近,与negative尽量远离了。这里附上论文中的网络结构图,从图中可以看出,论文中还有个half-sharing的结构,positivenegative共用encoder/decoder的权重,positiveanchor则是共用decoder的权重,我猜这样的用意是,用相同的encoder来理解输入的negativepositive,毕竟这个negativepositive是自然图像提取的边缘,与anchor是手绘草图不同,而使用同样的decoder来解释特征,是因为预期encoder能得到同样的特征吧。在代码中,我通过建立两套encoder与decoder来达成half-sharing结构。

1.1.2 检索框架 top-k

极为简单的L2距离,输入草图,经过网络的处理,得到该草图的特征向量,然后与数据集(自然图像)的(离线的,已经提取好的)特征向量进行相似度比较,也就是开头说的L2距离,然后排序,返回前k张自然图像。

1.2 准备工作

1.2.1 环境配置

下载flask,推荐使用anaconda,顺带pytorch什么的就可以一并安装了。装这些东西的教程网上有大把,这里我就提一下,anaconda之类的东西,装到想要的目录,最好别让它装到默认目录,也就是/home/<user_name>/里面,会让这个挂载盘符越来越小,~或者/home/<user_name>可以当做win的c盘一样的东西,满了就比较麻烦。下面是一些简单步骤:

# for anaconda
wget https://repo.anaconda.com/archive/Anaconda3-2019.07-Linux-x86_64.sh
sh ./Anaconda3-2019.07-Linux-x86_64.sh
vim ~/.bashrc
--------------- in ~/.bashrc file, and wq
--- # add anaconda env 
--- export PATH=<install_path>/anaconda3/bin:$PATH
---------------
source ~/.bashrc

# for pytorch
conda install pytorch

# for flask
conda install flask
1.2.1 检索数据集与模型

这里我使用的是比较经典的图像数据集Flickr15K,里面有33个类别、共计14,660张自然图像,在这里就可以下载,是我缩小过的,反正都是展示下,分辨率也不用那么高。

训练好的模型可以在这里(谷歌)/或者这里(百度 提取码 4vnw)下载。

之前提到,我们自然图像数据集的离线特征,这里提供两条路线:1) 下载我已经提取好的离线特征,这受限于图像数据集,你能检索到的图像就只能是Flickr15K的,我是用numpy来保存的,特征的大小应该是N*WN代表数据集中的图像个数,W是提取的特征向量的维度,一张图像的特征应该是1*WW维向量;2) 使用我训练好的模型,对目标数据集,运行extract_feat_photo.pyextract_feat_sketch.py脚本,提取目标数据集的离线特征,这样就可以检索任意数据集了。


更新 2020/04/23: SketchTriplet的所有东西我都上传到度盘了,链接在这里(71zw),文件内容如下,虽然度盘比较坑,但目前还没有想到另外低成本、稳定的网盘,先将就用吧:

  • SketchTriplet
    • feat
      • feat_sketch.npz # 离线特征集,由330sketches草图产生
      • feat_photo.npz # 离线特征集,由resize_img自然图像产生
    • resize_img.zip # 自然图像集,被缩放至100像素
    • groundtruth # 自然图像集的groundtruth
    • Flickr_15K_edge2.rar # 由自然图像集使用canny改算子求得的边缘集
    • 500.pth # 模型文件
    • 330sketches.zip # 330张草图

上述东西(模型、离线特征集、图像集)都准备好后,我们就可以正式开始了。

剩余内容大致如下,逐渐更新中:

  • 后端搭建
  • 前端搭建
  • 前后端交互
  • demo
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值