Saving TensorFlow Variables and Models, and Using a Saved Model

Part 1. The model (complete working example)

1. Building and saving the model

import tensorflow as tf
import numpy as np
# To plot pretty figures
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
t_min, t_max = 0, 30
resolution = 0.1

def time_series(t):
    return t * np.sin(t) / 3 + 2 * np.sin(t*5)

def next_batch(batch_size, n_steps):
    t0 = np.random.rand(batch_size, 1) * (t_max - t_min - n_steps * resolution)
    Ts = t0 + np.arange(0., n_steps + 1) * resolution
    ys = time_series(Ts)
    return ys[:, :-1].reshape(-1, n_steps, 1), ys[:, 1:].reshape(-1, n_steps, 1)

t = np.linspace(t_min, t_max, int((t_max - t_min) / resolution))

n_steps = 20
n_inputs = 1
n_neurons = 100
n_outputs = 1
t_instance = np.linspace(12.2, 12.2 + resolution * (n_steps + 1), n_steps + 1)

X = tf.placeholder(tf.float32, [None, n_steps, n_inputs], name='X')
y = tf.placeholder(tf.float32, [None, n_steps, n_outputs], name='y')

cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons, activation=tf.nn.relu)
rnn_outputs, states = tf.nn.dynamic_rnn(cell=cell, inputs=X, dtype=tf.float32)

learning_rate = 0.001

stacked_rnn_outputs = tf.reshape(tensor=rnn_outputs, shape=[-1, n_neurons])
stacked_outputs = tf.layers.dense(inputs=stacked_rnn_outputs, units=n_outputs)
outputs = tf.reshape(tensor=stacked_outputs, shape=[-1, n_steps, n_outputs], name='outputs')

loss = tf.reduce_mean(tf.square(outputs - y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss)

init = tf.global_variables_initializer()
saver = tf.train.Saver()

n_iterations = 1500
batch_size = 50

with tf.Session() as sess:
    init.run()
    for iteration in range(n_iterations):
        X_batch, y_batch = next_batch(batch_size, n_steps)
        sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        if iteration % 100 == 0:
            mse = loss.eval(feed_dict={X: X_batch, y: y_batch})
            print(iteration, "\tMSE:", mse)
            
    X_new = time_series(np.array(t_instance[:-1].reshape(-1, n_steps, n_inputs)))
    y_pred = sess.run(outputs, feed_dict={X: X_new})
    
    saver.save(sess, "./my_time_series_model")

2. Loading the model and using it for prediction

import tensorflow as tf
import numpy as np
def reset_graph(seed=42):
    tf.reset_default_graph()
    tf.set_random_seed(seed)
    np.random.seed(seed)
reset_graph()
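
Note: the helper function time_series and the arrays t and t_instance live only in Python, not in the saved graph, so they must be redefined in this fresh session before step 4 below. A minimal sketch repeating the definitions from Part 1:

t_min, t_max = 0, 30
resolution = 0.1
n_steps = 20

def time_series(t):
    return t * np.sin(t) / 3 + 2 * np.sin(t * 5)

t = np.linspace(t_min, t_max, int((t_max - t_min) / resolution))
t_instance = np.linspace(12.2, 12.2 + resolution * (n_steps + 1), n_steps + 1)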
1. Load the graph structure and parameters
sess = tf.Session()
saver = tf.train.import_meta_graph('./my_time_series_model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

INFO:tensorflow:Restoring parameters from ./my_time_series_model
2. Get the graph
graph = tf.get_default_graph()
3. Get the tensors
X = graph.get_tensor_by_name('X:0')
y = graph.get_tensor_by_name("y:0")
outputs = graph.get_tensor_by_name("outputs:0") 
4. Prepare the new input
X_new = time_series(np.array(t_instance[:-1].reshape(-1, 20, 1)))
y.shape
TensorShape([Dimension(None), Dimension(20), Dimension(1)])
5. Run the model for prediction
y_pred = sess.run(outputs, feed_dict={X: X_new})
print(y_pred)
[[[-3.44828582]
  [-2.48405623]
  [-1.13649726]
  [ 0.71962416]
  [ 2.01745081]
  [ 3.13937259]
  [ 3.54828739]
  [ 3.36234236]
  [ 2.77184248]
  [ 2.10781217]
  [ 1.64527285]
  [ 1.5579648 ]
  [ 1.87219918]
  [ 2.7233479 ]
  [ 3.85228252]
  [ 5.06193066]
  [ 6.07513857]
  [ 6.63054752]
  [ 6.59069633]
  [ 5.9993453 ]]]
plt.title("Testing the model", fontsize=14)
plt.plot(t_instance[:-1], time_series(t_instance[:-1]), "bo", markersize=10, label="instance")
plt.plot(t_instance[1:], time_series(t_instance[1:]), "w*", markersize=10, label="target")
plt.plot(t_instance[1:], y_pred[0,:,0], "r.", markersize=10, label="prediction")
plt.legend(loc="upper left")
plt.xlabel("Time")

# save_fig("time_series_pred_plot")
plt.show()

[Figure: the "Testing the model" plot, with blue dots = instance, stars = target, red dots = prediction]

n_steps = 20
sequence1 = [0.] * n_steps
for iteration in range(len(t) - n_steps):
    X_batch = np.array(sequence1[-n_steps:]).reshape(1, n_steps, 1)
    y_pred = sess.run(outputs, feed_dict={X: X_batch})
    sequence1.append(y_pred[0, -1, 0])

sequence2 = [time_series(i * resolution + t_min + (t_max - t_min / 3))
             for i in range(n_steps)]  # note: (t_max - t_min) / 3 was probably intended here
for iteration in range(len(t) - n_steps):
    X_batch = np.array(sequence2[-n_steps:]).reshape(1, n_steps, 1)
    y_pred = sess.run(outputs, feed_dict={X: X_batch})
    sequence2.append(y_pred[0, -1, 0])

sequence2
[-11.310069100186947,
 -10.293466647143219,
 -9.0351304228581899,
 -7.7791526184875721,
 -6.7463786174128,
 ...
 2.0629253,
 -0.033325758,
 -2.7715981,
 -6.3110862]
(output truncated; 300 values in total)
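
To visualize the free-running generation, both sequences can be plotted against time (a sketch, reusing t, n_steps, sequence1 and sequence2 from above, with matplotlib imported as in Part 1):

import matplotlib.pyplot as plt

plt.figure(figsize=(11, 4))
plt.subplot(1, 2, 1)
plt.plot(t, sequence1, "b-")                                   # generated from a zero seed
plt.plot(t[:n_steps], sequence1[:n_steps], "b-", linewidth=3)  # the seed itself
plt.xlabel("Time")
plt.ylabel("Value")
plt.subplot(1, 2, 2)
plt.plot(t, sequence2, "b-")                                   # generated from a real-data seed
plt.plot(t[:n_steps], sequence2[:n_steps], "b-", linewidth=3)  # the seed itself
plt.xlabel("Time")
plt.show()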
6. Other techniques
6.1 Inspecting the tensors and nodes in a checkpoint
Method 1: pywrap_tensorflow
tf.train.get_checkpoint_state(checkpoint_dir='<checkpoint directory>')  # e.g. if the checkpoint file lives at C:\Users\Administrator\Documents\checkpoint, pass r'C:\Users\Administrator\Documents' here
import os

logdir='./'

from tensorflow.python import pywrap_tensorflow
ckpt = tf.train.get_checkpoint_state(logdir)
reader = pywrap_tensorflow.NewCheckpointReader(ckpt.model_checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)  # tensor的名称
#     print(reader.get_tensor(key))

# Reference: https://blog.csdn.net/wc781708249/article/details/78040735
tensor_name:  dense/kernel
tensor_name:  beta1_power
tensor_name:  beta2_power
tensor_name:  dense/bias
tensor_name:  rnn/basic_rnn_cell/bias
tensor_name:  dense/bias/Adam_1
tensor_name:  dense/bias/Adam
tensor_name:  dense/kernel/Adam_1
tensor_name:  dense/kernel/Adam
tensor_name:  rnn/basic_rnn_cell/bias/Adam
tensor_name:  rnn/basic_rnn_cell/bias/Adam_1
tensor_name:  rnn/basic_rnn_cell/kernel
tensor_name:  rnn/basic_rnn_cell/kernel/Adam
tensor_name:  rnn/basic_rnn_cell/kernel/Adam_1
Method 2: inspect_checkpoint
inspect_checkpoint.print_tensors_in_checkpoint_file(file_name=, tensor_name=, all_tensors=)  # file_name is the checkpoint prefix: if the four model files live in C:\Users\Administrator\Documents and the meta file is C:\Users\Administrator\Documents\my_time_series_model.meta, then file_name=r'C:\Users\Administrator\Documents\my_time_series_model'
help(chkp.print_tensors_in_checkpoint_file)
Help on function print_tensors_in_checkpoint_file in module tensorflow.python.tools.inspect_checkpoint:

print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors)
    Prints tensors in a checkpoint file.
    
    If no `tensor_name` is provided, prints the tensor names and shapes
    in the checkpoint file.
    
    If `tensor_name` is provided, prints the content of the tensor.
    
    Args:
      file_name: Name of the checkpoint file.
      tensor_name: Name of the tensor in the checkpoint file to print.
      all_tensors: Boolean indicating whether to print all tensors.
# Use inspect_checkpoint to look at what is inside a ckpt
from tensorflow.python.tools import inspect_checkpoint as chkp

chkp.print_tensors_in_checkpoint_file(file_name="./my_time_series_model",
                                      tensor_name=None,   # None means: report all variables in the ckpt
                                      all_tensors=False,  # True also prints every tensor's values;
                                     )                    # False prints only names and shapes

# print_tensors_in_checkpoint_file is itself implemented on top of NewCheckpointReader.
beta1_power (DT_FLOAT) []
beta2_power (DT_FLOAT) []
dense/bias (DT_FLOAT) [1]
dense/bias/Adam (DT_FLOAT) [1]
dense/bias/Adam_1 (DT_FLOAT) [1]
dense/kernel (DT_FLOAT) [100,1]
dense/kernel/Adam (DT_FLOAT) [100,1]
dense/kernel/Adam_1 (DT_FLOAT) [100,1]
rnn/basic_rnn_cell/bias (DT_FLOAT) [100]
rnn/basic_rnn_cell/bias/Adam (DT_FLOAT) [100]
rnn/basic_rnn_cell/bias/Adam_1 (DT_FLOAT) [100]
rnn/basic_rnn_cell/kernel (DT_FLOAT) [101,100]
rnn/basic_rnn_cell/kernel/Adam (DT_FLOAT) [101,100]
rnn/basic_rnn_cell/kernel/Adam_1 (DT_FLOAT) [101,100]
Method 3: listing all node names

First load the model and get the graph, then print the nodes in its graph definition.
(This makes it easy to find the variables you need when building applications on top of the model.)

tf.get_default_graph().as_graph_def().node
# list the names of all nodes in the graph
[tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
['X',
 'y',
 'Rank',
 'range/start',
 'range/delta',
 'range',
 ...
 'rnn/basic_rnn_cell/kernel',
 ...
 'rnn/basic_rnn_cell/bias',
 ...
 'dense/kernel',
 ...
 'dense/bias',
 ...
 'outputs',
 ...
 'Adam',
 'init',
 ...
 'save/restore_all']
(output truncated; the full list also contains the RNN while-loop nodes, the gradient nodes, the Adam slot variables, and the save/restore ops)
6.2 Checkpoint files across TensorFlow versions
  • For TensorFlow 1.2 and later, pass the checkpoint path only up to model_name (e.g. for my_model.meta, write my_model). The same prefix can also be obtained programmatically, as sketched below.
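
A small sketch of resolving the checkpoint prefix programmatically (the directory here is illustrative):

import tensorflow as tf

# Returns e.g. './my_time_series_model', the common prefix of the
# .index/.data-*/.meta files, not the name of any single file.
ckpt_prefix = tf.train.latest_checkpoint('./')
saver = tf.train.import_meta_graph(ckpt_prefix + '.meta')
with tf.Session() as sess:
    saver.restore(sess, ckpt_prefix)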




References:
(1) TensorFlow checkpoint file versions
(2) Several ways to inspect the variables in a ckpt

Part 2. Other simple save/restore examples

1. Saving variables

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in path: %s" % save_path)
Model saved in path: /tmp/model.ckpt

2. Restoring variables

When restoring variables through a tf.train.Saver object, there is no need to initialize them beforehand, i.e. the initializer argument of tf.get_variable() can be left unset.

tf.reset_default_graph()
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())
INFO:tensorflow:Restoring parameters from /tmp/model.ckpt
Model restored.
v1 : [ 1.  1.  1.]
v2 : [-1. -1. -1. -1. -1.]

3. Choosing which variables to save and restore

Passing a dict to tf.train.Saver maps names in the checkpoint file (the keys) to in-graph variables (the values); only those variables are saved and restored, so any others must be initialized by hand.

tf.reset_default_graph()
# Create some variables.
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)
# Add ops to save and restore only `v2` using the name "v2"
saver = tf.train.Saver({"v2": v2})

# Use the saver object normally after that.
with tf.Session() as sess:
  # Initialize v1 since the saver will not.
  v1.initializer.run()
  saver.restore(sess, "/tmp/model.ckpt")

  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())
INFO:tensorflow:Restoring parameters from /tmp/model.ckpt
v1 : [ 0.  0.  0.]
v2 : [-1. -1. -1. -1. -1.]

Part 3. Saving and restoring models (SavedModel)

  • SavedModel saves and loads a complete model, including its variables, graph, and graph metadata. TensorFlow supports it through tf.saved_model and tf.estimator.Estimator.
1. Building a SavedModel
Simple save

Use the tf.saved_model.simple_save function:

simple_save(session,
            export_dir,
            inputs={"x": x, "y": y},
            outputs={"z": z})
Manually building a SavedModel
2. Loading a SavedModel

Three pieces of information are needed:

  • a session in which to restore the graph definition and variables
  • the tags identifying the MetaGraphDef to load
  • the location (directory) of the SavedModel

???

export_dir = ...
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], export_dir)
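
A matching sketch that loads the model exported by the simple_save sketch above (note that simple_save tags the MetaGraphDef with SERVING, so that is the tag to pass; the directory and tensor names follow the save sketch):

import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess,
                               [tf.saved_model.tag_constants.SERVING],
                               './simple_save_demo')
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name('x:0')
    z = graph.get_tensor_by_name('z:0')
    print(sess.run(z, feed_dict={x: [[1., 2., 3.]]}))  # [[6.]] with the all-ones w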

References:

  1. Official documentation: https://tensorflow.google.cn/guide/saved_model

Part 4. A quick, complete tutorial on saving models

Adapted from: https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

1. TensorFlow model files: the basics

(1) From version 0.11 onward (four files)
- checkpoint: records the prefix of the most recent checkpoint
- model.data-0000-of-0001: contains the values of the training variables
- model.index: maps each variable name to its location in the .data file(s)
- model.meta: contains the graph structure (the MetaGraphDef)

For example:
[Screenshot: the four files produced by saver.save]
(2) Before version 0.11 (three files)
- checkpoint
- model.ckpt
- model.meta

2. Saving variables

  • Notes:
    • saving must be done inside a session
    • the saved files do not store the values fed into placeholders

(1) Basic save

saver.save(sess, 'my-test-model')

(2) Tagging saves with the iteration count (global_step)

saver.save(sess, 'my-test-model', global_step=1000)  # e.g. called once every 1000 iterations

The resulting model file names have '-1000' appended.

(3) The write_meta_graph argument: pass False to skip re-writing the .meta file on a save. The graph does not change during training, so the meta file only needs to be written once.

saver.save(sess, 'my-test-model', global_step=1000, write_meta_graph=False)  # don't rewrite the meta file at each save; the first one is enough

(4) Keeping the latest m models, plus one every n hours

saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)  # keep the 4 most recent checkpoints, and additionally keep one checkpoint every 2 hours

(5) The arguments of tf.train.Saver()
- with the default arguments, all variables are saved
- pass a list or dict of specific variables to save only those (see the dict sketch after the snippet below)

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1,w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model', global_step=1000)
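
Passing a dict instead of a list additionally lets you choose the names under which variables are stored in the checkpoint (a sketch; the key names here are made up):

import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
# Keys are the names written into the checkpoint file;
# values are the in-graph variables they map to.
saver = tf.train.Saver({'weights_a': w1, 'weights_b': w2})
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model_renamed')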

3. Importing a model

This involves two main steps:

  • create the network
    Use the .meta file to recreate the original graph.
saver = tf.train.import_meta_graph('my_model-1000.meta')
  • load the parameters
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('my_model-1000.meta')  # recreate the graph (assume the model saved tensors w1 and w2)
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))  # load the saved values of w1 and w2 into the graph
    print(sess.run('w1:0'))  # prints the restored value of w1

4. Working with a loaded model

A trained model is typically reused for prediction, fine-tuning, or further training.
On understanding get_tensor_by_name:

# w1:0
# <name>:0 (the :0 is the output endpoint index, which is somewhat redundant)
# 'w1' is the node (op) name, while 'w1:0' is a tensor name: the node's first output tensor
tensor = tf.get_default_graph().get_tensor_by_name("w1:0")
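
A quick sketch of the node-name vs. tensor-name distinction (the constant is made up):

import tensorflow as tf

a = tf.constant(3.0, name='a')
graph = tf.get_default_graph()
op = graph.get_operation_by_name('a')     # the node (operation) itself
tensor = graph.get_tensor_by_name('a:0')  # the node's first output tensor
print(op.name, tensor.name)               # -> a a:0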

(1) Retrieving saved variables, tensors, placeholders, and operations

w1 = graph.get_tensor_by_name("w1:0")
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

(2) Feeding new data through the restored graph
- Load the meta graph and restore the weights → get the placeholders and build a feed_dict for the new data → get the operation → run it.

import tensorflow as tf

sess = tf.Session()
# First, load the meta graph and restore the weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Now access the saved placeholders and
# build a feed_dict to feed the new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Next, access the op that you want to run
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore, feed_dict))
# This prints 60, computed from the new values
# of w1 and w2 and the saved value of b1.

(3) Adding more operations on top of the restored graph

#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)
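
The new op then runs like any restored one; with the feed_dict from above, this would print 120 (2 × 60):

print(sess.run(add_on_op, feed_dict))  # 2 * 60 = 120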

(4) Fine-tuning
e.g. take a saved VGG graph, change the final output layer to 2 units, and fine-tune on new data.

saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning 
 
#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')
 
#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape = fc7.get_shape().as_list()
num_outputs = 2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)
# Now, you run this with fine-tuning data in sess.run()
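
A hedged sketch of what that fine-tuning step could look like (the label placeholder, loss, optimizer, and data feed are assumptions, not part of the original tutorial; plain SGD is used so no optimizer slot variables need initializing):

y_true = tf.placeholder(tf.float32, shape=[None, num_outputs])
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=output))
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('./'))  # pretrained VGG weights
    sess.run(tf.variables_initializer([weights, biases]))  # initialize only the new layer
    # sess.run(train_op, feed_dict={input_images: X_batch, y_true: y_batch})
    # input_images / X_batch / y_batch stand in for your own task's data.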

