数据预处理:将文件夹里的图片转换为符合要求的图片(只保留人脸区域,并缩放为固定大小)
from config import data_folder, generate_folder
import os
import numpy as np
import cv2
from PIL import Image
from glob import glob
from tqdm import tqdm
import pandas as pd
import face_recognition
# Make sure the output root exists. makedirs also creates missing parent
# directories, and exist_ok avoids the check-then-create race of
# os.path.exists + os.mkdir.
os.makedirs(generate_folder, exist_ok=True)
def scale_img(img, scale_size=139):
    """Resize ``img`` so its shorter side is ``scale_size``, then crop it to a square.

    The aspect ratio is preserved while resizing. A ``scale_size`` x ``scale_size``
    window is then taken at a random offset along the longer side. Square inputs are
    only resized.

    Args:
        img: image array whose first two axes are (height, width).
        scale_size: side length, in pixels, of the returned square image.

    Returns:
        The resized and cropped image, with height and width equal to ``scale_size``.

    Raises:
        ValueError: if ``img`` has zero height or width.
    """
    h, w = img.shape[:2]
    if h == 0 or w == 0:
        raise ValueError('cannot scale an empty image of shape {}'.format(img.shape))
    if h > w:
        new_h, new_w = int(scale_size * h / w), scale_size
    else:
        new_h, new_w = scale_size, int(scale_size * w / h)
    # cv2.resize takes the destination size as (width, height).
    img = cv2.resize(img, (new_w, new_h))
    if h == w:
        return img
    # np.random.randint excludes its upper bound, so "+ 1" lets the window end on the
    # last row/column. When the long side already equals scale_size (possible after
    # int() truncation), the only valid offset is 0.
    offset = np.random.randint(0, max(new_h, new_w) - scale_size + 1)
    if h > w:
        return img[offset:offset + scale_size, :scale_size]
    return img[:scale_size, offset:offset + scale_size]
def _limit_size(img, max_side=2000):
    """Return ``img`` downscaled (aspect ratio kept) so height and width are <= ``max_side``."""
    h, w = img.shape[:2]
    longest = max(h, w)
    if longest <= max_side:
        return img
    ratio = max_side / longest
    # cv2.resize expects dsize=(width, height) -- width first.
    return cv2.resize(img, (max(1, int(w * ratio)), max(1, int(h * ratio))))


def preprocess():
    """Detect faces in every ``data_folder/<category>/*.jpg`` and save each face crop.

    For each category sub-directory of ``data_folder``, every detected face is passed
    through ``scale_img``. It is then written as
    ``generate_folder/<category>_face/<original file name>_<face index>.png``.
    Images in which no face is detected are reported with "no face" and skipped.
    """
    for category in os.listdir(data_folder):
        in_path = os.path.join(data_folder, category)
        if not os.path.isdir(in_path):
            # Ignore stray files (e.g. .DS_Store) so they don't create empty
            # "<name>_face" output folders.
            continue
        out_path = os.path.join(generate_folder, category + '_face')
        os.makedirs(out_path, exist_ok=True)
        for file in glob(os.path.join(in_path, '*.jpg')):
            # basename works on both Windows and Linux (no hard-coded separator).
            file_name = os.path.basename(file)
            print(file_name)
            # Cap very large photos at 2000 px on the longer side before detection.
            img = _limit_size(face_recognition.load_image_file(file))
            # Face detection: usually one face per image, but some images yield several.
            locations = face_recognition.face_locations(img)
            if not locations:
                print("no face")
                continue
            # Each box is (top, right, bottom, left).
            for i, (top, right, bottom, left) in enumerate(locations):
                face = scale_img(img[top:bottom, left:right, :])
                Image.fromarray(face).save(os.path.join(out_path, file_name + '_{}.png'.format(i)))
