import numpy as np
import matplotlib.pyplot as plt
# 随机生成样本。二分类问题。每个类别生成5000个样本数据
np.random.seed(12)
num_observation=5000
#正太分布 multivariate_normal(mean, cov, size=None, check_valid=None, tol=None) cov协方差矩阵是对称的
x1=np.random.multivariate_normal([0,0],[[1,0.75],[0.75,1]],num_observation)
x2=np.random.multivariate_normal([1,4],[[1,0.75],[0.75,1]],num_observation)
# print(x1.shape) (5000, 2)
# 生成数据源
X=np.vstack((x1,x2)).astype(np.float32) # 拼接成数据集 #(10000, 2)
y=np.hstack((np.zeros(num_observation),np.ones(num_observation))) # (10000,)
# 数据可视化
plt.figure(figsize=(12,8))
plt.scatter(X[:,0],X[:,1],c=y,alpha=0.4)
# plt.show()
# 自定义segmod函数
def sigmod(x):
return 1/(1+np.exp(-x))
#计算log likelihood
def log_likelihood(X,y,w,b):
'''
针对所有的样本数据 计算负的log_likelihood 也叫做cross-entropy loss.值越小越好
:param X: 训练数据(特征向量) 大小为N* D
:param y: 训练数据(标签) 一维向量 长度为D
:param w: 模型的参数 一维向量 长度为D
:param b: 模型的偏移量 标量
:return:
'''
#首先按照标签来提取正样本和负样本的下标
pos=np.where(y==1)
neg=np.where(y==0)
#对于正样本计算loss 使用matrix operation(矩阵操作)。如果把每一个样本都循环一遍的话,效率会很低
pos_sum=np.sum(np.log(sigmod(np.dot(X[pos],w)+b)))
#计算负样本的loss
neg_sum=np.sum(np.log(1-sigmod(np.dot(X[pos],w)+b)))
return -(pos_sum+neg_sum)
# 实现逻辑回归模型
def logistic_regression(X,y,num_steps,learning_rate):
'''
:param X: 训练数据(特征向量) 大小为N* D
:param y: 训练数据(标签) 一维向量 长度为D
:param num_steps: 梯度下降法的迭代次数
:param learning_rate: 学习率(步长)
:return:
'''
#初始化参数 w=0 b=0
w=np.zeros(X.shape[1])
b=0
for step in range(num_steps):
# 预测值与实际值的误差
error =sigmod(np.dot(X,w)+b)-y
# 对w,b 进行梯度计算
grad_w=np.matmul(X.T,error)
grad_b=np.sum(error)
#对w,b进行梯度更新
w=w-learning_rate*grad_w
b=b-learning_rate*grad_b
#每隔一段时间计算一下log_likelihood1 看变化值
#一般会慢慢变小。最后收敛
if step%10000==0: # 选取部分 打印出来
print(step,log_likelihood(X,y,w,b))
return w,b
w,b=logistic_regression(X,y,num_steps=100000,learning_rate=5e-5)
print("自写的参数W,b:",w,b)
# 下面用调用逻辑回归模块
from sklearn.linear_model import LogisticRegression
#c设置一个很大的值,意味着不加入正则项。这里为了公平的比较
clf=LogisticRegression(fit_intercept=True,C=1e15)
clf .fit(X,y)
print("调用的参数W,b:",clf.coef_,clf.intercept_)
'''
自写的参数W,b: [-5.03280465 8.24664683] -14.017856497489417
调用的参数W,b: [[-5.02712572 8.23286799]] [-13.99400797]
'''