TF2:使用CNN模型训练自己的数据(分类)并保存模型

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2021/7/19 8:22
# @Author : wutiande

import numpy as np
import matplotlib.image as mping
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import RMSprop,SGD,Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# One seed shared by TensorFlow, the Dropout layers and the data generators'
# shuffling, so repeated runs are reproducible.
SEED = 2021
# TF2 API for the global seed (replaces tf.compat.v1.set_random_seed).
tf.random.set_seed(SEED)

# Spatial size expected by the first Conv2D layer; the generators below resize
# every image to the same size so the two can never drift apart.
IMG_SIZE = (180, 180)
# Number of output classes; must equal the number of class sub-directories
# under train_dir / validation_dir (2 in the run shown in the log).
NUM_CLASSES = 2

# Four conv/max-pool stages (180x180 -> 9x9x128 per the model summary), then a
# dropout-regularised fully connected head with a softmax over NUM_CLASSES.
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(16, (3, 3), activation="relu", input_shape=IMG_SIZE + (3,)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(32, (3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(128, (3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D(2, 2),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(550, activation="relu"),
    tf.keras.layers.Dropout(0.1, seed=SEED),
    tf.keras.layers.Dense(400, activation="relu"),
    tf.keras.layers.Dropout(0.3, seed=SEED),
    tf.keras.layers.Dense(300, activation="relu"),
    tf.keras.layers.Dropout(0.4, seed=SEED),
    tf.keras.layers.Dense(200, activation="relu"),
    tf.keras.layers.Dropout(0.2, seed=SEED),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])

# Print the layer-by-layer summary (output shapes and parameter counts).
model.summary()

# Pass the configured optimizer object itself. Previously compile() received
# the string 'adam', so this instance (and its learning_rate) was ignored.
adam = Adam(learning_rate=0.001)
model.compile(optimizer=adam, loss="categorical_crossentropy", metrics=['acc'])

bs = 30
train_dir = r"D:\projects\tensorflow2-learn\dataset\Dataset\train/"
validation_dir = r"D:\projects\tensorflow2-learn\dataset\Dataset\test/"

# Only rescale pixel values from [0, 255] to [0, 1]; no data augmentation.
train_datagen = ImageDataGenerator(rescale=1.0 / 255.)
test_datagen = ImageDataGenerator(rescale=1.0 / 255.)

# flow_from_directory takes each image's label from the name of the
# sub-directory it lies in. class_mode='categorical' yields one-hot labels,
# which matches the softmax output and the categorical_crossentropy loss.
# seed makes the shuffling order reproducible.
train_generator = train_datagen.flow_from_directory(
    train_dir,
    batch_size=bs,
    class_mode='categorical',
    target_size=IMG_SIZE,
    seed=SEED,
)
validation_generator = test_datagen.flow_from_directory(
    validation_dir,
    batch_size=bs,
    class_mode='categorical',
    target_size=IMG_SIZE,
    seed=SEED,
)
# The two calls above print the image/class counts found, e.g.:
#   Found 119 images belonging to 2 classes.
#   Found 40 images belonging to 2 classes.

# steps_per_epoch / validation_steps are left to Keras, which takes them from
# the generators' lengths (ceil(images / bs); 4 steps for 119 images, as in
# the log), so every image is used once per epoch.
history = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=5,
)

# Save through Keras so the SavedModel directory also gets the Keras metadata
# and can be reloaded with tf.keras.models.load_model (the run log warns that
# tf.saved_model.save does not write it).
# NOTE: on Keras 3 (TF >= 2.16) model.save requires a ".keras" path; use
# model.export("cnnTest") there to produce a SavedModel instead.
model.save("cnnTest")
C:\Users\a\AppData\Local\Programs\Python\Python38\python.exe D:/projects/tensorflow2-learn/CNN/test1.py
2021-07-19 08:44:42.519827: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found
2021-07-19 08:44:42.519961: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
2021-07-19 08:44:44.042458: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'nvcuda.dll'; dlerror: nvcuda.dll not found
2021-07-19 08:44:44.042583: W tensorflow/stream_executor/cuda/cuda_driver.cc:326] failed call to cuInit: UNKNOWN ERROR (303)
2021-07-19 08:44:44.045917: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: DESKTOP-DA0FST6
2021-07-19 08:44:44.046077: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: DESKTOP-DA0FST6
2021-07-19 08:44:44.046485: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 178, 178, 16)      448       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 89, 89, 16)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 87, 87, 32)        4640      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 43, 43, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 41, 41, 64)        18496     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 20, 20, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 18, 18, 128)       73856     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 9, 9, 128)         0         
_________________________________________________________________
flatten (Flatten)            (None, 10368)             0         
_________________________________________________________________
dense (Dense)                (None, 550)               5702950   
_________________________________________________________________
dropout (Dropout)            (None, 550)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 400)               220400    
_________________________________________________________________
dropout_1 (Dropout)          (None, 400)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 300)               120300    
_________________________________________________________________
dropout_2 (Dropout)          (None, 300)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 200)               60200     
_________________________________________________________________
dropout_3 (Dropout)          (None, 200)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 2)                 402       
=================================================================
Total params: 6,201,692
Trainable params: 6,201,692
Non-trainable params: 0
_________________________________________________________________
Found 119 images belonging to 2 classes.
Found 40 images belonging to 2 classes.
2021-07-19 08:44:44.454215: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/5
4/4 [==============================] - 2s 438ms/step - loss: 0.7406 - acc: 0.4370 - val_loss: 0.6927 - val_acc: 0.5500
Epoch 2/5
4/4 [==============================] - 2s 370ms/step - loss: 0.7001 - acc: 0.3950 - val_loss: 0.6929 - val_acc: 0.5000
Epoch 3/5
4/4 [==============================] - 2s 372ms/step - loss: 0.6934 - acc: 0.4958 - val_loss: 0.6935 - val_acc: 0.5000
Epoch 4/5
4/4 [==============================] - 2s 372ms/step - loss: 0.6920 - acc: 0.5462 - val_loss: 0.6928 - val_acc: 0.5000
Epoch 5/5
4/4 [==============================] - 2s 372ms/step - loss: 0.6967 - acc: 0.4790 - val_loss: 0.6926 - val_acc: 0.5000
2021-07-19 08:44:53.492661: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:tensorflow:FOR KERAS USERS: The object that you are saving contains one or more Keras models or layers. If you are loading the SavedModel with `tf.keras.models.load_model`, continue reading (otherwise, you may ignore the following instructions). Please change your code to save with `tf.keras.models.save_model` or `model.save`, and confirm that the file "keras.metadata" exists in the export directory. In the future, Keras will only load the SavedModels that have this file. In other words, `tf.saved_model.save` will no longer write SavedModels that can be recovered as Keras models (this will apply in TF 2.5).

FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.

Process finished with exit code 0

数据目录

 

 模型保存

 模型的使用

见这里:(原文链接在转载中丢失)

 

 

 

 

 

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值