#coding:utf-8
"""
随机产生32组生产出的零件的体积和重量,训练3000轮,每500轮输出一次损失函数
神经网络框架:输入层2个神经元,隐藏层3个神经元,输出层1个神经元
"""
"""0导入模块:导入模块,生成模拟数据集"""
import tensorflow as tf
import numpy as np
BATCH_SIZE = 8 #一次给神经网络喂入8组数据,不能太大
seed = 23455 # seed是随机数种子,为0则产生的随机数不同
# 基于seed产生随机数
rng = np.random.RandomState(seed) #对于某一个伪随机数发生器,只要种子seed相同,则产生的随机数序列就是相同的
#随机数返回32行2列的矩阵 表示32组 体积和重量 作为输入数据集
X = rng.rand(32,2)
#每次取出一行,和小于1则给Y赋值1,否则赋值0,制作标签
Y = [[int((x0 + x1) < 1)] for (x0,x1) in X]
print("X:\n",X)
print("Y:\n",Y)
"""1向前传播:定义神经网络的输入、参数和输出,定义前向传播过程"""
x = tf.placeholder(tf.float32, shape=(None,2)) # 占位函数,x是输入,n行2列,每行是一条样本数据
y_ = tf.placeholder(tf.float32, shape=(None,1)) # 标准值
w1 = tf.Variable(tf.random_normal([2,3], stddev=1, seed=1)) #输入层到隐藏层,输入层2元,隐藏层3元
# 以高斯分布的方式产生2行3列(由NN结构决定)数据,stddev=1:标准差为1,seed=1产生相同的随机数
w2 = tf.Variable(tf.random_normal([3,1], stddev=1, seed=1)) # 输出层1元
a = tf.matmul(x,w1) # 得到隐藏层的输出
y = tf.matmul(a,w2) # 得到输出,预测值
"""2反向传播:定义损失函数以及反向传播方法"""
loss = tf.reduce_mean(tf.square(y - y_)) # 均方误差作为损失函数
train_step= tf.train.GradientDescentOptimizer(0.001).minimize(loss)
# 定义优化器,学习率设为0.001,优化目的是最小化loss,还可以使用其他优化器
# 三种优化器的区别参考:https://blog.csdn.net/VictorHan01/article/details/98754644
# train_step = tf.train.MomentumOptimizer(0.001,0.9).minimize(loss)
# train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
"""3生成会话,训练steps轮"""
with tf.Session() as sess:
init_op = tf.global_variables_initializer() # 全局变量初始化
sess.run(init_op)
# 输出目前未经训练的参数取值
print("w1:\n", sess.run(w1))
print("w2:\n", sess.run(w2))
print("\n")
#训练模型
STEPS = 3000
for i in range(STEPS):
start = (i * BATCH_SIZE) % 32 #对总数取余,从数据中循环抽取BATCH_SIZE条数据的起始位置
end = start + BATCH_SIZE
sess.run(train_step, feed_dict={x:X[start:end],y_:Y[start:end]})
# 执行train_step,喂入参数,train_step用到了loss用到y和y_,y用到a和w2,a用到x
# feed_dict喂入参数,参数必须是字典的形式,feed_dict一定是一个字典
if i % 500 == 0: # 每500次输出一次结果
total_loss = sess.run(loss,feed_dict={x:X, y_:Y})
print("After %d training steps, the loss on all data is %f"%(i,total_loss))
#输出训练结束后的参数取值:
print("\nw1:\n",sess.run(w1)) # 训练是训练于loss相关的变量,框架中定义了两个变量w1,w2
print("w2:\n", sess.run(w2)) # sess.run(w2)是输出w2的值,否则是输出w2的数据类型信息