import os
import warnings
# NOTE(review): blanket suppression silences ALL warnings, including
# TensorFlow/Keras deprecation notices — consider a narrower filter.
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
from tensorflow.keras import Model
# Dataset layout: <base>/<split>/<class> for the cats-vs-dogs data.
base_dir = './data/cats_and_dogs'
train_dir, validation_dir = (os.path.join(base_dir, split)
                             for split in ('train', 'validation'))
# Per-class subdirectories inside each split.
train_cats_dir, train_dogs_dir = (os.path.join(train_dir, cls)
                                  for cls in ('cats', 'dogs'))
validation_cats_dir, validation_dogs_dir = (os.path.join(validation_dir, cls)
                                            for cls in ('cats', 'dogs'))
# tensorflow.keras.applications ships many ready-to-use pre-trained models.
### Import candidate backbone architectures
from tensorflow.keras.applications.resnet import ResNet50
from tensorflow.keras.applications.resnet import ResNet101
from tensorflow.keras.applications.inception_v3 import InceptionV3

# Load ResNet101 pre-trained on ImageNet, without its classifier head,
# to use as a feature extractor for 75x75 RGB inputs.
pre_trained_model = ResNet101(
    input_shape=(75, 75, 3),  # input size
    include_top=False,        # drop the final fully-connected layers
    weights='imagenet',
)

# Freeze every backbone layer; only the new head added below will train.
for frozen_layer in pre_trained_model.layers:
    frozen_layer.trainable = False
# Purpose of callbacks:
# they act as monitors hooked into training, letting you add custom behavior
# such as early stopping or learning-rate changes, e.g.:
# callbacks = [
#     # stop if val_loss has not improved for two consecutive epochs:
#     tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
#     # dynamically change the learning rate:
#     tf.keras.callbacks.LearningRateScheduler
#     # save the model:
#     tf.keras.callbacks.ModelCheckpoint
#     # custom hooks:
#     tf.keras.callbacks.Callback
# ]
class myCallback(tf.keras.callbacks.Callback):
    """Stop training as soon as training accuracy exceeds 95%."""

    def on_epoch_end(self, epoch, logs=None):
        """Inspect the accuracy metric after each epoch.

        Fixes two defects in the original:
        - mutable default argument ``logs={}``;
        - ``logs.get('acc') > 0.95`` raised ``TypeError`` when the key was
          missing (newer Keras reports the metric as 'accuracy').
        """
        logs = logs or {}
        # Support both the legacy 'acc' key and the newer 'accuracy' key.
        acc = logs.get('acc', logs.get('accuracy'))
        if acc is not None and acc > 0.95:
            print("\nReached 95% accuracy so cancelling training!")
            self.model.stop_training = True
# (redundant re-import of Adam removed; it is imported at the top of the file)

# Build the classifier head on top of the frozen backbone.
x = layers.Flatten()(pre_trained_model.output)
# New fully-connected layer — trained from scratch.
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)
# Single sigmoid unit for binary cats-vs-dogs output.
x = layers.Dense(1, activation='sigmoid')(x)

# Assemble the end-to-end model.
model = Model(pre_trained_model.input, x)
# 'lr' was deprecated and then removed from Keras optimizers;
# 'learning_rate' is the supported keyword.
model.compile(optimizer=Adam(learning_rate=0.001),
              loss='binary_crossentropy',
              metrics=['acc'])
# Augment the training images; the validation set is only rescaled.
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255.0,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)
test_datagen = ImageDataGenerator(rescale=1.0 / 255.0)

# Stream 75x75 images with binary labels in batches of 20 from each split.
flow_kwargs = dict(batch_size=20, class_mode='binary', target_size=(75, 75))
train_generator = train_datagen.flow_from_directory(train_dir, **flow_kwargs)
validation_generator = test_datagen.flow_from_directory(validation_dir,
                                                        **flow_kwargs)
#
# Train the model with the early-stopping callback attached.
callbacks = myCallback()
# model.fit_generator() was deprecated in TF 2.1 and removed in TF 2.6;
# model.fit() accepts generators directly with the same keyword arguments.
history = model.fit(
    train_generator,
    validation_data = validation_generator,
    steps_per_epoch = 100,
    epochs = 10,
    validation_steps = 50,
    verbose = 2,
    callbacks=[callbacks])
# (duplicate "import matplotlib.pyplot as plt" removed; plt is imported
#  at the top of the file)

# Pull the per-epoch metrics recorded by model.fit.
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(len(acc))

# Accuracy curves (blue = training, red = validation).
plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()

# Loss curves on a separate figure.
plt.figure()
plt.plot(epochs, loss, 'b', label='Training Loss')
plt.plot(epochs, val_loss, 'r', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()
# Transfer learning based on ResNet
# (blog-scrape residue, commented out because it was not valid Python:
#  "latest recommended article published 2024-07-19 14:39:32")