让我们创建一个可与所有后端配合使用的自定义密集层:
class MyDense(keras.layers.Layer):
def __init__(self, units, activation=None, name=None):
super().__init__(name=name)
self.units = units
self.activation = keras.activations.get(activation)
def build(self, input_shape):
input_dim = input_shape[-1]
self.w = self.add_weight(
shape=(input_dim, self.units),
initializer=keras.initializers.GlorotNormal(),
name="kernel",
trainable=True,
)
self.b = self.add_weight(
shape=(self.units,),
initializer=keras.initializers.Zeros(),
name="bias",
trainable=True,
)
def call(self, inputs):
# Use Keras ops to create backend-agnostic layers/metrics/etc.
x = keras.ops.matmul(inputs, self.w) + self.b
return self.activation(x)
接下来,让我们制作一个依赖于keras.random命名空间的自定义Dropout层:
class MyDropout(keras.layers.Layer):
def __init__(self, rate, name=None):
super().__init__(name=name)
self.rate = rate
# Use seed_generator for managing RNG state.
# It is a state element and its seed variable is
# tracked as part of `layer.variables`.
self.seed_generator = keras.random.SeedGenerator(1337)
def call(self, inputs):
# Use `keras.random` for random ops.
return keras.random.dropout(inputs, self.rate, seed=self.seed_generator)
接下来,让我们编写一个自定义子类模型,使用我们的两个自定义层:
class MyModel(keras.Model):
def __init__(self, num_classes):
super().__init__()
self.conv_base = keras.Sequential(
[
keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
keras.layers.MaxPooling2D(pool_size=(2, 2)),
keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
keras.layers.GlobalAveragePooling2D(),
]
)
self.dp = MyDropout(0.5)
self.dense = MyDense(num_classes, activation="softmax")
def call(self, x):
x = self.conv_base(x)
x = self.dp(x)
return self.dense(x)
让我们编译并适配它:
model = MyModel(num_classes=10)
model.compile(
loss=keras.losses.SparseCategoricalCrossentropy(),
optimizer=keras.optimizers.Adam(learning_rate=1e-3),
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="acc"),
],
)
model.fit(
x_train,
y_train,
batch_size=batch_size,
epochs=1, # For speed
validation_split=0.15,
)
现在咱们演绎如下:
在本地的TensorFlow虚拟环境中,首先导入keras:
from tensorflow import keras
(可以在Jupyter Notebook中运行)
如果在演绎执行中出错,可能是Keras版本问题,使用如下命令升级keras。
sudo pip install --upgrade keras
执行结果:
训练模型
在任意数据源上训练模型
所有的Keras模型都可以在各种数据来源上进行训练和评估,与您使用的后端无关。这包括:
NumPy数组 Pandas数据框 TensorFlow tf.data.Dataset对象 PyTorch DataLoader对象 Keras PyDataset对象 无论您使用TensorFlow、JAX还是PyTorch作为Keras后端,它们都可以工作。
让我们尝试使用PyTorch DataLoader:
import torch
# Create a TensorDataset
train_torch_dataset = torch.utils.data.TensorDataset(
torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_torch_dataset = torch.utils.data.TensorDataset(
torch.from_numpy(x_test), torch.from_numpy(y_test)
)
# Create a DataLoader
train_dataloader = torch.utils.data.DataLoader(
train_torch_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
val_torch_dataset, batch_size=batch_size, shuffle=False
)
model = MyModel(num_classes=10)
model.compile(
loss=keras.losses.SparseCategoricalCrossentropy(),
optimizer=keras.optimizers.Adam(learning_rate=1e-3),
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="acc"),
],
)
model.fit(train_dataloader, epochs=1, validation_data=val_dataloader)
现在让我们尝试使用tf.data来完成这个任务:
import tensorflow as tf
**自我介绍一下,小编13年上海交大毕业,曾经在小公司待过,也去过华为、OPPO等大厂,18年进入阿里一直到现在。**
**深知大多数Python工程师,想要提升技能,往往是自己摸索成长或者是报班学习,但对于培训机构动则几千的学费,着实压力不小。自己不成体系的自学效果低效又漫长,而且极易碰到天花板技术停滞不前!**
**因此收集整理了一份《2024年Python开发全套学习资料》,初衷也很简单,就是希望能够帮助到想自学提升又不知道该从何学起的朋友,同时减轻大家的负担。**
![img](https://img-blog.csdnimg.cn/img_convert/eba97ef4b20f0f269a573744e8398527.png)
![img](https://img-blog.csdnimg.cn/img_convert/b9683345245e387e61289ad9ee7fe738.png)
![img](https://img-blog.csdnimg.cn/img_convert/c2da324b9f8373bfbbc8423a1909059b.png)
![img](https://img-blog.csdnimg.cn/img_convert/34dbdb36466e5e658b2067662304d309.png)
![img](https://img-blog.csdnimg.cn/img_convert/6c361282296f86381401c05e862fe4e9.png)
![img](https://img-blog.csdnimg.cn/img_convert/9f49b566129f47b8a67243c1008edf79.png)
**既有适合小白学习的零基础资料,也有适合3年以上经验的小伙伴深入学习提升的进阶课程,基本涵盖了95%以上前端开发知识点,真正体系化!**
**由于文件比较大,这里只是将部分目录大纲截图出来,每个节点里面都包含大厂面经、学习笔记、源码讲义、实战项目、讲解视频,并且后续会持续更新**
**如果你觉得这些内容对你有帮助,可以扫码获取!!!(备注:Python)**
67243c1008edf79.png)
**既有适合小白学习的零基础资料,也有适合3年以上经验的小伙伴深入学习提升的进阶课程,基本涵盖了95%以上前端开发知识点,真正体系化!**
**由于文件比较大,这里只是将部分目录大纲截图出来,每个节点里面都包含大厂面经、学习笔记、源码讲义、实战项目、讲解视频,并且后续会持续更新**
**如果你觉得这些内容对你有帮助,可以扫码获取!!!(备注:Python)**
![](https://img-blog.csdnimg.cn/img_convert/6abc07de8be7d35103deed41e68a02ea.jpeg)