深度学习(六)——CNN识别典型地标建筑,并制作pb文件,部署在android studio端,制作app实现该功能

一、背景

两年前的一个项目,识别典型地标,并对图片进行标记。这里只写一下简单思路,最开始是打算用特征点匹配来做的,可以参考:基于特征点匹配方法——SIFT, SURF, ORB的图像识别 ,后来发现效果很差,所以后面又用了深度学习来做。

二、方法

实现的方法还是比较简单的,核心步骤如下:

(1)采集数据并数据预处理get_image.py

(2)编写神经网络模型CNN.py

(3)得到并查看pb文件checkpb.py

(4)参考tensorflow的官方demo,在移动端部署

(5)制作app并实验验证

三、实现

1. 数据获取

所有过程都是从0开始的,没有数据也要自己创造数据。怎么办呢?最简单的办法就是自己爬数据。

要爬数据,首先需要找到要爬数据的主要内容,既然是典型地标,那么一、二线城市必然都有当地的地标建筑,因此我们需要先找到一、二线的城市名单,可以参考:https://baijiahao.baidu.com/s?id=1589132487171121063&wfr=spider&for=pc

编写爬取数据的程序,并命名为get_image.py文件:

# 导入需要的库
import requests
import os
import json

# 爬取百度图片,解析页面的函数
def getManyPages(keyword, pages):
    '''
    参数keyword:要下载的影像关键词
    参数pages:需要下载的页面数
    '''
    params = []

    for i in range(30, 30 * pages + 30, 30):
        params.append({
            'tn': 'resultjson_com',
            'ipn': 'rj',
            'ct': 201326592,
            'is': '',
            'fp': 'result',
            'queryWord': keyword,
            'cl': 2,
            'lm': -1,
            'ie': 'utf-8',
            'oe': 'utf-8',
            'adpicid': '',
            'st': -1,
            'z': '',
            'ic': 0,
            'word': keyword,
            's': '',
            'se': '',
            'tab': '',
            'width': '',
            'height': '',
            'face': 0,
            'istype': 2,
            'qc': '',
            'nc': 1,
            'fr': '',
            'pn': i,
            'rn': 30,
            'gsm': '1e',
            '1488942260214': ''
        })
    url = 'https://image.baidu.com/search/acjson'
    urls = []
    for i in params:
        try:
            urls.append(requests.get(url, params=i).json().get('data'))
        except json.decoder.JSONDecodeError:
            print("解析出错")
    return urls

# 下载图片并保存
def getImg(dataList, localPath):
    '''
    参数datallist:下载图片的地址集
    参数localPath:保存下载图片的路径
    '''
    if not os.path.exists(localPath):  # 判断是否存在保存路径,如果不存在就创建
        os.mkdir(localPath)
    x = 0
    for list in dataList:
        for i in list:
            if i.get('thumbURL') != None:
                print('正在下载:%s' % i.get('thumbURL'))
                ir = requests.get(i.get('thumbURL'))
                open(localPath + '%d.jpg' % x, 'wb').write(ir.content)
                x += 1
            else:
                print('图片链接不存在')

