手把手教你使用pytorch+flask搭建草图检索系统(一)
文章目录
一. 写在前面
弄好了模型该怎么展示,不可能就简单地输入输出把结果拿给别人看吧,总要个界面吧,总要让人操作吧。模型训练大多都是在python上运行的,虽然c语言系的web框架有很多,但把用python写的模型转成c++能用的模型或许有点困难(有,比如pytorch1.0支持的c++ interface, torch::jit),行,找个python能使的web框架吧——flask,要的就是它的简单、轻量、易上手、开发速度快,不用动太多脑筋,前后台开发我可是一窍不通的。
更新:目前pytorch1.3.0及以上版本,转换成c++能运行的模型,实际上已经比较简单了,使用如下命令就可以一键转换了,简单说就是trace输入tensor经过网络的所有信息,不过无法trace例如printf、for循环等内容,lstm可能有点麻烦。
import torch
input = torch.rand(n, c, w, h).to(device)
traced_script_module = torch.jit.trace(model, (input, input, input))
traced_script_module.save(output_path)
1.1 系统框架说明
1.1.1 检索网络
模型是采用的孪生神经网络SketchTriplet,原文为CVIU 2016 paper “Compact Descriptors for Sketch-based Image Retrieval using a Triplet loss Convolutional Neural Network”[Repo|Page],github上的代码是基于caffe的,我这里有份我自己写的pytorch版的代码,网络结构与原版近似,结果也是比较好的。
该网络使用pytorch自带的Triplet Loss,在训练过程中,网络输入为手绘草图与[同类自然图像, 异类自然图像],encoder获得输入图像的特征,Triplet Loss使得encoder在获取特征上让手绘草图与同类自然图像的特征更加近似、手绘草图与异类自然图像的特征的差异更加明显,训练中不断优化这个encoder,在训练结束后,就可以用优化好的encoder提取数据库中自然图像的特征向量并保存,在检索阶段,对输入草图使用训练得到的encoder获取草图特征向量,与数据库中自然图像的特征向量进行相似度比较,从而得到检索结果。
论文中使用了anchor
、positive
、negative
这三个标签,所以按照上面的思路,就是将anchor
的特征与positive
尽量接近,与negative
尽量远离了。这里附上论文中的网络结构图,从图中可以看出,论文中还有个half-sharing
的结构,positive
与negative
共用encoder/decoder的权重,positive
与anchor
则是共用decoder的权重,我猜这样的用意是,用相同的encoder来理解输入的negative
与positive
,毕竟这个negative
与positive
是自然图像提取的边缘,与anchor
是手绘草图不同,而使用同样的decoder来解释特征,是因为预期encoder能得到同样的特征吧。在代码中,我通过建立两套encoder与decoder来达成half-sharing
结构。
1.1.2 检索框架 top-k
极为简单的L2距离,输入草图,经过网络的处理,得到该草图的特征向量,然后与数据集(自然图像)的(离线的,已经提取好的)特征向量进行相似度比较,也就是开头说的L2距离,然后排序,返回前k张自然图像。
1.2 准备工作
1.2.1 环境配置
下载flask,推荐使用anaconda,顺带pytorch什么的就可以一并安装了。装这些东西的教程网上有大把,这里我就提一下,anaconda之类的东西,装到想要的目录,最好别让它装到默认目录,也就是/home/<user_name>/
里面,会让这个挂载盘符越来越小,~
或者/home/<user_name>
可以当做win的c盘一样的东西,满了就比较麻烦。下面是一些简单步骤:
# for anaconda
wget https://repo.anaconda.com/archive/Anaconda3-2019.07-Linux-x86_64.sh
sh ./Anaconda3-2019.07-Linux-x86_64.sh
vim ~/.bashrc
--------------- in ~/.bashrc file, and wq
--- # add anaconda env
--- export PATH=<install_path>/anaconda3/bin:$PATH
---------------
source ~/.bashrc
# for pytorch
conda install pytorch
# for flask
conda install flask
1.2.1 检索数据集与模型
这里我使用的是比较经典的图像数据集Flickr15K,里面有33个类别、共计14,660张自然图像,在这里就可以下载,是我缩小过的,反正都是展示下,分辨率也不用那么高。
训练好的模型可以在这里(谷歌)/或者这里(百度 提取码 4vnw)下载。
之前提到,我们自然图像数据集的离线特征,这里提供两条路线:1) 下载我已经提取好的离线特征,这受限于图像数据集,你能检索到的图像就只能是Flickr15K的,我是用numpy来保存的,特征的大小应该是N*W
,N
代表数据集中的图像个数,W
是提取的特征向量的维度,一张图像的特征应该是1*W
的W
维向量;2) 使用我训练好的模型,对目标数据集,运行extract_feat_photo.py
和extract_feat_sketch.py
脚本,提取目标数据集的离线特征,这样就可以检索任意数据集了。
更新 2020/04/23: SketchTriplet的所有东西我都上传到度盘了,链接在这里(71zw),文件内容如下,虽然度盘比较坑,但目前还没有想到另外低成本、稳定的网盘,先将就用吧:
- SketchTriplet
- feat
- feat_sketch.npz # 离线特征集,由330sketches草图产生
- feat_photo.npz # 离线特征集,由resize_img自然图像产生
- resize_img.zip # 自然图像集,被缩放至100像素
- groundtruth # 自然图像集的groundtruth
- Flickr_15K_edge2.rar # 由自然图像集使用
canny改
算子求得的边缘集 - 500.pth # 模型文件
- 330sketches.zip # 330张草图
- feat
上述东西(模型、离线特征集、图像集)都准备好后,我们就可以正式开始了。
剩余内容大致如下,逐渐更新中:
- 后端搭建
- 前端搭建
- 前后端交互
- demo