#%%
import tensorflow as tf
#%%
# Report the installed TensorFlow version.
print(f'Tensorflow version: {tf.__version__}')
#%%
from tensorflow import keras
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import glob
#%%
# glob returns paths in filesystem-dependent order; sort them so the seeded
# shuffle further down yields the same train/test split on every machine.
imgs_path = sorted(glob.glob('dataset/birds/*/*.jpg'))
if not imgs_path:
    raise FileNotFoundError("no images matched 'dataset/birds/*/*.jpg'")
#%%
imgs_path[:5]
#%%
img_p = imgs_path[100]
img_p
#%%
def path_to_label(path):
    """Return the class name encoded in the parent directory of an image path.

    The parent folder name is split on its first '.', and the part after the
    dot is returned. Forward and backward slashes are both treated as path
    separators, so this works with the paths glob yields on Windows and on
    POSIX systems alike.
    """
    class_dir = path.replace('\\', '/').split('/')[-2]
    return class_dir.split('.', 1)[1]


path_to_label(img_p)
#%%
label_names = [path_to_label(p) for p in imgs_path]
#%%
unique_label = np.unique(label_names)
#%%
# class name -> integer id; np.unique returns the names sorted, so ids are stable.
label_to_index = {name: idx for idx, name in enumerate(unique_label)}
#%%
label_to_index
#%%
# integer id -> class name, used to decode predictions.
index_to_label = {idx: name for name, idx in label_to_index.items()}
#%%
index_to_label
#%%
# Every name comes from label_names, so a direct lookup cannot miss.
all_labels = [label_to_index[name] for name in label_names]
#%%
all_labels[:5]
#%%
all_labels[-5:]
#%%
len(imgs_path)
#%%
# Fixed seed so the permutation below is the same on every run.
np.random.seed(2021)
random_index = np.random.permutation(len(imgs_path))
#%%
# Reorder paths and labels with the same permutation so each pair stays aligned.
imgs_path, all_labels = (
    np.array(seq)[random_index] for seq in (imgs_path, all_labels)
)
#%%
# The first 80% of the shuffled data is used for training; the rest is held out.
i = int(len(imgs_path) * 0.8)
#%%
train_path, test_path = imgs_path[:i], imgs_path[i:]
train_labels, test_labels = all_labels[:i], all_labels[i:]
#%%
train_ds, test_ds = (
    tf.data.Dataset.from_tensor_slices(pair)
    for pair in ((train_path, train_labels), (test_path, test_labels))
)
#%% md
加载和格式化图像
#%%
def load_and_preprocess_image(path, label):
    """Decode the JPEG at `path` into a 256x256 RGB float32 tensor in [0, 1].

    `label` is passed through untouched, so the function can be mapped
    directly over a dataset of (path, label) pairs.
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.image.resize(tf.image.decode_jpeg(raw_bytes, channels=3), [256, 256])
    # Scale pixel values from [0, 255] down to [0, 1].
    return tf.cast(pixels, tf.float32) / 255.0, label
#%%
AUTOTUNE = tf.data.experimental.AUTOTUNE
#%%
test_count = len(test_path)
train_count = len(train_path)
#%%
BATCH_SIZE = 32
#%%
# Shuffle the cheap (path, label) pairs *before* decoding. The buffer can then
# span the whole training set without holding decoded float images in memory.
# Shuffling before repeat() also reshuffles cleanly at every epoch boundary.
train_ds = (
    train_ds
    .shuffle(buffer_size=train_count)
    .repeat()
    .map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)  # overlap input preparation with training steps
)
train_ds
#%%
# The evaluation set is finite and kept in order: decode, batch, prefetch.
test_ds = (
    test_ds
    .map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)
#%% md
建立模型
#%%
# Size the classifier head from the label mapping built from the dataset
# folders instead of hard-coding 200, so logits and labels always agree.
NUM_CLASSES = len(label_to_index)

# VGG-style stack: pairs/triples of 3x3 convolutions with batch norm, each
# stage followed by 2x2 max pooling, then global average pooling into a dense head.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(64, (3, 3), input_shape=(256, 256, 3),
                           activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(256, (3, 3), activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(256, (3, 3), activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(512, (3, 3), activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(512, (3, 3), activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(512, (3, 3), activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(512, (3, 3), activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(512, (3, 3), activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1024, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    # No activation: the layer emits raw logits (the loss uses from_logits=True).
    tf.keras.layers.Dense(NUM_CLASSES)
])
#%%
model.summary()
#%%
# Integer class ids + logit outputs -> sparse categorical cross-entropy on logits.
model.compile(optimizer=tf.keras.optimizers.Adam(0.0001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['acc']
              )
#%%
# train_ds repeats forever, so an "epoch" is defined by steps_per_epoch; keep
# it at least 1 so tiny datasets do not produce a zero-step epoch.
steps_per_epoch = max(1, train_count // BATCH_SIZE)
# test_ds is finite and not repeated: round *up* so the last partial batch is
# evaluated too (floor division silently skipped up to BATCH_SIZE-1 images).
validation_steps = (test_count + BATCH_SIZE - 1) // BATCH_SIZE
#%%
history = model.fit(train_ds, epochs=10,
                    steps_per_epoch=steps_per_epoch,
                    validation_data=test_ds,
                    validation_steps=validation_steps)
#%%
history.history.keys()
#%%
# Training vs. validation accuracy per epoch.
for metric_key in ('acc', 'val_acc'):
    plt.plot(history.epoch, history.history.get(metric_key), label=metric_key)
plt.legend()
#%%
# Training vs. validation loss per epoch.
for metric_key in ('loss', 'val_loss'):
    plt.plot(history.epoch, history.history.get(metric_key), label=metric_key)
plt.legend()
#%%
# Save the trained model to an .h5 file.
model_file = 'birds.h5'
model.save(model_file)
#%%
# Save only the model weights to a separate .h5 file.
weights_file = 'birds_weights.h5'
model.save_weights(weights_file)
#%%
def load_and_preprocess_image(path, label=None):
    """Decode the JPEG at `path` into a 256x256x3 float32 tensor in [0, 1].

    This cell redefines the training-time helper of the same name. `label` is
    optional so both call styles keep working after the redefinition:
      * ``load_and_preprocess_image(path)`` returns just the image (inference);
      * ``load_and_preprocess_image(path, label)`` returns ``(image, label)``,
        so re-running the dataset ``.map`` cells does not break.
    """
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [256, 256])
    image = tf.cast(image, tf.float32)
    image = image / 255.0  # normalize to [0, 1] range
    if label is None:
        return image
    return image, label
#%%
test_img = 'Bobolink_0013_9367.jpg'
# model.predict works on batches, so prepend a batch axis of size 1.
test_tensor = tf.expand_dims(load_and_preprocess_image(test_img), axis=0)
pred = model.predict(test_tensor)
#%%
# pred holds one row of class scores; the highest-scoring position is the class id.
predicted_index = np.argmax(pred)
index_to_label.get(predicted_index)
#%% md
两百种鸟类图片分类 (200-species bird image classification)

最新推荐文章于 2024-06-09 22:34:13 发布