||
||
将生成的人脸图片整理成一个 csv 描述文件,共两列:一列为图片路径(file_id),一列为标签(label)
def generate_desc_csv(folder, output_csv='../data/description.csv'):
    """Write a CSV index of all ``*.png`` images under ``folder/<category>/``.

    The CSV has two columns: ``file_id`` (the image path as returned by glob) and
    ``label`` (the name of the category sub-directory the image is in). Categories
    and files are sorted so the CSV is reproducible across runs.

    Args:
        folder: directory whose sub-directories are the categories.
        output_csv: path of the CSV file to write (UTF-8, no index column).
    """
    file_id = []
    label = []
    for category in tqdm(sorted(os.listdir(folder))):
        category_dir = os.path.join(folder, category)
        if not os.path.isdir(category_dir):
            continue
        # The face crops produced by preprocess() are saved as .png files.
        for img_path in sorted(glob(os.path.join(category_dir, '*.png'))):
            file_id.append(img_path)
            label.append(category)
    df_description = pd.DataFrame({'file_id': file_id, 'label': label})
    # Persisting the index mirrors the usual "dataset description file" convention.
    df_description.to_csv(output_csv, encoding='utf8', index=False)
查看 csv 文件:统计各标签的样本数量,初步观察数据分布
# Load the description CSV and inspect it: the type of the label column, a few
# random rows, and how many images each label has.
df_data = pd.read_csv('../data/description.csv', encoding='utf-8')
labels = df_data['label']
label_counts = labels.value_counts()
print(type(labels))
print(df_data.sample(5))
print(label_counts)
label_counts.plot(kind='bar')
# Call plt.show() (with matplotlib.pyplot imported as plt) to display the chart.
随机采样 10 张图片进行展示,并显示对应的标签
import random

import matplotlib.pyplot as plt

# Show up to 10 randomly chosen images from df_data in a 2x5 grid, titled by label.
datasize = len(df_data)
print('datasize', datasize)
# random.sample raises ValueError when asked for more items than exist.
sample_index = random.sample(range(datasize), min(10, datasize))
print(sample_index)
plt.figure(figsize=(18, 6))
for i, idx in enumerate(sample_index):
    plt.subplot(2, 5, i + 1)
    path = df_data['file_id'].iloc[idx]
    img = cv2.imread(path)
    if img is None:
        # cv2.imread returns None (instead of raising) for missing/unreadable files.
        plt.title('unreadable: {}'.format(path))
        continue
    # OpenCV loads BGR; matplotlib expects RGB.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    plt.title('label {}'.format(str(df_data['label'].iloc[idx])))
plt.show()
数据增广
"""
Author: Zhou Chen
Date: 2020/1/8
Desc: desc
"""
from keras.preprocessing.image import ImageDataGenerator
import pandas as pd
class DataSet(object):
    """Dataset backed by ``description.csv`` (columns ``file_id`` and ``label``).

    Builds Keras training/validation generators that read the listed image files.
    """

    def __init__(self, root_folder):
        """Load ``description.csv`` from ``root_folder`` and print the number of rows.

        Args:
            root_folder: directory containing ``description.csv``.
        """
        import os  # local import: this module's own imports do not include os

        self.folder = root_folder
        # os.path.join works whether or not root_folder ends with a separator.
        self.df_desc = pd.read_csv(os.path.join(self.folder, 'description.csv'), encoding="utf8")
        print('the total length of dataset: ', len(self.df_desc))

    def get_generator(self, batch_size=32, da=True, img_size=(64, 64),
                      color_mode='grayscale', validation_split=0.25, seed=42):
        """Return ``(train_generator, valid_generator)`` built from ``self.df_desc``.

        Args:
            batch_size: images per batch for both generators.
            da: if True, apply data augmentation (random horizontal flip) to the
                training subset. The validation subset is never augmented.
            img_size: (height, width) every image is resized to.
            color_mode: Keras color mode, e.g. 'grayscale' (1 channel) or 'rgb'.
            validation_split: fraction of rows used for validation.
            seed: random seed for the one-off shuffle that precedes the split.

        Returns:
            Two Keras dataframe iterators yielding (images scaled to [0, 1],
            one-hot labels).
        """
        # flow_from_dataframe takes the validation subset as a contiguous slice at the
        # start of the dataframe (and training as the rest). description.csv is grouped
        # by category, so shuffle once with a fixed seed. Both generators receive the
        # same shuffled frame, so the two subsets stay disjoint.
        df = self.df_desc.sample(frac=1.0, random_state=seed).reset_index(drop=True)
        train_gen = ImageDataGenerator(rescale=1 / 255., validation_split=validation_split,
                                       horizontal_flip=bool(da))
        # Separate generator without augmentation for validation; it needs the same
        # validation_split so that subset='validation' selects the complementary rows.
        valid_gen = ImageDataGenerator(rescale=1 / 255., validation_split=validation_split)
        common = dict(dataframe=df,
                      directory='.',          # file_id values are used as given
                      x_col='file_id',        # column holding image paths
                      y_col='label',          # column holding class names
                      batch_size=batch_size,
                      class_mode='categorical',  # 2D one-hot encoded labels
                      target_size=img_size,
                      color_mode=color_mode)
        train_generator = train_gen.flow_from_dataframe(subset='training', **common)
        valid_generator = valid_gen.flow_from_dataframe(subset='validation', **common)
        return train_generator, valid_generator
