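# The CIFAR-10 dataset provides 60,000 three-channel color images of size 32×32, in 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck.
# Architecture: one convolutional layer followed by two fully connected layers.
# Convolution: 6 kernels of size 5×5, stride 1.
# Pooling: 2×2 max pooling, stride 2.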
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import Model
import pandas as pd
# Load the CIFAR-10 train/test split through the Keras dataset helper.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Rescale the images by 1/255; for 8-bit pixel data this maps values into
# [0, 1]. A new float array is created for each split (no in-place division).
x_train = x_train / 255.0
x_test = x_test / 255.0
class Baseline(Model):
    """Small CNN classifier: one convolution stage and two dense layers.

    Layer order used by ``call``:
    Conv2D(6 filters, 5x5, "same") -> BatchNormalization -> ReLU ->
    MaxPool2D(2x2, stride 2, "same") -> Dropout(0.2) -> Flatten ->
    Dense(128, relu) -> Dropout(0.2) -> Dense(10, softmax).
    """

    def __init__(self):
        # Initialise the Keras ``Model`` base class before creating layers.
        super().__init__()
        layers = tf.keras.layers

        # Convolution stage.
        self.c1 = layers.Conv2D(
            filters=6,
            kernel_size=(5, 5),
            padding="same",
            input_shape=(32, 32, 3),
        )
        self.b1 = layers.BatchNormalization()  # batch normalization
        self.a1 = layers.Activation("relu")
        self.p1 = layers.MaxPool2D(pool_size=(2, 2), strides=2, padding="same")
        self.d1 = layers.Dropout(0.2)

        # Classifier head.
        self.flatten = layers.Flatten()
        self.f1 = layers.Dense(128, activation="relu")
        self.d2 = layers.Dropout(0.2)
        self.f2 = layers.Dense(10, activation="softmax")

    def call(self, x):
        """Run the forward pass.

        Overrides ``Model.call``: ``x`` is passed through every layer in
        the order listed below and the output of the final softmax layer
        is returned.
        """
        pipeline = (
            self.c1,
            self.b1,
            self.a1,
            self.p1,
            self.d1,
            self.flatten,
            self.f1,
            self.d2,
            self.f2,
        )
        for layer in pipeline:
            x = layer(x)
        return x
# Build the network and configure the optimizer, loss and reported metric.
model = Baseline()
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    # Keras documents `metrics` as a list; newer Keras releases reject a
    # bare string, so the single metric is wrapped in a list.
    metrics=["sparse_categorical_accuracy"],
)

# Where the checkpoint callback writes the model weights.
# NOTE(review): Keras 3 requires weight-only checkpoint paths to end in
# `.weights.h5`; confirm the Keras version in use before changing this path.
checkpoint_save_path = "./Baseline_cifar10.ckpt"

# save_best_only=True: the best model according to the monitored quantity
#   is not overwritten by later, worse epochs.
# save_weights_only=True: only the model weights are saved; otherwise the
#   whole model would be saved.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

# Train for 5 epochs, evaluating on the test split after every epoch
# (validation_freq=1) and letting the callback write checkpoints.
history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[cp_callback],
)

# Print the layer/parameter overview; called after fit(), once the model
# has been run on data.
model.summary()
# Store the trained parameters: for every trainable variable write its name,
# its shape and its full contents to a plain-text file.
with open('./cifar10_weights.txt', 'w', encoding='utf-8') as weights_file:
    for variable in model.trainable_variables:
        values = variable.numpy()
        weights_file.write(str(variable.name) + '\n')
        weights_file.write(str(variable.shape) + '\n')
        # str(ndarray) abbreviates arrays larger than NumPy's print threshold
        # (1000 elements by default) with "...", which would silently drop
        # most of the weights. A threshold above the array size forces the
        # full, unabbreviated text.
        weights_file.write(
            np.array2string(values, threshold=values.size + 1) + '\n'
        )
# Collect the per-epoch metrics recorded by fit() (one row per epoch), save
# them as CSV, and plot them from the in-memory frame instead of re-reading
# the file that was just written.
graph = pd.DataFrame(history.history)
graph.to_csv("cifar_training_log.csv", index=False)

graph.plot(figsize=(5, 7))
# The x axis is the epoch index (0 .. number_of_epochs - 1). Deriving the
# limit from the log keeps the plot correct if `epochs` is changed; the
# max() guard avoids a zero-width axis for a single-epoch run.
plt.xlim(0, max(len(graph) - 1, 1))
plt.ylim(0, 2)
plt.grid(True)
plt.show()
# Pick one random test image. randint's upper bound is exclusive, so this
# covers every valid index 0 .. len(x_test) - 1 (index 0 included).
num = np.random.randint(0, len(x_test))

# The model expects a batch dimension, so reshape the single image to
# (1, 32, 32, 3) before predicting.
demo = tf.reshape(x_test[num], (1, 32, 32, 3))
# The predicted class is the index of the largest softmax output.
y_pred = np.argmax(model.predict(demo))

plt.imshow(x_test[num])
plt.show()
# Prints the true label ("标签值") followed by the predicted label ("预测值").
print("标签值:" + str(y_test[num, 0]) + "\n预测值:" + str(y_pred))
# Classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck