如何使用TensorFlow
TensorFlow的使用分为下面几步
- 先定义训练的数据集
- 定义输入输出
- 定义计算图
- 定义损失函数
- 训练的过程
好了,你已经学会怎么使用TensorFlow,现在来试着写一个简单的线性回归吧!
目标
实现线性回归方程:y = a*x+b
首先先导入包,使用V1的版本
import numpy as np
import matplotlib.pyplot as plt
try:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
except:
import tensorflow as tf
定义训练数据集
定义x函数为-1到1,间距100
trainX = np.linspace(-1,1,100)
加入噪声,均值为0,方差为0.05,形状和trainX一样
noise = np.random.normal(0,0.05,trainX.shape)
定义y函数为线性函数,同时添加一些噪声数据
trainY = 4*trainX+10+noise
训练集定义结束,开始定义输入输出
定义输入值,输入结构的输入行数不固定
xs=tf.placeholder(tf.float32)
ys=tf.placeholder(tf.float32)
定义输出值 a ,b值为浮点型名叫a b
a = tf.Variable(0.0,name="a")
b = tf.Variable(0.0,name="a")
输入输出定义完成后,开始定义计算图
这次计算图是个线性方程
y = a*xs+b
接着定义损失函数,让损失值在一定的区间内
loss = tf.square(y-ys)
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)#设置反向传播算法
最后可以创建会话,开始训练
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
plt.ion() # 打开互交模式
for i in range(10):
for (X,Y) in zip(trainX,trainY):
_,w_value,b_value = sess.run([train_step,w,b],feed_dict={xs:X,ys:Y})
print("step:{},w:{},b:{}".format(i+1,w_value,b_value))
plt.plot(trainX,trainY,'+')
plt.plot(trainX,w.eval()*trainX+b.eval())
# 暂停时间
plt.pause(0.5)
运行结果
完整代码
import numpy as np
import matplotlib.pyplot as plt
# 首先先导入包,使用V1的版本
try:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
except:
import tensorflow as tf
# 定义训练数据集
trainX = np.linspace(-1,1,100)
noise = np.random.normal(0,0.05,trainX.shape)
trainY = 4*trainX+10+noise
# 定义输入输出
xs=tf.placeholder(tf.float32)
ys=tf.placeholder(tf.float32)
a = tf.Variable(0.0,name="a")
b = tf.Variable(0.0,name="a")
# 定义计算图
y = a*xs+b
loss = tf.square(y-ys)
# 定义损失函数
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)#设置反向传播算法
# 创建会话,开始训练
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
plt.ion() # 打开互交模式
for i in range(10):
for (X,Y) in zip(trainX,trainY):
_,a_value,b_value = sess.run([train_step,a,b],feed_dict={xs:X,ys:Y})
print("step:{},a:{},b:{}".format(i+1,a_value,b_value))
plt.plot(trainX,trainY,'+')
plt.plot(trainX,a.eval()*trainX+b.eval())
# 暂停时间
plt.pause(0.5)