爬取百度图片,使用keras进行图片识别

拉取data部分
config.py

# 关键词, 改为你想输入的词即可, 相当于在百度图片里搜索一样
keyword = '熊猫'

# 最大下载数量
max_download_images = 1000

# 精简一下网址,去掉网址中无意义的参数
url_init_first = 'https://image.baidu.com/search/flip?tn=baiduimage&word='

# 表头
headers = {
    'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 11_2_3) AppleWebKit/537.36 (KHTML, like Gecko) '
                  'Chrome/88.0.4324.192 Safari/537.36'
}

获取图片模块


import os
import re
from typing import List, Tuple
from urllib.parse import quote

import requests

from conf import *


def get_page_urls(page_url: str, headers: dict) -> Tuple[List[str], str]:
    """获取当前翻页的所有图片的链接
    Args:
        page_url: 当前翻页的链接
        headers: 请求表头
    Returns:
        当前翻页下的所有图片的链接, 当前翻页的下一翻页的链接
    """
    if not page_url:
        return [], ''
    try:
        html = requests.get(page_url, headers=headers)
        html.encoding = 'utf-8'
        html = html.text
    except IOError as e:
        print(e)
        return [], ''
    pic_urls = re.findall('"objURL":"(.*?)",', html, re.S)
    next_page_url = re.findall(re.compile(r'<a href="(.*)" class="n">下一页</a>'), html, flags=0)
    next_page_url = 'http://image.baidu.com' + next_page_url[0] if next_page_url else ''
    return pic_urls, next_page_url


def down_pic(pic_urls: List[str], max_download_images: int) -> None:
    """给出图片链接列表,下载所有图片
    Args:
        pic_urls: 图片链接列表
        max_download_images: 最大下载数量
    """
    pic_urls = pic_urls[:max_download_images]
    for i, pic_url in enumerate(pic_urls):
        try:
            pic = requests.get(pic_url, timeout=15)
            image_output_path = './images/' + str(i + 1) + '.jpg'
            with open(image_output_path, 'wb') as f:
                f.write(pic.content)
                print('成功下载第%s张图片: %s' % (str(i + 1), str(pic_url)))
        except IOError as e:
            print('下载第%s张图片时失败: %s' % (str(i + 1), str(pic_url)))
            print(e)
            continue


if __name__ == '__main__':
    url_init = url_init_first + quote(keyword, safe='/')
    all_pic_urls = []
    page_urls, next_page_url = get_page_urls(url_init, headers)
    all_pic_urls.extend(page_urls)

    page_count = 0  # 累计翻页数
    if not os.path.exists('./images'):
        os.mkdir('./images')

    # 获取图片链接
    while 1:
        page_urls, next_page_url = get_page_urls(next_page_url, headers)
        page_count += 1
        print('正在获取第%s个翻页的所有图片链接' % str(page_count))
        if next_page_url == '' and page_urls == []:
            print('已到最后一页,共计%s个翻页' % page_count)
            break
        all_pic_urls.extend(page_urls)
        if len(all_pic_urls) >= max_download_images:
            print('已达到设置的最大下载数量%s' % max_download_images)
            break

    down_pic(list(set(all_pic_urls)), max_download_images)

读取图片文件夹和训练模块

import pickle
import cv2
import os
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from keras.models import Sequential
from keras.layers.core import Dropout
import matplotlib.pyplot as plt
from keras.layers.core import Dense
from sklearn.metrics import classification_report
from keras.optimizers import SGD
from keras import initializers
from keras import regularizers



def cv_show(img):
    cv2.imshow('name',img)
    cv2.waitKey(10)
    cv2.destroyAllWindows()

print()
Pathlist=[]
#1:获取文件下所有的文件名字
for (filepath,dirnames,filenames)  in os.walk(os.getcwd()+'/images/train/'):
    for filename in filenames:
        Pathlist.append(filepath+'/'+filename)
data = []
labels =[]

for name in Pathlist:
    str='.'+name.split('keras_dnn')[1]
    img=cv2.imread(str)
    if img is None:
        continue
    #把数据做成一维的
    img=cv2.resize(img,(32,32)).flatten()
    data.append(img)
    #从路径上获取标签,cat dog,pandas
    label=name.split('/')[-2]
    labels.append(label)
#对图像进行scala操作,把值全部都放在(0,1)之间
data=np.array(data,dtype='float')/255.0
print(len(labels))
labels=np.array(labels)
print(labels)

#数据集切分
(trainx,testx,trainy,testy)=train_test_split(data,labels, test_size=0.25, random_state=42)
#将label转换
lb=LabelBinarizer()
trainy = lb.fit_transform(trainy)
testy = lb.transform(testy)


#网络模型结构
model=Sequential()
model.add(Dense(512, input_shape=(3072,), activation="relu" ,
                kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),
                kernel_regularizer=regularizers.l2(0.01)))
#正则化
model.add(Dropout(0.5))
model.add(Dense(256, activation="relu",
                kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),
                kernel_regularizer=regularizers.l2(0.01)))
model.add(Dropout(0.5))
model.add(Dense(len(lb.classes_), activation="softmax",
                kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),
                kernel_regularizer=regularizers.l2(0.01)))


#学习率
LR = 0.01
epochs = 50
batch_size=32


#指定损失函数和评估方法
model.compile(loss="categorical_crossentropy", optimizer=SGD(lr=LR),
              metrics=["accuracy"])

#训练网络模型
H = model.fit(trainx, trainy, validation_data=(testx, testy),
          epochs=epochs, batch_size=batch_size)

# 测试网络模型
print("[INFO] 正在评估模型")
predictions = model.predict(testx, batch_size=batch_size)

print(classification_report(testy.argmax(axis=1),
                            predictions.argmax(axis=1), target_names=lb.classes_))
#当训练完成时,绘制结果曲线
N = np.arange(0, epochs)
plt.style.use("ggplot")
plt.figure()

#loss、acc ,val_loss、val_acc 分别表示训练集的准确度和损失值、以及验证集的准确度和损失值,注意是验证集的而不是测试集的
plt.plot(N, H.history["loss"], label="train_loss")
plt.plot(N, H.history["val_loss"], label="val_loss")
plt.plot(N, H.history["acc"], label="train_acc")
plt.plot(N, H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy (Simple NN)")
plt.xlabel("Epoch ")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.show()
plt.savefig('.result.png')

# 保存模型到本地
print("保存模型")
model.save('.model')
#保存labels
f = open('.labels', "wb")

print(pickle.dumps(lb))
f.write(pickle.dumps(lb))
f.close()

预测模块

from keras.models import load_model
import cv2
import pickle

img=cv2.imread('D:/worksplace/python/keras/demo/keras_dnn/images/test/3.jpg')

data=img.copy()
data=data.astype('float')/255
data=cv2.resize(data,(32,32)).flatten()
data = data.reshape((1, data.shape[0]))

#加载模型文件.model
model=load_model('.model')
#预测
pred=model.predict(data)

#label=['cat','dog','pandas']

#加载标签.labels文件
lb = pickle.loads(open(".labels", "rb").read())
i = pred.argmax(axis=1)[0]
label = lb.classes_[i]
print(label)
# 在图像中把结果画出来
#准确率:pred[0][i] * 100
text = "{}: {:.2f}%".format(label, pred[0][i] * 100)
cv2.putText(img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
            (0, 0, 255), 2)

cv2.imshow("img", img)
cv2.waitKey(0)






  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值