# 根据关键词来下载图片
if __name__ == '__main__':

    # 参考:https://baijiahao.baidu.com/s?id=1589132487171121063&wfr=spider&for=pc
    # 一线城市:北京、上海、广州、深圳
    # 二线城市:成都市、杭州市、武汉市、重庆市、南京市、天津市、苏州市、
    #         西安市、长沙市、沈阳市、青岛市、郑州市、大连市、东莞市、宁波市
    # 三线城市:厦门市、福州市、无锡市、合肥市、昆明市、哈尔滨市、济南市、佛山市、
    #         长春市、温州市、石家庄市、南宁市、常州市、泉州市、南昌市、贵阳市、
    #         太原市、烟台市、嘉兴市、南通市、金华市、珠海市、惠州市、徐州市、
    #         海口市、乌鲁木齐市、绍兴市、中山市、台州市、兰州市

    # ------------------------一线城市--------------------------------
    # 北京市
    dataList = getManyPages('天安门', 10)
    getImg(dataList, './data/1bj-tiananmen/')
    dataList = getManyPages('北京鸟巢', 10)
    getImg(dataList, './data/1bj-niaochao/')
    dataList = getManyPages('长城', 10)
    getImg(dataList, './data/1bj-changcheng/')
    dataList = getManyPages('天坛', 10)
    getImg(dataList, './data/1bj-tiantan/')
    dataList = getManyPages('中央电视台总部大楼', 10)
    getImg(dataList, './data/1bj-zhongyangdianshitaizongbu/')
    dataList = getManyPages('国家大剧院', 10)
    getImg(dataList, './data/1bj-guojiadajuyuan/')
    dataList = getManyPages('水立方', 10)
    getImg(dataList, './data/1bj-shuilifang/')

    # 上海市
    dataList = getManyPages('东方明珠', 10)
    getImg(dataList, './data/1sh-dongfangmingzhu/')
    dataList = getManyPages('上海世博园', 10)
    getImg(dataList, './data/1sh-shiboyuan/')
    dataList = getManyPages('上海外滩', 10)
    getImg(dataList, './data/1sh-shanghaiwaitan/')
    dataList = getManyPages('杨浦大桥', 10)
    getImg(dataList, './data/1sh-yangpudaqiao/')
    dataList = getManyPages('上海大世界游乐中心', 10)
    getImg(dataList, './data/1sh-youlezhongxin/')
    dataList = getManyPages('上海中心大厦', 10)
    getImg(dataList, './data/1sh-zhongxindasha/')

    # 广州市
    dataList = getManyPages('广州塔', 10)
    getImg(dataList, './data/1gz-guangzhouta/')
    dataList = getManyPages('镇海楼', 10)
    getImg(dataList, './data/1gz-zhenhailou/')
    dataList = getManyPages('广州大剧院', 10)
    getImg(dataList, './data/1gz-dajuyuan/')
    dataList = getManyPages('中山纪念堂', 10)
    getImg(dataList, './data/1gz-zhongshanjiniantang/')
    dataList = getManyPages('五羊石像', 10)
    getImg(dataList, './data/1gz-wuyangshixiang/')

    # 深圳市
    dataList = getManyPages('孺子牛雕塑', 10)
    getImg(dataList, './data/1sz-ruziniu/')
    dataList = getManyPages('世界之窗 世界广场', 10)
    getImg(dataList, './data/1sz-shijiezhichuang/')
    dataList = getManyPages('深圳市民中心', 10)
    getImg(dataList, './data/1sz-shiminzhongxin/')

    # ------------------------二线城市--------------------------------
    # 成都市、
    dataList = getManyPages('天府广场', 10)
    getImg(dataList, './data/2cd-tianfuguangchang/')
    dataList = getManyPages('成都远洋太古里', 10)
    getImg(dataList, './data/2cd-taiguli/')
    dataList = getManyPages('望江楼公园', 10)
    getImg(dataList, './data/2cd-wangjianglou/')
    dataList = getManyPages('九眼桥', 10)
    getImg(dataList, './data/2cd-jiuyanqiao/')
    dataList = getManyPages('新世纪环球中心', 10)
    getImg(dataList, './data/2cd-huanqiuzhongxin/')

    # 杭州市、
    dataList = getManyPages('西湖十景之三潭印月', 10)
    getImg(dataList, './data/2hz-santanyingyu/')
    dataList = getManyPages('六和塔', 10)
    getImg(dataList, './data/2hz-liuheta/')
    dataList = getManyPages('雷峰塔', 10)
    getImg(dataList, './data/2hz-leifengta/')
    dataList = getManyPages('钱江新城市民中心', 10)
    getImg(dataList, './data/2hz-shiminzhongxin/')
    dataList = getManyPages('杭州国际会议中心', 10)
    getImg(dataList, './data/2hz-huiyizhongxin/')
    dataList = getManyPages('望宸阁', 10)
    getImg(dataList, './data/2hz-wangchenge/')

    # 武汉市、
    dataList = getManyPages('黄鹤楼', 10)
    getImg(dataList, './data/2wh-huanghelou/')
    dataList = getManyPages('鄂军都督府', 10)
    getImg(dataList, './data/2wh-ejundudufu/')
    dataList = getManyPages('古德寺', 10)
    getImg(dataList, './data/2wh-gudesi/')

    # 重庆市、
    dataList = getManyPages('重庆解放碑', 10)
    getImg(dataList, './data/2cq-jiefangbei/')
    dataList = getManyPages('重庆人民大礼堂', 10)
    getImg(dataList, './data/2cq-renmindalitang/')
    dataList = getManyPages('重庆大剧院', 10)
    getImg(dataList, './data/2cq-dajuyuan/')
    dataList = getManyPages('重庆黄金双子塔', 10)
    getImg(dataList, './data/2cq-huangjinshuangzita/')

    # 南京市、
    dataList = getManyPages('中山陵', 10)
    getImg(dataList, './data/2nj-zhongshanling/')

    # 天津市、
    dataList = getManyPages('天津之眼', 10)
    getImg(dataList, './data/2tj-tianjingzhiyan/')
    dataList = getManyPages('天津天塔', 10)
    getImg(dataList, './data/2tj-tianta/')
    dataList = getManyPages('望海楼教堂', 10)
    getImg(dataList, './data/2tj-wanghailoujiaotang/')

    # 苏州市、
    dataList = getManyPages('东方之门', 10)
    getImg(dataList, './data/2sz-dongfangzhimen/')
    dataList = getManyPages('环球188', 10)
    getImg(dataList, './data/2sz-huanqiu188/')

    # 西安市、
    dataList = getManyPages('大雁塔', 10)     # 参数1:关键字,参数2:要下载的页数
    getImg(dataList, './data/2xa-dayanta/')            # 参数2:指定保存的路径
    dataList = getManyPages('西安钟楼', 10)
    getImg(dataList, './data/2xa-zhonglou/')
    dataList = getManyPages('西安城墙', 10)
    getImg(dataList, './data/2xa-chengqiang/')
    dataList = getManyPages('兵马俑', 10)
    getImg(dataList, './data/2xa-bingmayong/')
    dataList = getManyPages('西安电视塔', 10)
    getImg(dataList, './data/2xa-dianshita/')
    dataList = getManyPages('西安火车站', 10)
    getImg(dataList, './data/2xa-huochezhan/')

    # 长沙市、
    dataList = getManyPages('岳麓书院', 10)
    getImg(dataList, './data/2cs-yuelushuyuan/')
    dataList = getManyPages('梅溪湖城市岛', 10)
    getImg(dataList, './data/2cs-meixihuchengshidao/')

    # 沈阳市、
    dataList = getManyPages('沈阳故宫', 10)
    getImg(dataList, './data/2sy-shenyanggugong/')

    # 青岛市、
    dataList = getManyPages('青岛栈桥', 10)
    getImg(dataList, './data/2qd-qingdaozhanqiao/')
    dataList = getManyPages('青岛五四广场', 10)
    getImg(dataList, './data/2qd-wusiguangchang/')
    dataList = getManyPages('胶州湾跨海大桥', 10)
    getImg(dataList, './data/2qd-jiaozhouwankuahaidaqiao/')

    # 郑州市、
    dataList = getManyPages('郑州二七纪念塔', 10)
    getImg(dataList, './data/2zz-erqijinianta/')

    # 大连市、
    dataList = getManyPages('大连国际会议中心', 10)
    getImg(dataList, './data/2dl-guojihuiyizhongxin/')

    # 东莞市、
    dataList = getManyPages('东莞市网球中心', 10)
    getImg(dataList, './data/2dg-wangqiuzhongxin/')

    # 宁波市
    dataList = getManyPages('天一阁', 10)
    getImg(dataList, './data/2nb-tianyige/')
    dataList = getManyPages('宁波港口', 10)
    getImg(dataList, './data/2nb-ningbogangkou/')
    dataList = getManyPages('宁波商会国贸大厦', 10)
    getImg(dataList, './data/2nb-shanghuiguomaodasha/')
    dataList = getManyPages('宁波财富中心', 10)
    getImg(dataList, './data/2nb-caifuzhongxin/')
    dataList = getManyPages('宁波保国寺', 10)
    getImg(dataList, './data/2nb-baoguosi/')
    dataList = getManyPages('河姆渡遗址', 10)
    getImg(dataList, './data/2nb-hemuduyizhi/')

    # ------------------------三线城市--------------------------------

