tf2学习 mnist

源码

# !/usr/bin/python
# -*- coding: UTF-8 -*-


import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
import numpy as np

(train_x,train_y),(test_x,test_y) = datasets.mnist.load_data()
train_x = tf.convert_to_tensor(train_x, dtype=tf.float32) / 255.
train_y = train_y.astype(np.int64)
train_dataset = tf.data.Dataset.from_tensor_slices((train_x,train_y))
# batch_size = 100,数据集最多重复10次
train_dataset = train_dataset.batch(100).repeat(10)

# 用keras.Sequential构建一个模型,并从keras.optimizers实例化一个随机梯度下降优化器。
model = tf.keras.Sequential([
    layers.Reshape(target_shape = (28 * 28,), input_shape=(28, 28)),
    layers.Dense(256, activation = tf.nn.relu),
    layers.Dense(256, activation = tf.nn.relu),
    layers.Dense(256, activation = tf.nn.relu),
    layers.Dense(10)
])

model.summary()


optimizer = optimizers.Adam(lr=1e-3)
acc = metrics.Accuracy()

# @tf.function
# def compute_loss(logits,label):
#     return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels = label))

for step,(x,y) in enumerate(train_dataset):

    """
    使用tf.GradientTape相对于网络的可训练变量手动计算损耗的梯度。GradientTape只是TensorFlow 2.0中执行梯度步骤的多种方法之一
    
    Tf.GradientTape:通过在上下文管理器中记录操作,针对给定变量手动计算损耗梯度。这是执行优化程序步骤的最灵活的方法,因为我们可以直接使用渐变,而无需预先定义的Keras模型或损失函数。
    Model.train():Keras的内置函数,用于遍历数据集并在其上拟合Keras.Model。这通常是训练Keras模型的最佳选择,并带有进度条显示,验证拆分,多处理和生成器支持的选项。
    Optimizer.minimize():通过给定的损失函数进行计算和微分,并执行一个步骤以通过梯度下降将其最小化。此方法易于实现,并且可以方便地应用于任何现有的计算图上,以进行有效的优化步骤。
    """
    with tf.GradientTape() as tape:

        # loss = compute_loss(logits=output, label=y)

        output = model(x)  # [batch_size,28,28] => [batch_size,10]
        y_onehot = tf.one_hot(y, depth=10) # [batch_size,1] => [batch_size,10]
        loss = tf.square(output - y_onehot)
        loss = tf.reduce_mean(loss)# [batch_size,10] => [batch_size,1]

    # 更新准确率
    acc.update_state(tf.argmax(output,axis = 1),y)

    # 求梯度
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))


    # 没200 步打印一次,并且重新统计准确率
    if step % 200 == 0:
        print(step, 'loss:', float(loss), 'acc:', acc.result().numpy())
        acc.reset_states()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值