小白学Tensorflow之Logistic回归

利用Tensorflow实现Logistic回归
第一,我们先来设计两个函数,使得在后续的程序中不用重复编写相同的代码。

def init_weights(shape):
    return tf.Variable(tf.random_normal(shape, stddev = 0.01))

def model(X, w):
    return tf.matmul(X, w)

第二,我们带入mnist的数据集,具体方法可以参考官网。

# 导入数据
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

第三,构建损失函数,我们采用softmax和交叉熵来训练模型

# 构建损失函数,我们采用softmax和交叉熵来训练模型
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(py_x, Y))
learning_rate = 0.01
train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

完整代码如下:

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf 
import input_data

def init_weights(shape):
    return tf.Variable(tf.random_normal(shape, stddev = 0.01))

def model(X, w):
    return tf.matmul(X, w)

# 导入数据
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

# 设置占位符
X = tf.placeholder("float", [None, 784])
Y = tf.placeholder("float", [None, 10])

# 初始化权重
w = init_weights([784, 10])

# 构建模型
py_x = model(X, w)

# 构建损失函数,我们采用softmax和交叉熵来训练模型
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(py_x, Y))
learning_rate = 0.01
train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
predict_op = tf.argmax(py_x, 1)

with tf.Session() as sess:

    init = tf.initialize_all_variables()
    sess.run(init)

    for i in xrange(100):
        for start, end in zip(range(0, len(trX), 128), range(128, len(trX), 128)):
            sess.run(train_op, feed_dict = {X: trX[start:end], Y: trY[start:end]})
        print i, np.mean(np.argmax(teY, axis = 1) == sess.run(predict_op, feed_dict = {X: teX, Y: teY}))

简书同步更新:http://www.jianshu.com/p/f51f0ca4278c

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值