# Load weight data from a saved TensorFlow model and continue training.

'''
读取tensorflow保存的模型中权重数据,并继续进行训练
'''
import time
import random
import numpy as np
import tensorflow as tf
import os

tf.reset_default_graph()

# time.clock() was removed in Python 3.8; perf_counter() is the replacement.
start = time.perf_counter()

MODEL_SAVE_PATH = './tensorflow_model'
MODEL_NAME = 'model_test.ckpt'
# Number of input image channels; grayscale = 1.
NUM_CHANNELS = 1
CONV1_DEEP = 8
CONV1_SIZE = 3
CONV2_DEEP = 16
CONV2_SIZE = 3
CONV3_DEEP = 32
CONV3_SIZE = 3
FC_SIZE = 200
# Number of output labels.
NUM_LABELS = 37
OUTPUT_NODE = NUM_LABELS
BATCH_SIZE = 200

KEEP_PROB = 0.5
# Learning-rate schedule parameters.
LEARNING_RATE_BASE = 0.002
LEARNING_RATE_DECAY = 0.98
LEARNING_RATE_DECAY_STEP = 200

REGULARAZTION_RATE = 0.0001

TRAINING_STEPS = 20000

MOVING_AVERAGE_DECAY = 0.99

VALIDATE_STEP = 50   # run validation every n steps
# Load the stored dataset.
captcha_data = np.load('./captcha_array.npy')
captcha_label = np.load('./captcha_label.npy')

CAPTCHA_HEIGHT = 40
CAPTCHA_WIDTH = 25

base_path = os.getcwd()
# os.path.join is portable; the original hard-coded Windows backslash
# strings ('\params\conv1_w.npy') only worked because '\p' and '\c' happen
# not to be escape sequences, and required manual '\\f' escaping elsewhere.
_params_dir = os.path.join(base_path, 'params')


def _load_param(filename, var_name):
    """Load a saved numpy weight array and wrap it in a trainable tf.Variable."""
    return tf.Variable(np.load(os.path.join(_params_dir, filename)),
                       name=var_name, trainable=True)


conv1_weights = _load_param('conv1_w.npy', 'conv1_weights')
conv1_biases = _load_param('conv1_b.npy', 'conv1_biases')

conv2_weights = _load_param('conv2_w.npy', 'conv2_weights')
conv2_biases = _load_param('conv2_b.npy', 'conv2_biases')

conv3_weights = _load_param('conv3_w.npy', 'conv3_weights')
conv3_biases = _load_param('conv3_b.npy', 'conv3_biases')

fc1_weights = _load_param('fc1_w.npy', 'fc1_weights')
fc1_biases = _load_param('fc1_b.npy', 'fc1_biases')

fc2_weights = _load_param('fc2_w.npy', 'fc2_weights')
fc2_biases = _load_param('fc2_b.npy', 'fc2_biases')


def data_batch():
    """Sample a random training batch.

    Returns:
        (train_data, label): float32 images of shape
        (BATCH_SIZE, CAPTCHA_HEIGHT, CAPTCHA_WIDTH), each min-max
        normalized to [0, 1], plus the matching label rows of shape
        (BATCH_SIZE, NUM_LABELS).
    """
    # Allocate as float32 up front; the original re-ran astype('float32')
    # on the whole array inside every loop iteration.
    train_data = np.zeros((BATCH_SIZE, CAPTCHA_HEIGHT, CAPTCHA_WIDTH),
                          dtype='float32')
    label = np.zeros((BATCH_SIZE, NUM_LABELS))

    for i in range(BATCH_SIZE):
        # Indices 0..5999 are the training split; 6000+ is the validation
        # slice used by validation_batch. The original randint(0, 6000)
        # is inclusive and leaked index 6000 into training.
        rand_index = random.randint(0, 5999)
        img_gray = captcha_data[rand_index, :, :]
        # Min-max normalize to [0, 1]; guard against a constant image,
        # which would otherwise divide by zero.
        pixel_min = np.min(img_gray)
        pixel_max = np.max(img_gray)
        pixel_range = pixel_max - pixel_min
        if pixel_range == 0:
            train_data[i] = 0.0
        else:
            train_data[i] = (img_gray - pixel_min) / pixel_range

        label[i] = captcha_label[rand_index, :]
    return train_data, label

