tensorflow1.1/线性回归

环境:tensorflow1.1 python3 matplotlib2.02

tensorflow 1.1和之前版本有了很大的改动,在构建神经网络方面代码量减少了很多,matplotlib2.02在画图上也比之前好看了许多

#coding:utf-8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-1,1,500)[:,np.newaxis] #列向量
noise = np.random.normal(0,0.1,x.shape)
y = np.power(x,3) + noise

xs = tf.placeholder(tf.float32,x.shape)
ys = tf.placeholder(tf.float32,y.shape)

#构建神经网络
#输入,输出神经元个数,激活函数
l1 = tf.layers.dense(xs,20,tf.nn.relu) #输出10个神经元的隐藏层,激活函数relu
output = tf.layers.dense(l1,1) #输入l1,输出神经元个数1

#定义均方误差loss
#tf.losses.mean_squared_error
loss = tf.losses.mean_squared_error(ys,output) #均方误差
#定义优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.4).minimize(loss) #数据量较小调大learning_rate使其学习加快

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    plt.ion() #打开交互模式
    for step in range(100):
        _,c = sess.run([optimizer,loss],feed_dict={xs:x,ys:y})
        prediction = sess.run(output,feed_dict={xs:x}) #计算预测值
        if step % 5 == 0:
            #可以用clf()来清空当前图像,用cla()来清空当前坐标
            plt.clf()#清空当前图像
            plt.scatter(x,y)
            plt.plot(x,prediction,'c-',lw='5')
            plt.text(0,0.5,'cost=%.4f' % c,fontdict={'size':15,'color':'red'}) #添加text,位置在坐标轴0,0.5处
            plt.pause(0.1) #暂停0.1s
    plt.ioff() #关闭交互模式
    plt.show()

结果

这里写图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值