一、什么是KNN算法
寻找最近的K个数据,推测新数据的分类
二、算法原理
通用步骤:
1、计算距离
2、升序排列
3、取前K个
4、加权平均
K的选取:
K太大:导致分类模糊
K太小:受个例影响波动较大
如何选取K:
1、经验
2、均方根误差
通过 for循环 + 可视化 选取合适的K
k_range = range(k_min, k_max)
k_scores = []
for k in k_range:
knn = KNeighborsClassifier(n_neighbors = k)
# loss = -cross_val_score(knn, data_X, data_y, cv = 10, scoring = 'neg_mean_squared_error') #for regression 小
scores = cross_val_score(knn, data_X, data_y, cv = 10, scoring='accuracy') #for classification 大
k_scores.append(scores.mean())
plt.plot(k_range, k_scores)
plt.xlabel('Value of K for KNN')
plt.ylabel("Cross-Validated Accuracy")
plt.show()
三、实战应用(材料专业方面-预测新型钙钛矿太阳能电池)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import csv
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Load the perovskite dataset: one dict per CSV row, keyed by column name.
with open('material.csv', 'r') as file:
    datas = list(csv.DictReader(file))

# Module-level state shared by group()/Knn()/the K-sweep below.
test_set = []
train_set = []
k_scores = []
def group(datas):
    """Shuffle ``datas`` in place and split it 20/80 into test/train sets.

    Results are published through the module-level ``test_set`` and
    ``train_set`` globals (this script's convention); nothing is returned.
    """
    global test_set, train_set
    random.shuffle(datas)
    cut = len(datas) // 5  # first fifth becomes the test set
    test_set, train_set = datas[:cut], datas[cut:]
def distance(d1, d2):
    """Return the Euclidean distance between two material records.

    Each record is a dict of strings (as produced by csv.DictReader); the
    six numeric feature columns are converted to float before use.
    """
    # NOTE(review): "Ban_gap_eV" looks like a typo for "Band_gap_eV", but it
    # must match the CSV header exactly — verify against material.csv.
    features = ("Formation_Energy_eV", "E_Above_Hull_eV",
                "Ban_gap_eV", "Volume", "Nsites", "Density_gm_cc")
    squared = sum((float(d1[key]) - float(d2[key])) ** 2 for key in features)
    return squared ** 0.5
#KNN
def Knn(data, K):
    """Classify ``data`` by a distance-weighted vote of its K nearest rows.

    Neighbours are drawn from the module-level ``train_set`` global.  Each of
    the K closest training rows votes for its 'Has_Bandstructure' label with
    weight ``1 - d/total`` so that nearer rows count more.

    Returns the string 'True' or 'False'; ties go to 'False', matching the
    original implementation.
    """
    neighbours = sorted(
        ({"result": row['Has_Bandstructure'], "distance": distance(data, row)}
         for row in train_set),
        key=lambda item: item['distance'],
    )[:K]
    # NOTE(review): assumes labels are exactly the strings 'True'/'False'
    # (as stored in the CSV) — any other label would raise KeyError.
    votes = {'True': 0, 'False': 0}
    # BUGFIX: was a local named `sum`, shadowing the builtin.
    total = 0
    for n in neighbours:
        total += n['distance']
    if total == 0:
        # BUGFIX: all K neighbours coincide with the query point; the
        # original divided by zero here.  Fall back to a plain majority vote.
        for n in neighbours:
            votes[n['result']] += 1
    else:
        for n in neighbours:
            votes[n['result']] += 1 - n['distance'] / total
    if votes['True'] > votes['False']:
        return 'True'
    else:
        return 'False'
#测试阶段
def TT(test_set, K):
    """Run Knn over ``test_set`` and return the accuracy in [0, 1].

    Prints the accuracy as a percentage (side effect).  The parameter
    deliberately shadows the module-level ``test_set`` global so callers can
    pass any split.
    """
    # BUGFIX: an empty test set previously raised ZeroDivisionError.
    if not test_set:
        return 0.0
    correct = 0
    for test in test_set:
        # Compare the true label against the KNN prediction.
        if test['Has_Bandstructure'] == Knn(test, K):
            correct += 1
    score = correct / len(test_set)
    print("准确率:{:.2f}%".format(score * 100))
    return score
# Sweep K over [2, 50): for each K, average the test accuracy across
# repeated random train/test splits, then plot accuracy against K.
k_range = range(2, 50)
for K in k_range:
    trial_scores = []
    # NOTE(review): range(1, 100) gives 99 trials, not 100 — possibly
    # unintentional, but preserved here.
    for _ in range(1, 100):
        group(datas)
        trial_scores.append(TT(test_set, K))
    mean_score = sum(trial_scores) / len(trial_scores)
    print(mean_score)
    k_scores.append(mean_score)
print(k_scores)
plt.plot(k_range, k_scores)
plt.xlabel('Value of K for KNN')
plt.ylabel("Cross-Validated Accuracy")
plt.show()