一.有监督学习:
1.在监督学习中,每个例子都是一对输入对象和一个期望输出组成。并由这些例子去训练一个推断模型,再由这个模型去预测位置样本。整个训练过程可由如下图表示:
·首先进行模型参数初始化,通常采用随机赋值或取零。
·读取训练数据,通常会随机打乱次序。
·执行推断模型,每个样本得到一个输出值。
·计算损失,损失是一个能够刻画模型在最后一步得到的输出与来自训练集希望的输出间差距的概括性指标。
·调整模型参数:对应实际学习的过程,给定损失函数,学习的目的在于通过大量训练步骤改善参数的值,从而将损失最小话。(常用随机梯度下降法)
当结束上述过程中便进入评估阶段,此时需要一个同样含有期望的输出信息的验证集依据模型进行推断,并评估模型在该数据集上的损失。(通常将70%的样本作为训练集,30%的样本作为验证和评估)。
利用上述过程定义一个训练及评估的通用代码框架:
import tensorflow as tf
def inference(X):
# 计算推断模型在X上的输出
def loss(X):
# 依据X的输出和期望输出Y计算损失
def inputs():
# 读取或生成训练数据X和期望的Y
def train(total_loss):
# 依据损失函数调整模型参数
def evaluate(sess,X,Y):
# 对训练的模型进行评估
# 在会话对象中启动数据流图
with tf.Session() as sess:
tf.initialize_local_variables().run()
# 读取数据
X,Y=inputs()
# 计算损失
total_loss=loss(X,Y)
# 训练
train_op=train()
# 创建线程
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(sess=sess,coord=coord)
# 迭代
training_step=1000
for step in range(training_step):
sess.run([train_op])
# 查看损失情况
if step%10==0:
print("loss:",sess.run([total_loss]))
# 评估
evaluate(sess,X,Y)
coord.request_stop()
# 线程停止
coord.join(threads)
sess.close()
2.保存训练检查点
在训练模型时需要通过很多周期更新参数,为防止计算机在长时间的训练中发生断电等故障,可借助tf.train.Saver将数据流图中的变量保存。代码如下:
....
# 迭代
training_step=1000
for step in range(training_step):
sess.run([train_op])
if step%100==0:
saver.save(sess,'my_model',global_step=step)
.....
# 评估
saver.save(sess,'my_model',global_step=train_step)
sess.close()
如果希望从某个检查点恢复训练,则应使用tf.train.get_checkpoint_state方法:
with tf.Session() as sess:
initial_step=0
# 验证是否已经保存了检查点文件
ckpt=tf.train.get_checkpoint_state(os.path.dirname(__file__))
if ckpt and ckpt.model_checkpoint_path:
# 从检查点恢复
saver.restore(sess,ckpt.model_checkpoint_path)
initial_step=int(ckpt.model_checkpoint_path.rsplit('-',1)[1])
for step in range(initial_step,training_step):
......
参考:《面向机器智能的Tensorflow实践》