# 由于VGGFace2提供的训练集和测试集类别完全不重合，说明这个数据集本身不是用来做分类问题的，所以以下的代码仅供参考。
from __future__ import print_function
import keras
from keras.layers import Input, Dense, Flatten, add
from keras.layers import Conv2D, Activation, MaxPooling2D, AveragePooling2D
from keras import backend as K
from keras.callbacks import ModelCheckpoint
import tensorflow as tf
from keras.models import Model
from keras.utils import plot_model
from sklearn.model_selection import train_test_split
import os
import numpy as np
import cv2
import random
# Command-line configuration. Every flag below is registered on the global
# tf.app.flags registry; parsed values are read later through FLAGS.<name>.
FLAGS = tf.app.flags.FLAGS

# Each entry is (definer, flag name, default value, help text). Entries are
# listed in the order they are registered. Help strings are runtime text and
# are kept verbatim.
_FLAG_SPECS = (
    # Name of the model graph image (help text: "model image name").
    (tf.app.flags.DEFINE_string, 'graph_name', 'vggface2', '模型图片的名字'),
    # Training data location.
    (tf.app.flags.DEFINE_string, 'train_path',
     'E://dataset//vggface2//train', 'Filepattern for training data.'),
    # Test data location.
    (tf.app.flags.DEFINE_string, 'test_path',
     'E://dataset//vggface2//test', 'Filepattern for testing data.'),
    # Where the model is saved (help text: "model save path").
    (tf.app.flags.DEFINE_string, 'model_path',
     'modeldir.VGGface', '模型保存路径'),
    # Image geometry; presumably the input size fed to the network -- confirm.
    (tf.app.flags.DEFINE_integer, 'height', 190, ''),
    (tf.app.flags.DEFINE_integer, 'width', 170, ''),
    (tf.app.flags.DEFINE_integer, 'IMAGE_CHANNELS', 3, ''),
    # Number of classes (help text: "number of classes").
    (tf.app.flags.DEFINE_integer, 'num_classes', 8631, '类别数'),
    # Training schedule (help text for epochs: "number of training epochs").
    (tf.app.flags.DEFINE_integer, 'epochs', 9, '训练轮数'),
    (tf.app.flags.DEFINE_integer, 'batch_size', 4, ''),
    # Run mode: 'train' or 'eval'.
    (tf.app.flags.DEFINE_string, 'flag', 'train', 'train or eval.'),
)

for _define, _flag_name, _default, _help in _FLAG_SPECS:
    _define(_flag_name, _default, _help)

# Remove the temporary helper names so only FLAGS remains at module level.
del _FLAG_SPECS, _define, _flag_name, _default, _help
def res_block(x, channels, i):
if i == 1: # 第二个block
strides = (1, 1)