2. 数据预处理

爬取到的数据并不是说直接就可以使用的,我们还需要进行预处理。为了方便,我这里只把过宽或者过窄的图像删除,目标不明显的图像进行删除。其实如果有条件的话,也可以做做数据增广处理,具体做法可以参见之前我的博客:从零开始制作人脸表情的数据集,里面有详细介绍如何进行数据增广处理。

爬好的数据统一放到路径'./data/'路径下,该路径下为各个城市相应地标建筑的文件夹:

随便打开一个数据集,其效果为:

我把我用到的所有数据放到了CSDN上,我对原始数据只是做了删减处理,没有做数据的增广。这里给出CSDN的下载地址,有需要的话可以自行下载:。当然,这些数据是自己费了一定的时间精力来做的,所以就收2个积分吧。

3. 编写神经网络模型文件CNN.py

这里我反复修改过很多次,因为最大的问题就在于训练好的pb文件不能太大,毕竟最终是要放到移动端的。测试了一下,pb文件的大小大概是checkpoint文件大小的1/3,最后我构建了一个四层卷积+2层全连接+1层softmax的CNN网络,这样算下来的checkpoint模型文件的大小大概为105M,相应的pb文件大小约为34M,这个大小放到手机端还是可以接受的。

CNN的参考代码网上非常多,这里我参考了:https://blog.csdn.net/u014281392/article/details/74881967
。下面直接给出我的最终代码。

