KNN算法例子(java,scala,python 代码实现)

java 版本

package com.fullshare.test;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

/**
 * @author huangjiangnan
 * @email huangjiangnanjava@163.com
 * @version 1.0
 * @since 2017年12月14日 上午10:27:35 类说明
 */
public class Knn {

	public static void main(String[] args) {
		// 电影名称 搞笑镜头 拥抱镜头 打斗镜头 电影类型
		Object[][] sample = { 
				{ 1, "宝贝当家", 45, 2, 9, "喜剧片" },
				{ 2, "美人鱼", 21, 17, 5, "喜剧片" },
				{ 3, "澳门风云3", 54, 9, 11, "喜剧片" }, 
				{ 4, "功夫熊猫3", 39, 0, 31, "喜剧片" }, 
				{ 5, "谍影重重", 5, 2, 57, "动作片" },
				{ 6, "叶问3", 3, 2, 65, "动作片" }, 
				{ 7, "伦敦陷落", 2, 3, 55, "动作片" }, 
				{ 8, "我的特工爷爷", 6, 4, 21, "动作片" }, 
				{ 9, "奔爱", 7, 46, 4, "爱情片" }, 
				{ 10, "夜孔雀", 9, 39, 8, "爱情片" },
				{ 11, "代理情人", 9, 38, 2, "爱情片" }, 
				{ 12, "新步步惊心", 8, 34, 17, "爱情片" }, };
		// 求唐人街办案类型
		Object[] movie = { 13, "唐人街探案", 23, 3, 17, null };
		int length = sample.length;
		System.out.println("序号 名称        距离");
		List<MovieDis> movieDisList = new ArrayList<>();

		for (int i = 0; i < length; i++) {
			Object[] mv = sample[i];
			double distances = getDistance(mv, movie);
			MovieDis movieDis = new MovieDis((int) mv[0], (String) mv[1], distances, (String) mv[5]);
			// Object[] disInfo={};
			System.out.println(String.format("%s %s %s", mv[0], mv[1], distances));
			movieDisList.add(movieDis);
		}

		Collections.sort(movieDisList, new Comparator<MovieDis>() {

			@Override
			public int compare(MovieDis o1, MovieDis o2) {
				double sub = (o1.getDistance() - o2.getDistance());
				if (sub == 0) {
					return 0;
				}
				if (sub > 0) {
					return 1;
				}
				return -1;
			}
		});
		
		int k=5;
		System.out.println("按照欧式距离排序,取k=5");
		movieDisList=movieDisList.subList(0,k);
		for (MovieDis movieDis : movieDisList) {
			System.out.println(movieDis);
		}

	}

	public static double getDistance(Object[] movie1, Object[] movie2) {
		double[] ps1 = { (Integer) movie1[2], (Integer) movie1[3], (Integer) movie1[4] };
		double[] ps2 = { (Integer) movie2[2], (Integer) movie2[3], (Integer) movie2[4] };
		return getDistance(ps1, ps2);
	}

	public static double getDistance(double[] ps1, double[] ps2) {
		if (ps1.length != ps1.length) {
			throw new RuntimeException("属性数量不对应");
		}
		int length = ps1.length;
		double total = 0;
		for (int i = 0; i < length; i++) {
			double sub = ps1[i] - ps2[i];
			total = total + (sub * sub);
		}
		return Math.sqrt(total);
	}

}

class MovieDis {
	private int id;
	private String title;
	private double distance;
	private String type;

	public int getId() {
		return id;
	}

	public void setId(int id) {
		this.id = id;
	}

	public String getTitle() {
		return title;
	}

	public void setTitle(String title) {
		this.title = title;
	}

	public double getDistance() {
		return distance;
	}

	public void setDistance(double distance) {
		this.distance = distance;
	}

	public MovieDis(int id, String title, double distance, String type) {
		super();
		this.id = id;
		this.title = title;
		this.distance = distance;
		this.type = type;
	}


	@Override
	public String toString() {
		return "MovieDis [id=" + id + ", title=" + title + ", distance=" + distance + ", type=" + type + "]";
	}

	public String getType() {
		return type;
	}

	public void setType(String type) {
		this.type = type;
	}

}
scala版本

package com.test.api

class KnnScala {

}

