Implementing Logistic Regression with TensorFlow

   The principle of logistic regression is simple, so I won't restate it here. My TensorFlow approach is the same as before: I again rely on the Supervisor module (it really is handy), together with argparse for command-line flags and the logging module for log output.
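
   For reference (this is just standard logistic regression, restated so the loss in the code below is easy to check), the model predicts

$$\hat{y} = \sigma(xW + b) = \frac{1}{1 + e^{-(xW + b)}}$$

and training minimizes the summed cross-entropy

$$L = -\sum_{i} \left[ y_i \log \hat{y}_i + (1 - y_i) \log (1 - \hat{y}_i) \right],$$

which is exactly what the loss tensor in main() computes. The implementation code is as follows: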

# Python 2 only: force UTF-8 as the default string encoding
import sys
reload(sys)
sys.setdefaultencoding('utf-8')

import tensorflow as tf
import logging
import argparse
import os
import matplotlib.pyplot as plt
import numpy as np
logging.basicConfig(format="[%(process)d] %(levelname)s %(filename)s:%(lineno)s | %(message)s")
log = logging.getLogger('train')
log.setLevel(logging.INFO)

# Training data: two features per sample plus a binary label
# (the first 8 samples are positive, the remaining 9 negative)
data = np.mat([[0.697,0.460,1],
        [0.774,0.376,1],
        [0.634,0.264,1],
        [0.608,0.318,1],
        [0.556,0.215,1],
        [0.403,0.237,1],
        [0.481,0.149,1],
        [0.437,0.211,1],
        [0.666,0.091,0],
        [0.243,0.267,0],
        [0.245,0.057,0],
        [0.343,0.099,0],
        [0.639,0.161,0],
        [0.657,0.198,0],
        [0.360,0.370,0],
        [0.593,0.042,0],
        [0.719,0.103,0]])

def log_hook(sess,log_fetches):
    # Runs on a Supervisor service thread every log_interval seconds.
    fetched = sess.run(log_fetches)  # renamed from 'data' to avoid shadowing the dataset
    loss = fetched['loss']
    step = fetched['step']
    log.info('Step {}|loss = {:.4f}'.format(step,loss))


def logistic_regression(W,b,x):
    # Sigmoid of the affine transform; equivalent to tf.sigmoid(tf.matmul(x,W)+b)
    pred = 1/(1+tf.exp(-(tf.matmul(x,W)+b)))
    return pred


def main(args):
    W = tf.Variable(tf.random_normal([2,1],stddev=0.1))
    b = tf.Variable(tf.random_normal([1],stddev=0.1))
    x = tf.to_float(data[:,0:2])
    y = tf.to_float(data[:,2])
    global_step = tf.contrib.framework.get_or_create_global_step()
    pred = logistic_regression(W,b,x)
    # Summed cross-entropy: -sum(y*log(pred) + (1-y)*log(1-pred))
    loss = tf.reduce_sum(-tf.reshape(y,[-1,1])*tf.log(pred)-(1-tf.reshape(y,[-1,1]))*tf.log(1-pred))
    train_op = tf.train.GradientDescentOptimizer(args.learning_rate).minimize(loss,global_step)

    tf.summary.scalar('loss',loss)
    log_fetches ={
        "W":W,
        "b":b,
        "loss":loss,
        "step":global_step
    }
    sv = tf.train.Supervisor(logdir = args.checkpoint_dir,save_model_secs=args.checkpoint_interval,
                             save_summaries_secs=args.summary_interval)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with sv.managed_session(config=config) as sess:
        sv.loop(args.log_interval,log_hook,[sess,log_fetches])
        while True:
            if sv.should_stop():
                log.info('stopping supervisor')
                break
            try:
                WArr,bArr,_ = sess.run([W,b,train_op])
                x0 = np.array(data[:8])    # positive samples
                x0_ = np.array(data[8:])   # negative samples
                plt.scatter(x0[:,0],x0[:,1],c='r',label='+')
                plt.scatter(x0_[:,0],x0_[:,1],c='b',label='-')
                x1 = np.arange(-0.2,1.0,0.1)
                # Decision boundary: W[0]*x1 + W[1]*y1 + b = 0
                y1 = (-bArr-WArr[0]*x1)/WArr[1]
                plt.plot(x1,y1)
                plt.pause(0.01)
                plt.cla()
            except tf.errors.AbortedError:
                log.error('Aborted')
                break
            except KeyboardInterrupt:
                break
        chkpt_path = os.path.join(args.checkpoint_dir, 'on_stop.ckpt')
        log.info("Training complete, saving chkpt {}".format(chkpt_path))
        sv.saver.save(sess, chkpt_path)
        sv.request_stop()




if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--learning_rate",default=5e-2,type=float,help='learning rate for the stochastic gradient update.')
    parser.add_argument('--checkpoint_dir', default='summary/', help='directory of summary to save.')
    parser.add_argument('--summary_interval', type=int, default=1, help='interval between tensorboard summaries (in s)')
    parser.add_argument('--log_interval', type=int, default=1, help='interval between log messages (in s).')
    parser.add_argument('--checkpoint_interval', type=int, default=20, help='interval between model checkpoints (in s)')

    args =  parser.parse_args()
    main(args)
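
The script can be run directly; for example (the filename logistic.py is my own choice, the post does not name the file):

python logistic.py --learning_rate 0.05 --checkpoint_dir summary/

All flags have the defaults shown in the argparse section, so a plain python logistic.py works too; checkpoints and TensorBoard summaries land in summary/.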

The result is as shown in the figure:

(figure: scatter plot of the two classes, with the current decision boundary redrawn at each training step)

   Most of this functionality was already explained in the linear regression post; here I'll just add a few more notes on Supervisor usage.
   The typical steps for using a Supervisor are:
(1) Create a Supervisor object, passing it the directory where checkpoints and summaries should be saved.
(2) Request a session from the supervisor via tf.train.Supervisor.managed_session.
(3) Use that session to run the training op, checking at each step whether the supervisor has requested that training stop (a minimal skeleton of this pattern follows the list).
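
A minimal, self-contained sketch of that three-step pattern (this is my own condensed illustration, not part of the original script; the stand-in train_op and the /tmp/sv_demo logdir are placeholders):

import tensorflow as tf

# Stand-in training op: just increments global_step (placeholder for real training).
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)

# (1) The Supervisor owns checkpointing and summaries for the given logdir.
sv = tf.train.Supervisor(logdir='/tmp/sv_demo')

# (2) managed_session initializes or restores all variables automatically.
with sv.managed_session() as sess:
    # (3) Train, checking each step whether the supervisor wants to stop.
    while not sv.should_stop():
        step = sess.run(train_op)
        if step >= 100:          # demo stop condition
            sv.request_stop()
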
   The graph contains an integer variable named global_step, and the supervisor's services use its value to measure the number of training steps executed. As for sv.should_stop(): when an exception is raised in one of the supervisor's service threads, it is reported to the supervisor, which sets the should-stop condition to true; the service threads notice the condition and terminate cleanly, and the training loop should likewise check should_stop() each step and exit.
   In sv.loop(), the first argument is how often (in seconds) the thread should run, the second is the callable to invoke each time, and the third is the list of arguments passed to that callable on each invocation (here [sess, log_fetches]).
   Initial model loading, parameter saving, and so on are all prepared the moment sv = tf.train.Supervisor(...) is constructed; for more Supervisor features, consult the official documentation.
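
As a small illustration of that point (a sketch, under the assumptions that training above has written summary/on_stop.ckpt and that W and b kept their auto-generated names Variable and Variable_1, since the script never names them), the saved parameters can be loaded back with a plain tf.train.Saver:

import tensorflow as tf

# Rebuild variables with the same shapes (and, by default, the same
# auto-generated names) as in the training script above.
W = tf.Variable(tf.zeros([2, 1]))
b = tf.Variable(tf.zeros([1]))

saver = tf.train.Saver()  # restores every variable in this graph
with tf.Session() as sess:
    saver.restore(sess, 'summary/on_stop.ckpt')  # path saved on stop above
    print(sess.run(W), sess.run(b))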
