k-Nearest Neighbor
Nearest Neighbor
Compare each test example with every example in the training set by distance, and predict the label of the closest training example.
Distance:
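- L1 (Manhattan) distance: $d_1(I_1, I_2) = \sum_p |I_1^p - I_2^p|$
- L2 (Euclidean) distance: $d_2(I_1, I_2) = \sqrt{\sum_p (I_1^p - I_2^p)^2}$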
Algorithm:
class NearestNeighbour(object):
    """1-nearest-neighbour classifier using the L2 (Euclidean) distance.

    Training only memorises the data; all the work happens in ``predict``,
    which labels each test row with the label of its closest training row.
    """

    def __init__(self):
        # Populated by train(); None means "not trained yet".
        self.X_train = None
        self.y_train = None

    def train(self, X_train, y_train):
        """Memorise the training data.

        Args:
            X_train: Array of shape (num_train, D), one example per row.
            y_train: Array of shape (num_train,), the label of each row.
        """
        self.X_train = X_train
        self.y_train = y_train

    def predict(self, X_test):
        """Predict a label for every row of ``X_test``.

        Args:
            X_test: Array of shape (num_test, D), one example per row.

        Returns:
            Array of shape (num_test,) holding, for each test row, the
            label of the training row with the smallest L2 distance.

        Raises:
            RuntimeError: If ``train`` has not been called.
        """
        if self.X_train is None or self.y_train is None:
            raise RuntimeError("train() must be called before predict()")

        # Work in float64 so squaring integer data (e.g. uint8 pixels)
        # cannot overflow.
        X_train = np.asarray(self.X_train, dtype=np.float64)
        X_test = np.asarray(X_test, dtype=np.float64)
        y_train = np.asarray(self.y_train)

        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, computed for all pairs at
        # once; the result has shape (num_test, num_train).
        l2 = (
            np.sum(X_test ** 2, axis=1)[:, np.newaxis]
            + np.sum(X_train ** 2, axis=1)[np.newaxis, :]
            - 2.0 * np.dot(X_test, X_train.T)
        )
        # Rounding can leave tiny negative values, which sqrt would turn
        # into NaN.
        np.maximum(l2, 0.0, out=l2)
        l2 = np.sqrt(l2)

        # The nearest neighbour is the one with the *smallest* distance.
        return y_train[np.argmin(l2, axis=1)]
k-Nearest Neighbour
Find the k closest training examples and let their labels vote: the most common label among them is the prediction.
Algorithm:
'''same as above'''
class kNearestNeighbour(object):
    """k-nearest-neighbour classifier using the L2 (Euclidean) distance.

    Training only memorises the data. ``predict`` finds the ``k`` closest
    training rows for each test row and returns the majority label.
    """

    def __init__(self):
        # Populated by train(); None means "not trained yet".
        self.X_train = None
        self.y_train = None

    def train(self, X_train, y_train):
        """Memorise the training data.

        Args:
            X_train: Array of shape (num_train, D), one example per row.
            y_train: Array of shape (num_train,), the label of each row.
        """
        self.X_train = X_train
        self.y_train = y_train

    def predict(self, X_test, k=1):
        """Predict a label for every row of ``X_test`` by majority vote.

        Args:
            X_test: Array of shape (num_test, D), one example per row.
            k: Number of nearest neighbours that vote. Values larger than
                the training set size use every training example.

        Returns:
            Array of shape (num_test,) with the most common label among the
            ``k`` nearest training rows. Ties go to the smallest label.

        Raises:
            RuntimeError: If ``train`` has not been called.
            ValueError: If ``k`` is smaller than 1.
        """
        if self.X_train is None or self.y_train is None:
            raise RuntimeError("train() must be called before predict()")
        if k < 1:
            raise ValueError("k must be at least 1")

        # Work in float64 so squaring integer data (e.g. uint8 pixels)
        # cannot overflow.
        X_train = np.asarray(self.X_train, dtype=np.float64)
        X_test = np.asarray(X_test, dtype=np.float64)
        y_train = np.asarray(self.y_train)
        num_test = X_test.shape[0]

        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, computed for all pairs at
        # once; the result has shape (num_test, num_train).
        l2 = (
            np.sum(X_test ** 2, axis=1)[:, np.newaxis]
            + np.sum(X_train ** 2, axis=1)[np.newaxis, :]
            - 2.0 * np.dot(X_test, X_train.T)
        )
        # Rounding can leave tiny negative values, which sqrt would turn
        # into NaN.
        np.maximum(l2, 0.0, out=l2)
        l2 = np.sqrt(l2)

        y_prd = np.empty(num_test, dtype=y_train.dtype)
        for i in range(num_test):
            # A stable sort keeps equidistant neighbours in training order,
            # so the result is deterministic.
            l2_index = np.argsort(l2[i], kind="stable")[:k]
            y_closest = y_train[l2_index]
            # np.unique returns labels sorted, so argmax picks the smallest
            # label among those tied for the most votes.
            labels, votes = np.unique(y_closest, return_counts=True)
            y_prd[i] = labels[votes.argmax()]
        return y_prd
Cross-Validation
A more sophisticated technique for hyperparameter tuning
Typical number of folds would be 3-fold, 5-fold or 10-fold cross-validation
Visualization Algorithm:
# Scatter every fold's accuracy above its k so the spread per k is visible.
for k in k_choices:
    accuracies = k_to_accuracies[k]
    plt.scatter([k] * len(accuracies), accuracies)

# Trend line: per-k mean accuracy, with the per-k standard deviation drawn as
# error bars. Scores are taken in ascending-k order.
fold_scores_by_k = [scores for _, scores in sorted(k_to_accuracies.items())]
accuracies_mean = np.array(list(map(np.mean, fold_scores_by_k)))
accuracies_std = np.array(list(map(np.std, fold_scores_by_k)))
plt.errorbar(k_choices, accuracies_mean, yerr=accuracies_std)

# Label and display the figure.
plt.title('Cross-validation on k')
plt.xlabel('k')
plt.ylabel('Cross-validation accuracy')
plt.show()