如何使用k近邻算法解决约会问题

实验环境 PyCharm,Python3

有句古话叫做 ”近朱者赤近墨者黑“ ,这差不多就是k-近邻算法的中文描述。

理论基础

k近邻算法(k-NN)是一种基础且易于实现的机器学习方法,它的工作原理是在特征空间中寻找与新数据点最近的k个已知数据点,并根据这些最近点的标签来预测新数据点的标签。在约会问题中,我们可以使用k-NN来根据人们的兴趣、活动偏好和其他特征来预测他们是否能成为彼此的好搭档。

算法流程

  1. 收集数据:获取约会者的特征数据,如兴趣爱好、出行频率、喜欢的食物类型等。
  2. 准备数据:处理数据,转换成适合机器学习算法的格式。
  3. 分析数据:可视化数据,了解特征之间的关系。
  4. 训练算法:这里的训练就是简单地存储数据。
  5. 测试算法:使用k-NN算法对约会者进行分类,以测试算法的效果。
  6. 使用算法:构建完整的程序,输入一个人的特征,来预测他/她的约会类型。

下面是一个具体的例子

下面就以经典的海伦约会为例

    海伦一直使用在线约会网站寻找适合自己的约会对象。尽管约会网站会推荐不同的人选,但她并不是喜欢每一个人。经过一番总结,她发现曾交往过三种类型的人:不喜欢的人,魅力一般的人,极具魅力的人

    尽管发现了上述规律,但海伦依然无法将约会网站推荐的匹配对象归入恰当的类别。海伦希望分类软件可以更好地帮助她将匹配对象划分到确切的分类中,此外海伦收集了一些约会网站未记录的数据信息,她认为这些数据有助于匹配对象的归类。

海伦约会的数据集

每个数据占一行,总共有1000行,前三列为特征

第一列:每年获得的飞行常客里程数

第二列:看视频玩游戏所耗时间百分比

第三列:每周消费的冰淇淋公升数

第四列:男人的类别(不喜欢,一般,魅力十足)
————————————————

                            版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
                        
原文链接:https://blog.csdn.net/qq_19381989/article/details/98471819

代码实现

使用Python来实现k近邻算法,解决约会问题。

import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
from operator import itemgetter

def create_dataset(filename):
    with open(filename) as f:
        lines = f.readlines()
    data = [line.strip().split(' ') for line in lines]
    features = np.array(data).astype(float)
    labels = features[:, -1].astype(int)
    return features[:, :-1], labels

def normalize(features):
    min_vals = features.min(axis=0)
    max_vals = features.max(axis=0)
    ranges = max_vals - min_vals
    norm_features = (features - min_vals) / ranges
    return norm_features, ranges, min_vals

def knn_classify(test_feature, features, labels, k):
    distances = np.sqrt(((features - test_feature) ** 2).sum(axis=1))
    nearest = np.argsort(distances)
    top_k = labels[nearest[:k]]
    vote = Counter(top_k).most_common(1)[0][0]
    return vote

def calculate_error_rate(test_features, test_labels, train_features, train_labels, k):
    error_count = sum(knn_classify(test_f, train_features, train_labels, k) != test_l
                      for test_f, test_l in zip(test_features, test_labels))
    return error_count / len(test_labels)

def plot_data(features, labels):
    plt.scatter(features[:, 0], features[:, 1], s=20, c=labels, alpha=0.5)
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')
    plt.title('Feature Scatter Plot')
    plt.colorbar()
    plt.show()

def main():
    filename = 'hailun.txt'
    features, labels = create_dataset(filename)
    norm_features, ranges, min_vals = normalize(features)
    
    # Assuming we're using the first 10% of data as test set
    test_ratio = 0.1
    num_test = int(len(features) * test_ratio)
    test_features = norm_features[:num_test]
    test_labels = labels[:num_test]
    train_features = norm_features[num_test:]
    train_labels = labels[num_test:]
    
    error_rate = calculate_error_rate(test_features, test_labels, train_features, train_labels, k=3)
    print(f"Error rate: {error_rate:.2%}")

    # Uncomment to see the plots
    # plot_data(features, labels)
    # plot_data(norm_features, labels)

if __name__ == '__main__':
    main()
算法优缺点
优点:
简单易理解,易于实现。
适合对稀有事件进行分类。
特别适合多分类问题(multi-modal,对象具有多个类别标签)。
缺点:
计算成本高,尤其是在样本数据库很大的情况下。
样本不平衡问题会影响准确性。
需要手动选择k值。
对异常值敏感。
应用场景
K近邻算法广泛应用于金融领域信用评分、电商网站推荐系统、医学领域疾病诊断、手写识别以及视频内容的推荐系统等。
实现步骤
收集数据:可以使用任何方法。
准备数据:距离计算需要数值型数据,最好是结构化的数据格式。
分析数据:可以使用任何方法。
训练算法:此步骤不适用于K近邻算法,因为训练过程就是存储数据的过程。
测试算法:计算错误率。
使用算法:首先需要输入样本数据和结构化的输出结果,然后运行K近邻算法判断输入数据分别属于哪个分类。
结语
K近邻算法是一种有效的分类算法,但在实际应用中需要注意数据预处理、距离度量的选择以及k值的选择等问题。通过适当的优化和调整,K近邻算法可以在多个领域发挥出强大的作用。

  • 19
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值