Deep.Learning.for.Computer.Vision.with.Python-----chapter13网络模型保存与加载


一、代码演示

保存模型 shallownet_train.py:

"保存训练模型"
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import imagetoarraypreprocessor
import SimplePreprocessor
import SimpleDatasetLoader
import shallownet
from keras.optimizers import SGD
from imutils import paths
import matplotlib.pyplot as plt
import numpy as np
import argparse

ap = argparse.ArgumentParser()
ap.add_argument('-d','--dataset',required=True,help='path to input dataset')
#保存训练好的模型的路径
ap.add_argument('-m','--model',required=True,help='path to output model')
args = vars(ap.parse_args())
print('[INFO] loading images...')
imagePaths = list(paths.list_images(args['dataset']))

sp = SimplePreprocessor.SimplePreprocessor(32,32)
iap = imagetoarraypreprocessor.ImageToArrayPreprocessor()

sdl = SimpleDatasetLoader.SimpleDatasetLoader(preprocessors=[sp,iap])
(data,labels) = sdl.load(imagePaths,verbose=500)
data = data.astype('float')/255.0

(trainX,testX,trainY,testY) = train_test_split(data,labels,test_size=0.25,random_state=42)
trainY = LabelBinarizer().fit_transform(trainY)
testY = LabelBinarizer().fit_transform(testY)

print('[INFO] compiling model...')
opt = SGD(lr=0.005)
model = shallownet.ShallowNet.build(width=32,height=32,depth=3,classes=3)
model.compile(loss='categorical_crossentropy',optimizer=opt,metrics=['accuracy'])

print('[INFO] training network...')
H = model.fit(trainX,trainY,validation_data=(testX,testY),batch_size=32,epochs=100,verbose=1)
#保存模型,保存格式HDF5
print('[INFO] serializing network...')
model.save(args['model'])

print('[INFO] evaluating network...')
predictions = model.predict(testX,batch_size=32)
print(classification_report(testY.argmax(axis=1),predictions.argmax(axis=1),target_names=['cat','dog','panda']))

#画图
plt.style.use('ggplot')
plt.figure()
plt.plot(np.arange(0,100),H.history['loss'],label = 'train_loss')
plt.plot(np.arange(0,100),H.history['val_loss'],label = 'val_loss')
plt.plot(np.arange(0,100),H.history['acc'],label = 'train_acc')
plt.plot(np.arange(0,100),H.history['val_acc'],label = 'val_acc')
plt.title('Training Loss and Accuracy')
plt.xlabel('Epoch #')
plt.ylabel('Loss/Accuracy')
plt.legend()
plt.show()

加载模型 shallownet_load.py:

"加载训练好的模型并进行分类验证"
import imagetoarraypreprocessor
import SimplePreprocessor
import SimpleDatasetLoader
#load_model加载训练好的网络(HDF5文件),解码HDF5文件
from keras.models import load_model
from imutils import paths
import numpy as np
import argparse
import cv2

ap =argparse.ArgumentParser()
ap.add_argument('-d','--dataset',required=True,help='path to input dataset')
ap.add_argument('-m','--model',required=True,help='path to pre-trained model')
args = vars(ap.parse_args())

classLabels = ['cat','dog','panda']
#获取图像列表,随机抽取
print('[INFO] sampling images...')
#加载所有图片
imagePaths = np.array(list(paths.list_images(args['dataset'])))
#0-3000 随机取10个数
idxs = np.random.randint(0,len(imagePaths),size=(10,))
#随机抽取10张图片
imagePaths = imagePaths[idxs]
#对10张图片预处理(裁剪、灰度化)、加载
sp = SimplePreprocessor.SimplePreprocessor(32,32)
iap = imagetoarraypreprocessor.ImageToArrayPreprocessor()
#加载图片,并将像素缩放到[0,1]
sdl = SimpleDatasetLoader.SimpleDatasetLoader(preprocessors=[sp,iap])
#data:(10, 32, 32, 3) labels:(10,)
(data,labels) = sdl.load(imagePaths)
data = data.astype('float')/255.0
#加载训练好的模型
print('[INFO] loading pre-trained network...')
model = load_model(args['model'])
print('----------------------------')
print(model)
#对图片进行预测
print('[INFO] predicting...')
"preds返回数据中每幅图像的概率列表中概率最大的类标签索引"
preds = model.predict(data,batch_size=32).argmax(axis=1)
#可视化
for (i, imagePath) in enumerate(imagePaths):
    image = cv2.imread(imagePath)
    cv2.putText(image,'Label:{}'.format(classLabels[preds[i]]),(10,30),cv2.FONT_HERSHEY_SIMPLEX,0.7,(0,255,0),2)
    cv2.imshow('Image',image)
    cv2.waitKey(0)

二、注意事项

1.执行程序shallownet_train.py
最好从电脑终端(或pycharm终端执行)进入 该程序所在路径 执行语句:python shallownet_train.py --dataset ../datasets/animals --model shallownet_weights.hdf5
注:终端要在annoconda环境下
2.执行程序shallownet_load.py
执行语句python shallownet_load.py --dataset ../datasets/animals --model shallownet_weights.hdf5


  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值