一、背景
两年前的一个项目,识别典型地标,并对图片进行标记。这里只写一下简单思路,最开始是打算用特征点匹配来做的,可以参考:基于特征点匹配方法——SIFT, SURF, ORB的图像识别 ,后来发现效果很差,所以后面又用了深度学习来做。
二、方法
实现的方法还是比较简单的,核心步骤如下:
(1)采集数据并数据预处理get_image.py
(2)编写神经网络模型CNN.py
(3)得到并查看pb文件checkpb.py
(4)参考tensorflow的官方demo,在移动端部署
(5)制作app并实验验证
三、实现
1. 数据获取
所有过程都是从0开始的,没有数据也要自己创造数据。怎么办呢?最简单的办法就是自己爬数据。
要爬数据,首先需要找到要爬数据的主要内容,既然是典型地标,那么一、二线城市必然都有当地的地标建筑,因此我们需要先找到一、二线的城市名单,可以参考:https://baijiahao.baidu.com/s?id=1589132487171121063&wfr=spider&for=pc
编写爬取数据的程序,并命名为get_image.py文件:
# 导入需要的库
import requests
import os
import json
# 爬取百度图片,解析页面的函数
def getManyPages(keyword, pages):
'''
参数keyword:要下载的影像关键词
参数pages:需要下载的页面数
'''
params = []
for i in range(30, 30 * pages + 30, 30):
params.append({
'tn': 'resultjson_com',
'ipn': 'rj',
'ct': 201326592,
'is': '',
'fp': 'result',
'queryWord': keyword,
'cl': 2,
'lm': -1,
'ie': 'utf-8',
'oe': 'utf-8',
'adpicid': '',
'st': -1,
'z': '',
'ic': 0,
'word': keyword,
's': '',
'se': '',
'tab': '',
'width': '',
'height': '',
'face': 0,
'istype': 2,
'qc': '',
'nc': 1,
'fr': '',
'pn': i,
'rn': 30,
'gsm': '1e',
'1488942260214': ''
})
url = 'https://image.baidu.com/search/acjson'
urls = []
for i in params:
try:
urls.append(requests.get(url, params=i).json().get('data'))
except json.decoder.JSONDecodeError:
print("解析出错")
return urls
# 下载图片并保存
def getImg(dataList, localPath):
'''
参数datallist:下载图片的地址集
参数localPath:保存下载图片的路径
'''
if not os.path.exists(localPath): # 判断是否存在保存路径,如果不存在就创建
os.mkdir(localPath)
x = 0
for list in dataList:
for i in list:
if i.get('thumbURL') != None:
print('正在下载:%s' % i.get('thumbURL'))
ir = requests.get(i.get('thumbURL'))
open(localPath + '%d.jpg' % x, 'wb').write(ir.content)
x += 1
else:
print('图片链接不存在')
# 根据关键词来下载图片
if __name__ == '__main__':
# 参考:https://baijiahao.baidu.com/s?id=1589132487171121063&wfr=spider&for=pc
# 一线城市:北京、上海、广州、深圳
# 二线城市:成都市、杭州市、武汉市、重庆市、南京市、天津市、苏州市、
# 西安市、长沙市、沈阳市、青岛市、郑州市、大连市、东莞市、宁波市
# 三线城市:厦门市、福州市、无锡市、合肥市、昆明市、哈尔滨市、济南市、佛山市、
# 长春市、温州市、石家庄市、南宁市、常州市、泉州市、南昌市、贵阳市、
# 太原市、烟台市、嘉兴市、南通市、金华市、珠海市、惠州市、徐州市、
# 海口市、乌鲁木齐市、绍兴市、中山市、台州市、兰州市
# ------------------------一线城市--------------------------------
# 北京市
dataList = getManyPages('天安门', 10)
getImg(dataList, './data/1bj-tiananmen/')
dataList = getManyPages('北京鸟巢', 10)
getImg(dataList, './data/1bj-niaochao/')
dataList = getManyPages('长城', 10)
getImg(dataList, './data/1bj-changcheng/')
dataList = getManyPages('天坛', 10)
getImg(dataList, './data/1bj-tiantan/')
dataList = getManyPages('中央电视台总部大楼', 10)
getImg(dataList, './data/1bj-zhongyangdianshitaizongbu/')
dataList = getManyPages('国家大剧院', 10)
getImg(dataList, './data/1bj-guojiadajuyuan/')
dataList = getManyPages('水立方', 10)
getImg(dataList, './data/1bj-shuilifang/')
# 上海市
dataList = getManyPages('东方明珠', 10)
getImg(dataList, './data/1sh-dongfangmingzhu/')
dataList = getManyPages('上海世博园', 10)
getImg(dataList, './data/1sh-shiboyuan/')
dataList = getManyPages('上海外滩', 10)
getImg(dataList, './data/1sh-shanghaiwaitan/')
dataList = getManyPages('杨浦大桥', 10)
getImg(dataList, './data/1sh-yangpudaqiao/')
dataList = getManyPages('上海大世界游乐中心', 10)
getImg(dataList, './data/1sh-youlezhongxin/')
dataList = getManyPages('上海中心大厦', 10)
getImg(dataList, './data/1sh-zhongxindasha/')
# 广州市
dataList = getManyPages('广州塔', 10)
getImg(dataList, './data/1gz-guangzhouta/')
dataList = getManyPages('镇海楼', 10)
getImg(dataList, './data/1gz-zhenhailou/')
dataList = getManyPages('广州大剧院', 10)
getImg(dataList, './data/1gz-dajuyuan/')
dataList = getManyPages('中山纪念堂', 10)
getImg(dataList, './data/1gz-zhongshanjiniantang/')
dataList = getManyPages('五羊石像', 10)
getImg(dataList, './data/1gz-wuyangshixiang/')
# 深圳市
dataList = getManyPages('孺子牛雕塑', 10)
getImg(dataList, './data/1sz-ruziniu/')
dataList = getManyPages('世界之窗 世界广场', 10)
getImg(dataList, './data/1sz-shijiezhichuang/')
dataList = getManyPages('深圳市民中心', 10)
getImg(dataList, './data/1sz-shiminzhongxin/')
# ------------------------二线城市--------------------------------
# 成都市、
dataList = getManyPages('天府广场', 10)
getImg(dataList, './data/2cd-tianfuguangchang/')
dataList = getManyPages('成都远洋太古里', 10)
getImg(dataList, './data/2cd-taiguli/')
dataList = getManyPages('望江楼公园', 10)
getImg(dataList, './data/2cd-wangjianglou/')
dataList = getManyPages('九眼桥', 10)
getImg(dataList, './data/2cd-jiuyanqiao/')
dataList = getManyPages('新世纪环球中心', 10)
getImg(dataList, './data/2cd-huanqiuzhongxin/')
# 杭州市、
dataList = getManyPages('西湖十景之三潭印月', 10)
getImg(dataList, './data/2hz-santanyingyu/')
dataList = getManyPages('六和塔', 10)
getImg(dataList, './data/2hz-liuheta/')
dataList = getManyPages('雷峰塔', 10)
getImg(dataList, './data/2hz-leifengta/')
dataList = getManyPages('钱江新城市民中心', 10)
getImg(dataList, './data/2hz-shiminzhongxin/')
dataList = getManyPages('杭州国际会议中心', 10)
getImg(dataList, './data/2hz-huiyizhongxin/')
dataList = getManyPages('望宸阁', 10)
getImg(dataList, './data/2hz-wangchenge/')
# 武汉市、
dataList = getManyPages('黄鹤楼', 10)
getImg(dataList, './data/2wh-huanghelou/')
dataList = getManyPages('鄂军都督府', 10)
getImg(dataList, './data/2wh-ejundudufu/')
dataList = getManyPages('古德寺', 10)
getImg(dataList, './data/2wh-gudesi/')
# 重庆市、
dataList = getManyPages('重庆解放碑', 10)
getImg(dataList, './data/2cq-jiefangbei/')
dataList = getManyPages('重庆人民大礼堂', 10)
getImg(dataList, './data/2cq-renmindalitang/')
dataList = getManyPages('重庆大剧院', 10)
getImg(dataList, './data/2cq-dajuyuan/')
dataList = getManyPages('重庆黄金双子塔', 10)
getImg(dataList, './data/2cq-huangjinshuangzita/')
# 南京市、
dataList = getManyPages('中山陵', 10)
getImg(dataList, './data/2nj-zhongshanling/')
# 天津市、
dataList = getManyPages('天津之眼', 10)
getImg(dataList, './data/2tj-tianjingzhiyan/')
dataList = getManyPages('天津天塔', 10)
getImg(dataList, './data/2tj-tianta/')
dataList = getManyPages('望海楼教堂', 10)
getImg(dataList, './data/2tj-wanghailoujiaotang/')
# 苏州市、
dataList = getManyPages('东方之门', 10)
getImg(dataList, './data/2sz-dongfangzhimen/')
dataList = getManyPages('环球188', 10)
getImg(dataList, './data/2sz-huanqiu188/')
# 西安市、
dataList = getManyPages('大雁塔', 10) # 参数1:关键字,参数2:要下载的页数
getImg(dataList, './data/2xa-dayanta/') # 参数2:指定保存的路径
dataList = getManyPages('西安钟楼', 10)
getImg(dataList, './data/2xa-zhonglou/')
dataList = getManyPages('西安城墙', 10)
getImg(dataList, './data/2xa-chengqiang/')
dataList = getManyPages('兵马俑', 10)
getImg(dataList, './data/2xa-bingmayong/')
dataList = getManyPages('西安电视塔', 10)
getImg(dataList, './data/2xa-dianshita/')
dataList = getManyPages('西安火车站', 10)
getImg(dataList, './data/2xa-huochezhan/')
# 长沙市、
dataList = getManyPages('岳麓书院', 10)
getImg(dataList, './data/2cs-yuelushuyuan/')
dataList = getManyPages('梅溪湖城市岛', 10)
getImg(dataList, './data/2cs-meixihuchengshidao/')
# 沈阳市、
dataList = getManyPages('沈阳故宫', 10)
getImg(dataList, './data/2sy-shenyanggugong/')
# 青岛市、
dataList = getManyPages('青岛栈桥', 10)
getImg(dataList, './data/2qd-qingdaozhanqiao/')
dataList = getManyPages('青岛五四广场', 10)
getImg(dataList, './data/2qd-wusiguangchang/')
dataList = getManyPages('胶州湾跨海大桥', 10)
getImg(dataList, './data/2qd-jiaozhouwankuahaidaqiao/')
# 郑州市、
dataList = getManyPages('郑州二七纪念塔', 10)
getImg(dataList, './data/2zz-erqijinianta/')
# 大连市、
dataList = getManyPages('大连国际会议中心', 10)
getImg(dataList, './data/2dl-guojihuiyizhongxin/')
# 东莞市、
dataList = getManyPages('东莞市网球中心', 10)
getImg(dataList, './data/2dg-wangqiuzhongxin/')
# 宁波市
dataList = getManyPages('天一阁', 10)
getImg(dataList, './data/2nb-tianyige/')
dataList = getManyPages('宁波港口', 10)
getImg(dataList, './data/2nb-ningbogangkou/')
dataList = getManyPages('宁波商会国贸大厦', 10)
getImg(dataList, './data/2nb-shanghuiguomaodasha/')
dataList = getManyPages('宁波财富中心', 10)
getImg(dataList, './data/2nb-caifuzhongxin/')
dataList = getManyPages('宁波保国寺', 10)
getImg(dataList, './data/2nb-baoguosi/')
dataList = getManyPages('河姆渡遗址', 10)
getImg(dataList, './data/2nb-hemuduyizhi/')
# ------------------------三线城市--------------------------------
2. 数据预处理
爬取到的数据并不是说直接就可以使用的,我们还需要进行预处理。为了方便,我这里只把过宽或者过窄的图像删除,目标不明显的图像进行删除。其实如果有条件的话,也可以做做数据增广处理,具体做法可以参见之前我的博客:从零开始制作人脸表情的数据集,里面有详细介绍如何进行数据增广处理。
爬好的数据统一放到路径'./data/'路径下,该路径下为各个城市相应地标建筑的文件夹:
随便打开一个数据集,其效果为:
我把我用到的所有数据放到了CSDN上,我对原始数据只是做了删减处理,没有做数据的增广。这里给出CSDN的下载地址,有需要的话可以自行下载:。当然,这些数据是自己费了一定的时间精力来做的,所以就收2个积分吧。
3. 编写神经网络模型文件CNN.py
这里我反复修改过很多次,因为最大的问题就在于训练好的pb文件不能太大,毕竟最终是要放到移动端的。测试了一下,pb文件的大小大概是checkpoint文件大小的1/3,最后我构建了一个四层卷积+2层全连接+1层softmax的CNN网络,这样算下来的checkpoint模型文件的大小大概为105M,相应的pb文件大小约为34M,这个大小放到手机端还是可以接受的。
CNN的参考代码网上非常多,这里我参考了:https://blog.csdn.net/u014281392/article/details/74881967
。下面直接给出我的最终代码。
# 参考:https://blog.csdn.net/u014281392/article/details/74881967
import tensorflow as tf
from skimage import io, transform
import glob
import os
import numpy as np
import time
import matplotlib.pyplot as plt
import cv2
from tensorflow.python.framework import graph_util
# 训练过程需要设置的参数
plot_loss = False # 是否绘制loss图
train = True # 是否训练,需要每次进行修改
n_epoch = 1000 # 训练的总次数
# 测试过程需要设置的参数,test之前需要先train
test = False if train else True # 是否测试训练好的模型,需要每次进行修改
test_image = "test_data/4.jpg" # 测试图像的路径,需要每次进行修改[1 14]
batch_size = 64 if train else 1
# 类别目录:
# 一线城市:北京、上海、广州、深圳 7+3+4+1=15
# 二线城市:成都市、杭州市、武汉市、重庆市、南京市、天津市、苏州市、
# 西安市、长沙市、沈阳市、青岛市、郑州市、大连市、东莞市、宁波市
classify = [
"BJ-长城", "BJ-国家大剧院", "BJ-鸟巢", "BJ-水立方", "BJ-天安门", "BJ-天坛", "BJ-中央电视台总部",
"GZ-广州塔", "GZ-五羊石像", "GZ-镇海楼",
"SH-东方明珠", "SH-世博园", "SH-杨浦大桥", "SH-游乐中心",
"SZ-深圳市民中心",
"CD-九眼桥",
"CQ-重庆大剧院", "CQ-解放碑", "CQ-人民大礼堂",
"CS-梅溪湖城市道", "CS-岳麓书院",
"HZ-六和塔", "HZ-三潭映月", "HZ-市民中心", "HZ-杭州忘尘阁",
"NB-宁波港口",
"QD-青岛栈桥", "QD-五四广场",
"SZ-东方之门",
"TJ-天津之眼", "TJ-天津天塔", "TJ-望海楼教堂",
"WH-鄂军都督府", "WH-古德寺", "WH-黄鹤楼",
"XA-兵马俑", "XA-西安城墙", "XA-大雁塔", "XA-电视塔", "XA-火车站", "XA-钟楼",
"ZZ-二七纪念塔"
]
# 设置相关参数
path = 'data/' # 原始数据的路径
width = 128 # 图像resize后的宽
height = 128 # 图像resize后的高
channel = 3 # 图像的通道数,一般3表示RGB图像
ratio = 0.8 # 80%的数据用于训练
model_path = 'checkpoint/' # checkpoint的存储路径
learning_rate = 0.00005 # 学习率
# 定义读取图片的函数, 并将其resize成width*height尺寸大小
def read_img(image_path):
cate = [path+f for f in os.listdir(image_path) if os.path.isdir(image_path+f)]
imgs = []
labels = []
for idx, folder in enumerate(cate):
for im in glob.glob(folder+'/*.jpg'):
print('reading the images:%s' % im)
img = io.imread(im)
try:
if img.shape[2] == 3:
img = transform.resize(img, (width, height))
imgs.append(img)
labels.append(idx)
except:
continue
# imgs = np.asarray(imgs, np.float32)
# return imgs, labels
return np.asarray(imgs, np.float32), np.asarray(labels, np.int32)
def l2_weight_init(shape, stddev, w1):
weight = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
if w1 is not None:
weight_loss = tf.multiply(tf.nn.l2_loss(weight), w1, name='weight_loss')
tf.add_to_collection('losses', weight_loss)
return weight
def loss(logit, labels):
labels = tf.cast(labels, tf.int64)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logit,
labels=labels,
name='cross_entropy_per_example')
# 交叉熵损失
cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
# 权重损失
tf.add_to_collection('losses', cross_entropy_mean)
return tf.add_n(tf.get_collection('losses'), name='total_loss')
def accuracy(test_labels, test_y_out):
test_labels = tf.to_int64(test_labels)
prediction_result = tf.equal(test_labels, tf.argmax(test_y_out, 1))
accu = tf.reduce_mean(tf.cast(prediction_result, tf.float32))
return accu
# -----------------构建网络----------------------
# 占位符
x = tf.placeholder(tf.float32, shape=[None, width, height, channel], name='input')
y_ = tf.placeholder(tf.int32, shape=[None, ], name='y_')
global_step = tf.Variable(0, name="global_step", trainable=False)
# 第一个卷积层(128——>64)
weight1 = tf.Variable(tf.truncated_normal([5, 5, 3, 16], stddev=0.05))
biases1 = tf.Variable(tf.random_normal([16]))
conv1 = tf.nn.relu(tf.nn.conv2d(x, weight1, strides=[1, 1, 1, 1],
padding='SAME') + biases1) # shape:[batchsize,128,128,16]
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1], padding='SAME') # shape:[batchsize,64,64,16]
lrnorm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # shape:[batchsize,64,64,16]
# 第二个卷积层(64——>32)
weight2 = tf.Variable(tf.truncated_normal([5, 5, 16, 32], stddev=0.05))
biases2 = tf.Variable(tf.random_normal([32]))
conv2 = tf.nn.relu(tf.nn.conv2d(lrnorm1, weight2, strides=[1, 1, 1, 1],
padding='SAME') + biases2) # shape:[batchsize,64,64,32]
pool2 = tf.nn.max_pool(conv2, ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1], padding='SAME') # shape:[batchsize,32,32,32]
lrnorm2 = tf.nn.lrn(pool2, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # shape:[batchsize,32,32,32]
# 第三个卷积层(32->16)
weight3 = tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.05))
biases3 = tf.Variable(tf.random_normal([64]))
conv3 = tf.nn.relu(tf.nn.conv2d(lrnorm2, weight3, strides=[1, 1, 1, 1],
padding='SAME') + biases3) # shape:[batchsize,32,32,64]
pool3 = tf.nn.max_pool(conv3, ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1], padding='SAME') # shape:[batchsize,16,16,64]
lrnorm3 = tf.nn.lrn(pool3, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # shape:[batchsize,16,16,64]
# 第四个卷积层(16->8)
weight4 = tf.Variable(tf.truncated_normal([5, 5, 64, 128], stddev=0.05))
biases4 = tf.Variable(tf.random_normal([128]))
conv4 = tf.nn.relu(tf.nn.conv2d(lrnorm3, weight4, strides=[1, 1, 1, 1],
padding='SAME') + biases4) # shape:[batchsize,16,16,128]
lrnorm4 = tf.nn.lrn(conv4, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # shape:[batchsize,16,16,128]
pool4 = tf.nn.max_pool(lrnorm4, ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1], padding='SAME') # shape:[batchsize,8,8,128]
# flatten,池化后的特征转换为一维
reshape = tf.reshape(pool4, [-1, 8*8*128]) # batchsize x 4096
n_input = 8 * 8 * 128
# reshape = tf.reshape(pool4, [-1, n_input])
# 全连接隐藏层1
weight5 = l2_weight_init([n_input, 1024], 0.05, w1=0.001)
biases5 = tf.Variable(tf.random_normal([1024]))
fullc1 = tf.nn.relu(tf.matmul(reshape, weight5) + biases5)
# 全连接隐藏层2
weight6 = l2_weight_init([1024, 256], 0.05, w1=0.003)
biases6 = tf.Variable(tf.random_normal([256]))
fullc2 = tf.nn.relu(tf.matmul(fullc1, weight6) + biases6)
# output layer
weight7 = tf.Variable(tf.truncated_normal([256, 42], stddev=0.015))
biases7 = tf.Variable(tf.random_normal([42]))
logits = tf.add(tf.matmul(fullc2, weight7), biases7, name='logits') # 未激活输出
y_out = tf.nn.softmax(logits, name='output')
# ---------------------------网络结束---------------------------
loss = loss(y_out, y_)
train_op = tf.train.AdamOptimizer(learning_rate=0.0005).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(y_out, 1), tf.int32), y_)
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
if train:
# 读取路径下的所有数据,并打上标签
data, label = read_img(path)
# 打乱顺序
num_example = data.shape[0]
arr = np.arange(num_example)
np.random.shuffle(arr)
data = data[arr]
label = label[arr]
# 将所有数据分为训练集和验证集
# ratio = 0.8
s = np.int(num_example * ratio)
x_train = data[:s]
y_train = label[:s]
x_val = data[s:]
y_val = label[s:]
# 启动会话sess并初始化
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
saver = tf.train.Saver()
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
# 逐个epoch内训练
Cross_loss = []
valid_loss = []
# train_loss, train_acc = 0, 0
for i in range(sess.run(global_step), n_epoch):
# 图像每个epoch内可以放(size // batch_size)个size
train_loss, train_acc = 0, 0
start_time = time.time()
for j in range(int(num_example*ratio) // batch_size):
# 训练一个batch的数据
batch_end = j * batch_size + batch_size
# 如果一个batch的最后数目大于数据集的大小,那么就取到数据集的最后一个数据
if batch_end >= int(num_example*ratio):
batch_end = int(num_example*ratio) - 1
# 取出一个batch的数据
x_value = x_train[j * batch_size: batch_end]
y_value = y_train[j * batch_size: batch_end]
# print(x_value.shape, y_value.shape)
_, err, ac = sess.run([train_op, loss, acc], feed_dict={x: x_value, y_: y_value})
Cross_loss.append(err)
# print(err)
train_loss += err
train_acc += ac
every_epoch_time = time.time() - start_time
print("epoch is :{}, train loss is :{:6F}, train accuracy is:{:6F}, time:{:6F}".
format(i, train_loss, train_acc, time.time() - start_time))
# print("epoch is :{}, tra loss is :{:6F}, tra accuracy is:{:6F}, time:{:6F}".format(i, train_loss, train_acc, every_epoch_time))
# 每个epoch训练完毕之后,进行一次validation
val_loss, val_acc = 0, 0
for k in range(int(num_example * (1-ratio)) // batch_size):
# 训练一个batch的数据
batch_end = k * batch_size + batch_size
# 如果一个batch的最后数目大于数据集的大小,那么就取到数据集的最后一个数据
if batch_end >= int(num_example * (1-ratio)):
batch_end = int(num_example * (1-ratio)) - 1
# 取出一个batch的数据
x_valid = x_val[k * batch_size: batch_end]
y_valid = y_val[k * batch_size: batch_end]
v_err, v_ac = sess.run([loss, acc], feed_dict={x: x_valid, y_: y_valid})
valid_loss.append(v_err)
val_loss += v_err
val_acc += v_ac
print("epoch is :{}, val loss is :{:6F}, val accuracy is:{:6F}, time:{:6F}".
format(i, val_loss, val_acc, time.time() - start_time))
# print("epoch is :{}, val loss is :{:6F}, val accuracy is:{:6F}".format(i, val_loss, val_acc))
# # 保存会话
# sess.run(tf.assign(global_step, i + 1))
# saver.save(sess, os.path.join(model_path, "model"), global_step=global_step)
# 将训练好的模型保存为.pb文件,方便在Android studio中使用
output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def,
output_node_names=['input', 'output'])
with tf.gfile.FastGFile(model_path+'model.pb', mode='wb') as f: # ’wb’中w代表写文件,b代表将数据以二进制方式写入文件。
f.write(output_graph_def.SerializeToString())
if plot_loss:
fig, ax = plt.subplots(figsize=(13, 6))
ax.plot(Cross_loss)
plt.grid()
plt.title('Train loss')
plt.show()
if test:
test_img = io.imread(test_image)
test_img = transform.resize(test_img, (width, height))
test_img = [test_img]
test_img = np.asarray(test_img, np.float32)
saver = tf.train.Saver()
with tf.Session().as_default() as sess:
ckpt = tf.train.get_checkpoint_state(model_path) # checkpoint file information
if ckpt and ckpt.model_checkpoint_path:
ckpt_name = os.path.basename(ckpt.model_checkpoint_path) # first line
saver.restore(sess, os.path.join(model_path, ckpt_name))
print(" [*] Success to read {}".format(ckpt_name))
else:
print(" [*] Failed to find a checkpoint")
# 设置测试输出结果的路径pass
result = sess.run(y_out, feed_dict={x: test_img})
print(result)
# 输出结果的判断,max_value表示分类得分最高的结果
max_value = max(max(result))
print(classify[35])
if max_value > 0.8:
print("拍摄于: {}, 概率: {:5f}%".format(classify[int(np.argmax(result))],
result[0][np.argmax(result)]*100))
else:
print("非典型名胜古迹")
# 显示最终结果
img_t = cv2.imread(test_image)
img_t = cv2.resize(img_t, (600, 400))
cv2.imshow("test picture", img_t)
cv2.waitKey(0)
cv2.destroyAllWindows()
4. 查看pb文件checkpb.py
做好CNN.py文件之后,可以进行训练,并生成pb文件。pb文件是一堆二进制数,直接打开文件当然什么也看不出来。做好的pb文件在路径'./checkpoint/model.pb'路径中,训练完成之后,可以查看pb文件,这里先给出代码:
import tensorflow as tf
model = './checkpoint/model.pb' #请将这里的model.pb文件路径改为自己的
graph = tf.get_default_graph()
graph_def = graph.as_graph_def()
graph_def.ParseFromString(tf.gfile.FastGFile(model, 'rb').read())
tf.import_graph_def(graph_def, name='graph')
summaryWriter = tf.summary.FileWriter('log/', graph)
执行checkpb.py文件。之后我是参考了博文:【Tensorflow】如何有效的查看已有的pb模型文件?后面就非常简单了,因为我是用的pycharm,编译器中自带terminal,打开,输入:
tensorboard --logdir log
按回车之后,会出来一段描述,不用管,里面有一个网址,直接打开:
打开之后,我们就可以看到我们的网络结构了,这里需要记住我们的输入参数名字为'input',输出参数名字为'output',这个后面还会再用到:
5. 在Android studio中修改相应的文件内容
以上的所有内容在python端即可完成,后面的部分就需要在android studio中完成了,要完成这部分内容,可以先查看我上一篇文章的介绍:深度学习(五)——win10环境下将tensorflow的官方demo在Android Studio上运行。不过这篇文章中主要修改的是build.gradle,而本文还需要修改一些参数。
至于这部分内容都需要改哪些文件,这里给出几篇文章的介绍:
[1]Tensorflow lite for 移动端安卓开发(三)——移动端测试自己的模型
[2]Tensorflow移动端之如何将自己训练的MNIST模型加载到Android手机上
实际上移动端可以用tensorflow训练好的pb文件或者tflite文件,这里是我自己做的pb文件,参考了tensorflow的官方demo和一些博客。我对android studio不太熟悉,所以只能简单改改了。
(1)拷贝tensorflow的官方demo到项目路径下
这部分内容在之前的文章中也有提到,可以参考:深度学习(五)——win10环境下将tensorflow的官方demo在Android Studio上运行。这里再简单说一下,首先在github上找到tensorflow的官方源码,地址:https://github.com/tensorflow/tensorflow。进入后直接打包download。
下载好之后是一个解压文件,将其解压,解压好之后,找到路径'./tensorflow-master/tensorflow/examples/android'下的文件,这些文件就是我们需要用到的文件,可以直接将其拷贝出来,放入到项目所在的文件夹中。
(2)按照提示修改build.gradle
这里请直接参考上一篇文章进行修改:深度学习(五)——win10环境下将tensorflow的官方demo在Android Studio上运行。
(3)将label.txt和model.pb文件放到assets文件夹下
这一步还比较重要,首先我们需要自己制作label.txt文件,这里也没有什么比较好的办法,只能手打了,需要注意的是,汉字在app中会显示乱码,所以我就写的拼音+英文简单表示:
这里还有一个小细节,尽量在写字板中编辑,这样最终的显示结果没有问题。
做好label.txt之后,我们将之前的pb文件和label.txt一起放入到assets文件夹下,不过项目文件中有两个assets文件夹,我也不知道放到哪个,就两个一起放了:
(4)删除不必要文件
tensorflow的官方demo中并非所有文件都是有用的,我们删除路径'./src/org/tensorflow/demo'下部分没用的文件,理论上只保留和分类有关的文件就可以了,这里我保留的文件都包括:
其中,核心文件是classifier,classifierActivity,TensorflowImageClassifier这三个文件。
(5)修改classifierActivity.java
主要修改里面的参数内容,这里我修改的地方包括:
public class ClassifierActivity extends CameraActivity implements OnImageAvailableListener {
static {
System.loadLibrary("tensorflow_inference");
}
......
// 以下内容需要自己根据自己的实际情况来设计
private static final int INPUT_SIZE = 128;
private static final int IMAGE_MEAN = 117;
private static final float IMAGE_STD = 1;
private static final String INPUT_NAME = "input";
private static final String OUTPUT_NAME = "output";
private static final String MODEL_FILE = "file:///android_asset/model.pb";
private static final String LABEL_FILE = "file:///android_asset/label.txt";
private static final boolean MAINTAIN_ASPECT = true;
private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);
......
}
修改的地方主要是网络的输入输出节点,和pb以及label的路径,还有输入网络图像的大小。
(6)修改TensorflowImageClassifer.java
因为我最终是只想显示一个概率最高的结果,如果这个结果的概率相对较高,比如出现了两种很类似的塔,模型将其识别为甲的概率为0.5而乙的概率为0.4,那这种情况还是不要输出了,测试者看不到输出结果,也许会换个更明显的角度来拍摄。因此我们取一个绝对阈值0.8,如果某一物体的识别率低于0.8,那就不再显示,因此需要在文件中修改一下这两个参数:
public class TensorFlowImageClassifier implements Classifier {
// 由于TAG的字符长度最大为23,因此这里将"TensorflowImageClassifier"修改为"TFImageClassifier"
private static final String TAG = "TFImageClassifier";
// 可能结果的数量.
private static final int MAX_RESULTS = 1;
private static final float THRESHOLD = 0.8f;
......
}
(7)修改androidManifest.xml
这个文件是控制最终生成的app的,原官方demo是生成了4个app,我们不需要那么多,只需要一个分类的app即可,因此找到该文件,将DetectorActivity,StylizeActivity,SpeechActivity三部分内容全部注释掉即可:
6. 用android手机测试
四、实验结果
1. CNN网络训练精度为88%
实验设置epoch为1000次,最终的训练精度为88%,实验到epoch=200时,精度约为75%:
运算时间来看,我的电脑配置是GTX 1060 3G,运行一个epoch大约需要7秒,1000个epoch也就是7000秒,即两个小时。下面可以随意看看实验效果:
西安大雁塔,成功识别:
长城,成功识别:
天坛,成功识别:
兵马俑,成功识别:
当然,也有识别错的,比如这张西安钟楼,识别为了兵马俑,可能是由于色调与兵马俑的色调比较类似?
2. app效果
五、分析
1. 所有文件结构为:
-- get_image.py
-- CNN.py
-- checkpb.py
-- data
|------ 1bj-changcheng
|------ image1.jpg
|------ image2.jpg
|------ ...
|------ 1bj-guojiadajuyuan
|------ image1.jpg
|------ image2.jpg
|------ ...
|------ 1bj-niaochao
|------ image1.jpg
|------ image2.jpg
|------ ...
|------ ......
|------ ...
|------ ...
|------ ...
2. 关于CNN网络,其实我个人是不建议用CNN的,因为有很多冗余的参数占了很多内存,我看了一下tensorflow的官方demo,是用shuffleNet做的,做好的pb文件大小有50+M,shuffleNet的效果要比CNN的效果好很多,要达到同样效果的CNN网络,还需要加深层数和更多参数。具体的shuffleNet的解读可以参考这篇文献:ShuffleNet总结 ,一般是将通道分为3组,这样参数量约为原来的1/3,计算量为1/9,所以效率和大小上都非常适合移动端,传统CNN在部署到移动端还有很多问题。当然,还有mobileNet, squeezeNet等相关轻量级网络也可以关注。