KNN算法学习

一、KNN是什么?

 knn,K-Nearest Neighbor是一种分类/回归算法。

 意为寻找最近的K个数据,以推测出新数据的分类


二、算法原理

K近邻算法:给定一个训练数据集,对新的的输入实例,在训练数据集中找到与该实例最邻近的的K个实例,这K个实例的多数属于某个类,就把该实例分为这个类。

K值选择、距离度量、以及分类决策(一般多数表决)为K近邻算法的三个基本要素。

1.通用步骤

  • 计算距离(常用欧几里得距离或马氏距离)
  • 按距离升序排列
  • 取前K个
  • 进行加权平均

2.关于K值的选取

根据经验、均方根误差选取

 如K值过大,将导致分类模糊。太小,将受个例影响,波动较大。


实战应用

  • 以部分癌症监测数据进行模拟检测
    • M恶性,B良性。相当于true,false
    • import csv
      import random
      
      # 打开文件
      with open('Prostate_Cancer.csv', 'r') as file:
          # 以字典的形式读取文件
          reader = csv.DictReader(file)
          datas = [row for row in reader]
      
      # 将数据进行随机打乱
      random.shuffle(datas)
      
      # 分组(测试集和训练集)
      n = len(datas)//3      # 整除去掉小数
      test_set = datas[0:n]   # 取测试集
      train_set = datas[n:]
      
      # 进行KNN操作
      
      # 计算距离
      def distance(d1, d2):
          res = 0
          for key in ("radius", "texture", "perimeter", "area", "smoothness", "compactness", "symmetry", "fractal_dimension"):
              res += (float(d1[key])-float(d2[key]))**2
      
          return res**0.5
      
      K = 5     # 可取任意值,适度,过大过小都不好
      def knn(data):
          # 距离
          res = [
              {"result": train['diagnosis_result'], "distance": distance(data,train)}
              for train in train_set
          ]
          # 排序,升序
          res = sorted(res, key=lambda item:item['distance'])
      
          # 取前K个
          res2 = res[0:K]
      
          # 加权平均
          result = {'B':0, 'M': 0}
      
          # 总距离
          sum = 0
          for r in res2:
              sum+=r['distance']
      
          # 具体算每个人的数
          for r in res2:
              result[r['result']] += 1-r['distance'] / sum
              
          # 确诊结果
          if result['B']>result['M']:
              return 'B'
          else:
              return 'M'
      
      # 测试阶段
      correct = 0
      for test in test_set:
          # 真实结果
          result = test['diagnosis_result']
          # 测试结果
          result2 = knn(test)
      
          if result == result2:
              correct += 1
      
      print("准确率:{:.2f}%".format(100*correct/len(test_set)))

      参考:https://www.bilibili.com/video/BV1Nt411i7oD?from=search&seid=4603953531395093043

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值