炼数成金Tensorflow学习笔记之3.1_非线性回归
代码及分析
"""
Created on Sun Mar 15 09:10:07 2020
@author: 寒火qwer
"""
import tensorflow as tf
import numpy as np
import matplotlib. pyplot as plt
x_data = np. linspace( - 0.5 , 0.5 , 200 ) [ : , np. newaxis]
noise = np. random. normal( 0 , 0.02 , x_data. shape)
y_data = np. square( x_data) + noise
x = tf. placeholder( tf. float32, [ None , 1 ] )
y = tf. placeholder( tf. float32, [ None , 1 ] )
weight_l1 = tf. Variable( tf. random_normal( [ 1 , 10 ] ) )
biases_l1 = tf. Variable( tf. zeros( [ 1 , 10 ] ) )
a1 = tf. matmul( x, weight_l1) + biases_l1
l1 = tf. nn. tanh( a1)
weight_l2 = tf. Variable( tf. random_normal( [ 10 , 1 ] ) )
biases_l2 = tf. Variable( tf. zeros( [ 1 , 1 ] ) )
a2 = tf. matmul( l1, weight_l2) + biases_l2
predict = tf. nn. tanh( a2)
loss = tf. reduce_mean( tf. square( y - predict) )
train_op = tf. train. GradientDescentOptimizer( 0.1 ) . minimize( loss)
init_op = tf. global_variables_initializer( )
with tf. Session( ) as sess:
sess. run( init_op)
for _ in range ( 2000 ) :
sess. run( train_op, feed_dict= { x: x_data, y: y_data} )
predict_value = sess. run( predict, feed_dict= { x: x_data} )
plt. figure( )
plt. scatter( x_data, y_data)
plt. plot( x_data, predict_value, 'r-' , lw= 5 )
plt. show( )
np.linspace (start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0) 作用:返回的是 [start, stop](或[start, stop))之间的均匀分布(至于包不包含stop取决于endpoint参数的取值。为True,这时stop就是最后的样本。为False时,不包含stop的值。)np.newaxis 的作用就是在这一位置增加一个一维,这一位置指的是np.newaxis所在的位置np.random.normal (loc=0.0, scale=1.0, size=None) 作用:生成高斯分布的概率密度随机数 loc :float, 此概率分布的均值(对应着整个分布的中心centre) scale :float, 此概率分布的标准差(对应于分布的宽度,scale越大越矮胖,scale越小,越瘦高) size :int or tuple of ints,输出的shape,默认为None,只输出一个值tf.nn.tanh (x, name=None)作用:tanh激活函数 plt.figure ((num=None, figsize=None, dpi=None, facecolor=None, edgecolor=None, frameon=True) 作用:创建自定义画板 num :图像编号或名称,数字为编号 ,字符串为名称 figsize :指定figure的宽和高,单位为英寸; dpi 参数指定绘图对象的分辨率,即每英寸多少个像素,缺省值为80 facecolor :背景颜色 edgecolor :边框颜色 frameon :是否显示边框plt.scatter (x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, verts=None, edgecolors=None, hold=None, data=None, **kwargs) 作用:绘制散点图 x,y :array_like,shape(n,),输入数据plt.plot (x,y,format_string,**kwargs) 作用:绘制折线图plt.show () 作用:显示图像