# 参考:https://blog.csdn.net/u014281392/article/details/74881967

import tensorflow as tf
from skimage import io, transform
import glob
import os
import numpy as np
import time
import matplotlib.pyplot as plt
import cv2
from tensorflow.python.framework import graph_util

# 训练过程需要设置的参数
plot_loss = False                               # 是否绘制loss图
train = True                                    # 是否训练,需要每次进行修改
n_epoch = 1000                                  # 训练的总次数
# 测试过程需要设置的参数,test之前需要先train
test = False if train else True                 # 是否测试训练好的模型,需要每次进行修改
test_image = "test_data/4.jpg"                  # 测试图像的路径,需要每次进行修改[1 14]
batch_size = 64 if train else 1


# 类别目录:
# 一线城市:北京、上海、广州、深圳    7+3+4+1=15
# 二线城市:成都市、杭州市、武汉市、重庆市、南京市、天津市、苏州市、
#         西安市、长沙市、沈阳市、青岛市、郑州市、大连市、东莞市、宁波市
classify = [
    "BJ-长城", "BJ-国家大剧院", "BJ-鸟巢", "BJ-水立方", "BJ-天安门", "BJ-天坛", "BJ-中央电视台总部",
    "GZ-广州塔", "GZ-五羊石像", "GZ-镇海楼",
    "SH-东方明珠", "SH-世博园", "SH-杨浦大桥", "SH-游乐中心",
    "SZ-深圳市民中心",
    "CD-九眼桥",
    "CQ-重庆大剧院", "CQ-解放碑", "CQ-人民大礼堂",
    "CS-梅溪湖城市道", "CS-岳麓书院",
    "HZ-六和塔", "HZ-三潭映月", "HZ-市民中心", "HZ-杭州忘尘阁",
    "NB-宁波港口",
    "QD-青岛栈桥", "QD-五四广场",
    "SZ-东方之门",
    "TJ-天津之眼", "TJ-天津天塔", "TJ-望海楼教堂",
    "WH-鄂军都督府", "WH-古德寺", "WH-黄鹤楼",
    "XA-兵马俑", "XA-西安城墙", "XA-大雁塔", "XA-电视塔", "XA-火车站", "XA-钟楼",
    "ZZ-二七纪念塔"
]

# 设置相关参数
path = 'data/'              # 原始数据的路径
width = 128                 # 图像resize后的宽
height = 128                # 图像resize后的高
channel = 3                 # 图像的通道数,一般3表示RGB图像

ratio = 0.8                 # 80%的数据用于训练
model_path = 'checkpoint/'  # checkpoint的存储路径
learning_rate = 0.00005      # 学习率


# 定义读取图片的函数, 并将其resize成width*height尺寸大小
def read_img(image_path):
    cate = [path+f for f in os.listdir(image_path) if os.path.isdir(image_path+f)]
    imgs = []
    labels = []
    for idx, folder in enumerate(cate):
        for im in glob.glob(folder+'/*.jpg'):
            print('reading the images:%s' % im)
            img = io.imread(im)
            try:
                if img.shape[2] == 3:
                    img = transform.resize(img, (width, height))
                    imgs.append(img)
                    labels.append(idx)
            except:
                continue
    # imgs = np.asarray(imgs, np.float32)
    # return imgs, labels
    return np.asarray(imgs, np.float32), np.asarray(labels, np.int32)


def l2_weight_init(shape, stddev, w1):
    weight = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
    if w1 is not None:
        weight_loss = tf.multiply(tf.nn.l2_loss(weight), w1, name='weight_loss')
        tf.add_to_collection('losses', weight_loss)
    return weight