def validation_batch():
    """Sample a random validation batch from the held-out slice [6000, 6599].

    Returns:
        (validation_data, label): float32 images of shape
        (BATCH_SIZE, CAPTCHA_HEIGHT, CAPTCHA_WIDTH), normalized to [0, 1],
        plus the matching label rows of shape (BATCH_SIZE, NUM_LABELS).
    """
    validation_data = np.zeros((BATCH_SIZE, CAPTCHA_HEIGHT, CAPTCHA_WIDTH),
                               dtype='float32')
    label = np.zeros((BATCH_SIZE, NUM_LABELS))

    for i in range(BATCH_SIZE):
        # BUG FIX: the index was previously drawn once OUTSIDE the loop,
        # so every "batch" was BATCH_SIZE copies of a single sample and the
        # reported precision was meaningless. Draw a fresh index per sample.
        rand_index = random.randint(6000, 6599)
        img_gray = captcha_data[rand_index, :, :]
        # Min-max normalize to [0, 1]; guard against a constant image.
        pixel_min = np.min(img_gray)
        pixel_max = np.max(img_gray)
        pixel_range = pixel_max - pixel_min
        if pixel_range == 0:
            validation_data[i] = 0.0
        else:
            validation_data[i] = (img_gray - pixel_min) / pixel_range
        label[i] = captcha_label[rand_index, :]

    return validation_data, label
 

           
def inference(input_tensor, _dropout, regularizer):
    """Forward pass of the CNN: three conv/relu/max-pool stages, then two
    fully-connected layers.

    Uses the module-level weight/bias variables loaded from the saved
    numpy parameter files.

    Args:
        input_tensor: image batch, shape (batch, H, W, NUM_CHANNELS).
        _dropout: if truthy, apply dropout (keep prob KEEP_PROB) after fc1
            — pass True for training, False for inference.
        regularizer: callable returning an L2 loss tensor for a weight
            variable; its outputs are accumulated in the 'losses' collection.

    Returns:
        Unscaled logits of shape (batch, NUM_LABELS).
    """
    # Stage 1: 3x3 conv -> ReLU -> 2x2 max-pool (halves H and W).
    conv1 = tf.nn.conv2d(input_tensor, conv1_weights,
                         strides=[1, 1, 1, 1], padding='SAME')
    relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases))
    pool1 = tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1], padding='SAME')

    # Stage 2.
    conv2 = tf.nn.conv2d(pool1, conv2_weights,
                         strides=[1, 1, 1, 1], padding='SAME')
    relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases))
    pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1], padding='SAME')

    # Stage 3.
    conv3 = tf.nn.conv2d(pool2, conv3_weights,
                         strides=[1, 1, 1, 1], padding='SAME')
    relu3 = tf.nn.relu(tf.nn.bias_add(conv3, conv3_biases))
    pool3 = tf.nn.max_pool(relu3, ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1], padding='SAME')

    # Flatten spatial dims. Using -1 lets TF infer the batch dimension, so
    # the graph also works with a dynamic (None) batch size; the original
    # pool_shape[0] failed in that case. Identical result for a fixed batch.
    pool_shape = pool3.get_shape().as_list()
    nodes = pool_shape[1] * pool_shape[2] * pool_shape[3]
    reshaped = tf.reshape(pool3, [-1, nodes])

    # L2-regularize only the fully-connected weights (conv weights excluded).
    tf.add_to_collection('losses', regularizer(fc1_weights))
    fc1 = tf.nn.relu(tf.matmul(reshaped, fc1_weights) + fc1_biases)

    if _dropout:
        fc1 = tf.nn.dropout(fc1, KEEP_PROB)

    tf.add_to_collection('losses', regularizer(fc2_weights))
    logit = tf.matmul(fc1, fc2_weights) + fc2_biases

    return logit

