tensorflow 神经网络解决二分类问题

import tensorflow as tf
from numpy.random import RandomState

# 定义训练数据batch大小
batch_size = 8

# 定义神经网络的参数
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))

# 在shape的一个维度上使用None可以方便使用不大的batch大小,在训练时需要把数据分
# 成比较小的batch,但是在测试时,可以一次性使用全部的数据。当数据集比较小时这样比较
# 方便测试,但数据集比价大时,将大量数据放入一个batch可能会导致内存溢出。
x = tf.placeholder(tf.float32, shape=(None, 2), name='x-input')
y_ = tf.placeholder(tf.float32, shape=(None, 1), name='y-input')

# 定义神经网络前向传播的过程
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)

# 定义损失函数和反向传播算法
cross_entropy = -tf.reduce_mean(
    y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

# 通过随机数生成一个模型数据集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)

# 定义规则来给出样本的标签。在这里所有x1+x2<1的样例都被认为是正样本(比如零件合格),
# 而其他为负样本(比如零件不合格)。和Tensorflow游乐场中的表示法不大一样的地方是,
# 这里使用0来表示负样本,1来表示正样本。大部分解决分类问题的神经网络都会采用0和1的表示方法。
Y = [[int(x1+x2 < 1)] for (x1, x2) in X ]

# 创建一个会话来运行Tensorflow程序
with tf.Session() as sess:
    init_op = tf.initialize_all_variables()
# 初始化变量
    sess.run(init_op)
    print sess.run(w1)
    print sess.run(w2)

    '''在训练之前神经网络参数的值为:

    '''

# 设定训练的轮数
STEPS = 5000
for i in range(STEPS):
    # 每次选取batch_size个样本进行训练。
    start = (i * batch_size) % dataset_size
    end = min(start+batch_size, dataset_size)

    # 通过选取的样本训练神经网络并更新参数。
    sess.run(train_step,
             feed_dict={x: X[start:end], y_: Y[start:end]})
    if i % 1000 == 0:
        #每隔一段时间计算在所有数据上的交叉熵并输出。
        total_cross_entropy = sess.run(
            cross_entropy, feed_dict={x: X, y_: Y})
        print ("After %d training step(s), cross entropy on all data is % g" % (i, total_cross_entropy))

        '''
        输出结果:

        '''

    print sess.run(w1)
    print sess.run(w2)

    '''
    在训练之后神经网络参数的值:

    '''
上面的程序实现了训练神经网络的全部过程。从中可以总结出训练神经网络的过程可以分为以下3步:
1、定义神经网络的结构和前向传播的过程。
2、定义损失函数以及选择反向传播优化的算法。
3、生成会话(tf.Session)并且在训练数据上反复运行反向传播优化算法。
无论神经网络的结构如何变化,这3个步骤是不变的
展开阅读全文

没有更多推荐了,返回首页