Keras High-Level API: Code Walkthrough and Analysis

Let's jump straight into the code.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, datasets, optimizers, Sequential, metrics


# Preprocessing: scale pixel values from [0, 255] to [-1, 1] and cast labels to int32
def preprocess(x, y):
    x = 2 * tf.cast(x, dtype=tf.float32) / 255 - 1
    y = tf.cast(y, dtype=tf.int32)
    return x, y
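
# A quick sanity check of the scaling (optional sketch): pixel value 0 should map to -1.0 and 255 to +1.0
# xs, _ = preprocess(tf.constant([[0, 255]], dtype=tf.uint8), tf.constant([0]))
# print(xs.numpy())  # [[-1.  1.]]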


batchsz = 128
# CIFAR-10 shapes: train x [50k, 32, 32, 3], train y [50k, 1]; test x [10k, 32, 32, 3], test y [10k, 1]

# Load the CIFAR-10 dataset
(x, y), (x_val, y_val) = datasets.cifar10.load_data()

# One-hot encode the labels
y = tf.squeeze(y)  # remove the size-1 dimension from the label tensor: [50k, 1] -> [50k]
y_val = tf.squeeze(y_val)
y = tf.one_hot(y, depth=10)  # [50k, 10]
y_val = tf.one_hot(y_val, depth=10)  # [10k, 10]
# print('datasets:', x.shape, y.shape)
# datasets: (50000, 32, 32, 3) (50000, 10)
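
# A toy illustration of squeeze + one_hot (sketch, not part of the training pipeline):
# toy = tf.constant([[3], [1]])                       # shape (2, 1), like the raw CIFAR-10 labels
# print(tf.squeeze(toy).shape)                        # (2,)
# print(tf.one_hot(tf.squeeze(toy), depth=10).shape)  # (2, 10)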

# Build the tf.data input pipelines
train_db = tf.data.Dataset.from_tensor_slices((x, y))  # slices the tensors along their first dimension to produce a Dataset of (image, label) pairs
# print(train_db)
# <TensorSliceDataset shapes: ((32, 32, 3), (10,)), types: (tf.uint8, tf.float32)>

train_db = train_db.map(preprocess).shuffle(10000).batch(batchsz)
# map applies preprocess to every element of the Dataset and builds a new Dataset from the return values
# shuffle randomizes the sample order (buffer size 10000); batch groups samples into batches of 128

test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(preprocess).batch(batchsz)
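
# Quick arithmetic check (sketch): with batchsz = 128 one epoch is ceil(50000 / 128) = 391 training batches
# and ceil(10000 / 128) = 79 test batches, matching the 391/391 and 79/79 progress bars in the log below.
# import math
# print(math.ceil(50000 / batchsz), math.ceil(10000 / batchsz))  # 391 79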

sample = next(iter(train_db))
# print(sample)
'''
iter(iterable) turns an iterable object into an iterator.
next(iterator[, default]) returns the next item of the iterator.
default is optional: it is returned when there is no next item; if it is omitted and the iterator is exhausted, StopIteration is raised.

(<tf.Tensor: shape=(128, 32, 32, 3), dtype=float32, numpy=
array([[[[-0.827451  , -0.7019608 , -0.85882354],
         [-0.8117647 , -0.6784314 , -0.827451  ],
         [-0.85882354, -0.73333335, -0.8666667 ],
         ...,
The final return value is the first batch, with images of shape (128, 32, 32, 3).
'''
print('batch:', sample[0].shape, sample[1].shape, len(sample))


# batch: (128, 32, 32, 3) (128, 10) 2
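
# A tiny demo of next() with a default value (sketch): the default is returned instead of raising StopIteration
# it = iter([1, 2])
# print(next(it), next(it), next(it, 'done'))  # 1 2 done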

# Custom layer defined as a class
class MyDense(layers.Layer):  # subclass the layers.Layer base class
    def __init__(self, inp_dim, outp_dim):
        '''
        Define the layer's parameters.
        :param inp_dim: input feature length
        :param outp_dim: output feature length
        '''
        super(MyDense, self).__init__()  # initialize the attributes inherited from the parent class
        self.kernel = self.add_weight('w', [inp_dim, outp_dim])
        # self.bias = self.add_weight('b', [outp_dim])

    def call(self, inputs, training=None):  # training=True: training mode; training=False or None: inference mode
        # forward pass: a single matrix multiplication
        x = inputs @ self.kernel
        return x
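
# Quick shape check for the custom layer (optional sketch):
# fc = MyDense(4, 3)
# print(fc(tf.random.normal([2, 4])).shape)  # (2, 3)
# To add a bias, uncomment self.bias above and change call() to return inputs @ self.kernel + self.bias.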


# Custom network
class MyNetwork(keras.Model):  # subclass the keras.Model base class
    def __init__(self):
        super(MyNetwork, self).__init__()
        # define the network structure: five fully connected layers
        self.fc1 = MyDense(32 * 32 * 3, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None):
        # inputs.shape = (batchsz, 32, 32, 3)
        x = tf.reshape(inputs, [-1, 32 * 32 * 3])
        # flatten each image into a vector of length 32*32*3; -1 lets TensorFlow infer the batch dimension
        x = self.fc1(x)
        x = tf.nn.relu(x)

        x = self.fc2(x)
        x = tf.nn.relu(x)

        x = self.fc3(x)
        x = tf.nn.relu(x)

        x = self.fc4(x)
        x = tf.nn.relu(x)

        x = self.fc5(x)

        return x
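
# Forward-pass check on the sample batch (optional sketch): one logit vector per image is expected
# print(MyNetwork()(sample[0]).shape)  # (128, 10)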


network = MyNetwork()  # instantiate the model
network.compile(optimizer=optimizers.Adam(learning_rate=0.001),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
'''
compile configures the model for training: after the network is built, the learning process has to be
configured this way, otherwise calling fit or evaluate raises an exception.
optimizer: the optimizer
loss: the loss function
metrics: the evaluation metric(s)
'''
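# Note: fc5 returns raw logits (no softmax), so from_logits=True tells the loss to apply the softmax
# internally, which is numerically more stable than adding an explicit softmax layer. The equivalent
# formulation (sketch) would be x = tf.nn.softmax(self.fc5(x)) in call() together with from_logits=False here.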
network.fit(train_db, epochs=15, validation_data=test_db, validation_freq=1)
# Trains the model for a fixed number of epochs (iterations on a dataset).
# arguments: training set, number of epochs, validation set, validation frequency
network.evaluate(test_db)
network.save_weights('ckpt/weights.ckpt')
# save only the weights (not the architecture)
del network  # delete the network object
print('Saved!')
# ------------------------------------------------------------------------
network = MyNetwork()
network.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
network.load_weights('ckpt/weights.ckpt')  # load the saved weights
print('loaded weights from file.')
network.evaluate(test_db)
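
# Alternative (sketch): checkpoint the weights automatically during training instead of calling save_weights by hand:
# network.fit(train_db, epochs=15, validation_data=test_db,
#             callbacks=[keras.callbacks.ModelCheckpoint('ckpt/weights.ckpt', save_weights_only=True)])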

Output:

Epoch 1/15
391/391 [==============================] - 5s 10ms/step - loss: 1.8407 - accuracy: 0.3437 - val_loss: 1.5644 - val_accuracy: 0.4487
Epoch 2/15
391/391 [==============================] - 4s 9ms/step - loss: 1.5218 - accuracy: 0.4638 - val_loss: 1.4836 - val_accuracy: 0.4803
Epoch 3/15
...
...
Epoch 15/15
391/391 [==============================] - 4s 9ms/step - loss: 0.7371 - accuracy: 0.7382 - val_loss: 1.7126 - val_accuracy: 0.5144
79/79 [==============================] - 1s 6ms/step - loss: 1.7126 - accuracy: 0.5144
Saved!
loaded weights from file.
79/79 [==============================] - 1s 6ms/step - loss: 1.6934 - accuracy: 0.5207

Mind map

(mind map image)
