心血来潮,想从零开始编写一个相对完整的深度学习小项目。想到就做,那么首先要考虑的问题是,写什么?
思量再三,我决定写一个宠物识别系统,即给定一张图片,判断图片上的宠物是什么。宠物种类暂定为四类——猫、狗、鼠、兔。之所以想到做这个,是因为在不使用公开数据集的情况下,宠物图片数据集获取的难度相对低一些。
小项目分为如下几个部分:
- 爬虫。从网络上下载宠物图片,构建训练用的数据集。
- 模型构建、训练和调优。鉴于我们的数据比较少,这部分需要做迁移学习。
- 模型部署和Web服务。将训练好的模型部署成web接口,并使用Vue.js + Element UI编写测试页面。
好嘞,开搞吧!
本文涉及到的所有代码,均已上传到GitHub:
pets_classifer (https://github.com/AaronJny/pets_classifer)
转载请注明来源:https://blog.csdn.net/aaronjny/article/details/103605988
一、爬虫
训练模型肯定是需要数据集的,那么数据集从哪来?因为是从零开始嘛,假设我们做的这个问题,业内没有公开的数据集,我们需要自己制作数据集。
一个很简单的想法是,利用搜索引擎搜索相关图片,使用爬虫批量下载,然后人工去除不正确的图片。举个例子,我们先处理猫的图片,步骤如下:
- 1.使用搜索引擎搜索猫的图片。
- 2.使用爬虫将搜索出的猫的图片批量下载到本地,放到一个名为
cats
的文件夹里面。 - 3.人工浏览一遍图片,将“不包含猫”的图片和“除猫外还包含其他宠物(狗、鼠、兔)”的图片从文件夹中删除。
这样,猫的图片我们就搜集完成了,其他几个类别的图片也是类似的操作。不用担心人工过滤图片花费的时间较长,全部过一遍也就二十多分钟吧。
然后是搜索引擎的选择。搜索引擎用的比较多的无非两种——Google和百度。我分别使用Google和百度进行了图片搜索,发现百度的搜索结果远不如Google准确,于是就选择了Google,所以我的爬虫代码是基于Google编写的,运行我的爬虫代码需要你的网络能够访问Google。
如果你的网络不能访问Google,可以考虑自行实现基于百度的爬虫程序,逻辑都是相通的。
因为想让项目轻量级一些,故没有使用scrapy框架。爬虫使用requests+beautifulsoup4实现,并发使用gevent实现。
# -*- coding: utf-8 -*-
# @File : spider.py
# @Author : AaronJny
# @Time : 2019/12/16
# @Desc : 从谷歌下载指定图片
from gevent import monkey
monkey.patch_all()
import functools
import logging
import os
from bs4 import BeautifulSoup
from gevent.pool import Pool
import requests
import settings
# 设置日志输出格式
logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
level=logging.INFO)
# 搜索关键词字典
keywords_map = settings.IMAGE_CLASS_KEYWORD_MAP
# 图片保存根目录
images_root = settings.IMAGES_ROOT
# 每个类别下载多少页图片
download_pages = settings.SPIDER_DOWNLOAD_PAGES
# 图片编号字典,每种图片都从0开始编号,然后递增
images_index_map = dict(zip(keywords_map.keys(), [0 for _ in keywords_map]))
# 图片去重器
duplication_filter = set()
# 请求头
headers = {
'accept-encoding': 'gzip, deflate, br',
'accept-language': 'zh-CN,zh;q=0.9',
'user-agent': 'Mozilla/5.0 (Linux; Android 4.0.4; Galaxy Nexus Build/IMM76B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/46.0.2490.76 Mobile Safari/537.36',
'accept': '*/*',
'referer': 'https://www.google.com/',
'authority': 'www.google.com',
}
# 重试装饰器
def try_again_while_except(max_times=3):
"""
当出现异常时,自动重试。
连续失败max_times次后放弃。
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
error_cnt = 0
error_msg = ''
while error_cnt < max_times:
try:
return func(*args, **kwargs)
except Exception as e:
error_msg = str(e)
error_cnt += 1
if error_msg:
logging.error(error_msg)
return wrapper
return decorator
@try_again_while_except()
def download_image(session, image_url, image_class):
"""
从给定的url中下载图片,并保存到指定路径
"""
# 下载图片
resp = session.get(image_url, timeout=20)
# 检查图片是否下载成功
if resp.status_code != 200:
raise Exception('Response Status Code {}!'.format(resp.status_code))
# 分配一个图片编号
image_index = images_index_map.get(image_class, 0)
# 更新待分配编号
images_index_map[image_class] = image_index + 1
# 拼接图片路径
image_path = os.path.join(images_root, image_class, '{}.jpg'.format(image_index))
# 保存图片
with open(image_path, 'wb') as f:
f.write(resp.content)
# 成功写入了一张图片
return True
@try_again_while_except()
def get_and_analysis_google_search_page(session, page, image_class, keyword):
"""
使用google进行搜索,下载搜索结果页面,解析其中的图片地址,并对有效图片进一步发起请求
"""
logging.info('Class:{} Page:{} Processing...'.format(image_class, page + 1))
# 记录从本页成功下载的图片数量
downloaded_cnt = 0
# 构建请求参数
params = (
('q', keyword),
('tbm', 'isch'),
('async', '_id:islrg_c,_fmt:html'),
('asearch', 'ichunklite'),
('start', str(page * 100)),
('ijn', str(page)),
)
# 进行搜索
resp = requests.get('https://www.google.com/search', params=params, timeout=20)
# 解析搜索结果
bsobj = BeautifulSoup(resp.content, 'lxml')
divs = bsobj.find_all('div', {
'class': 'islrtb isv-r'})
for div in divs:
image_url = div.get('data-ou')
# 只有当图片以'.jpg','.jpeg','.png'结尾时才下载图片
if image_url.endswith('.jpg') or image_url.endswith('.jpeg') or image_url.endswith('.png'):
# 过滤掉相同图片
if image_url not in duplication_filter:
# 使用去重器记录
duplication_filter.add(image_url)
# 下载图片
flag = download_image(session, image_url, image_class)
if flag:
downloaded_cnt += 1
logging.info('Class:{} Page:{} Done. {} images downloaded.'.format(image_class, page + 1, downloaded_cnt))
def search_with_google(image_class, keyword):
"""
通过google下载数据集
"""
# 创建session对象
session = requests.session()
session.headers.update(headers)
# 每个类别下载10页数据
for page