训练无法打印theta

本文介绍了使用Python进行机器学习线性回归的过程,包括数据预处理(归一化),梯度下降法和正规方程求解权重参数theta,以及学习率对梯度下降的影响。通过示例展示了不同学习速率下损失函数的收敛情况,并对比了两种方法的预测结果。
摘要由CSDN通过智能技术生成

搞了一天 发现打印theta 打印不出来值,显示为nan

`
import pandas
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

#在机器学习过程中会用的一些函数,这里全部进行列出,方便后期的调用

#归一化函数
def normalization(X):
‘’’
输入参数:X
‘’’
#平均值
#mu = np.mean(X, axis=0, dtype=None, out=None, keepdims=0, where=0)
mu = np.mean(X, axis=0)
# ddof : int, optional Means Delta Degrees of Freedom.
#The divisor used in calculationsis N - ddof, where N represents the number of elements.By default ddof is zero.
#σ 标准差
#sigma = np.std(a, axis=0, dtype=None, out=None, ddof=0, keepdims=0,where=0)
sigma = np.std(a, axis=0, ddof=1)
#归一化公式
X_norm = (X-mu)/ sigma
#返回值
return X_norm,mu,sigma

#计算损失函数
def computeCostMulti(X,y,theta):
‘’’
输入参数:X
结果参数:y
损失参数:thetaθ
‘’’

#Return the shape of an array.
m = X.shape[0] #行数
#dot 代表乘
costs = X.dot(theta) - y
total_cost = costs.transpose().dot(costs)/(2 * m)
return total_cost[0][0]

#梯度下降函数
def gradientDescentMulti(X, y, theta, alpha, iterNum):
“”"
梯度下降实现
:param X:
:param y:
:param theta:
:param alpha:
:param iterNum:
:return:
“”"
m = len(X)

J_history = list()

for i in range(0, iterNum):
    costs = X.dot(theta) - y
    theta = theta - np.transpose(costs.transpose().dot(X) * (alpha / m))

    J_history.append(computeCostMulti(X, y, theta))

return theta, J_history

#梯度下降速度比较
def learningRatePlot(X_norm, y):
“”"
不同学习速率下的梯度下降比较
:param X_norm:
:param y:
:return:
“”"
colors = [‘b’, ‘g’, ‘r’, ‘c’, ‘m’, ‘y’, ‘k’]
plt.figure()
iter_num = 50
# 如果学习速率取到3,损失函数的结果随着迭代次数增加而发散,值越来越大,不太适合在同一幅图中展示
for i, al in enumerate([0.01, 0.03, 0.1, 0.3, 0.5, 0.7, 0.9]):
ta = np.zeros((X_norm.shape[1], 1))
ta, J_history = gradientDescentMulti(X_norm, y, ta, al, iter_num)

    #plt.plot([i for i in range(len(J_history))], J_history, colors[i], label=str(al))

#plt.title("Learn Rate")
#plt.legend()
#plt.show()

#正规方程下降函数
def normalEquation(X, y):
“”"
正规方程实现
:param X:
:param y:
:return:
“”"
#np.linalg.inv 矩阵求逆
return np.linalg.inv(X.transpose().dot(X)).dot(X.transpose()).dot(y)
##后续涉及到的函数会继续增加

if name == ‘main’:
#数据读取
data_path = r’D:/15-创新规划云平台/电芯智能预测/learn-coding/SVM/新建文本文档.csv’
#delimiter 分隔符
data = pandas.read_csv(data_path, delimiter=“,”, header=None)
# 切分特征和目标, 注意:索引是从0开始的
#iloc 是按数字索引
X = data.iloc[:, 0:2].values #取前两列
#print(X)
y = data.iloc[:, 2:3].values #取第三列
#print(y)

# 数据标准化
X_norm, mu, sigma = normalization(X)
#mpy.ones = ones(shape, dtype=None, order='C', *, like=None)Return a new array of given shape and type, filled with ones.
ones = np.ones((X_norm.shape[0], 1))

# 假设函数中考虑截距的情况下,给每个样本增加一个为1的特征
X_norm = np.c_[ones, X_norm]

# 初始化theta
"""
numpy.zeros = zeros(...)
zeros(shape, dtype=float, order='C', *, like=None)
Return a new array of given shape and type, filled with zeros.
"""
theta = np.zeros((X_norm.shape[1], 1))
print (X_norm)
print(theta)

# 梯度下降学习速率为0.02
alpha = 0.02
# 梯度下降迭代次数为200
iterNum = 200
# 梯度下降
theta ,J_history = gradientDescentMulti(X_norm, y, theta, alpha, iterNum)
print("没取到值"+str(theta))
# 画出梯度下降过程中的收敛情况
plt.figure()
plt.plot([i for i in range(len(J_history))], J_history)
plt.title("alpha:" +str(alpha))
plt.show()


# 使用不同学习速率下的收敛情况
learningRatePlot(X_norm, y)

# 开始进行预测

#方法一:梯度下降求解,标准化(归一化)后,进行预测
x_pre = np.array([1650,3])

x_pre_norm = (x_pre - mu) / sigma
print(x_pre,mu,sigma)
print(x_pre_norm)
numpy_ones = np.ones((1,))
print(numpy_ones)
x_pre_norm = np.concatenate((np.ones((1,)), x_pre_norm))
print(x_pre_norm)

price = x_pre_norm.dot(theta)
print("通过梯度下降求解的参数x1、x2的终值为:%f" % price[0])

#方法二:使用正规方程求解,进行预测

X_ = np.c_[ones, data.iloc[:, 0:2].values]
y_ = data.iloc[:, 2:3].values

theta = normalEquation(X_, y)

x_pre = np.array([1, 1650, 3])
price = x_pre.dot(theta)
plt.plot(y_,X_                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          )
print("通过正规方程求解的参数x1、x2的终值为为:%f" % price[0])
#noeq = np.linalg.inv(X.transpose().dot(X)).dot(X.transpose()).dot(y)
#print("这是什么" + str(noeq))
#now = datetime.now()
#print("时间"+str(now))

`
在这里插入图片描述
在这里插入图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值