卷积神经网络实现---cifar10数据集

cifer数据集提供的是6万张32*32大小的三通道彩色图像,共有10类分别为:飞机,汽车,鸟,猫,鹿,狗,青蛙,马,船,卡车

结构: 一层卷积,两层全连接
6个5×5的卷积核,步长为1
2×2的最大值池化,步长为2

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import Model
import pandas as pd

(x_train,y_train),(x_test,y_test)=tf.keras.datasets.cifar10.load_data()

x_train,x_test=x_train/255.0,x_test/255.0

class Baseline(Model):#继承父类Model
    def __init__(self):
        super(Baseline,self).__init__()  #继承Model的方法属性
        self.c1=tf.keras.layers.Conv2D(filters=6,kernel_size=(5,5),padding="same",input_shape=(32,32,3))
        self.b1=tf.keras.layers.BatchNormalization()#批标准化
        self.a1=tf.keras.layers.Activation("relu")
        self.p1=tf.keras.layers.MaxPool2D(pool_size=(2,2),strides=2,padding="same")
        self.d1=tf.keras.layers.Dropout(0.2)
        
        self.flatten=tf.keras.layers.Flatten()
        self.f1=tf.keras.layers.Dense(128,activation="relu")
        self.d2=tf.keras.layers.Dropout(0.2)
        self.f2=tf.keras.layers.Dense(10,activation="softmax")
        
    def call(self, x):#重写call,实现自己的前向传播
        x=self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)
 
        x = self.flatten(x)
        x = self.f1(x)
        x = self.d2(x)
        y = self.f2(x)
        return y

model=Baseline()

model.compile(optimizer="adam",
             loss="sparse_categorical_crossentropy",
             metrics="sparse_categorical_accuracy")

checkpoint_save_path = "./Baseline_cifar10.ckpt"
#save_best_only=True,被监测数据的最佳模型就不会被覆盖
#save_weights_only=true,只有模型的权重会被保存否则保存整个模型
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True, save_best_only=True)

history=model.fit(x_train,y_train,batch_size=32,epochs=5,validation_data=(x_test,y_test),validation_freq=1,callbacks=[cp_callback])

model.summary()

with open('./cifar10_weights.txt', 'w') as file:  #存储训练好的参数
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')
        
pd.DataFrame(history.history).to_csv("cifar_training_log.csv",index=False)
graph=pd.read_csv("cifar_training_log.csv")

graph.plot(figsize=(5, 7))
plt.xlim(0,4)
plt.ylim(0,2)
plt.grid(1)
  
plt.show()

num=np.random.randint(1,10000)
demo=tf.reshape(x_test[num],(1,32,32,3))
y_pred=np.argmax(model.predict(demo))
plt.imshow(x_test[num])
plt.show()
print("标签值:"+str(y_test[num,0])+"\n预测值:"+str(y_pred))
#飞机,汽车,鸟,猫,鹿,狗,青蛙,马,船,卡车




  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值