def loss(logit, labels):
    labels = tf.cast(labels, tf.int64)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logit,
                                                                   labels=labels,
                                                                   name='cross_entropy_per_example')
    # 交叉熵损失
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
    # 权重损失
    tf.add_to_collection('losses', cross_entropy_mean)
    return tf.add_n(tf.get_collection('losses'), name='total_loss')


def accuracy(test_labels, test_y_out):
    test_labels = tf.to_int64(test_labels)
    prediction_result = tf.equal(test_labels, tf.argmax(test_y_out, 1))
    accu = tf.reduce_mean(tf.cast(prediction_result, tf.float32))
    return accu


# -----------------构建网络----------------------
# 占位符
x = tf.placeholder(tf.float32, shape=[None, width, height, channel], name='input')
y_ = tf.placeholder(tf.int32, shape=[None, ], name='y_')

global_step = tf.Variable(0, name="global_step", trainable=False)

# 第一个卷积层(128——>64)
weight1 = tf.Variable(tf.truncated_normal([5, 5, 3, 16], stddev=0.05))
biases1 = tf.Variable(tf.random_normal([16]))
conv1 = tf.nn.relu(tf.nn.conv2d(x, weight1, strides=[1, 1, 1, 1],
                                padding='SAME') + biases1)              # shape:[batchsize,128,128,16]
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1],
                       strides=[1, 2, 2, 1], padding='SAME')            # shape:[batchsize,64,64,16]
lrnorm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)     # shape:[batchsize,64,64,16]


# 第二个卷积层(64——>32)
weight2 = tf.Variable(tf.truncated_normal([5, 5, 16, 32], stddev=0.05))
biases2 = tf.Variable(tf.random_normal([32]))
conv2 = tf.nn.relu(tf.nn.conv2d(lrnorm1, weight2, strides=[1, 1, 1, 1],
                                padding='SAME') + biases2)              # shape:[batchsize,64,64,32]
pool2 = tf.nn.max_pool(conv2, ksize=[1, 3, 3, 1],
                       strides=[1, 2, 2, 1], padding='SAME')            # shape:[batchsize,32,32,32]
lrnorm2 = tf.nn.lrn(pool2, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)     # shape:[batchsize,32,32,32]


# 第三个卷积层(32->16)
weight3 = tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.05))
biases3 = tf.Variable(tf.random_normal([64]))
conv3 = tf.nn.relu(tf.nn.conv2d(lrnorm2, weight3, strides=[1, 1, 1, 1],
                                padding='SAME') + biases3)              # shape:[batchsize,32,32,64]
pool3 = tf.nn.max_pool(conv3, ksize=[1, 3, 3, 1],
                       strides=[1, 2, 2, 1], padding='SAME')            # shape:[batchsize,16,16,64]
lrnorm3 = tf.nn.lrn(pool3, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)     # shape:[batchsize,16,16,64]

# 第四个卷积层(16->8)
weight4 = tf.Variable(tf.truncated_normal([5, 5, 64, 128], stddev=0.05))
biases4 = tf.Variable(tf.random_normal([128]))
conv4 = tf.nn.relu(tf.nn.conv2d(lrnorm3, weight4, strides=[1, 1, 1, 1],
                                padding='SAME') + biases4)              # shape:[batchsize,16,16,128]
lrnorm4 = tf.nn.lrn(conv4, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)     # shape:[batchsize,16,16,128]
pool4 = tf.nn.max_pool(lrnorm4, ksize=[1, 3, 3, 1],
                       strides=[1, 2, 2, 1], padding='SAME')            # shape:[batchsize,8,8,128]


# flatten,池化后的特征转换为一维
reshape = tf.reshape(pool4, [-1, 8*8*128])                           # batchsize x 4096
n_input = 8 * 8 * 128
# reshape = tf.reshape(pool4, [-1, n_input])

# 全连接隐藏层1
weight5 = l2_weight_init([n_input, 1024], 0.05, w1=0.001)
biases5 = tf.Variable(tf.random_normal([1024]))
fullc1 = tf.nn.relu(tf.matmul(reshape, weight5) + biases5)

# 全连接隐藏层2
weight6 = l2_weight_init([1024, 256], 0.05, w1=0.003)
biases6 = tf.Variable(tf.random_normal([256]))
fullc2 = tf.nn.relu(tf.matmul(fullc1, weight6) + biases6)

# output layer
weight7 = tf.Variable(tf.truncated_normal([256, 42], stddev=0.015))
biases7 = tf.Variable(tf.random_normal([42]))
logits = tf.add(tf.matmul(fullc2, weight7), biases7, name='logits')    # 未激活输出
y_out = tf.nn.softmax(logits, name='output')

# ---------------------------网络结束---------------------------

loss = loss(y_out, y_)
train_op = tf.train.AdamOptimizer(learning_rate=0.0005).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(y_out, 1), tf.int32), y_)
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


if train:
    # 读取路径下的所有数据,并打上标签
    data, label = read_img(path)

    # 打乱顺序
    num_example = data.shape[0]
    arr = np.arange(num_example)
    np.random.shuffle(arr)
    data = data[arr]
    label = label[arr]

    # 将所有数据分为训练集和验证集
    # ratio = 0.8
    s = np.int(num_example * ratio)
    x_train = data[:s]
    y_train = label[:s]
    x_val = data[s:]
    y_val = label[s:]

    # 启动会话sess并初始化
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    saver = tf.train.Saver()
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    # 逐个epoch内训练
    Cross_loss = []
    valid_loss = []

    # train_loss, train_acc = 0, 0
    for i in range(sess.run(global_step), n_epoch):
        # 图像每个epoch内可以放(size // batch_size)个size
        train_loss, train_acc = 0, 0
        start_time = time.time()
        for j in range(int(num_example*ratio) // batch_size):

            # 训练一个batch的数据
            batch_end = j * batch_size + batch_size
            # 如果一个batch的最后数目大于数据集的大小,那么就取到数据集的最后一个数据
            if batch_end >= int(num_example*ratio):
                batch_end = int(num_example*ratio) - 1
            # 取出一个batch的数据
            x_value = x_train[j * batch_size: batch_end]
            y_value = y_train[j * batch_size: batch_end]
            # print(x_value.shape, y_value.shape)

            _, err, ac = sess.run([train_op, loss, acc], feed_dict={x: x_value, y_: y_value})
            Cross_loss.append(err)
            # print(err)
            train_loss += err
            train_acc += ac
            every_epoch_time = time.time() - start_time
        print("epoch is :{}, train loss is :{:6F}, train accuracy is:{:6F}, time:{:6F}".
              format(i, train_loss, train_acc, time.time() - start_time))
        # print("epoch is :{}, tra loss is :{:6F}, tra accuracy is:{:6F}, time:{:6F}".format(i, train_loss, train_acc, every_epoch_time))

        # 每个epoch训练完毕之后,进行一次validation
        val_loss, val_acc = 0, 0
        for k in range(int(num_example * (1-ratio)) // batch_size):
            # 训练一个batch的数据
            batch_end = k * batch_size + batch_size
            # 如果一个batch的最后数目大于数据集的大小,那么就取到数据集的最后一个数据
            if batch_end >= int(num_example * (1-ratio)):
                batch_end = int(num_example * (1-ratio)) - 1
            # 取出一个batch的数据
            x_valid = x_val[k * batch_size: batch_end]
            y_valid = y_val[k * batch_size: batch_end]

            v_err, v_ac = sess.run([loss, acc], feed_dict={x: x_valid, y_: y_valid})
            valid_loss.append(v_err)
            val_loss += v_err
            val_acc += v_ac
        print("epoch is :{}, val loss is :{:6F}, val accuracy is:{:6F}, time:{:6F}".
              format(i, val_loss, val_acc, time.time() - start_time))
        # print("epoch is :{}, val loss is :{:6F}, val accuracy is:{:6F}".format(i, val_loss, val_acc))

        # # 保存会话
        # sess.run(tf.assign(global_step, i + 1))
        # saver.save(sess, os.path.join(model_path, "model"), global_step=global_step)

        # 将训练好的模型保存为.pb文件,方便在Android studio中使用
        output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def,
                                                                     output_node_names=['input', 'output'])
        with tf.gfile.FastGFile(model_path+'model.pb', mode='wb') as f:  # ’wb’中w代表写文件,b代表将数据以二进制方式写入文件。
            f.write(output_graph_def.SerializeToString())

    if plot_loss:
        fig, ax = plt.subplots(figsize=(13, 6))
        ax.plot(Cross_loss)
        plt.grid()
        plt.title('Train loss')
        plt.show()

if test:

    test_img = io.imread(test_image)
    test_img = transform.resize(test_img, (width, height))
    test_img = [test_img]
    test_img = np.asarray(test_img, np.float32)

    saver = tf.train.Saver()

    with tf.Session().as_default() as sess:
        ckpt = tf.train.get_checkpoint_state(model_path)  # checkpoint file information
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)  # first line
            saver.restore(sess, os.path.join(model_path, ckpt_name))
            print(" [*] Success to read {}".format(ckpt_name))
        else:
            print(" [*] Failed to find a checkpoint")

        # 设置测试输出结果的路径pass
        result = sess.run(y_out, feed_dict={x: test_img})
        print(result)
        # 输出结果的判断,max_value表示分类得分最高的结果
        max_value = max(max(result))
        print(classify[35])
        if max_value > 0.8:
            print("拍摄于: {}, 概率: {:5f}%".format(classify[int(np.argmax(result))],
                                               result[0][np.argmax(result)]*100))
        else:
            print("非典型名胜古迹")

        # 显示最终结果
        img_t = cv2.imread(test_image)
        img_t = cv2.resize(img_t, (600, 400))
        cv2.imshow("test picture", img_t)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

