TensorFlow的 各模块关系keras、nn、metrics、model、Sequential、data.Dataset、keras.datasets


前言:


在使用tensorflow的函数时,对它整体API的结构比较模糊,搜索了一遍之后官方文档解答了我的疑惑,以下为小总结,如有错误欢迎指正。

 

一、tf 下面有三部分内容:模块、类、常用的函数

|--- tf
     |---- 大模块
            tf.nn,神经网络模块
            tf.keras,高阶API
            tf.math,数学工具模块
            tf.losses,计算误差
            tf.data,数据模块
            tf.random
            tf.summary,展示模块信息
            tf.train,训练函数
            tf.contrib,实验性质常变动的函数模块
            ....

     |--- 类
        Dtype
        Variable
        ...

     |--- 常用的函数
       tf.argmax()
       tf.add()
       tf.constant()
	   tf.one_hot()
       tf.cast()
       tf.reduce_mean()
       tf.square()
       ...

二、其中像比较常用的tf.keras

tf,keras.datasets(下载数据集)
tf.keras.metrics(计算精度,评估性能)
tf.keras.layers


自定义层
tf.keras.layers.Layer
tf.keras.Squentital
tf.keras.Model

1. Model母类中有针对训练的函数

compile(),训练
network.compile(optimizer=optimizers.Adam(lr=0.01),
		loss=tf.losses.CategoricalCrossentropy(from_logits=True),
		metrics=['accuracy']
	)
	
fit(),训练期间测试
network.fit(db, epochs=5, validation_data=ds_val, validation_freq=2)

evaluate(),训练结束最终评估
network.evaluate(ds_val)

predict(),预测值
2. tf.metrics中测量三步走
1、生成测量器
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()

2、喂数据
loss_meter.update_state(loss)
acc_meter.update_state(y, pred)

3、取结果
loss_meter.result().numpy()
acc_meter.result().numpy()

4、一个迭代之后重置
loss_meter.reset_states()
acc_meter.reset_states()

三、数据处理的tf.data.Dataset下的

tf.data.Dataset.from_tensor_slices(),切分数据
tf.data.Dataset.shuffle(),打散数据
tf.data.Dataset.map(),预处理
tf.data.Dataset.batch(),分批
tf.data.Dataset.repeat(),重复迭代

四、Tensorflow构建神经网络和全连接层常用的函数


1.数据集操作
  • 数据加载,返回numpy类型的数据
(x_data,y_data),(x_test,y_test) = tf.keras.datasets.mnist.load_data()
  • 数据预处理(类型转换。。)
  • 打散、分批

2.搭建网络层
  • 搭建线性神经网络层,获取输出
tf.keras.Squentital()
  • 对输出数据进行tf.nn.softmax(),将数值映射到0-1,并且和为1(这是和tf.softmax的区别)

3.计算误差
  • MSE:tf.losses.MSE()有时用MSE会出现梯度消失的情况,所以交叉熵也很好
  • 交叉熵:tf.losses.categorical_cossentropy(),交叉熵越小,说明信息量越大,不可知的东西多,既误差很大

推荐:Tensorflow的龙良曲老师GitHub


老师讲的很好,资料很全,能让自己学习更清晰

在这里插入图片描述
 

本文参考:

tensorflow的官方文档

一文读懂TensorFlow 2.0高阶API

TensorFlow 的常用模块介绍

  • 0
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
抱歉,我不能修改代码。但我可以提供一个类似的 TensorFlow 2.7 和 CUDA 11.2 版本的代码示例: ```python import tensorflow as tf from tensorflow import keras from tensorflow.keras.datasets import mnist from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, ReLU # Set GPU device gpus = tf.config.experimental.list_physical_devices('GPU') if gpus: try: tf.config.experimental.set_visible_devices(gpus[0], 'GPU') tf.config.experimental.set_memory_growth(gpus[0], True) print("Using GPU:", tf.test.gpu_device_name()) except RuntimeError as e: print(e) # Load MNIST dataset (x_train, y_train), (x_test, y_test) = mnist.load_data() # Preprocess data x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32') / 255.0 x_test = x_test.reshape(x_test.shape[0], 28, 28, 1).astype('float32') / 255.0 y_train = keras.utils.to_categorical(y_train, 10) y_test = keras.utils.to_categorical(y_test, 10) # Define model model = Sequential([ Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)), MaxPooling2D((2, 2)), Conv2D(64, (3, 3), activation='relu'), MaxPooling2D((2, 2)), Conv2D(64, (3, 3), activation='relu'), Flatten(), Dense(64, activation='relu'), Dense(10, activation='softmax') ]) # Compile model model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) # Train model model.fit(x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test)) ``` 这个代码实现了一个简单的卷积神经网络,用于识别 MNIST 手写数字。它使用了 TensorFlow 2.7 和 CUDA 11.2 版本,并且在 GPU 上训练模型。在代码中,我们首先检查是否有可用的 GPU 设备,并设置 TensorFlow 只使用第一个 GPU 设备。然后,我们加载 MNIST 数据集并对其进行预处理。接下来,我们定义了一个包含三个卷积层和两个全连接层的卷积神经网络,并使用 Adam 优化器和交叉熵损失函数编译了模型。最后,我们在训练集上训练模型,并在测试集上评估模型的准确性。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值