# --- Graph construction ---------------------------------------------------
x=tf.placeholder(tf.float32,[BATCH_SIZE,CAPTCHA_HEIGHT,CAPTCHA_WIDTH,NUM_CHANNELS],name='x-input')
y_=tf.placeholder(tf.float32,[BATCH_SIZE,OUTPUT_NODE],name='y-input')
#    L2 regularization loss factory for the FC weights
regularizer=tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
# Forward pass under the current parameters (dropout enabled: training mode)
y=inference(x,True,regularizer)
#    Training-step counter (not trainable, incremented by the optimizer)
global_step=tf.Variable(0,trainable=False)
#    Initialize the exponential moving average of the parameters
variable_averages=tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step)
# tf.trainable_variables() returns every variable not created with trainable=False
#    Maintain shadow moving-average copies of all trainable parameters
variables_averages_op=variable_averages.apply(tf.trainable_variables())   

# Total loss = mean sigmoid cross-entropy + the L2 terms inference() added
# to the 'losses' collection.
# NOTE(review): labels look one-hot, yet sigmoid (not softmax) cross-entropy
# is used, scoring each of the 37 outputs independently — confirm intended.
loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = y_,logits = y))+tf.add_n(tf.get_collection('losses'))

# Exponentially decayed learning rate, stepped every LEARNING_RATE_DECAY_STEP.
learning_rate=tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,LEARNING_RATE_DECAY_STEP,LEARNING_RATE_DECAY,staircase=True)

train_step=tf.train.AdamOptimizer(learning_rate).minimize(loss,global_step=global_step)
         
# Group the optimizer step and the moving-average update into one train op.
with tf.control_dependencies([train_step,variables_averages_op]):
    train_op=tf.no_op(name='train')      
    
saver=tf.train.Saver()    
   
# --- Training session -----------------------------------------------------
with tf.Session() as sess:
    tf.global_variables_initializer().run()
 
    for i in range(TRAINING_STEPS):             
        train_data = data_batch()
        # Add the channel dimension expected by the conv layers.
        xs=np.reshape(train_data[0],(BATCH_SIZE,CAPTCHA_HEIGHT,CAPTCHA_WIDTH,NUM_CHANNELS))
        ys = train_data[1]
         
        _,loss_value,step=sess.run([train_op,loss,global_step],feed_dict={x:xs,y_:ys})

        if i % VALIDATE_STEP ==0:
            # Accuracy on the current training batch
            k = 0
            train_batch_right_num = 0    # correct predictions within the batch
            # NOTE(review): this re-runs the forward pass (with dropout active)
            # on the same batch instead of reusing the training-step outputs.
            for _vec in sess.run(y,feed_dict={x:xs,y_:ys}):

                if np.argmax(_vec) == np.argmax(ys[k]):
                    train_batch_right_num += 1
                k += 1
            training_correct_rate = train_batch_right_num / BATCH_SIZE
            
#        Validation data                         
            valid_data = validation_batch()
#                
            valid_x=np.reshape(valid_data[0],(BATCH_SIZE,CAPTCHA_HEIGHT,CAPTCHA_WIDTH,NUM_CHANNELS))
            k = 0
            right_num_in_the_batch = 0    # correct predictions within the batch
            for _vec in sess.run(y,feed_dict={x:valid_x,y_:valid_data[1]}):

                if np.argmax(_vec) == np.argmax(valid_data[1][k]):
                    right_num_in_the_batch += 1
                k += 1
                    
            precision_rate = right_num_in_the_batch / BATCH_SIZE
                                 
            print ('after %d training steps, loss on training ' 'batch is %g.' % (step,loss_value))
            print('training_correct_rate',training_correct_rate,'precision_rate is ',precision_rate)# accuracy

            # Checkpoint the model; global_step is appended to the filename.
            saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=global_step)