4. 查看pb文件checkpb.py

做好CNN.py文件之后,可以进行训练,并生成pb文件。pb文件是一堆二进制数,直接打开文件当然什么也看不出来。做好的pb文件在路径'./checkpoint/model.pb'路径中,训练完成之后,可以查看pb文件,这里先给出代码:

import tensorflow as tf

model = './checkpoint/model.pb'   #请将这里的model.pb文件路径改为自己的
graph = tf.get_default_graph()
graph_def = graph.as_graph_def()
graph_def.ParseFromString(tf.gfile.FastGFile(model, 'rb').read())
tf.import_graph_def(graph_def, name='graph')
summaryWriter = tf.summary.FileWriter('log/', graph)

执行checkpb.py文件。之后我是参考了博文:【Tensorflow】如何有效的查看已有的pb模型文件?后面就非常简单了,因为我是用的pycharm,编译器中自带terminal,打开,输入:

tensorboard --logdir log

按回车之后,会出来一段描述,不用管,里面有一个网址,直接打开:

打开之后,我们就可以看到我们的网络结构了,这里需要记住我们的输入参数名字为'input',输出参数名字为'output',这个后面还会再用到:

5. 在Android studio中修改相应的文件内容

以上的所有内容在python端即可完成,后面的部分就需要在android studio中完成了,要完成这部分内容,可以先查看我上一篇文章的介绍:深度学习(五)——win10环境下将tensorflow的官方demo在Android Studio上运行。不过这篇文章中主要修改的是build.gradle,而本文还需要修改一些参数。

至于这部分内容都需要改哪些文件,这里给出几篇文章的介绍:

[1]Tensorflow lite for 移动端安卓开发(三)——移动端测试自己的模型

[2]Tensorflow移动端之如何将自己训练的MNIST模型加载到Android手机上

实际上移动端可以用tensorflow训练好的pb文件或者tflite文件,这里是我自己做的pb文件,参考了tensorflow的官方demo和一些博客。我对android studio不太熟悉,所以只能简单改改了。

(1)拷贝tensorflow的官方demo到项目路径下

这部分内容在之前的文章中也有提到,可以参考:深度学习(五)——win10环境下将tensorflow的官方demo在Android Studio上运行。这里再简单说一下,首先在github上找到tensorflow的官方源码,地址:https://github.com/tensorflow/tensorflow。进入后直接打包download。

下载好之后是一个解压文件,将其解压,解压好之后,找到路径'./tensorflow-master/tensorflow/examples/android'下的文件,这些文件就是我们需要用到的文件,可以直接将其拷贝出来,放入到项目所在的文件夹中。

(2)按照提示修改build.gradle

这里请直接参考上一篇文章进行修改:深度学习(五)——win10环境下将tensorflow的官方demo在Android Studio上运行

(3)将label.txt和model.pb文件放到assets文件夹下

这一步还比较重要,首先我们需要自己制作label.txt文件,这里也没有什么比较好的办法,只能手打了,需要注意的是,汉字在app中会显示乱码,所以我就写的拼音+英文简单表示:

这里还有一个小细节,尽量在写字板中编辑,这样最终的显示结果没有问题。

