常见API例子运用

#命令行参数
tf.app.flags.DEFINE_integer("max_step", 1000, "train step number")

FLAGS = tf.app.flags.FLAGS

def linearregression():
    """
    tensorflow实现线性回归
    :return:
    """
    with tf.variable_scope("original_data"):
        #生成一个100个一维数据
        X = tf.random_normal([100, 1], mean=0.0, stddev=1.0, name="original_data_x")
        #用100个一维数据生成一个符合 y=0.8x+0.7分布的[100,1]矩阵
        y_true = tf.matmul(X, [[0.8]]) + [[0.7]]

    with tf.variable_scope("liner_model"):
        #随机初始化w,b
        weights = tf.Variable(initial_value=tf.random_normal([1, 1]), trainable=False, name="w")
        bias = tf.Variable(initial_value=tf.random_normal([1, 1]), name="b")
        #构建预测线性关系的计算矩阵
        y_predict = tf.matmul(X, weights) + bias

    with tf.variable_scope("loss"):
        #构造损失函数
        loss = tf.reduce_mean(tf.square(y_predict - y_true))

    with tf.variable_scope("optimizer"):
        #剃度下降优化损失,指定学习率
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(loss)

    #1.收集观察张量
    tf.summary.scalar('losses', loss)
    tf.summary.histogram('weight', weights)
    tf.summary.histogram('biases', bias)

    #合并收集的张量
    merge = tf.summary.merge_all()


    #初始化变量
    init_op = tf.global_variables_initializer()

    #创建一个saver
    saver = tf.train.Saver()

    #开启会话进行训练
    with tf.Session() as sess:
        sess.run(init_op)

        filewriter = tf.summary.FileWriter("./tmp/summary", graph=sess.graph)

        # print("weights:", sess.run(weights))
        # print("bias:", sess.run(bias))
        # saver.restore(sess, "./tmp/ckpt/linerregression")
        # print("weights:", sess.run(weights))
        # print("bias:", sess.run(bias))
        for i in range(FLAGS.max_step):
            sess.run(optimizer)
            # print("loss:", sess.run(loss))
            # print("weights:", sess.run(weights))
            # print("bias:", sess.run(bias))
            summary = sess.run(merge)

            filewriter.add_summary(summary, i)
            print("train loss:%f, weights:%f, bias:%f" % (sess.run(loss), sess.run(weights), sess.run(bias)))

            #checkpoint:检查点文件格式
            saver.save(sess, "./tmp/ckpt/linerregression")

    return None


if __name__ == '__main__':
    linearregression()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值