# 线性可分 (linearly separable case)
import numpy as np
import matplotlib.pyplot as plt
# Load the training set, skipping the CSV header row. Columns 0-1 are the
# two features; column 2 is the class label (the plots below treat 1 and 0
# as the two classes).
train = np.loadtxt('images2.csv', delimiter=',', skiprows=1)
train_x, train_y = train[:, :2], train[:, 2]

# Per-feature mean and standard deviation (NumPy default ddof=0), read by
# standardize().
mu, sigma = train_x.mean(axis=0), train_x.std(axis=0)

# Random initial parameters: one intercept weight plus one weight per feature.
theta = np.random.rand(3)
def standardize(x):
    """Return *x* z-scored with the training-set statistics.

    Subtracts the module-level per-feature mean ``mu`` and divides by the
    module-level per-feature standard deviation ``sigma``.
    """
    centered = x - mu
    return centered / sigma
# Standardized features, then the design matrix: a leading column of ones
# (so theta[0] acts as the intercept) followed by the two feature columns.
train_z = standardize(train_x)
X = np.column_stack((np.ones(len(train_z)), train_z))
def f(X):
    """Evaluate the logistic model 1 / (1 + exp(-X . theta)).

    Reads the module-level ``theta`` at call time, so each call reflects the
    most recent update made by the training loop.
    """
    z = np.dot(X, theta)
    return 1 / (1 + np.exp(-z))
def classify(x):
    """Return the predicted class (0 or 1) for each row of *x*.

    A row is assigned class 1 when the model output ``f(x)`` is at least 0.5,
    otherwise class 0.
    """
    # ``np.int`` was merely an alias of the builtin ``int``; it was deprecated
    # in NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    return (f(x) >= 0.5).astype(int)
epoch = 5000
ETA = 1e-3
# Batch gradient descent for logistic regression: each step moves theta
# against (f(X) - y) . X, summed over all training samples. f() reads the
# global theta, so every step uses the parameters from the previous one.
# The loop variable doubles as the 1-based step number that is printed.
for count in range(1, epoch + 1):
    theta = theta - ETA * np.dot(f(X) - train_y, X)
    print('第{}次,theta={}'.format(count, theta))
# Sample points for the decision boundary, spanning [-2, 2] in standardized units.
x0 = np.linspace(-2, 2, 100)
# Training samples in standardized coordinates: class 1 as circles, class 0
# as crosses. Transposing the selected (k, 2) rows yields the x and y arrays.
plt.plot(*train_z[train_y == 1].T, 'o')
plt.plot(*train_z[train_y == 0].T, 'x')
# Boundary theta0 + theta1*z1 + theta2*z2 = 0, solved for z2.
plt.plot(x0, -(theta[1] * x0 + theta[0]) / theta[2], linestyle='dashed')
plt.axis('scaled')
plt.show()
# 线性不可分 (linearly non-separable case)
import numpy as np
import matplotlib.pyplot as plt
# Load the second training set, skipping the CSV header row. Columns 0-1 are
# the features; column 2 is the class label (plotted below as 1 vs 0).
train = np.loadtxt('data3.csv', delimiter=',', skiprows=1)
train_x, train_y = train[:, :2], train[:, 2]

# Per-feature mean and standard deviation (ddof=0), read by standardize().
mu, sigma = train_x.mean(axis=0), train_x.std(axis=0)
def standardize(x):
    """Return *x* z-scored with the training-set statistics.

    Subtracts the module-level per-feature mean ``mu`` and divides by the
    module-level per-feature standard deviation ``sigma``.
    """
    centered = x - mu
    return centered / sigma
# Training features rescaled to zero mean / unit variance per column.
train_z = standardize(x=train_x)
def to_matrix(x):
    """Build the polynomial design matrix for the 2-D array *x*.

    The result has columns ``[1, <every column of x>, x[:, 0] ** 2]``: an
    intercept column, the features themselves, and the square of the first
    feature, which lets a linear model fit a parabolic decision boundary.
    """
    intercept = np.ones(len(x))
    first_squared = x[:, 0] ** 2
    return np.column_stack((intercept, x, first_squared))
X = to_matrix(train_z)
# One random initial parameter per design-matrix column (intercept, the
# features, and the squared term). Deriving the count from X keeps it in sync
# with to_matrix() instead of hard-coding 4.
theta = np.random.rand(X.shape[1])
def f(X):
    """Evaluate the logistic model 1 / (1 + exp(-X . theta)).

    Reads the module-level ``theta`` at call time, so each call reflects the
    most recent update made by the training loop.
    """
    z = np.dot(X, theta)
    return 1 / (1 + np.exp(-z))
def classify(X):
    """Return the predicted class (0 or 1) for each row of *X*.

    A row is assigned class 1 when the model output ``f(X)`` is at least 0.5,
    otherwise class 0.
    """
    # ``np.int`` was merely an alias of the builtin ``int``; it was deprecated
    # in NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    return (f(X) >= 0.5).astype(int)
epoch = 5000
ETA = 1e-3
# Training-set accuracy recorded after every parameter update.
accuracies = []
# The loop variable doubles as the 1-based update number that is printed.
for count in range(1, epoch + 1):
    # Batch gradient-descent step: move theta against (f(X) - y) . X.
    theta = theta - ETA * np.dot(f(X) - train_y, X)
    # Fraction of samples whose predicted class matches the label. The mean
    # of a boolean array is that fraction directly; float() keeps the stored
    # value a plain Python float as before.
    result = classify(X) == train_y
    accuracy = float(np.mean(result))
    accuracies.append(accuracy)
    print('第{}次更新,theta={}'.format(count, theta))
# Sample points for the decision boundary, spanning [-2, 2] in standardized units.
x1 = np.linspace(-2, 2, 100)
# Training samples in standardized coordinates: class 1 as circles, class 0
# as crosses. Transposing the selected (k, 2) rows yields the x and y arrays.
plt.plot(*train_z[train_y == 1].T, 'o')
plt.plot(*train_z[train_y == 0].T, 'x')
# Boundary theta0 + theta1*z1 + theta2*z2 + theta3*z1**2 = 0, solved for z2.
plt.plot(
    x1,
    -(theta[3] * x1 ** 2 + theta[1] * x1 + theta[0]) / theta[2],
    linestyle='dashed',
)
plt.axis('scaled')
plt.show()
# Accuracy curve: one point per recorded update, indexed from 0.
x = np.arange(0, len(accuracies))
plt.plot(x, accuracies)
plt.show()