如何清空 Tensorflow 默认图

tf.reset_default_graph() 可以清空默认图里所有的节点。

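下面代码中的 transferLearnRetrainLayer2() 要通过迁移学习重用 tensorflowMain() 训练出来的模型,详见 https://blog.csdn.net/brooknew/article/details/83107491(《Tensorflow 实现迁移学习的例子》)。若不包含上面这句,transferLearnRetrainLayer2() 中读出的 w1Org、w2Org 将是 tensorflowMain() 中定义的初始值,而不是训练后的值:因为 tensorflowMain() 已经在默认图里创建了同名变量 w1、w2 等,get_tensor_by_name('w1:0') 取到的是这些原变量(被 global_variables_initializer() 重新初始化),而不是导入的训练结果。《Tensorflow 实现迁移学习的例子》里的代码则是通过 with tf.Graph().as_default() 创建新的计算图,同样可以避免这个问题。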

训练结果(Tensorflow):
w1: [[1.0515914  1.8256049  2.6668158 ]
 [0.582871   0.6132111  0.90144134]]
w2: [[1.0773216]
 [1.585386 ]
 [2.32953  ]]
b1: [-0.14707229 -0.069271   -0.02620909]
b2: [5.1552486]
Accury : 0.828125
start transferLearn:
Accury re-using model: 0.828125
start transferLearnRetrainLayer2:
after transfer learning:
w1Org: [[1. 1. 1.]
 [2. 1. 1.]]
w2Org: [[2.]
 [3.]
 [5.]]
b1Org: [0. 0. 0.]
b2Org: [0.]
w2new: [[-6.659646 ]
 [ 7.4890714]
 [ 9.489082 ]]
b2new: [4.299692]
Accury after transfer learning: 0.8984375
#NN_transferLearn.py
import tensorflow as tf
import numpy as np
from traindata import *
from dataDim import *
from tensorflow.python.framework import graph_util

# File the frozen model (variables converted to constants) is written to by
# tensorflowMain() and read back from by the transfer-learning functions.
pbName = "NN_transferLearn.pb"

# Mini-batch size and number of gradient-descent steps for both trainings.
BATCH_SIZE = 8
TRAIN_STEPS = 10000

def calAccury(ya):
    """Score predictions against the validation labels.

    A prediction counts as correct when it lies strictly within
    ``TRAIN_THRESHOLD`` of the corresponding entry of ``Y_T``.

    Args:
        ya: predicted results (list or array), compared element-wise with
            ``Y_T`` after both are converted with ``np.array``.

    Returns:
        ``(per, a)``: ``per`` is the number of correct predictions divided by
        ``VALIDATE_SIZE``; ``a`` is the boolean array of per-element hits.
    """
    predicted = np.array(ya)
    expected = np.array(Y_T)
    hits = np.abs(predicted - expected) < TRAIN_THRESHOLD
    return hits.astype(np.float32).sum() / VALIDATE_SIZE, hits


def tensorflowMain():
    """Main for training using TensorFlow.

    Builds a two-layer linear model in the *default* graph:
    ``a = x @ w1 + b1`` (node 'mulAddFir') and ``y = a @ w2 + b2``
    (node 'mulAddSec'). It trains the model with gradient descent on the mean
    squared error over mini-batches of ``X`` / ``Y_``. It then converts the
    trained variables to constants, writes the frozen graph to ``pbName``, and
    prints the accuracy on the validation inputs ``XT``.

    Returns:
        Tuple ``(w1, w2, b1, b2)`` holding the trained values as returned by
        ``sess.run``.
    """
    x = tf.placeholder(tf.float32, shape=(None, 2), name='input')
    y_ = tf.placeholder(tf.float32, shape=(None, 1), name='label')

    w1 = tf.Variable([[1.0, 1.0, 1.0], [2.0, 1.0, 1.0]], name='w1')
    w2 = tf.Variable([[2.0], [3.0], [5.0]], name='w2')
    b1 = tf.Variable([0.0, 0.0, 0.0], name='b1')
    b2 = tf.Variable([0.0], name='b2')
    a = tf.add(tf.matmul(x, w1), b1, name='mulAddFir')
    y = tf.add(tf.matmul(a, w2), b2, name='mulAddSec')

    # Loss function (mean squared error) and the backpropagation step.
    loss_mse = tf.reduce_mean(tf.square(y - y_))
    train_step = tf.train.GradientDescentOptimizer(0.001).minimize(loss_mse)

    # Create a session and train for TRAIN_STEPS steps.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Walk through the training set in BATCH_SIZE slices. A slice may be
        # shorter than BATCH_SIZE when it runs past the end of X / Y_.
        for i in range(TRAIN_STEPS):
            start = (i * BATCH_SIZE) % SAMPLE_SIZE
            end = start + BATCH_SIZE
            sess.run(train_step, feed_dict={x: X[start:end], y_: Y_[start:end]})

        print("训练结果(Tensorflow):")
        r_w1, r_w2, r_b1, r_b2 = sess.run([w1, w2, b1, b2])
        print("w1:", r_w1)
        print("w2:", r_w2)
        print("b1:", r_b1)
        print("b2:", r_b2)

        # Replace the variables with constants holding their trained values,
        # so the saved graph can be re-imported without a checkpoint.
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ['mulAddFir', 'label', 'mulAddSec'])
        with open(pbName, mode='wb') as f:
            f.write(constant_graph.SerializeToString())

        # Validate on the held-out inputs.
        rv = sess.run(y, feed_dict={x: XT})
        per, _ = calAccury(rv)
        print("Accury :", per)
        return r_w1, r_w2, r_b1, r_b2

def transferLearn():
    """Re-use the frozen model unchanged and report its validation accuracy.

    Imports the graph stored in ``pbName`` into a brand-new graph, so nodes
    that tensorflowMain() created in the default graph cannot clash with the
    imported ones. It then evaluates 'mulAddSec:0' on ``XT``.
    """
    print("start transferLearn:")
    with tf.Graph().as_default():
        frozen_def = tf.GraphDef()
        with open(pbName, mode='rb') as pb_file:
            serialized = pb_file.read()
        frozen_def.ParseFromString(serialized)
        tf.import_graph_def(frozen_def, name='')

        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            graph = session.graph
            inputs = graph.get_tensor_by_name('input:0')
            # The label tensor is looked up but not needed for pure inference.
            _labels = graph.get_tensor_by_name('label:0')
            output = graph.get_tensor_by_name('mulAddSec:0')

            predicted = session.run(output, feed_dict={inputs: XT})
            accuracy, _hits = calAccury(predicted)
            print("Accury re-using model:", accuracy)


def transferLearnRetrainLayer2():
    """Reuse the frozen first layer and train a new second layer on top of it.

    Loads the frozen graph from ``pbName`` into a freshly reset default graph.
    It evaluates the imported first layer ('mulAddFir:0') and feeds the result
    into a new trainable second layer ``(w2new, b2new)``, which it trains with
    gradient descent on ``X`` / ``Y_``. The loss is written as a summary to
    ./logs/. Finally it prints the imported weights, the new weights, and the
    accuracy on ``XT``.
    """
    print("start transferLearnRetrainLayer2:")
    # Clear the nodes tensorflowMain() left in the default graph. Without this,
    # sess.run(w1Org) below returns tensorflowMain()'s initial values instead
    # of the trained values stored in the frozen graph.
    tf.reset_default_graph()
    #with tf.Graph().as_default():  # an alternative to the reset above

    graph0 = tf.GraphDef()
    with open(pbName, mode='rb') as f:
        graph0.ParseFromString(f.read())
    tf.import_graph_def(graph0, name='')

    # New second layer; its input is the first layer's output fed via yFirst,
    # so training only touches w2new / b2new.
    yFirst = tf.placeholder(tf.float32)
    w2new = tf.Variable([[2.0], [3.0], [5.0]], name='w2new')
    b2new = tf.Variable([0.0], name='b2new')
    ynew = tf.matmul(yFirst, w2new) + b2new

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        x = sess.graph.get_tensor_by_name('input:0')
        y_ = sess.graph.get_tensor_by_name('label:0')
        layer1Tensor = sess.graph.get_tensor_by_name('mulAddFir:0')
        loss_mse = tf.reduce_mean(tf.square(ynew - y_))
        tf.summary.scalar('loss_mse', loss_mse)
        train_stepH = tf.train.GradientDescentOptimizer(0.001).minimize(loss_mse)
        summary_ops = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter('./logs/', sess.graph)
        try:
            for i in range(TRAIN_STEPS):
                start = (i * BATCH_SIZE) % SAMPLE_SIZE
                end = start + BATCH_SIZE
                ar = sess.run(layer1Tensor, feed_dict={x: X[start:end]})
                _, val = sess.run([train_stepH, summary_ops],
                                  feed_dict={yFirst: ar, y_: Y_[start:end]})
                summary_writer.add_summary(val, global_step=i)
        finally:
            # Flush pending events to disk and release the event file, even
            # if training is interrupted.
            summary_writer.close()

        print("after transfer learning:")
        # In the frozen graph these are constants holding the trained values.
        w1Org = sess.graph.get_tensor_by_name('w1:0')
        b1Org = sess.graph.get_tensor_by_name('b1:0')
        w2Org = sess.graph.get_tensor_by_name('w2:0')
        b2Org = sess.graph.get_tensor_by_name('b2:0')
        r_w1Org = sess.run(w1Org)
        r_w2Org = sess.run(w2Org)
        r_b1Org = sess.run(b1Org)
        r_b2Org = sess.run(b2Org)
        print("w1Org:", r_w1Org)
        print("w2Org:", r_w2Org)
        print("b1Org:", r_b1Org)
        print("b2Org:", r_b2Org)

        r_w2new = sess.run(w2new)
        r_b2new = sess.run(b2new)
        print("w2new:", r_w2new)
        print("b2new:", r_b2new)

        ar = sess.run(layer1Tensor, feed_dict={x: XT})
        rv = sess.run(ynew, feed_dict={yFirst: ar})
        per, _ = calAccury(rv)
        print("Accury after transfer learning:", per)


def main():
    """Train the base model, reuse it unchanged, then retrain its second layer."""
    # tensorflowMain() prints the trained weights itself; the returned values
    # are not needed here.
    tensorflowMain()
    transferLearn()
    transferLearnRetrainLayer2()


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值