tensorflow-多元线性模型

原创 2018年04月16日 19:26:49

     数据见deep learning,该模型为多元输入的线性模型,即z=mx+ny+b的空间平面

#--coding:UTF-8--
import tensorflow as tf
import numpy as np
from sklearn import preprocessing

#读取数据为float32形式//tf.zeros数据类型为tf.float32
x_data=np.loadtxt('./ex3x.dat').astype(np.float32)
y_data=np.loadtxt('./ex3y.dat').astype(np.float32)

#对数据进行预处理标准化
scaler=preprocessing.StandardScaler().fit(x_data)
x_data_standard=scaler.transform(x_data)

#定义变量并初始化为0
W=tf.Variable(tf.zeros([2,1]))
b=tf.Variable(tf.zeros([1,1]))

#W为矩阵,用tf.matmul函数进行矩阵相乘运算
y_=tf.matmul(x_data_standard,W)+b

#定义优化器
optimizer=tf.train.GradientDescentOptimizer(0.01)
#定义损失函数
loss=tf.reduce_mean(tf.square(y_-y_data.reshape(-1,1)))
train=optimizer.minimize(loss)

sess=tf.Session()
init=tf.global_variables_initializer()
sess.run(init)
for step in range(1500):
	sess.run(train)
	if step%100==0:
		print step,sess.run(W).flatten(),sess.run(b).flatten()
#flatten函数将参数展开

运行结果如下:

0 [2115.2822 1094.1765] [6808.2524]
100 [83597.23  15084.638] [296170.16]
200 [99544.86    3148.6218] [334545.28]
300 [105389.125   -2527.0706] [339634.53]
400 [107772.35    -4903.2246] [340309.5]
500 [108755.66   -5886.229] [340398.97]
600 [109161.85    -6292.4053] [340410.75]
700 [109329.64   -6460.213] [340411.9]
800 [109398.984   -6529.5547] [340411.9]
900 [109427.63   -6558.196] [340411.9]
1000 [109439.47    -6570.0293] [340411.9]
1100 [109444.33   -6574.912] [340411.9]
1200 [109446.36   -6576.919] [340411.9]
1300 [109447.18   -6577.736] [340411.9]
1400 [109447.445  -6578.082] [340411.9]

同样整个过程并不像结果展示的那样顺利,首先是数据类型报错的问题,tf.zeros数据类型为float32,而读出来的数据类型为float64,所以矩阵相乘时会报错,附上函数解释:tf.zeros(shape, dtype=tf.float32, name=None)

当没有对x_data进行预处理时,其结果如下

0 [13621140.       21583.605] [6808.2524]
100 [nan nan] [nan]
200 [nan nan] [nan]
300 [nan nan] [nan]
400 [nan nan] [nan]
500 [nan nan] [nan]
600 [nan nan] [nan]
700 [nan nan] [nan]
800 [nan nan] [nan]
900 [nan nan] [nan]
1000 [nan nan] [nan]
1100 [nan nan] [nan]
1200 [nan nan] [nan]
1300 [nan nan] [nan]
1400 [nan nan] [nan]

同样又是一排排的nan,从第一行的13621140可以看出,很大程度上应该就是梯度爆炸导致参数数值过大,超过阈值

0 [1362114.1       2158.3606] [680.82526]
2 [1.16299066e+14 1.67974388e+11] [5.035392e+10]
4 [9.9308346e+21 1.4343421e+19] [4.2997471e+18]
6 [8.4799905e+29 1.2247917e+27] [3.6715757e+26]
8 [          inf 1.0458559e+35] [3.1351772e+34]

通过查看前面输出可以看出的确是由于参数过大导致的nan,由此我们也可以看出数据预处理对于参数优化的重要性,如果不同因素的数量级不在一个层次,那么势必其权重也会有特别大的差别,通过对数据预处理,进行标准化,可以维持参数的稳定性。

版权声明: https://blog.csdn.net/qq_33668920/article/details/79961157

多元非线性回归分析

  • 2006年02月23日 09:05
  • 2.82MB
  • 下载

tensorflow实现非线性回归

模拟非线性回归,给定一些二维点,y = x^2 + noise,用梯度下降进行训练,实线前向传播神经网络。 import tensorflow as tf import numpy as np i...
  • y12345678904
  • y12345678904
  • 2017-08-31 11:13:35
  • 1521

tensorflow 非线性回归

#encoding:utf-8 #encoding:utf-8 import tensorflow as tf import numpy as np import matplotlib.pyplot ...
  • SDUTyangkun
  • SDUTyangkun
  • 2017-10-12 20:58:05
  • 435

TensorFlow训练Logistic回归

Logistic回归在用线性模型进行回归训练时,有时需要根据这个线性模型进行分类,则要找到一个单调可微的用于分类的函数将线性回归模型的预测值关联起来。这时就要用到逻辑回归,之前看吴军博士的《数学之美》...
  • wangyangzhizhou
  • wangyangzhizhou
  • 2017-04-22 20:05:56
  • 5462

tensorflow-多元线性模型

     数据见deep learning,该模型为多元输入的线性模型,即z=mx+ny+b的空间平面#--coding:UTF-8-- import tensorflow as tf import ...
  • qq_33668920
  • qq_33668920
  • 2018-04-16 19:26:49
  • 13

TensorFlow学习笔记(二):TensorFlow实现线性回归模型

一、线性回归模型中所涉及到API#导入TensorFlow包 import tensorflow as tf #TensorFlow程序分为两个阶段:准备阶段和执行阶段 #--------------...
  • kenwengqie2235
  • kenwengqie2235
  • 2017-11-09 20:33:28
  • 307

多元非线性回归问题算法

http://dl2.csdn.net/down4/20070726/26064842539.rar 这是我要解决的公式,见附件  
  • dfkoko
  • dfkoko
  • 2007-07-26 06:51:00
  • 1233

[03]tensorflow实现softmax回归(softmax regression)

MNIST数据集MNIST数据集的官网是Yann LeCun’s website。在这里,我们提供了一份python源代码用于自动下载和安装这个数据集。你可以下载这份代码,然后用下面的代码导入到你的项...
  • SA14023053
  • SA14023053
  • 2016-07-11 23:25:40
  • 14367

【机器学习】非线性回归算法分析

AI机器学习 - 非线形回归分析。我们上文深入本质了解了机器学习基础线性回归算法后,本文继续研究非线性回归。非线性回归在机器学习中并非热点,并且较为小众,且其应用范畴也不如其他广。鉴于此,我们本文也将...
  • L70AShC3Q50
  • L70AShC3Q50
  • 2017-12-15 00:00:00
  • 933
收藏助手
不良信息举报
您举报文章:tensorflow-多元线性模型
举报原因:
原因补充:

(最多只允许输入30个字)