核对环境:
tensorflow版本: 2.0.1
keras版本: 2.3.1
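数据集文件夹结构(与下方代码中的路径一致):
my_dataset/train/rose, my_dataset/train/sunflowers,
my_dataset/validation/rose, my_dataset/validation/sunflowers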
模型训练:
# -*- coding: utf-8 -*-
"""
Created on Sat Jun 6 14:28:50 2020
@author: USER
"""
from __future__ import absolute_import,division,print_function,unicode_literals
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
import os
import matplotlib.pyplot as plt
import pathlib
# Root folder of the dataset. It must contain `train/` and `validation/`,
# each with one sub-folder per class (`rose`, `sunflowers`).
data_dir = '~/dataset/my_dataset/'
# Neither pathlib.Path() nor os.listdir() expands "~" on its own, so the home
# directory is resolved explicitly; otherwise every path below would point at
# a literal "~" folder relative to the current working directory.
PATH = pathlib.Path(os.path.expanduser(data_dir))

# Split-level folders.
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

# Class-level folders inside each split.
train_rose_dir = os.path.join(train_dir, 'rose')
train_sunflowers_dir = os.path.join(train_dir, 'sunflowers')
validation_rose_dir = os.path.join(validation_dir, 'rose')
validation_sunflowers_dir = os.path.join(validation_dir, 'sunflowers')
# Count the directory entries in every class folder (train first, then
# validation). os.listdir() counts every entry, not only image files.
num_rose_tr, num_sunflowers_tr, num_rose_val, num_sunflowers_val = (
    len(os.listdir(class_dir))
    for class_dir in (
        train_rose_dir,
        train_sunflowers_dir,
        validation_rose_dir,
        validation_sunflowers_dir,
    )
)

# Totals per split.
total_train = sum((num_rose_tr, num_sunflowers_tr))
total_val = sum((num_rose_val, num_sunflowers_val))
# Training hyper-parameters.
batch_size = 128  # number of images per batch
epochs = 5  # number of training epochs
# Every image is resized to a square of this size before entering the network.
IMG_HEIGHT = IMG_WIDTH = 150  # image height / width in pixels
# Training generator: pixel normalisation plus random augmentation.
train_image_generator = ImageDataGenerator(
    rescale=1.0 / 255,        # scale pixel values into [0, 1]
    rotation_range=45,        # random rotation of up to 45 degrees
    width_shift_range=0.15,   # random horizontal shift (fraction of the width)
    height_shift_range=0.15,  # random vertical shift (fraction of the height)
    horizontal_flip=True,     # random left/right flip
    zoom_range=0.5,           # random zoom factor in [1 - 0.5, 1 + 0.5]
)

# The validation set is only normalised; no augmentation is applied to it.
valiadation_image_generator = ImageDataGenerator(rescale=1.0 / 255)
# Batches of (images, labels) read from the training folder.
train_data_gen = train_image_generator.flow_from_directory(
    train_dir,                              # training set path
    target_size=(IMG_HEIGHT, IMG_WIDTH),    # resize every image
    class_mode='binary',                    # two classes -> one 0/1 label
    batch_size=batch_size,
    shuffle=True,                           # shuffle the image order
)

# Batches of (images, labels) read from the validation folder.
val_data_gen = valiadation_image_generator.flow_from_directory(
    validation_dir,                         # validation set path
    target_size=(IMG_HEIGHT, IMG_WIDTH),    # resize every image
    class_mode='binary',
    batch_size=batch_size,
)
def plotImages(images_arr):
    """Show up to five images side by side in a single row.

    Args:
        images_arr: iterable of images accepted by ``Axes.imshow``. Only the
            first five are drawn because the figure has five axes.
    """
    _, axes = plt.subplots(1, 5, figsize=(20, 20))
    for image, axis in zip(images_arr, axes.flatten()):
        axis.imshow(image)
        axis.axis('off')  # hide ticks and frame
    plt.tight_layout()
    plt.show()
# Fetch batch 0 five times and keep its first image each time.
# NOTE(review): presumably every access re-applies the random augmentation,
# so the five previews differ - confirm against the Keras iterator.
augemted_images = []
for _ in range(5):
    batch_images = train_data_gen[0][0]  # [0] -> batch 0, [0] -> images (not labels)
    augemted_images.append(batch_images[0])
plotImages(augemted_images)
# Small CNN: three Conv2D + MaxPooling2D stages, two Dropout layers,
# then a dense classifier head.
model = tf.keras.models.Sequential([
    # 16 filters with a 3x3 kernel; the first layer also declares the input shape.
    layers.Conv2D(
        filters=16,
        kernel_size=3,
        padding='same',
        activation='relu',
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
    ),
    layers.MaxPooling2D(),
    layers.Dropout(rate=0.2),  # guards against over-fitting
    layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Dropout(rate=0.2),
    layers.Flatten(),
    layers.Dense(units=512, activation='relu'),
    # Binary classification, hence a single sigmoid unit.
    layers.Dense(units=1, activation='sigmoid'),
])

model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model.summary()
# Train the model. The step counts are taken from the iterators (their length
# is the number of batches needed to cover the whole directory once) instead
# of a hard-coded 3, so one epoch really is one full pass over the training
# set and validation covers the whole validation set.
# NOTE(review): fit_generator matches the TensorFlow 2.0.x stated for this
# script; later releases deprecate it in favour of model.fit.
history = model.fit_generator(
    train_data_gen,                       # training set
    steps_per_epoch=len(train_data_gen),  # batches per epoch
    epochs=epochs,
    validation_data=val_data_gen,         # validation set
    validation_steps=len(val_data_gen),   # batches per validation run
)
def show_train_history(train_history, train, validation):
    """Plot a training metric and its validation counterpart per epoch.

    Args:
        train_history: object whose ``history`` attribute maps metric names
            to per-epoch value lists (e.g. the result of ``fit_generator``).
        train: key of the training metric; also used as the y-axis label.
        validation: key of the matching validation metric.
    """
    recorded = train_history.history
    # Training curve first, then validation, matching the legend order below.
    for metric_name in (train, validation):
        plt.plot(recorded[metric_name])
    plt.title('Train History')  # figure title
    plt.xlabel('epoch')         # x-axis label
    plt.ylabel(train)           # y-axis label
    plt.legend(['train', 'validation'], loc='upper left')  # legend, top-left corner
# Evaluate the trained model on the training iterator (the result is not stored).
model.evaluate(x=train_data_gen)
# Persist the whole model as a single HDF5 file.
model.save(filepath="F:/spyder_project/lidongdong/lab/model/my_model.h5")
模型文件 .h5 转换为 .pb:
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
def convert_h5to_pb(h5_path="F:/spyder_project/lidongdong/lab/model/my_model.h5",
                    pb_dir="~/model",
                    pb_name="mnist.pb"):
    """Freeze a Keras ``.h5`` model into a binary frozen-graph ``.pb`` file.

    The model is loaded without its training configuration, wrapped in a
    concrete ``tf.function``, its variables are converted to constants, and
    the resulting graph is written to ``pb_dir/pb_name``. The frozen graph's
    operation names, inputs and outputs are printed along the way.

    Args:
        h5_path: path of the Keras HDF5 model to load.
        pb_dir: output directory; a leading ``~`` is expanded to the user's
            home directory.
        pb_name: file name of the frozen graph inside ``pb_dir``.
    """
    # Local import: the snippet this function lives in only imports tensorflow.
    import os

    model = tf.keras.models.load_model(h5_path, compile=False)
    model.summary()

    # Wrap the model call in a tf.function and specialise it to the model's
    # own input signature. The lambda parameter name is kept unchanged
    # because it may determine the input tensor name in the exported graph.
    full_model = tf.function(lambda Input: model(Input))
    full_model = full_model.get_concrete_function(
        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Get frozen ConcreteFunction (variables replaced by constants).
    frozen_func = convert_variables_to_constants_v2(full_model)

    op_names = [op.name for op in frozen_func.graph.get_operations()]
    print("-" * 50)
    print("Frozen model layers: ")
    for op_name in op_names:
        print(op_name)
    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    # Save frozen graph from frozen ConcreteFunction to hard drive.
    # "~" is expanded here; passed through unchanged it would create a
    # literal "~" directory under the current working directory.
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir=os.path.expanduser(pb_dir),
                      name=pb_name,
                      as_text=False)


convert_h5to_pb()