# Step 1: import libraries (第一步,导入库)
import os
import zipfile
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Step 2: extract the zip archives (第二步,解压压缩包)
# Unpack the training and validation archives into their own directories.
# Each archive is opened with a context manager so its file handle is always
# closed, even if extraction raises part-way through.
for local_zip, extract_dir in (
    ('horse-or-human.zip', 'horse-or-human'),
    ('validation-horse-or-human.zip', 'validation-horse-or-human'),
):
    with zipfile.ZipFile(local_zip, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)
# Step 3: build the data directory paths (第三步,读取路径)
# Per-class image folders for the training and validation splits.
# Both splits use the same 'horses' / 'humans' sub-folder names; the training
# human folder was previously misspelled as 'human', which does not match.
train_horse_dir = os.path.join('horse-or-human', 'horses')
train_human_dir = os.path.join('horse-or-human', 'humans')
validation_horse_dir = os.path.join('validation-horse-or-human', 'horses')
validation_human_dir = os.path.join('validation-horse-or-human', 'humans')
# Step 4: inspect the data by displaying a sample image (第四步,观察数据,显示图片)
# Sanity-check the extracted data: load one training horse image, print the
# shape of the decoded array, and show it on screen.
x = os.listdir(train_horse_dir)
sample_path = os.path.join(train_horse_dir, x[0])  # first entry os.listdir returned

img = mpimg.imread(sample_path)  # read the image file into an array
print(img.shape)

plt.imshow(img)
plt.show()
# Step 5: load the data with image generators (第五步,用图片生成器读取数据)
# Stream images from disk in batches. Pixel values are multiplied by 1/255
# (normalization), and every image is resized to the same square size.
IMAGE_SIZE = (150, 150)  # all images are resized to 150x150 on load

train_datagen = ImageDataGenerator(rescale=1 / 255)       # normalize pixel values
validation_datagen = ImageDataGenerator(rescale=1 / 255)

train_generator = train_datagen.flow_from_directory(
    'horse-or-human/',      # root directory of the training images
    target_size=IMAGE_SIZE,
    batch_size=128,         # 128 images per batch
    class_mode='binary',    # two classes, taken from the class sub-directories
)

validation_generator = validation_datagen.flow_from_directory(
    'validation-horse-or-human/',  # root directory of the validation images
    target_size=IMAGE_SIZE,
    batch_size=32,
    class_mode='binary',
)
# Step 6: build the model (第六步,建立模型)
# A small CNN: three Conv2D/MaxPooling2D stages, then a dense head that ends in
# a single sigmoid unit for the binary label.
model = keras.models.Sequential([
    # Input is 150x150x3, matching the target_size the generators resize to.
    keras.layers.Conv2D(32, (3, 3), activation=tf.nn.relu,
                        input_shape=(150, 150, 3)),
    keras.layers.MaxPooling2D(2, 2),
    keras.layers.Conv2D(32, (3, 3), activation=tf.nn.relu),
    keras.layers.MaxPooling2D(2, 2),
    keras.layers.Conv2D(32, (3, 3), activation=tf.nn.relu),
    keras.layers.MaxPooling2D(2, 2),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation=tf.nn.relu),
    # One sigmoid output: a probability in (0, 1) for the binary label.
    keras.layers.Dense(1, activation=tf.nn.sigmoid),
])
model.compile(
    loss=keras.losses.binary_crossentropy,
    # `learning_rate` is the supported keyword; the old `lr` alias is
    # deprecated and rejected by newer Keras optimizers.
    optimizer=keras.optimizers.RMSprop(learning_rate=0.001),
    metrics=['acc'],
)
# Step 7: train the network model (第七步,训练网络模型)
# Train the network and evaluate on the validation generator after each epoch.
# model.fit() accepts generators directly; fit_generator() is deprecated and
# has been removed from recent Keras releases.
history = model.fit(
    train_generator,                       # training data
    steps_per_epoch=8,                     # batches drawn per epoch
    epochs=15,
    verbose=1,                             # how training progress is displayed
    validation_data=validation_generator,  # validation data
    validation_steps=8,                    # validation batches drawn per epoch
)