代码
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
def create_data():
    """Load the Iris dataset and return ``(features, labels)``.

    Both arrays are sliced from one combined NumPy array built from a
    DataFrame holding the four measurement columns plus the target column,
    so the labels share the features' (float) dtype.

    Returns:
        features: array with the four measurement columns.
        labels: array with the class target column.
    """
    iris = load_iris()
    # Attach the target as a fifth column, then rename all five columns.
    frame = pd.DataFrame(iris.data, columns=iris.feature_names).assign(label=iris.target)
    frame.columns = [
        'sepal length', 'sepal width', 'petal length', 'petal width', 'labels',
    ]
    table = np.array(frame.iloc[:, [0, 1, 2, 3, -1]])
    features = table[:, :4]
    labels = table[:, -1]
    return features, labels
# Load Iris and hold out 30% of the samples for evaluation. No random_state
# is passed, so the split (and therefore the reported scores) differs from
# run to run.
X, y = create_data()
(X_train, X_test,
 y_train, y_test) = train_test_split(X, y, test_size=0.3)
class LogisticRegressionClassifier:
    """Multinomial (softmax) logistic regression trained with per-sample SGD.

    The learned parameters live in ``self.weights``, a float32 matrix of shape
    ``(n_features + 1, n_classes)`` whose first row holds the bias terms
    (``data_matrix`` prepends a constant 1 to every sample).
    """

    def __init__(self, max_iter=200, lr=0.01):
        """
        Args:
            max_iter: number of full passes (epochs) over the training data.
            lr: learning rate applied to every per-sample gradient step.
        """
        self.max_iter = max_iter
        self.lr = lr

    def sigmoid(self, X):
        """Return the element-wise logistic function ``1 / (1 + exp(-X))``.

        Not used by the softmax training path below.
        """
        return 1 / (1 + np.exp(-X))

    def data_matrix(self, X):
        """Return the rows of X as lists, each prefixed with a constant 1 for the bias."""
        return [[1, *x] for x in X]

    def softmax(self, d):
        """Return the softmax of the 1-D logit vector ``d`` (entries sum to 1).

        The largest logit is subtracted before exponentiating. The result is
        mathematically identical, but ``np.exp`` can no longer overflow to
        ``inf`` and turn the probabilities into ``nan`` for large logits.
        """
        exps = np.exp(np.asarray(d, dtype=float) - np.max(d))
        return exps / np.sum(exps)

    def fit(self, X, y, X_val=None, y_val=None, verbose=True):
        """Train ``self.weights`` by stochastic gradient descent on cross-entropy.

        Args:
            X: training samples, one feature sequence per row.
            y: class labels, integer-valued (floats such as ``2.0`` are accepted).
                The number of classes is inferred as ``max(y) + 1``.
            X_val, y_val: optional evaluation data whose accuracy is shown in the
                progress report. When omitted, training accuracy is reported.
            verbose: print a progress report every 5 epochs.

        Raises:
            ValueError: if X is empty, X and y differ in length, or a label
                is negative.
        """
        data_mat = self.data_matrix(X)
        labels = [int(label) for label in y]
        if not data_mat:
            raise ValueError("fit() needs at least one training sample")
        if len(labels) != len(data_mat):
            raise ValueError("X and y must contain the same number of samples")
        if min(labels) < 0:
            raise ValueError("class labels must be non-negative integers")
        if X_val is None or y_val is None:
            # Evaluate on the training data rather than on module-level globals.
            X_val, y_val = X, y

        n_classes = max(labels) + 1
        targets = np.eye(n_classes)  # row k is the one-hot target for class k
        self.weights = np.zeros((len(data_mat[0]), n_classes), dtype=np.float32)
        for step_ in range(self.max_iter):
            for row, label in zip(data_mat, labels):
                pre = self.softmax(np.dot(row, self.weights))
                err = pre - targets[label]
                # Cross-entropy gradient w.r.t. the weights is outer(x, p - t).
                self.weights -= self.lr * np.outer(row, err)
            if verbose and step_ % 5 == 0:
                # err/pre shown are those of the last sample of this epoch.
                print("*********************************************************")
                print("round {}\nweights\n {} \nerr {} \nscore {}".format(step_, self.weights, err, self.score(X_val, y_val)))
                print("distribution\t", pre)

    def score(self, X, y):
        """Return the fraction of samples whose highest-scoring class equals the label.

        Softmax is monotonic, so the arg-max of the raw logits is the
        predicted class; no normalisation is needed here.
        """
        data_mat = self.data_matrix(X)
        right = sum(
            1
            for row, label in zip(data_mat, y)
            if np.argmax(np.dot(row, self.weights)) == label
        )
        return right / len(data_mat)
# Train for 500 epochs at the default learning rate; fit() prints a
# progress report every 5 epochs.
lrc = LogisticRegressionClassifier(lr=0.01, max_iter=500)
lrc.fit(X_train, y_train)
输出 (Output: the last two progress reports printed by `fit` during the 500-round run above)
round 490
weights
[[ 0.6455882 3.0346825 -3.6802142 ]
[ 1.4472942 0.852686 -2.2999246 ]
[ 3.3743029 0.15490974 -3.5290432 ]
[-4.4074197 -0.23293304 4.640532 ]
[-2.1648731 -2.5515227 4.716464 ]]
err [-1.16406736e-04 1.16406736e-04 5.53598502e-18]
score 0.9777777777777777
distribution [9.99883593e-01 1.16406736e-04 5.53598502e-18]
*********************************************************
round 495
weights
[[ 0.6471047 3.0558276 -3.7028728 ]
[ 1.4510568 0.85293597 -2.3039396 ]
[ 3.3821776 0.15481405 -3.5368207 ]
[-4.417349 -0.23393045 4.6514635 ]
[-2.1703055 -2.5591588 4.7295337 ]]
err [-1.13765965e-04 1.13765965e-04 5.01576396e-18]
score 0.9777777777777777
distribution [9.99886234e-01 1.13765965e-04 5.01576396e-18]