第四章:Tensorflow 2.0 实现自定义层和自定义模型的编写并实现cifar10 的全连接网络(理论+实战)

通过阅读少奶奶上一篇博文,我们初步实现了调用Tensorflow 2.0 提供的keras模块下的方法,使用极少的代码完成了一个五层的全连接网络,具体代码如下:

                                 

若小伙伴需要复习一下上一篇博文的内容的话,少奶奶这里给出对应链接,希望对大家有所帮助:

             第三章:Tensorflow 2.0 利用高级接口实现对cifar10 数据集的全连接(理论+实战实现)

我相信,通过阅读上一篇博文,大家已经初步领略了Tensorflow 2.0 良好的封装性,我们只需要书写几行代码就能实现较为复杂的全连接神经网络。那么,我们如何实现自定义的神经网络呢?下面开始正文讲解。

利用Tensorflow 2.0 的keras模块中的方法,实现自定义的神经网络层和模型

1,实现自定义层

要实现自定义层,我们的类需要先继承keras.layers.Layer,并在新的类中实现_ _init_ _()和call()这两个方法。因为,当模型进行前向传播时,model类会先调用Dense层母类中的_ _ init _ _()方法,母类中的_ _ init_ _()方法会调用子类当中的call()方法,并执行call()中对应的逻辑。具体代码如下:  

   

我们可以简单理解为,在自定义层类中,_ _init_ _()方法需要先调用母类中的_ _init_ _()方法,然后定义一些参数(通常是权重)。然后,在call()方法中,构建层结构(这里构建的是线性结构)

2,实现自定义模型

要实现自定义模型,我们的模型类需要先继承keras.layers.Model类,这样我们的模型类就能使用母类中的许多方法,例如:compile、fit、  evaluate等。然后,在模型类中实现_ _init_ _()和call()这两个方法。其原理和自定义层是一样的。下面给出代码

                  

我们在模型类的__init__()方法中,调用自定义层,并为它设置每一层的连接数。在call方法中,我们为每一层传入训练数据,并添加一个激活函数,提高模型的表达能力。

实战

import tensorflow as tf
from tensorflow import keras
from  tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

import os


os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"


(x ,y), (test_x ,test_y) = datasets.cifar10.load_data()
print(x.shape, y.shape)

def progross (x, y):
    x = 2 * tf.cast(x, dtype=tf.float32) / 255. - 1.
    y = tf.one_hot(y, depth=10, dtype=tf.int32)
    y = tf.squeeze(y,axis=0)
    return x, y


db_train = tf.data.Dataset.from_tensor_slices((x,y)).shuffle(1000)
db_train = db_train.map(progross).batch(128)

db_test = tf.data.Dataset.from_tensor_slices((test_x,test_y))
db_test = db_test.map(progross).batch(128)


train_inter = iter(db_train)
train_next = next(train_inter)
print(train_next[0].shape, train_next[1][1])


# 继承layers.Layer方法
class myDense (layers.Layer):
    # 实现__init__()方法
    def __init__(self, in_dim, out_dim):
        # 调用母类中的__init__()
        super(myDense, self).__init__()

        self.kernel = self.add_variable('w', [in_dim, out_dim])
        self.bias = self.add_variable('b', [out_dim])
    # 实现call()方法
    def call(self, inputs, training = None):
        # 构建模型结构
        out = inputs @ self.kernel + self.bias
        return out

# 继承keras.Model母类
class myModel (keras.Model):

    def __init__(self):
        # 调用母类中的__init__()方法
        super(myModel, self).__init__()
        # 调用自定义层类 并构建每一层的连接数
        self.fc1 = myDense(32*32*3 ,256)
        self.fc2 = myDense(256, 128)
        self.fc3 = myDense(128, 64)
        self.fc4 = myDense(64, 32)
        self.fc5 = myDense(32, 10)

    # 构建一个五层的全连接网络
    def call(self, inputs, training = None):
        # 把输入模型中的图片进行打平操作
        inputs = tf.reshape(inputs, [-1, 32 * 32 * 3])
        # 把训练数据输入到自定义层中
        x = self.fc1(inputs)
        # 利用relu函数进行非线性激活操作
        out = tf.nn.relu(x)
        x = self.fc2(out)
        out = tf.nn.relu(x)
        x = self.fc3(out)
        out = tf.nn.relu(x)
        x = self.fc4(out)
        out = tf.nn.relu(x)
        x = self.fc5(out)
        return x


netWork = myModel()
netWork.build(input_shape=[None, 32*32*3])
netWork.summary()

netWork.compile(optimizer = optimizers.Adam(lr = 1e-3),
                loss = tf.losses.CategoricalCrossentropy(from_logits = True),
                metrics=['accuracy']
                )
netWork.fit(db_train, epochs=10, validation_data=db_test,
            validation_freq=2
            )
netWork.evaluate(db_test)

在下一章中,少奶奶将使用13层的卷积神经网络实现对cifar 100数据集的训练,其中代码的编写将综合前四章中所提到的所有知识点,希望对大家有帮助,

 

开篇:开启Tensorflow 2.0时代

第一章:Tensorflow 2.0 实现简单的线性回归模型(理论+实践)

第二章:Tensorflow 2.0 手写全连接MNIST数据集(理论+实战)

第三章:Tensorflow 2.0 利用高级接口实现对cifar10 数据集的全连接(理论+实战实现)

第四章:Tensorflow 2.0 实现自定义层和自定义模型的编写并实现cifar10 的全连接网络(理论+实战)

第五章:Tensorflow 2.0 利用十三层卷积神经网络实现cifar 100训练(理论+实战)

第六章:优化神经网络的技巧(理论)

第七章:Tensorflow2.0 RNN循环神经网络实现IMDB数据集训练(理论+实践)

第八章:Tensorflow2.0 传统RNN缺陷和LSTM网络原理(理论+实战)

  • 4
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值