keras模型训练 + .h5转变可部署.pb

核对环境

tensorflow版本: 2.0.1
keras版本: 2.3.1

数据集文件夹结构
在这里插入图片描述
模型训练

# -*- coding: utf-8 -*-
"""
Created on Sat Jun  6 14:28:50 2020

@author: USER
"""
from __future__ import absolute_import,division,print_function,unicode_literals
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator 
from tensorflow.keras import layers
import os
import matplotlib.pyplot as plt
import pathlib


data_dir = '~/dataset/my_dataset/'
PATH = pathlib.Path(data_dir)

train_dir = os.path.join(PATH,'train')
validation_dir = os.path.join(PATH,'validation')
train_rose_dir = os.path.join(train_dir,'rose')
train_sunflowers_dir = os.path.join(train_dir,'sunflowers')
validation_rose_dir = os.path.join(validation_dir,'rose')
validation_sunflowers_dir = os.path.join(validation_dir,'sunflowers')

num_rose_tr = len(os.listdir(train_rose_dir))
num_sunflowers_tr = len(os.listdir(train_sunflowers_dir))
num_rose_val = len(os.listdir(validation_rose_dir))
num_sunflowers_val = len(os.listdir(validation_sunflowers_dir))
total_train = num_rose_tr + num_sunflowers_tr
total_val = num_rose_val + num_sunflowers_val

batch_size = 128 #batch数量
epochs = 5 #训练次数
IMG_HEIGHT = 150 #图片高
IMG_WIDTH = 150 #图片宽

train_image_generator = ImageDataGenerator(rescale=1./255,         # 归一化
                                           horizontal_flip=True,   # 图片翻转
                                           width_shift_range=.15,  # 宽变化
                                           height_shift_range=.15, # 高变化
                                           rotation_range=45,      # 旋转45度
                                           zoom_range=0.5          # 缩放0.5倍
                                              )
valiadation_image_generator = ImageDataGenerator(rescale=1./255)   # 验证集不进行augmentation处理

train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                           directory=train_dir,               # 训练集路径
                                                           shuffle=True,                      # 打乱图片顺序
                                                           target_size=(IMG_HEIGHT,IMG_WIDTH),# 修改图片尺寸
                                                           class_mode='binary')
val_data_gen = valiadation_image_generator.flow_from_directory(batch_size=batch_size,
                                                             directory=validation_dir,           # 验证集路径
                                                             target_size=(IMG_HEIGHT,IMG_WIDTH), # 修改图片尺寸
                                                             class_mode='binary')

def plotImages(images_arr):
    fig,axes = plt.subplots(1,5,figsize=(20,20))
    axes = axes.flatten()
    for img,ax in zip(images_arr,axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

augemted_images=[train_data_gen[0][0][0] for i in range(5)]
plotImages(augemted_images)

model = tf.keras.models.Sequential([
    layers.Conv2D(16,3,padding='same',activation='relu',input_shape =(IMG_HEIGHT,IMG_WIDTH,3)), #16为filter个数 3为kernel_size 
    layers.MaxPooling2D(), 
    layers.Dropout(0.2), # 防过拟合
    layers.Conv2D(32,3,padding='same',activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64,3,padding='same',activation='relu'),
    layers.MaxPooling2D(),
    layers.Dropout(0.2),
    layers.Flatten(),
    layers.Dense(512,activation='relu'),
    layers.Dense(1,activation='sigmoid') # 2分类,因此1个神经元
])

model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])
model.summary()
history = model.fit_generator(
    train_data_gen,    # 训练集
    steps_per_epoch=3, # 每个epoch训练batchsize的个数
    epochs=epochs,
    validation_data=val_data_gen,# 验证集
    validation_steps=3,
)

def show_train_history(train_history,train,validation):
    plt.plot(train_history.history[train])      # 绘制训练数据的执行结果
    plt.plot(train_history.history[validation]) # 绘制验证数据的执行结果
    plt.title('Train History')                  # 图标题 
    plt.xlabel('epoch')                         # x轴标签
    plt.ylabel(train)                           # y轴标签
    plt.legend(['train','validation'],loc='upper left') # 添加左上角图例

model.evaluate(train_data_gen)

model.save("F:/spyder_project/lidongdong/lab/model/my_model.h5")

模型文件.h5转换.pb

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
 
def convert_h5to_pb():
    model = tf.keras.models.load_model("F:/spyder_project/lidongdong/lab/model/my_model.h5",compile=False)
    model.summary()
    full_model = tf.function(lambda Input: model(Input))
    full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
 
    # Get frozen ConcreteFunction
    frozen_func = convert_variables_to_constants_v2(full_model)
    frozen_func.graph.as_graph_def()
 
    layers = [op.name for op in frozen_func.graph.get_operations()]
    print("-" * 50)
    print("Frozen model layers: ")
    for layer in layers:
        print(layer)
 
    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)
 
    # Save frozen graph from frozen ConcreteFunction to hard drive
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir="~/model",
                      name="mnist.pb",
                      as_text=False)
  
convert_h5to_pb()
  • 3
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值