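Hardware really is a sore point when doing machine learning, and this is only an entry-level project.
This post uses a CNN and VGG16 to classify One Piece characters. I built a VGG16 network by hand and compared its results with a VGG16 trained by transfer learning. I also describe a few glitches that came up along the way. Part of the code is adapted from the blog of "K同学啊".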
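A hand-built VGG16 strains entry-level hardware, and counting its parameters layer by layer shows why. The sketch below assumes the standard VGG16 layout: 13 3×3 convolutions in five blocks, then two 4096-unit dense layers and a softmax layer, with 224×224×3 inputs. The helper `vgg16_param_count` is hypothetical. The counting is done by hand in plain Python, not through Keras.

```python
# Standard VGG16 convolution configuration: numbers are filter counts,
# "M" marks a 2x2 max-pooling layer (stride 2) that halves height and width.
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
             512, 512, 512, "M", 512, 512, 512, "M"]


def vgg16_param_count(num_classes, input_size=224, in_channels=3):
    """Return (conv_params, dense_params) for a VGG16-style network.

    Hypothetical helper: counts weights + biases by hand, assuming 3x3
    'same' convolutions and the classic 4096-4096-num_classes head.
    """
    conv_params = 0
    channels = in_channels
    size = input_size
    for item in VGG16_CFG:
        if item == "M":
            size //= 2  # pooling has no parameters, it only shrinks the map
        else:
            conv_params += 3 * 3 * channels * item + item  # kernel + bias
            channels = item
    flat = size * size * channels  # 7 * 7 * 512 = 25088 for a 224x224 input
    dense_params = 0
    for units in (4096, 4096, num_classes):
        dense_params += flat * units + units  # weights + bias of each dense layer
        flat = units
    return conv_params, dense_params


conv, dense = vgg16_param_count(num_classes=7)
print(f"conv: {conv:,}  dense: {dense:,}  total: {conv + dense:,}")
```

Under these assumptions, the 7-class network has about 134 million parameters. Almost 90% of them sit in the dense head. This is one reason transfer learning commonly reuses only the pretrained convolutional base and trains a much smaller head.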
1. Import libraries
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os,PIL,pathlib
from tensorflow import keras
from tensorflow.keras import layers,models,Sequential,Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D,MaxPooling2D,Dense,Flatten,Dropout
2. Data processing
Directory containing the data:
data_dir = "E:/tmp/.keras/datasets/hzw_photos"
data_dir = pathlib.Path(data_dir)
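flow_from_directory infers the labels from the folder structure. Each class is one subfolder of `data_dir`, and the class indices follow the alphabetical order of the subfolder names. Before training, it is worth counting the images in each subfolder as a quick sanity check.

The stdlib-only sketch below builds a tiny throwaway tree in a temporary directory, so it runs anywhere. The helper `count_images_per_class`, the toy file counts and the extension list are assumptions for illustration. They are not part of the original code.

```python
import pathlib
import tempfile

# Extensions treated as images (an assumption; Keras keeps its own allow-list).
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images_per_class(data_dir):
    """Map each class subfolder name to its number of image files.

    Classes are sorted alphabetically, matching the order Keras uses
    for label indices in flow_from_directory.
    """
    data_dir = pathlib.Path(data_dir)
    counts = {}
    for class_dir in sorted(p for p in data_dir.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_EXTS
        )
    return counts


# Toy tree <root>/<class>/<image files> with made-up file counts.
with tempfile.TemporaryDirectory() as root:
    for name, n in {"namei": 2, "lufei": 3}.items():
        (pathlib.Path(root) / name).mkdir()
        for i in range(n):
            (pathlib.Path(root) / name / f"{i}.jpg").touch()
    counts = count_images_per_class(root)
    class_indices = {name: i for i, name in enumerate(counts)}
    print(counts)         # {'lufei': 3, 'namei': 2}
    print(class_indices)  # {'lufei': 0, 'namei': 1}
```

On the real dataset, `count_images_per_class(data_dir)` should list one entry per character folder. Using `Path.name` also avoids any dependence on the OS path separator.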
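Create an ImageDataGenerator to preprocess the images (normalization and data augmentation):
train_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,        # normalize pixel values to [0, 1]
    rotation_range=45,     # random rotation of up to 45 degrees
    shear_range=0.2,       # random shear transform
    zoom_range=0.2,        # random zoom in the range [0.8, 1.2]
    validation_split=0.2,  # hold out 20% of the images (8:2 split)
    horizontal_flip=True   # random horizontal flip
)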
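The NumPy sketch below shows what two of these options, `rescale=1./255` and `horizontal_flip=True`, do to a single image array of shape height × width × channels. It is a simplified illustration, not Keras' implementation. Keras flips each image at random with probability 0.5 and also applies rotation, shear and zoom, which are omitted here. The function name `augment` and its `rng` and `flip_prob` arguments are hypothetical.

```python
import numpy as np


def augment(image, rng, rescale=1.0 / 255, flip_prob=0.5):
    """Rescale pixel values and randomly mirror the image left-right.

    image: uint8 array of shape (height, width, channels).
    """
    out = image.astype(np.float32) * rescale  # 0..255 -> 0.0..1.0
    if rng.random() < flip_prob:
        out = out[:, ::-1, :]  # reversing the width axis = horizontal flip
    return out


rng = np.random.default_rng(0)
img = np.arange(6, dtype=np.uint8).reshape(2, 3, 1) * 50  # toy 2x3 grayscale image
flipped = augment(img, rng, flip_prob=1.0)  # force the flip for the demo
print(flipped[..., 0])
```

Setting `flip_prob` to 1.0 or 0.0 makes the result deterministic, which is handy for checking the transform by eye.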
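Split the images 8:2 into a training set and a test set (the "validation" subset). Here height, width and batch_size are assumed to be defined beforehand; VGG16's default input size is 224×224. Both subsets come from the same generator, so the random augmentations above are applied to the test images as well.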
train_ds = train_data_gen.flow_from_directory(
directory=data_dir,
target_size=(height,width),
batch_size=batch_size,
shuffle=True,
class_mode='categorical',
subset='training'
)
test_ds = train_data_gen.flow_from_directory(
directory=data_dir,
target_size=(height,width),
batch_size=batch_size,
shuffle=True,
class_mode='categorical',
subset='validation'
)
Output:
Found 499 images belonging to 7 classes.
Found 122 images belonging to 7 classes.
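The validation subset has 122 images rather than 20% of 621, which would be 124.2. As far as I can tell from Keras' DirectoryIterator, validation_split is applied to each class folder separately. For a folder with n images, the first int(validation_split × n) files in sorted order go to the validation subset and the rest go to training. Rounding down in every folder can therefore shave a few images off the total.

In the sketch below, the per-class counts are made up, since the post does not list them. They were chosen only so that they add up to 621. `split_counts` is a hypothetical helper.

```python
def split_counts(class_counts, validation_split=0.2):
    """Return (train_total, val_total) when each class is split separately.

    Mirrors the per-folder rounding described above: validation gets
    int(validation_split * n) images of a class, training gets the rest.
    """
    val_total = 0
    train_total = 0
    for n in class_counts.values():
        n_val = int(validation_split * n)  # rounds down within each class
        val_total += n_val
        train_total += n - n_val
    return train_total, val_total


# Hypothetical per-class image counts (NOT from the original post).
counts = {"a": 89, "b": 89, "c": 88, "d": 90, "e": 90, "f": 85, "g": 90}
print(sum(counts.values()), split_counts(counts))  # 621 (499, 122)
```

Whatever the real per-class counts are, the validation total can never exceed int(0.2 × 621) = 124. Each of the 7 folders loses less than one image to rounding.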
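The 7 classes:
all_images_paths = sorted(data_dir.glob('*'))  # "*" matches zero or more characters, i.e. every class folder in data_dir
all_images_paths = [str(path) for path in all_images_paths]
all_label_names = [os.path.basename(path) for path in all_images_paths]  # folder name = label; avoids the hard-coded "\\" and index 5, so it works on any OS and path depth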
print(all_label_names)
Result: ['lufei', 'luobin', 'namei', 'qiaoba', 'shanzhi',