The KNN Algorithm

1. What Is the KNN Algorithm

KNN finds the K data points closest to a new sample and uses them to infer the new sample's class.
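As a minimal sketch of the idea, scikit-learn's KNeighborsClassifier can fit such a model in a few lines (the toy arrays X and y below are made up purely for illustration):

from sklearn.neighbors import KNeighborsClassifier

# toy data: two features, two classes (illustrative only)
X = [[0, 0], [1, 1], [0, 1], [5, 5], [6, 5], [5, 6]]
y = [0, 0, 0, 1, 1, 1]

knn = KNeighborsClassifier(n_neighbors=3)   # K = 3
knn.fit(X, y)
print(knn.predict([[4, 4]]))                # -> [1]: the 3 nearest points all belong to class 1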

2. How the Algorithm Works

General steps (see the sketch after this list):

1. Compute the distance from the new sample to every training sample
2. Sort the distances in ascending order
3. Take the K nearest neighbours
4. Combine their labels with a distance-weighted average (or weighted vote)
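A minimal NumPy sketch of these four steps, assuming arrays train_X (samples × features) and train_y plus a single query point x (all names are made up for illustration; the full script in section 3 does the classification variant with a weighted vote):

import numpy as np

def knn_predict(train_X, train_y, x, K=5):
    # 1. compute the Euclidean distance from x to every training point
    dists = np.sqrt(((train_X - x) ** 2).sum(axis=1))
    # 2. sort ascending / 3. keep the indices of the K nearest points
    nearest = np.argsort(dists)[:K]
    # 4. distance-weighted average of the neighbours' labels
    #    (a small epsilon avoids division by zero when a distance is exactly 0)
    weights = 1.0 / (dists[nearest] + 1e-12)
    return np.average(train_y[nearest], weights=weights)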

Choosing K:

K too large: the classification becomes blurred, because distant, unrelated samples also get a vote
K too small: the result fluctuates strongly with individual outliers

How to choose K:

1. Experience (rules of thumb)
2. Root-mean-square error (or another cross-validated score)

Pick a suitable K with a for loop plus visualization, as in the snippet below:

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt

# data_X / data_y: the feature matrix and labels, prepared beforehand
k_range = range(1, 31)   # the range of K values to try; adjust as needed
k_scores = []
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    # For regression, minimise the loss instead:
    # loss = -cross_val_score(knn, data_X, data_y, cv=10, scoring='neg_mean_squared_error')
    scores = cross_val_score(knn, data_X, data_y, cv=10, scoring='accuracy')  # for classification, larger is better
    k_scores.append(scores.mean())

plt.plot(k_range, k_scores)
plt.xlabel('Value of K for KNN')
plt.ylabel("Cross-Validated Accuracy")
plt.show()
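Keep the K at which the cross-validated accuracy peaks (or, for regression, where the loss bottoms out); for binary classification an odd K also avoids tied votes.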

3. Hands-On Application (materials science: predicting new perovskite solar cells)

#!/usr/bin/env python
# -*- coding:utf-8 -*-
import csv
import random
import matplotlib.pyplot as plt

# Read the data set
with open('material.csv', 'r') as file:
    reader = csv.DictReader(file)
    datas = [row for row in reader]

test_set = []
train_set = []
k_scores = []

# Split the data: 1/5 as the test set, 4/5 as the training set
def group(datas):
    random.shuffle(datas)
    n = len(datas) // 5
    global test_set
    global train_set
    test_set = datas[0: n]
    train_set = datas[n:]
# Euclidean distance between two samples over the numeric features
def distance(d1, d2):
    res = 0
    for key in ("Formation_Energy_eV", "E_Above_Hull_eV",
                "Ban_gap_eV", "Volume", "Nsites", "Density_gm_cc"):
        res = res + (float(d1[key]) - float(d2[key])) ** 2
    return res ** 0.5
# KNN: classify one sample with a distance-weighted vote of its K nearest neighbours
def Knn(data, K):
    res = [
        {"result": train['Has_Bandstructure'], "distance": distance(data, train)}
        for train in train_set
    ]
    res = sorted(res, key=lambda item: item['distance'])   # sort ascending by distance
    res2 = res[0: K]                                        # keep the K nearest
    result = {'True': 0, 'False': 0}
    total = 0
    for r in res2:
        total += r['distance']
    for r in res2:
        # closer neighbours get a larger weight
        result[r['result']] += 1 - r['distance'] / total
    if result['True'] > result['False']:
        return 'True'
    else:
        return 'False'
# Evaluation: the fraction of test samples whose predicted label matches the true label
def TT(test_set, K):
    correct = 0
    for test in test_set:
        result = test['Has_Bandstructure']
        result2 = Knn(test, K)
        if result == result2:
            correct += 1
    score = correct / len(test_set)
    print("Accuracy: {:.2f}%".format(score * 100))
    return score

k_range = range(2, 50)
for K in k_range:
    scores = []
    for i in range(100):              # repeat 100 random splits for each K
        group(datas)
        score = TT(test_set, K)
        scores.append(score)
    arr = sum(scores) / len(scores)   # mean accuracy over the repeats
    print(arr)
    k_scores.append(arr)
print(k_scores)

plt.plot(k_range, k_scores)
plt.xlabel('Value of K for KNN')
plt.ylabel("Mean accuracy over 100 random splits")
plt.show()
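For comparison, here is a rough sketch of the same experiment with pandas and scikit-learn (the feature and label column names are taken from the script above and not verified against material.csv):

import pandas as pd
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

features = ["Formation_Energy_eV", "E_Above_Hull_eV", "Ban_gap_eV",
            "Volume", "Nsites", "Density_gm_cc"]

df = pd.read_csv('material.csv')
data_X = df[features].astype(float)
data_y = df['Has_Bandstructure']

# Standardising the features matters for KNN, because the raw columns have very different ranges
model = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=5, weights='distance'))
scores = cross_val_score(model, data_X, data_y, cv=10, scoring='accuracy')
print("Mean 10-fold accuracy: {:.2f}%".format(scores.mean() * 100))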