How to Build a Simple Neural Network
What Neural Networks Do
The most important use of a neural network is classification. Put simply, it is a classifier: it uses existing data to find an (approximate) weight relationship between inputs and outputs, and then uses that relationship to make predictions, e.g. feeding in a set of inputs to produce the corresponding outputs.
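As a tiny illustration of this idea (separate from the TensorFlow network built below, and with hand-picked rather than learned weights), a weight matrix maps each input to a score, and thresholding that score gives a class label:

```python
import numpy as np

# Two 2-feature samples; the rule we want is int(x0 + x1 < 1)
X = np.array([[0.2, 0.3],   # x0 + x1 < 1  -> class 1
              [0.8, 0.9]])  # x0 + x1 >= 1 -> class 0

# Hand-picked weights and bias that approximate that rule
w = np.array([[-1.0],
              [-1.0]])
b = 1.0

scores = X @ w + b                 # weight relationship: inputs -> scores
labels = (scores > 0).astype(int)  # threshold the scores into classes
print(labels)                      # [[1], [0]]
```

A trained network does the same thing, except the weights are found from data instead of chosen by hand.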
Building a Simple Neural Network
- Define the inputs and outputs, and define the forward propagation process
- Define the loss function and the backpropagation method
- Create a session and train
- Print the results
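The loss mentioned in step two is mean squared error; computed by hand on two hypothetical predictions, it looks like this:

```python
import numpy as np

# Mean squared error: average of the squared prediction errors
y_pred = np.array([[0.8], [0.2]])
y_true = np.array([[1.0], [0.0]])
mse = np.mean((y_pred - y_true) ** 2)
print(mse)  # (0.2**2 + 0.2**2) / 2 = 0.04
```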
```python
# coding: utf-8
# Simple two-layer fully connected NN
import tensorflow as tf
import numpy as np

BATCH_SIZE = 8
seed = 23455
rng = np.random.RandomState(seed)

# X: 32 rows, 2 columns of random features
X = rng.rand(32, 2)
# Label is 1 when the two features sum to less than 1, else 0
Y = [[int(x0 + x1 < 1)] for (x0, x1) in X]
print("X:\n", X)
print("Y:\n", Y)

# The number of samples is not fixed in advance, so only the
# feature dimensions are specified
x = tf.placeholder(tf.float32, shape=(None, 2))
y_ = tf.placeholder(tf.float32, shape=(None, 1))

# Normally distributed 2x3 and 3x1 weight matrices, with the given
# standard deviation and random seed
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))

# Define the forward propagation process
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)

# Define the loss function (mean squared error) and the
# backpropagation method (gradient descent)
loss = tf.reduce_mean(tf.square(y - y_))
train_step = tf.train.GradientDescentOptimizer(0.001).minimize(loss)

# Run the computation in a session
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    print("w1 before training is:\n", sess.run(w1))
    print("w2 before training is:\n", sess.run(w2))
    print("\n")

    # Train the model
    STEPS = 3000
    for i in range(STEPS):
        start = (i * BATCH_SIZE) % 32
        end = start + BATCH_SIZE
        sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
        if i % 500 == 0:
            total_loss = sess.run(loss, feed_dict={x: X, y_: Y})
            print("After %d training step(s), loss on all data is %g" % (i, total_loss))

    print("\n")
    print("w1 after training is:\n", sess.run(w1))
    print("w2 after training is:\n", sess.run(w2))
```
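Once training finishes, the learned weights can be applied to new samples. A minimal numpy sketch of that inference step is below; the `w1`/`w2` values here are made-up placeholders, not the actual trained results:

```python
import numpy as np

# Placeholder weights standing in for the trained w1 (2x3) and w2 (3x1)
w1 = np.array([[-0.8,  0.9, 0.1],
               [-2.3, -0.1, 0.6]])
w2 = np.array([[-0.1], [0.9], [-0.4]])

x_new = np.array([[0.4, 0.5]])  # one new 2-feature sample
a = x_new @ w1                  # hidden layer (no activation, as in the model)
y = a @ w2                      # network output, shape (1, 1)
label = int(y[0, 0] > 0.5)      # threshold to recover the 0/1 class
print(y, label)
```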