object KnnScala {
  def main(args: Array[String]): Unit = {
    var sample: Array[Array[Any]] = Array(
      Array(1, "宝贝当家", 45, 2, 9, "喜剧片"),
      Array(2, "美人鱼", 21, 17, 5, "喜剧片"),
      Array(3, "澳门风云3", 54, 9, 11, "喜剧片"),
      Array(4, "功夫熊猫3", 39, 0, 31, "喜剧片"),
      Array(5, "谍影重重", 5, 2, 57, "动作片"),
      Array(6, "叶问3", 3, 2, 65, "动作片"),
      Array(7, "伦敦陷落", 2, 3, 55, "动作片"),
      Array(8, "我的特工爷爷", 6, 4, 21, "动作片"),
      Array(9, "奔爱", 7, 46, 4, "爱情片"),
      Array(10, "夜孔雀", 9, 39, 8, "爱情片"),
      Array(11, "代理情人", 9, 38, 2, "爱情片"),
      Array(12, "新步步惊心", 8, 34, 17, "爱情片"))
    // 求唐人街办案类型
    var movie = Array(13, "唐人街探案", 23, 3, 17, null);
    var length = sample.length - 1;
    println("序号 名称        距离");
    var movieDisList = List[MovieDis]();

    for (i <- 0 to length) {
      var mv: Array[Any] = sample(i);
      var distances: Double = getDistance(mv, movie);
      var movieDis = new MovieDis(mv(0).asInstanceOf[Int], mv(1).asInstanceOf[String], distances, mv(5).asInstanceOf[String]);
      println(printf("%s %s %s", mv(0), mv(1), distances));
      //列表添加跟java不一样,坑
      movieDisList = (movieDisList.+:(movieDis))
    }
    movieDisList = movieDisList.sortWith((o1: MovieDis, o2: MovieDis) => (o1.distance < o2.distance));

    var k: Int = 5;
    println("按照欧式距离排序,取k=5");
    movieDisList = movieDisList take 5;
    movieDisList.foreach { o => println(o) }
  }

  def getDistance(movie1: Array[Any], movie2: Array[Any]): Double = {
    var ps1 = Array(movie1(2).asInstanceOf[Integer].doubleValue(), movie1(3).asInstanceOf[Integer].doubleValue(), movie1(4).asInstanceOf[Integer].doubleValue());
    var ps2 = Array(movie2(2).asInstanceOf[Integer].doubleValue(), movie2(3).asInstanceOf[Integer].doubleValue(), movie2(4).asInstanceOf[Integer].doubleValue());
    return getDistance(ps1, ps2);
  }

  def getDistance(ps1: Array[Double], ps2: Array[Double]): Double = {
    if (ps1.length != ps1.length) {
      throw new RuntimeException("属性数量不对应");
    }
    var length = ps1.length - 1;
    var total: Double = 0;
    for (i <- 0 to length) {
      var sub = ps1(i) - ps2(i);
      total = total + (sub * sub);
    }
    
    return Math.sqrt(total);
  }

}

class MovieDis extends Serializable {
  var id: Int = 0;
  var title: String = null;
  var distance: Double = 0;
  var movieType: String = null;

  def this(_id: Int, _title: String, _distance: Double, _movieType: String) {
    this();
    this.id = _id;
    this.title = _title;
    this.distance = _distance;
    this.movieType = _movieType;
  }

  override def toString(): String = {
    return "MovieDis [id=" + id + ", title=" + title + ", distance=" + distance + ", movieType=" + movieType + "]";
  }

}



python版本

#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
# 给多个变量赋值
import math

class MovieDis :
    # 序号
    id = 0
    # 电影标题
    title = ""
    # 差距
    distance = 0.0
    # 电影类型
    movieType = ""

    def __init__(self, id ,title,distance,movieType):
        self.id=id
        self.title=title
        self.distance=distance
        self.movieType=movieType


    def __str__(self):
        return "MovieDis [id=%s, title=%s, distance=%s, movieType=%s]" % (self.id,self.title,self.distance,self.movieType)


def get_distance(ps1,ps2):
    if len(ps1) != len(ps2):
        raise Exception("数组长度不匹配")
    ll = len(ps1)
    total=0
    for i in range(0,ll):
        sub = ps1[i] - ps2[i]
        total = total + (sub * sub)
    return math.sqrt(total)




sample = ((1, "宝贝当家", 45, 2, 9, "喜剧片"),
      (2, "美人鱼", 21, 17, 5, "喜剧片"),
      (3, "澳门风云3", 54, 9, 11, "喜剧片"),
      (4, "功夫熊猫3", 39, 0, 31, "喜剧片"),
      (5, "谍影重重", 5, 2, 57, "动作片"),
      (6, "叶问3", 3, 2, 65, "动作片"),
      (7, "伦敦陷落", 2, 3, 55, "动作片"),
      (8, "我的特工爷爷", 6, 4, 21, "动作片"),
      (9, "奔爱", 7, 46, 4, "爱情片"),
      (10, "夜孔雀", 9, 39, 8, "爱情片"),
      (11, "代理情人", 9, 38, 2, "爱情片"),
      (12, "新步步惊心", 8, 34, 17, "爱情片"))

movie = (13, "唐人街探案", 23, 3, 17, "")


length = len(sample) - 1
print("序号 名称        距离")
movieDisList = []

for i in range(0,length):
    mv= sample[i]
    start=mv[2:5]
    end= movie[2:5]
    distances = get_distance(start, end)
    print("%s %s %s" % (mv[0],mv[1],distances))
    movieDis = MovieDis(mv[0], mv[1], distances, mv[5])
    movieDisList.append(movieDis)
movieDisList.sort(key=lambda x:x.distance)
for i in range(0,5):
    print(movieDisList[i])

 
结果如下

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值