import tensorflow as tf
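# NOTE: if downloading the dataset fails, uncomment the two lines below.
# They disable HTTPS certificate verification, so use them only as a temporary workaround.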
# import ssl
# ssl._create_default_https_context = ssl._create_unverified_context
from tensorflow import keras
# Number of training images kept for this small-scale experiment.
NUM_TRAIN_SAMPLES = 5000
# Spatial size (height, width) the images are resized to before training.
INPUT_SIZE = (224, 224)

cifar100 = keras.datasets.cifar100
# Load the CIFAR-100 training split (downloaded on first use); the test split is unused.
(train_images, train_labels), _ = cifar100.load_data()
# Slice BEFORE normalizing so that only the images actually used are converted
# to floating point, instead of the whole training set.
train_images = train_images[:NUM_TRAIN_SAMPLES] / 255.0
# Resize the subset to INPUT_SIZE; the channel dimension is left untouched.
train_images = tf.image.resize(train_images, INPUT_SIZE)
# Keep the labels aligned with the selected images.
train_labels = train_labels[:NUM_TRAIN_SAMPLES]
# AlexNet for CIFAR-100 classification. Layer spec (kernel, filters, stride, padding):
# 1. conv1: 11x11, 96, 4, valid   (e.g. an 11x11 kernel applied with stride 4)
#    max pooling: 3x3, stride 2
# 2. conv2: 5x5, 256, 1, same
#    max pooling: 3x3, stride 2, valid
# 3. conv3: 3x3, 384, 1, same
# 4. conv4: 3x3, 384, 1, same
# 5. conv5: 3x3, 256, 1, same
#    max pooling: 3x3, stride 2, valid
# 6. fully connected layers (after flattening): fc1: 4096, dropout=0.5
#    fc2: 4096, dropout=0.5
#    output: 100 (one unit per class)
# AlexNet-style convolutional network that outputs 100 class logits.
# No input shape is declared here, so the model is built the first time it
# is called on data.
alexNet = keras.Sequential(layers=[
    # Conv1: 96 filters, 11x11 kernel, stride 4, no padding.
    keras.layers.Conv2D(96, 11, strides=(4, 4), padding='valid', activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='valid'),
    # Conv2: 256 filters, 5x5 kernel, stride 1, 'same' padding.
    keras.layers.Conv2D(256, 5, strides=(1, 1), padding='same', activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='valid'),
    # Conv3: 384 filters, 3x3 kernel, stride 1, 'same' padding.
    # (The stride was previously written as `(1*1)`, a typo for `(1, 1)`.)
    keras.layers.Conv2D(384, 3, strides=(1, 1), padding='same', activation='relu'),
    # Conv4: 384 filters, 3x3 kernel, stride 1, 'same' padding.
    keras.layers.Conv2D(384, 3, strides=(1, 1), padding='same', activation='relu'),
    # Conv5: 256 filters, 3x3 kernel, stride 1, 'same' padding, then pooling.
    keras.layers.Conv2D(256, 3, strides=(1, 1), padding='same', activation='relu'),
    keras.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='valid'),
    # Flatten the feature maps into a vector for the dense layers.
    keras.layers.Flatten(),
    # FC1: 4096 units followed by 50% dropout.
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dropout(rate=0.5),
    # FC2: 4096 units followed by 50% dropout.
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dropout(rate=0.5),
    # Output: 100 units with no activation, i.e. raw logits.
    keras.layers.Dense(100),
])
# Configure the model. The loss is computed from logits because the final
# Dense layer applies no softmax.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
alexNet.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

# Train the model: 10 epochs with mini-batches of 16 images.
alexNet.fit(
    x=train_images,
    y=train_labels,
    batch_size=16,
    epochs=10,
)
# If this content infringes any rights, please contact me and I will remove it.