Tensorflow-- 用估算器框架训练一个回归模型

估算器框架(Estimators API)是属于Tensorflow里的一个高级API。 它的特点是对底层代码进行了高度的封装,使其对开发模型的过程变得简单。

估算器框架主要包含了三个部分:1. 输入函数、2.模型函数、 3. 估算器。 使用估算器框架开发模型,就是具体化上述三个过程

下面代码中,输入函数包含了两个部分,即训练部分与测试部分。 其中 train_input_fn函数用于训练, eval_input_fn函数用于测试。

估算器的模型函数有固定的模式,其函数名可以任意指定,但是输入参数与返回值都有固定的要求、

输入参数有四个: feature (样本数据) labels (标签数据) mode (模型运行模式) params (用于模型其他参数)

输入参数:  必须是一个tf.estimator.EstimatorSpec类型的对象(见代码,return值)

定义估算器: 估算器在这里作用相当于操作系统,它将前面的输入数据部分与模型函数部分作为一个整合进行处理。

## 微调模型,热启动问题。 该问题的参数设置在定义的评估器对象中、需要重新定义评估器,载入预训练参数模型,并定义新的模型路径。

#本实例利用估算器实现一个回归模型
import tensorflow as tf
import numpy as np

tf.reset_default_graph()
#在内存中生成模拟数据
def GenerateData(datasize = 100 ):
    train_X = np.linspace(-1, 1, datasize)   #train_X为-1到1之间连续的100个浮点数
    train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.3 # y=2x,但是加入了噪声
    return train_X, train_Y   #以生成器的方式返回

train_data = GenerateData()  
test_data = GenerateData(20)  
batch_size=10

def train_input_fn(train_data, batch_size):  #定义训练数据集输入函数
    #构造数据集的组成:一个特征输入,一个标签输入
    dataset = tf.data.Dataset.from_tensor_slices( (  train_data[0],train_data[1]) )   
    dataset = dataset.shuffle(1000).repeat().batch(batch_size) #将数据集乱序、重复、批次划分. 
    return dataset     #返回数据集 

def eval_input_fn(data,labels, batch_size):  #定义测试或应用模型时,数据集的输入函数
    #batch不允许为空
    assert batch_size is not None, "batch_size must not be None" 
    
    if labels is None:  #如果评估,则没有标签
        inputs = data  
    else:  
        inputs = (data,labels)  
    #构造数据集 
    dataset = tf.data.Dataset.from_tensor_slices(inputs)  
 
    dataset = dataset.batch(batch_size)  #按批次划分
    return dataset     #返回数据集     

def my_model(features, labels, mode, params):#自定义模型函数:参数是固定的。一个特征,一个标签
    #定义网络结构
    W = tf.Variable(tf.random_normal([1]), name="weight")
    b = tf.Variable(tf.zeros([1]), name="bias")
    # 前向结构
    predictions = tf.multiply(tf.cast(features,dtype = tf.float32), W)+ b
    
    if mode == tf.estimator.ModeKeys.PREDICT: #预测处理
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    #定义损失函数
    loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)

    meanloss  = tf.metrics.mean(loss)#添加评估输出项
    metrics = {'meanloss':meanloss}

    if mode == tf.estimator.ModeKeys.EVAL: #测试处理
        return tf.estimator.EstimatorSpec(   mode, loss=loss, eval_metric_ops=metrics) 
        #return tf.estimator.EstimatorSpec(   mode, loss=loss)

    #训练处理.
    assert mode == tf.estimator.ModeKeys.TRAIN
    optimizer = tf.train.AdagradOptimizer(learning_rate=params['learning_rate'])
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)  


tf.reset_default_graph()  #清空图
tf.logging.set_verbosity(tf.logging.INFO)      #能够控制输出信息  ,
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)  #构建gpu_options,防止显存占满
session_config=tf.ConfigProto(gpu_options=gpu_options)
#构建估算器
estimator = tf.estimator.Estimator(  model_fn=my_model,model_dir='./myestimatormode',params={'learning_rate': 0.1},
                                   config=tf.estimator.RunConfig(session_config=session_config)  )
#匿名输入方式
estimator.train(lambda: train_input_fn(train_data, batch_size),steps=200)
tf.logging.info("训练完成.")#输出训练完成

print("##########################################################################################")
#通过热启动实现模型的微调

#热启动
warm_start_from = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from='./myestimatormode',
        )
#重新定义带有热启动的估算器
estimator2 = tf.estimator.Estimator(  model_fn=my_model,model_dir='./myestimatormode3',warm_start_from=warm_start_from,params={'learning_rate': 0.1},
                                   config=tf.estimator.RunConfig(session_config=session_config)  )
estimator2.train(lambda: train_input_fn(train_data, batch_size),steps=200)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值