Fully-connected image classification on the CIFAR-10 dataset (tensorflow-gpu 2.6): custom network class implementation

import numpy as np
import tensorflow as tf

import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import tensorflow.keras.optimizers as optimizers
import tensorflow.keras.metrics as metrics
import tensorflow.keras.datasets as datasets

# Data-preprocessing function
def preprocess(x, y):
    # Cast features to float32 and scale to [0, 1]
    x = tf.cast(x, tf.float32) / 255.
    # Flatten each 32x32x3 image into a 3072-dim vector
    x = tf.reshape(x, [32 * 32 * 3])
    # Cast the (already one-hot encoded) labels to int32
    y = tf.cast(y, tf.int32)
    # Return the processed sample
    return x, y
# Load CIFAR-10: 50,000 training and 10,000 test images, each 32x32x3
(x, y), (x_test, y_test) = datasets.cifar10.load_data()
# Squeeze out the size-1 label dimension for both splits: matrix (50000, 1) -> vector (50000,)
y = tf.squeeze(y)
y_test = tf.squeeze(y_test)
# One-hot encode the training and test labels
y = tf.one_hot(y, depth=10)
y_test = tf.one_hot(y_test, depth=10)
# Build Dataset objects for the training and test sets
db_train = tf.data.Dataset.from_tensor_slices((x, y))
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
# Preprocess, shuffle, and batch (note: a buffer of 100 only shuffles locally;
# a larger buffer such as 10000 would randomize the training order more thoroughly)
db_train = db_train.map(preprocess).shuffle(100).batch(128)
db_test = db_test.map(preprocess).shuffle(100).batch(128)
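# (Optional sanity check, not in the original script: pull one batch and confirm the
# shapes; with 50000 samples and batch size 128 there are ceil(50000/128) = 391
# steps per epoch, matching the "391/391" in the log below)
sample_x, sample_y = next(iter(db_train))
print('batch x:', sample_x.shape, 'batch y:', sample_y.shape)  # (128, 3072) (128, 10)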
# Define a custom layer class
class MyDense(layers.Layer):
    # Constructor
    def __init__(self, inp_dim, oup_dim):
        super(MyDense, self).__init__()
        # Register the trainable weight matrix (add_weight is the supported API;
        # the add_variable call used here originally is deprecated)
        self.kernel = self.add_weight('w', [inp_dim, oup_dim])
    # Override the call method
    def call(self, inputs, training=None):
        # Forward pass: a plain matrix multiplication, with no bias term
        x = inputs @ self.kernel
        # Return the output
        return x
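# (Optional sketch, not used anywhere below: MyDense above has no bias term, so each
# layer computes only X @ W. A bias-enabled variant (the name MyDenseWithBias is
# hypothetical) would simply register a second trainable weight:)
class MyDenseWithBias(layers.Layer):
    def __init__(self, inp_dim, oup_dim):
        super(MyDenseWithBias, self).__init__()
        self.kernel = self.add_weight('w', [inp_dim, oup_dim])
        self.bias = self.add_weight('b', [oup_dim], initializer='zeros')
    def call(self, inputs, training=None):
        return inputs @ self.kernel + self.bias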

# Define a custom network class
class MyNetwork(keras.Model):
    # Constructor
    def __init__(self):
        super(MyNetwork, self).__init__()
        # Stack the network from the custom layer class; since the layers include
        # no activation, the activation functions are applied functionally in call()
        self.fc1 = MyDense(32 * 32 * 3, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)
    # Override the call method
    def call(self, inputs, training=None):
        # Forward pass
        x = self.fc1(inputs)
        x = tf.nn.relu(x)
        x = self.fc2(x)
        x = tf.nn.relu(x)
        x = self.fc3(x)
        x = tf.nn.relu(x)
        x = self.fc4(x)
        x = tf.nn.relu(x)
        x = self.fc5(x)
        # Return the logits (no softmax here; the loss uses from_logits=True)
        return x

# Create a network instance
network = MyNetwork()
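# (Optional, not in the original script: build the model once so summary() can
# report each custom layer's parameter count, e.g. fc1 holds 3072*256 weights)
network.build(input_shape=(None, 32 * 32 * 3))
network.summary()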
# Compile the network
network.compile(optimizer=optimizers.Adam(0.01),
                loss=keras.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
# Train the network
network.fit(db_train, epochs=30, validation_data=db_test, validation_freq=2)
# Evaluate (test) the network
network.evaluate(db_test)
# Save the weights
network.save_weights('weights.ckpt')
print('Weights saved')
# Delete the network to release resources
del network
# ----------- Reload the network -----------
# Re-create the model and compile it
network = MyNetwork()
network.compile(optimizer=optimizers.Adam(0.01),
                loss=keras.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
# Load the saved weights
network.load_weights('weights.ckpt')
print('Weights loaded')
# Evaluate the restored model
network.evaluate(db_test)
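# (Optional, not part of the original run and therefore absent from the console
# output below: run the restored network on one test batch and compare the
# predicted classes against the one-hot labels)
xb, yb = next(iter(db_test))
pred = tf.argmax(network(xb), axis=1)
truth = tf.argmax(yb, axis=1)
print('batch accuracy:', tf.reduce_mean(tf.cast(pred == truth, tf.float32)).numpy())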

Console output:

Epoch 1/30
2023-07-09 20:57:02.630087: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
2023-07-09 20:57:03.613787: I tensorflow/stream_executor/cuda/cuda_blas.cc:1760] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
391/391 [==============================] - 3s 4ms/step - loss: 2.1952 - accuracy: 0.2295
Epoch 2/30
391/391 [==============================] - 2s 5ms/step - loss: 1.8685 - accuracy: 0.3157 - val_loss: 1.8184 - val_accuracy: 0.3325
Epoch 3/30
391/391 [==============================] - 1s 4ms/step - loss: 1.8131 - accuracy: 0.3393
Epoch 4/30
391/391 [==============================] - 2s 4ms/step - loss: 1.7851 - accuracy: 0.3495 - val_loss: 1.8058 - val_accuracy: 0.3405
Epoch 5/30
391/391 [==============================] - 1s 3ms/step - loss: 1.7655 - accuracy: 0.3593
Epoch 6/30
391/391 [==============================] - 2s 4ms/step - loss: 1.7369 - accuracy: 0.3720 - val_loss: 1.7356 - val_accuracy: 0.3782
Epoch 7/30
391/391 [==============================] - 2s 5ms/step - loss: 1.7130 - accuracy: 0.3830
Epoch 8/30
391/391 [==============================] - 2s 5ms/step - loss: 1.6961 - accuracy: 0.3919 - val_loss: 1.6657 - val_accuracy: 0.4020
Epoch 9/30
391/391 [==============================] - 2s 5ms/step - loss: 1.6802 - accuracy: 0.3986
Epoch 10/30
391/391 [==============================] - 2s 6ms/step - loss: 1.6747 - accuracy: 0.4000 - val_loss: 1.6767 - val_accuracy: 0.3996
Epoch 11/30
391/391 [==============================] - 1s 4ms/step - loss: 1.6609 - accuracy: 0.4046
Epoch 12/30
391/391 [==============================] - 2s 6ms/step - loss: 1.6528 - accuracy: 0.4103 - val_loss: 1.6693 - val_accuracy: 0.3985
Epoch 13/30
391/391 [==============================] - 2s 5ms/step - loss: 1.6584 - accuracy: 0.4048
Epoch 14/30
391/391 [==============================] - 2s 4ms/step - loss: 1.6532 - accuracy: 0.4072 - val_loss: 1.6353 - val_accuracy: 0.4082
Epoch 15/30
391/391 [==============================] - 2s 5ms/step - loss: 1.6334 - accuracy: 0.4140
Epoch 16/30
391/391 [==============================] - 2s 5ms/step - loss: 1.6411 - accuracy: 0.4117 - val_loss: 1.6429 - val_accuracy: 0.4117
Epoch 17/30
391/391 [==============================] - 2s 4ms/step - loss: 1.6330 - accuracy: 0.4128
Epoch 18/30
391/391 [==============================] - 2s 6ms/step - loss: 1.6274 - accuracy: 0.4153 - val_loss: 1.6835 - val_accuracy: 0.3887
Epoch 19/30
391/391 [==============================] - 2s 4ms/step - loss: 1.6277 - accuracy: 0.4143
Epoch 20/30
391/391 [==============================] - 2s 5ms/step - loss: 1.6266 - accuracy: 0.4167 - val_loss: 1.6558 - val_accuracy: 0.4063
Epoch 21/30
391/391 [==============================] - 2s 6ms/step - loss: 1.6199 - accuracy: 0.4211
Epoch 22/30
391/391 [==============================] - 2s 6ms/step - loss: 1.6172 - accuracy: 0.4190 - val_loss: 1.6785 - val_accuracy: 0.3972
Epoch 23/30
391/391 [==============================] - 2s 4ms/step - loss: 1.6201 - accuracy: 0.4173
Epoch 24/30
391/391 [==============================] - 2s 5ms/step - loss: 1.6158 - accuracy: 0.4161 - val_loss: 1.6956 - val_accuracy: 0.3920
Epoch 25/30
391/391 [==============================] - 2s 4ms/step - loss: 1.6198 - accuracy: 0.4147
Epoch 26/30
391/391 [==============================] - 2s 4ms/step - loss: 1.6163 - accuracy: 0.4190 - val_loss: 1.7060 - val_accuracy: 0.3869
Epoch 27/30
391/391 [==============================] - 2s 5ms/step - loss: 1.6111 - accuracy: 0.4194
Epoch 28/30
391/391 [==============================] - 2s 4ms/step - loss: 1.6127 - accuracy: 0.4215 - val_loss: 1.6890 - val_accuracy: 0.3977
Epoch 29/30
391/391 [==============================] - 2s 5ms/step - loss: 1.6184 - accuracy: 0.4170
Epoch 30/30
391/391 [==============================] - 2s 4ms/step - loss: 1.6062 - accuracy: 0.4224 - val_loss: 1.6402 - val_accuracy: 0.4136
79/79 [==============================] - 0s 4ms/step - loss: 1.6402 - accuracy: 0.4136
Weights saved
Weights loaded
79/79 [==============================] - 0s 4ms/step - loss: 1.6402 - accuracy: 0.4136

Process finished with exit code 0
The loss and accuracy printed here are averages over a whole epoch; how to inspect the loss and accuracy of each batch, or after a given number of training steps, will be covered in the next section.
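As a preview, one common way to get per-batch values is a custom Keras callback. The sketch below is a minimal example assuming the standard keras.callbacks API (the on_train_batch_end hook and the logs keys 'loss'/'accuracy' exist in TF 2.x), not necessarily the approach the next section takes:

class BatchLogger(keras.callbacks.Callback):
    # Called after every training batch; logs holds the running metric values
    def on_train_batch_end(self, batch, logs=None):
        if batch % 100 == 0:  # print every 100 steps
            print('step', batch, 'loss:', logs['loss'], 'accuracy:', logs['accuracy'])

# It would be passed to fit, e.g.:
# network.fit(db_train, epochs=30, validation_data=db_test, validation_freq=2,
#             callbacks=[BatchLogger()])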