做好label.txt之后,我们将之前的pb文件和label.txt一起放入到assets文件夹下,不过项目文件中有两个assets文件夹,我也不知道放到哪个,就两个一起放了:

(4)删除不必要文件

tensorflow的官方demo中并非所有文件都是有用的,我们删除路径'./src/org/tensorflow/demo'下部分没用的文件,理论上只保留和分类有关的文件就可以了,这里我保留的文件都包括:

其中,核心文件是classifier,classifierActivity,TensorflowImageClassifier这三个文件。

(5)修改classifierActivity.java

主要修改里面的参数内容,这里我修改的地方包括:

public class ClassifierActivity extends CameraActivity implements OnImageAvailableListener {

  static {
    System.loadLibrary("tensorflow_inference");
  }

  ......

  // 以下内容需要自己根据自己的实际情况来设计
  private static final int INPUT_SIZE = 128;
  private static final int IMAGE_MEAN = 117;
  private static final float IMAGE_STD = 1;
  private static final String INPUT_NAME = "input";
  private static final String OUTPUT_NAME = "output";

  private static final String MODEL_FILE = "file:///android_asset/model.pb";
  private static final String LABEL_FILE = "file:///android_asset/label.txt";

  private static final boolean MAINTAIN_ASPECT = true;

  private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);

  ......
}

修改的地方主要是网络的输入输出节点,和pb以及label的路径,还有输入网络图像的大小。

(6)修改TensorflowImageClassifer.java

因为我最终是只想显示一个概率最高的结果,如果这个结果的概率相对较高,比如出现了两种很类似的塔,模型将其识别为甲的概率为0.5而乙的概率为0.4,那这种情况还是不要输出了,测试者看不到输出结果,也许会换个更明显的角度来拍摄。因此我们取一个绝对阈值0.8,如果某一物体的识别率低于0.8,那就不再显示,因此需要在文件中修改一下这两个参数:

public class TensorFlowImageClassifier implements Classifier {
  // 由于TAG的字符长度最大为23,因此这里将"TensorflowImageClassifier"修改为"TFImageClassifier"
  private static final String TAG = "TFImageClassifier";

  // 可能结果的数量.
  private static final int MAX_RESULTS = 1;
  private static final float THRESHOLD = 0.8f;

  ......
}

(7)修改androidManifest.xml

这个文件是控制最终生成的app的,原官方demo是生成了4个app,我们不需要那么多,只需要一个分类的app即可,因此找到该文件,将DetectorActivity,StylizeActivity,SpeechActivity三部分内容全部注释掉即可:

6. 用android手机测试

 

四、实验结果

1. CNN网络训练精度为88%

实验设置epoch为1000次,最终的训练精度为88%,实验到epoch=200时,精度约为75%:

运算时间来看,我的电脑配置是GTX 1060 3G,运行一个epoch大约需要7秒,1000个epoch也就是7000秒,即两个小时。下面可以随意看看实验效果:

西安大雁塔,成功识别:

长城,成功识别:

天坛,成功识别:

兵马俑,成功识别:

当然,也有识别错的,比如这张西安钟楼,识别为了兵马俑,可能是由于色调与兵马俑的色调比较类似?

2. app效果

五、分析

1. 所有文件结构为:

-- get_image.py
-- CNN.py
-- checkpb.py
-- data
    |------ 1bj-changcheng
                |------ image1.jpg
                |------ image2.jpg
                |------ ...
    |------ 1bj-guojiadajuyuan
                |------ image1.jpg
                |------ image2.jpg
                |------ ...
    |------ 1bj-niaochao
                |------ image1.jpg
                |------ image2.jpg
                |------ ...
    |------ ......
                |------ ...
                |------ ...
                |------ ...

2. 关于CNN网络,其实我个人是不建议用CNN的,因为有很多冗余的参数占了很多内存,我看了一下tensorflow的官方demo,是用shuffleNet做的,做好的pb文件大小有50+M,shuffleNet的效果要比CNN的效果好很多,要达到同样效果的CNN网络,还需要加深层数和更多参数。具体的shuffleNet的解读可以参考这篇文献:ShuffleNet总结 ,一般是将通道分为3组,这样参数量约为原来的1/3,计算量为1/9,所以效率和大小上都非常适合移动端,传统CNN在部署到移动端还有很多问题。当然,还有mobileNet, squeezeNet等相关轻量级网络也可以关注。

  • 1
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

全部梭哈迟早暴富

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值