Data-fetching section
config.py
# Search keyword ('熊猫' = panda); change it to anything you would type into Baidu Images
keyword = '熊猫'
# Maximum number of images to download
max_download_images = 1000
# Trimmed-down search URL, with the meaningless query parameters removed
url_init_first = 'https://image.baidu.com/search/flip?tn=baiduimage&word='
# Request headers
headers = {
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 11_2_3) AppleWebKit/537.36 (KHTML, like Gecko) '
'Chrome/88.0.4324.192 Safari/537.36'
}
Image-fetching module
import os
import re
from typing import List, Tuple
from urllib.parse import quote
import requests
from config import *
def get_page_urls(page_url: str, headers: dict) -> Tuple[List[str], str]:
"""获取当前翻页的所有图片的链接
Args:
page_url: 当前翻页的链接
headers: 请求表头
Returns:
当前翻页下的所有图片的链接, 当前翻页的下一翻页的链接
"""
if not page_url:
return [], ''
try:
html = requests.get(page_url, headers=headers)
html.encoding = 'utf-8'
html = html.text
except IOError as e:
print(e)
return [], ''
    # objURL fields in the page source hold the original image URLs
    pic_urls = re.findall('"objURL":"(.*?)",', html, re.S)
    # match the link of the "next page" (下一页) button
    next_page_url = re.findall(r'<a href="(.*?)" class="n">下一页</a>', html)
    next_page_url = 'http://image.baidu.com' + next_page_url[0] if next_page_url else ''
return pic_urls, next_page_url
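
# Usage sketch (illustrative, values assumed): one call fetches a result page and
# returns both its image URLs and the link to the next page:
#   pic_urls, next_page = get_page_urls(url_init_first + quote(keyword, safe='/'), headers)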
def down_pic(pic_urls: List[str], max_download_images: int) -> None:
"""给出图片链接列表,下载所有图片
Args:
pic_urls: 图片链接列表
max_download_images: 最大下载数量
"""
pic_urls = pic_urls[:max_download_images]
for i, pic_url in enumerate(pic_urls):
try:
pic = requests.get(pic_url, timeout=15)
image_output_path = './images/' + str(i + 1) + '.jpg'
with open(image_output_path, 'wb') as f:
f.write(pic.content)
            print('Successfully downloaded image %s: %s' % (str(i + 1), str(pic_url)))
except IOError as e:
            print('Failed to download image %s: %s' % (str(i + 1), str(pic_url)))
print(e)
continue
if __name__ == '__main__':
url_init = url_init_first + quote(keyword, safe='/')
all_pic_urls = []
page_urls, next_page_url = get_page_urls(url_init, headers)
all_pic_urls.extend(page_urls)
    page_count = 0  # running count of result pages
if not os.path.exists('./images'):
os.mkdir('./images')
    # collect image URLs page by page
    while True:
page_urls, next_page_url = get_page_urls(next_page_url, headers)
page_count += 1
        print('Fetching all image URLs on result page %s' % str(page_count))
if next_page_url == '' and page_urls == []:
            print('Reached the last page; %s result pages in total' % page_count)
break
all_pic_urls.extend(page_urls)
if len(all_pic_urls) >= max_download_images:
            print('Reached the configured maximum download count of %s' % max_download_images)
break
    down_pic(list(set(all_pic_urls)), max_download_images)  # deduplicate before downloading
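Some objURL links are dead or return non-image payloads, so the downloaded files are worth validating before training. A minimal sketch of such a cleanup pass (remove_broken_images is a hypothetical helper, not part of the original script), reusing the same OpenCV readability check the training code applies below:

import os
import cv2

def remove_broken_images(folder='./images'):
    # hypothetical helper: delete any file OpenCV cannot decode as an image
    for filename in os.listdir(folder):
        path = os.path.join(folder, filename)
        if cv2.imread(path) is None:
            os.remove(path)
            print('Removed unreadable file: %s' % path)

remove_broken_images()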
Reading the image folder and training the model
import pickle
import cv2
import os
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import SGD
from keras import initializers
from keras import regularizers
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
def cv_show(img):
    """Debug helper: show an image in a window until a key is pressed."""
    cv2.imshow('name', img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
Pathlist = []
# 1: collect the path of every file under the training folder
for (filepath, dirnames, filenames) in os.walk(os.getcwd() + '/images/train/'):
    for filename in filenames:
        Pathlist.append(filepath + '/' + filename)
data = []
labels = []
for name in Pathlist:
    img = cv2.imread(name)  # read the image directly from its full path
    if img is None:  # skip files OpenCV cannot decode (e.g. broken downloads)
        continue
    # resize to 32x32 and flatten into a 1-D vector (32*32*3 = 3072 values)
    img = cv2.resize(img, (32, 32)).flatten()
    data.append(img)
    # the label is the parent folder name: cat, dog, or pandas
    label = name.split('/')[-2]
    labels.append(label)
# scale pixel values into [0, 1]
data = np.array(data, dtype='float') / 255.0
print(len(labels))
labels=np.array(labels)
print(labels)
# train/test split
(trainx, testx, trainy, testy) = train_test_split(data, labels, test_size=0.25, random_state=42)
# one-hot encode the string labels
lb=LabelBinarizer()
trainy = lb.fit_transform(trainy)
testy = lb.transform(testy)
# network architecture: 3072 -> 512 -> 256 -> number of classes
model = Sequential()
model.add(Dense(512, input_shape=(3072,), activation="relu",
                kernel_initializer=initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),
                kernel_regularizer=regularizers.l2(0.01)))
# Dropout for regularization
model.add(Dropout(0.5))
model.add(Dense(256, activation="relu",
                kernel_initializer=initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),
                kernel_regularizer=regularizers.l2(0.01)))
model.add(Dropout(0.5))
model.add(Dense(len(lb.classes_), activation="softmax",
                kernel_initializer=initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),
                kernel_regularizer=regularizers.l2(0.01)))
# learning rate and other training hyperparameters
LR = 0.01
epochs = 50
batch_size=32
# specify the loss function and evaluation metric
model.compile(loss="categorical_crossentropy", optimizer=SGD(lr=LR),
              metrics=["accuracy"])
# train the model
H = model.fit(trainx, trainy, validation_data=(testx, testy),
epochs=epochs, batch_size=batch_size)
# evaluate the model
print("[INFO] evaluating the model")
predictions = model.predict(testx, batch_size=batch_size)
print(classification_report(testy.argmax(axis=1),
predictions.argmax(axis=1), target_names=lb.classes_))
# once training is done, plot the result curves
N = np.arange(0, epochs)
plt.style.use("ggplot")
plt.figure()
# loss/acc are the training-set loss and accuracy; val_loss/val_acc come from the
# validation_data split (a validation set, not a separate test set)
# note: on Keras >= 2.3 the history keys are "accuracy"/"val_accuracy" instead of "acc"/"val_acc"
plt.plot(N, H.history["loss"], label="train_loss")
plt.plot(N, H.history["val_loss"], label="val_loss")
plt.plot(N, H.history["acc"], label="train_acc")
plt.plot(N, H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy (Simple NN)")
plt.xlabel("Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig('.result.png')  # save before show(), otherwise the saved figure is blank
plt.show()
# save the model to disk
print("Saving the model")
model.save('.model')
# save the label binarizer too, so predictions can be mapped back to class names
with open('.labels', 'wb') as f:
    f.write(pickle.dumps(lb))
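As a quick sanity check on the pipeline above: each image is resized to 32x32 with 3 color channels, which is why the first Dense layer expects 3072 inputs, and LabelBinarizer one-hot encodes the folder names. A minimal illustration (assuming the three classes cat, dog, and pandas):

from sklearn.preprocessing import LabelBinarizer

print(32 * 32 * 3)  # 3072, the flattened input size of the first Dense layer

lb_demo = LabelBinarizer()
y = lb_demo.fit_transform(['cat', 'dog', 'pandas', 'cat'])
print(lb_demo.classes_)  # ['cat' 'dog' 'pandas']
print(y)
# [[1 0 0]
#  [0 1 0]
#  [0 0 1]
#  [1 0 0]]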
Prediction module
from keras.models import load_model
import cv2
import pickle
img = cv2.imread('D:/worksplace/python/keras/demo/keras_dnn/images/test/3.jpg')
# same preprocessing as training: resize to 32x32, flatten, then scale into [0, 1]
data = cv2.resize(img.copy(), (32, 32)).flatten()
data = data.astype('float') / 255.0
data = data.reshape((1, data.shape[0]))
# load the trained model file .model
model = load_model('.model')
# predict
pred = model.predict(data)
# load the .labels file to map the prediction back to a class name,
# e.g. ['cat', 'dog', 'pandas']
lb = pickle.loads(open(".labels", "rb").read())
i = pred.argmax(axis=1)[0]
label = lb.classes_[i]
print(label)
# draw the result on the image
# confidence: pred[0][i] * 100
text = "{}: {:.2f}%".format(label, pred[0][i] * 100)
cv2.putText(img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
cv2.imshow("img", img)
cv2.waitKey(0)
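Since inference must mirror the training preprocessing exactly (resize, flatten, scale), the steps above can be wrapped in a small helper; a sketch (predict_image is a hypothetical name, assuming model and lb are loaded as above):

def predict_image(path, model, lb):
    # same resize -> flatten -> scale pipeline as in training
    img = cv2.imread(path)
    data = cv2.resize(img, (32, 32)).flatten().astype('float') / 255.0
    data = data.reshape((1, -1))
    pred = model.predict(data)
    i = pred.argmax(axis=1)[0]
    return lb.classes_[i], pred[0][i]

label, prob = predict_image('D:/worksplace/python/keras/demo/keras_dnn/images/test/3.jpg', model, lb)
print('%s: %.2f%%' % (label, prob * 100))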