TensorFlow 基础练习 5 一个简单的回归算法

今天利用TensorFlow实现一个简单的回归算法。

1、首先导入库

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

 2、创建数据集

这里使用numpy库生成输入数据 X, 输出数据 Y,并且在生成 Y 时加入了噪声。

# 初始化随机种子
tf.set_random_seed(123)
np.random.seed(123)


# 创建输入数据 X 。
x = np.linspace(-2, 2, 100)   # 形状 (100, )
# 加噪声
noise = np.random.normal(0, 0.2, 100)
# 创建输出数据 Y
y = np.power(x, 2) + noise  # 形状 (100,)
x, y = np.reshape(x, newshape=(-1, 1)), np.reshape(y, newshape=(-1, 1))  # 使数据变为(100,1)的形状
print(x.shape, y.shape)

3、创建神经网络,

这里使用了一个隐层,隐层具有10个神经元。输出层为一个神经元。

# 简单的神经网络, 完成线性回归。
# 1、定义输入与输出变量的占位符
tf_x = tf.placeholder(dtype=tf.float32, shape=x.shape)
tf_y = tf.placeholder(dtype=tf.float32, shape=y.shape)

# 定义神经网络
net1 = tf.layers.dense(tf_x, 10, activation=tf.nn.relu)
output = tf.layers.dense(net1, 1)

4、定义损失函数

对于回归问题一般使用均方误差,而对于分类问题一般使用交叉熵。这里使用 均方误差,并利用梯度下降优化器完成优化。

# 定义损失,回归问题一般使用 均方根误差。RMSE
loss = tf.losses.mean_squared_error(tf_y, output)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.3)
train_op = optimizer.minimize(loss)

5、执行会话

 这里使用里 plt.ion() 和 plt.ioff() 。实现动态的显示拟合数据与原数据的变化。

plt.ion()    # 交互式显示图像

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    # 训练次数 step
    step = 1000
    for i in range(step):
        # 训练
        _, loss_ow, y_pred = sess.run([train_op, loss, output],
                                  feed_dict={tf_x: x, tf_y: y})

        if i % 10 == 0:
            print('loss_ow=', sess.run(loss, feed_dict={tf_x: x, tf_y: y}))

            # 画图
            plt.cla()
            plt.scatter(x, y)
            plt.plot(x, y_pred, 'r-', lw=2)
            plt.text(0.5, 0, loss_ow, fontdict={'size': 20, 'color': 'red'})    # 设置图像上的文本。
            plt.pause(0.1)

plt.ioff()
plt.show()

 结果:

 

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值