配置
tensorflow2.4.0
python3.6
猫狗大战数据集
代码
VGG16网络很著名,这里不再介绍。
keras里有预训练好的VGG16,tensorflow2.0以后的版本中已经集成了keras。
解释在代码中。
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras import applications
from tensorflow.keras.layers import Dropout, Flatten, Dense
from tensorflow.keras.optimizers import SGD
import pickle
import numpy as np
# 开启GPU加速
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
OUT_CATEGORIES = 2 # 分类数
batch_size = 2 # 批量大小
epochs = 50 # 迭代次数
imgSize = 256
def model():
img_shape = (imgSize, imgSize, 3)
# 加载不包含全连接层的VGG16网络
base_model = applications.VGG16(weights='imagenet', include_top=False, input_shape=img_shape)
bas