一、代码演示
保存模型 shallownet_train.py:
"保存训练模型"
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import imagetoarraypreprocessor
import SimplePreprocessor
import SimpleDatasetLoader
import shallownet
from keras.optimizers import SGD
from imutils import paths
import matplotlib.pyplot as plt
import numpy as np
import argparse
# Command-line arguments: the input dataset directory and the output model path.
ap = argparse.ArgumentParser()
ap.add_argument('-d', '--dataset', required=True, help='path to input dataset')
# Path where the trained model is saved.
ap.add_argument('-m', '--model', required=True, help='path to output model')
args = vars(ap.parse_args())

# Number of training epochs.
EPOCHS = 100

print('[INFO] loading images...')
imagePaths = list(paths.list_images(args['dataset']))
# Resize to 32x32, then convert to an array (semantics inferred from the
# preprocessor names -- TODO confirm against their implementations).
sp = SimplePreprocessor.SimplePreprocessor(32, 32)
iap = imagetoarraypreprocessor.ImageToArrayPreprocessor()
sdl = SimpleDatasetLoader.SimpleDatasetLoader(preprocessors=[sp, iap])
# NOTE(review): presumably verbose=500 prints progress every 500 images -- verify.
(data, labels) = sdl.load(imagePaths, verbose=500)
# Divide pixel values by 255 (maps 8-bit intensities into [0, 1]; assumes 8-bit input).
data = data.astype('float') / 255.0

# Hold out 25% of the data for validation/testing, with a fixed seed.
(trainX, testX, trainY, testY) = train_test_split(
    data, labels, test_size=0.25, random_state=42)

# Fit the binarizer on the training labels only, then reuse it for the test
# labels so both splits share one class -> one-hot column mapping.
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)

print('[INFO] compiling model...')
try:
    opt = SGD(learning_rate=0.005)
except TypeError:
    # Keras < 2.3 only accepts the legacy `lr` keyword; newer releases
    # deprecate or reject `lr` in favor of `learning_rate`.
    opt = SGD(lr=0.005)
model = shallownet.ShallowNet.build(width=32, height=32, depth=3, classes=3)
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

print('[INFO] training network...')
H = model.fit(trainX, trainY, validation_data=(testX, testY),
              batch_size=32, epochs=EPOCHS, verbose=1)

# Save the trained model. Keras chooses the on-disk format from the file
# extension (a .hdf5 / .h5 path is written as HDF5).
print('[INFO] serializing network...')
model.save(args['model'])

print('[INFO] evaluating network...')
predictions = model.predict(testX, batch_size=32)
# NOTE(review): target_names assumes lb.classes_ is ['cat', 'dog', 'panda']
# (LabelBinarizer stores classes in sorted order) -- confirm with the dataset labels.
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1),
                            target_names=['cat', 'dog', 'panda']))

# Plot the training and validation loss/accuracy curves.
# Keras >= 2.3 records metrics=['accuracy'] under 'accuracy'/'val_accuracy';
# older releases used 'acc'/'val_acc'.
acc_key = 'accuracy' if 'accuracy' in H.history else 'acc'
# Derive the x-axis from the recorded history so it always matches its length.
epochs_range = np.arange(0, len(H.history['loss']))
plt.style.use('ggplot')
plt.figure()
plt.plot(epochs_range, H.history['loss'], label='train_loss')
plt.plot(epochs_range, H.history['val_loss'], label='val_loss')
plt.plot(epochs_range, H.history[acc_key], label='train_acc')
plt.plot(epochs_range, H.history['val_' + acc_key], label='val_acc')
plt.title('Training Loss and Accuracy')
plt.xlabel('Epoch #')
plt.ylabel('Loss/Accuracy')
plt.legend()
plt.show()
加载模型 shallownet_load.py:
"加载训练好的模型并进行分类验证"
import imagetoarraypreprocessor
import SimplePreprocessor
import SimpleDatasetLoader
#load_model加载训练好的网络(HDF5文件),解码HDF5文件
from keras.models import load_model
from imutils import paths
import numpy as np
import argparse
import cv2
# Command-line arguments: the dataset directory to sample from and the trained model file.
ap = argparse.ArgumentParser()
ap.add_argument('-d', '--dataset', required=True, help='path to input dataset')
ap.add_argument('-m', '--model', required=True, help='path to pre-trained model')
args = vars(ap.parse_args())

# Class names indexed by the model's output column. Must match the column
# order used at training time (presumably LabelBinarizer's sorted label
# order -- confirm against the training run).
classLabels = ['cat', 'dog', 'panda']

# Number of images to sample for the demo.
NUM_SAMPLES = 10

print('[INFO] sampling images...')
# Collect every image path under the dataset directory.
imagePaths = np.array(list(paths.list_images(args['dataset'])))
if len(imagePaths) == 0:
    raise SystemExit('[ERROR] no images found in {}'.format(args['dataset']))
# Draw up to NUM_SAMPLES distinct indices (without replacement, so the same
# image is never shown twice, and never more than the dataset holds).
idxs = np.random.choice(len(imagePaths),
                        size=min(NUM_SAMPLES, len(imagePaths)),
                        replace=False)
imagePaths = imagePaths[idxs]

# Resize each sampled image to 32x32 and convert it to an array (semantics
# inferred from the preprocessor names -- TODO confirm).
sp = SimplePreprocessor.SimplePreprocessor(32, 32)
iap = imagetoarraypreprocessor.ImageToArrayPreprocessor()
sdl = SimpleDatasetLoader.SimpleDatasetLoader(preprocessors=[sp, iap])
# The ground-truth labels are not needed for this demo.
(data, _) = sdl.load(imagePaths)
# Divide pixel values by 255, the same scaling applied in the training script.
data = data.astype('float') / 255.0

print('[INFO] loading pre-trained network...')
model = load_model(args['model'])
# Debug output: print the loaded model object.
print('----------------------------')
print(model)

print('[INFO] predicting...')
# For each image, keep the index of the highest-scoring class.
preds = model.predict(data, batch_size=32).argmax(axis=1)

# Show each sampled image annotated with its predicted label; any key advances.
for (i, imagePath) in enumerate(imagePaths):
    image = cv2.imread(imagePath)
    if image is None:
        # cv2.imread returns None for unreadable files; skip rather than crash in putText.
        print('[WARN] could not read {}'.format(imagePath))
        continue
    cv2.putText(image, 'Label:{}'.format(classLabels[preds[i]]), (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
    cv2.imshow('Image', image)
    cv2.waitKey(0)
cv2.destroyAllWindows()
二、注意事项
1.执行程序shallownet_train.py
最好在电脑终端(或 PyCharm 终端)中进入该程序所在路径,执行语句:python shallownet_train.py --dataset ../datasets/animals --model shallownet_weights.hdf5
注:终端需在 Anaconda 环境下运行。
2.执行程序shallownet_load.py
执行语句:python shallownet_load.py --dataset ../datasets/animals --model shallownet_weights.hdf5
注:需先运行 shallownet_train.py 生成 shallownet_weights.hdf5 模型文件。