tfe 模型保存和载入

原文链接: tfe 模型保存和载入

上一篇: tfe 配合 Keras model 线性拟合 和 自己处理梯度进行线性拟合

下一篇: tfe mnist 使用dataset 分类 保存和载入

简单参数保存和载入

如果路径不存在回自动创建

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()
x = tfe.Variable(10.)

checkpoint = tfe.Checkpoint(x=x)
x.assign(2.)  # Assign a new value to the variables and save.
print(x.numpy())  # 2.0

save_path = checkpoint.save('./ckpt/')
print(save_path)  # ./ckpt/-1

x.assign(11.)  # Change the variable after saving.
print(x.numpy())  # 11.0

# Restore values from the checkpoint
checkpoint.restore(save_path)
print(x.numpy())  # 2.0

2d827ab74bfe7d6a766053194ac10c61ea4.jpg

使用Keras的Model时,需要保存很多参数,此时使用对象保存的方式

载入时使用的是模型的文件夹路径

主要代码

# 保存训练参数
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
checkpoint_dir = './save/'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
root = tfe.Checkpoint(optimizer=optimizer,
                      model=model,
                      optimizer_step=tf.train.get_or_create_global_step())

root.save(file_prefix=checkpoint_prefix)
# or
# root.restore(tf.train.latest_checkpoint(checkpoint_dir))

完整代码

import tensorflow as tf
import tensorflow.contrib.eager as tfe
import os

tf.enable_eager_execution()


class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        self.W = tfe.Variable(5., name='weight')
        self.B = tfe.Variable(10., name='bias')

    def call(self, inputs):
        return inputs * self.W + self.B


# A toy dataset of points around 3 * x + 2
NUM_EXAMPLES = 2000
inputs = tf.random_normal([NUM_EXAMPLES])
noise = tf.random_normal([NUM_EXAMPLES])
targets = inputs * 3 + 2 + noise


# The loss function to be optimized
def loss():
    error = model(inputs) - targets
    return tf.reduce_mean(tf.square(error))


# Define:
# 1. A model.
# 2. Derivatives of a loss function with respect to model parameters.
# 3. A strategy for updating the variables based on the derivatives.
model = Model()
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

# 载入训练参数
# checkpoint_dir = './save/'
# checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# root = tfe.Checkpoint(optimizer=optimizer,
#                       model=model,
#                       optimizer_step=tf.train.get_or_create_global_step())
# root.restore(tf.train.latest_checkpoint(checkpoint_dir))
#


# Training loop
for i in range(300):
    optimizer.minimize(loss)
    if i % 20 == 0:
        print(model.W.numpy(), model.B.numpy())

# 保存训练参数
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
checkpoint_dir = './save/'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
root = tfe.Checkpoint(optimizer=optimizer,
                      model=model,
                      optimizer_step=tf.train.get_or_create_global_step())

root.save(file_prefix=checkpoint_prefix)
# or
# root.restore(tf.train.latest_checkpoint(checkpoint_dir))

72e5904347d2162a796d0a1cf7926c602fa.jpg

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
tfe 函数是 Matlab 信号处理工具箱中的一个函数,用于计算两个信号的传递函数估计。它的基本用法如下: ```matlab [H, f] = tfe(x, y, fs, window, noverlap, nfft) ``` 其中,参数的含义如下: - `x`:输入信号的向量,通常是信号源。 - `y`:输出信号的向量,通常是测量的响应。 - `fs`:信号的采样率,单位为 Hz。 - `window`:窗函数的名称或向量。默认值为 `'hann'`。 - `noverlap`:相邻窗口之间重叠的样本数。默认值为窗口长度的一半。 - `nfft`:用于计算功率谱密度的 FFT 点数。默认值为窗口长度。 函数的输出包括: - `H`:传递函数估计的向量。如果 `x` 是多列的,则 `H` 将是一个矩阵,每列代表一个信道。 - `f`:频率向量,单位为 Hz。 例如,以下代码演示了如何使用 tfe 函数计算两个信号之间的传递函数估计: ```matlab % 生成输入信号和输出信号 fs = 1000; % 采样率为 1000 Hz t = (0:1/fs:1)'; x = sin(2*pi*50*t) + sin(2*pi*120*t); y = 0.5*x + randn(size(t)); % 计算传递函数估计 [H, f] = tfe(x, y, fs); % 绘制结果 figure; subplot(2, 1, 1); plot(t, x, 'b', t, y, 'r'); xlabel('Time (s)'); ylabel('Amplitude'); legend('Input', 'Output'); title('Input and output signals'); subplot(2, 1, 2); plot(f, abs(H).^2); xlim([0 500]); xlabel('Frequency (Hz)'); ylabel('Power'); title('Power spectral density'); ``` 此代码将生成两个信号,一个是包含 50 Hz 和 120 Hz 正弦波的混合信号,另一个是加入噪声的输出信号。然后,tfe 函数被用于计算输入信号和输出信号之间的传递函数估计。最后,绘制了输入信号、输出信号和功率谱密度的图形。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值