if __name__ == '__main__':
    # Smoke test: build the generators and report the size of the training split.
    dataset = DataSet('../data/')
    generators = dataset.get_generator(batch_size=64)
    train_generator, valid_generator = generators
    for caption, value in (('the length of train dataset', train_generator.n),
                           ('the size of one batch : ', train_generator.batch_size)):
        print(caption, value)
运行结果(result):
如果 train_gen.flow_from_dataframe 报错,可以尝试重新安装 keras.processing:
下载下面链接中的 keras.processing.zip 后执行
pip install <本地路径>/keras.processing.zip
keras.processing.zip
链接:https://pan.baidu.com/s/1Vo3lPrt4itjrzqudgugfwg
提取码:qigx
模型构建
from tensorflow.keras.layers import Input, Conv2D, PReLU, MaxPooling2D, Dense, Flatten, Dropout, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.applications import ResNet50
def CNN(input_shape=(224, 224, 3), n_classes=5):
    """Build a small convolutional classifier.

    Architecture:
    - a 1x1 ReLU stem with 32 filters;
    - one block of two 64-filter convs (3x3 then 5x5), each followed by PReLU;
    - a 2x2 max-pool;
    - a Flatten -> Dense(2048) -> Dense(1024) head with 0.5 dropout after each
      dense layer;
    - a softmax output over ``n_classes``.

    Note: Flatten keeps the full pooled feature map. With the default 224x224 input
    that is 112*112*64 features, so the first Dense layer is very large.
    """
    inputs = Input(shape=input_shape)
    # Stem: 1x1 convolution.
    x = Conv2D(32, (1, 1), strides=1, padding='same', activation='relu')(inputs)
    # Block 1: 3x3 then 5x5 convolution, each with a PReLU activation.
    for filters, kernel in ((64, (3, 3)), (64, (5, 5))):
        x = Conv2D(filters, kernel, strides=1, padding='same')(x)
        x = PReLU()(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=2)(x)
    # Fully connected head.
    x = Flatten()(x)
    for units in (2048, 1024):
        x = Dense(units, activation='relu')(x)
        x = Dropout(0.5)(x)
    outputs = Dense(n_classes, activation='softmax')(x)
    return Model(inputs=inputs, outputs=outputs)
def ResNet_pretrained(input_shape=(224, 224, 3), n_classes=5, weights=None):
    """Build a ResNet50 backbone with a pooled dropout + softmax classification head.

    Despite the function name, the default ``weights=None`` initialises ResNet50
    randomly, so no pretrained weights are loaded. Pass ``weights='imagenet'`` to
    load ImageNet weights. Per the Keras applications docs, that requires a
    3-channel input, so it is incompatible with a single-channel (grayscale)
    ``input_shape``.

    Args:
        input_shape: shape of one input image, channels last.
        n_classes: number of output classes (softmax units).
        weights: passed to ``ResNet50`` (None, 'imagenet', or a weights file path).

    Returns:
        An uncompiled Keras ``Model``.
    """
    input_layer = Input(shape=input_shape)
    backbone = ResNet50(include_top=False, weights=weights, input_tensor=input_layer)
    x = GlobalAveragePooling2D()(backbone.output)
    x = Dropout(0.5)(x)
    x = Dense(n_classes, activation='softmax')(x)
    model = Model(input_layer, x)
    return model
if __name__ == '__main__':
    # Build the ResNet50-based classifier with default arguments and print its layers.
    model = ResNet_pretrained()
    model.summary()
训练模型
from data import DataSet
from model import CNN, ResNet_pretrained
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.optimizers import Adam
from utils import plot_history
import os
# Directory for the best-checkpoint weight files.
os.makedirs("../models/", exist_ok=True)

# Data: DataSet.get_generator resizes to 64x64 in grayscale by default, so
# img_shape has a single channel and must stay in sync with those settings.
ds = DataSet('../data/')
train_generator, valid_generator = ds.get_generator(batch_size=64)
img_shape = (64, 64, 1)

# Models
model_cnn = CNN(input_shape=img_shape)
model_resnet = ResNet_pretrained(input_shape=img_shape)

epochs = 20


def _compile_and_fit(model, weights_path):
    """Compile and train ``model``, checkpointing its best weights.

    Compiles with Adam (lr 3e-4) and categorical cross-entropy, trains for
    ``epochs`` epochs, and saves the weights with the lowest ``val_loss`` to
    ``weights_path``.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # `learning_rate` replaces the deprecated `lr` keyword.
    model.compile(optimizer=Adam(learning_rate=3e-4),
                  loss='categorical_crossentropy', metrics=['accuracy'])
    callbacks = [
        ModelCheckpoint(weights_path, monitor='val_loss', save_best_only=True,
                        verbose=1, save_weights_only=True),
        # EarlyStopping(monitor='val_loss', patience=5),  # optional early stopping
    ]
    # Model.fit accepts the Keras iterators directly (fit_generator is deprecated).
    # len(iterator) == ceil(n / batch_size), so the last partial batch is used too.
    return model.fit(train_generator,
                     steps_per_epoch=len(train_generator),
                     validation_data=valid_generator,
                     validation_steps=len(valid_generator),
                     epochs=epochs,
                     callbacks=callbacks)


history_cnn = _compile_and_fit(model_cnn, '../models/cnn_best_weights.h5')
# Fit model_resnet (not model_cnn) so this history really belongs to ResNet50.
history_resnet = _compile_and_fit(model_resnet, '../models/resnet_best_weights.h5')

plot_history([history_cnn, history_resnet])
分别画出各个模型在训练集和验证集上的精度(accuracy)曲线
import matplotlib.pyplot as plt
import numpy as np
# Apply matplotlib's built-in 'fivethirtyeight' style sheet to the figures drawn below.
plt.style.use('fivethirtyeight')
def plot_history(his, titles=("CNN", "ResNet50"), save_path="../assets/his.png"):
    """Plot training vs. validation accuracy for each model, side by side.

    The figure is saved to ``save_path`` and then shown.

    Args:
        his: sequence of Keras ``History`` objects (as returned by ``Model.fit``).
            Each ``.history`` dict must contain 'accuracy' and 'val_accuracy'.
        titles: subplot titles in the same order as ``his``. Histories beyond the
            given titles are titled "model <n>".
        save_path: output image path; its parent directory is created if missing.

    Raises:
        ValueError: if ``his`` is empty.
    """
    import os

    if not his:
        raise ValueError('plot_history needs at least one History object')
    n = len(his)
    titles = list(titles) + ['model {}'.format(i + 1) for i in range(len(titles), n)]
    plt.figure(figsize=(9 * n, 6))
    for i, (history, title) in enumerate(zip(his, titles), start=1):
        record = history.history
        plt.subplot(1, n, i)
        plt.plot(np.arange(len(record['accuracy'])), record['accuracy'], label="training accuracy")
        plt.plot(np.arange(len(record['val_accuracy'])), record['val_accuracy'], label="validation accuracy")
        plt.title(title)
        plt.legend(loc=0)
    # savefig does not create missing directories.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path)
    plt.show()