项目实例:KNN预测电影网站用户性别(Hadoop学习笔记三)

通过学习《Hadoop大数据开发基础》这本书,整理了一下书本上的项目案例。让自己再梳理一下流程,也希望能给有需要的人提供一定的帮助,写的不好的希望大家提出来,一起进步。

1 学习目标

  1. 理解KNN算法的原理。
  2. 掌握以MapReduce编程实现KNN算法。
  3. 掌握以MapReduce编程实现KNN分类器评价。

2 认识KNN算法

2.1 KNN算法简介

KNN算法,全称是K Nearest Neighbor算法,即K最近邻分类算法。 其中的K表示最接近自己的K个数据样本。

比如,有一个样本空间里的样本已经分成了几个类型,然后,给定一个待分类的数据,通过计算接近自己最近的K个样本来判断这个待分类数据属于哪个分类。

简单理解就是由那离自己最近的K个点来投票决定待分类数据归为哪一类

2.2 KNN算法实现流程

  1. 准备数据,对数据进行预处理。
  2. 选用合适的数据结构来存储训练数据和测试元组。
  3. 设定参数,如k=3。
  4. 对于每一个测试记录维护一个大小为k的按距离由小到大的队列,用于存储最近邻训练元组。
  5. 遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离L与最近邻元组中的最大距离Lmax比较。
  6. 若L>=Lmax,则舍弃该元组,遍历下一个元组。若L < Lmax,删除优先级队列中最大距离的元组,将当前训练元组存入最近邻元组。
  7. 遍历完毕,计算最近邻元组中k 个元组的多数类,并将其作为测试元组的类别。

3 数据预处理

3.1 获取数据

项目所需的数据:
(百度云链接老是被和谐,这次用微云分享一下)

链接:https://share.weiyun.com/1uVK7mpg 密码:wc677y

  1. 用户对电影的部分评分数据 ratings. dat如图所示。该数据包含4个字段,即 UserID(用户ID) MovieID(电影ID) Rating(评分)及 Timestamp(时间戳)其中, UserID的范围是1~6040, MovieID的范围是1~3952, Rating采用5分好评制度,即最高分为5分,最低分为1分。

    在这里插入图片描述

  2. 已知性别的用户信息部分数据 users.dat如图所示。该数据包含5个字段,分别为 UserID(用户ID) Gender(性别)Age(年龄Occupation(职业)以及Zip-code(编码)其中, Occupation字段代表的是21种不同的职业类型,Age字段记录的也并不是用户的实际年龄,而是一个年龄段,例如,1代表的是18岁以下,具体的解释请参考 README。

在这里插入图片描述

  1. 部分电影信息数据 movies.dat如图所示。该字段包含 MovieID(电影ID) Title(电影名称) Genres(电影类型)3个字段。其中, Title字段不仅记录电影的名称,还记录了电影的上映时间。数据中总共记录了18种电影类型,包括喜剧片、动作片、警匪片、爱情片等,具体的电影类型请参见 README。

在这里插入图片描述
4. 数据相关字段的解释文件README

3.2 数据变换

(1)根据UserID字段字段连接ratings.dat数据和users.dat数据,连接结果得到一份包含UserID(用户ID),Gender(性别),Age(年龄),Occupation(职业),Zip-code(编码),MovieID(电影ID)的数据。

只需下载上边百度云链接里的ratings_users.jar包。将JAR包上传到 Linux的opt目录下,在HDFS上新建文件夹/movie,将 ratings.dt、 users.dat传到/movie下,将程序运行结果保存在/movie/ratingsusers目录下。

命令如下:

hadoop jar /opt/ratings_users.jar demo. RatingsAndusers /movie/users.dat/movie/ratings.dat/movie/ratings_ users

运行之后得到

(2)同理,根据MovieID连接movies.dat数据和/movie/ratings_users/part-m-00000上的数据,连接结果得到一份包含UserID(用户ID),Gender(性别),Age(年龄),Occupation(职业),Zip-code(编码),MovieID(电影ID),Genres(电影类型)。

然后把百度云链接里的users_movies.jar包下载,Linux的opt目录下,将movies.dat数据上传到HDFS的/movies目录下,运行结果保存在/movie/users_movies。

命令如下:

 hadoop jar /opt/users_movies.jar demo.UserAndMovies /movie/movies.dat /movie/ratings_users/part-m-00000 /movie/users_movies

结果如下:

在这里插入图片描述

(3)对每个用户看过电影类型进行统计。对Gender(性别)做一步转换,如果是女性(F)则用1标记,如果是男性(M)则用0标记.

这一步的处理看Map端和Reduce端的处理流程:

对每个用户看过的电影类型进行统计的Mapper类及Reducer类代码:


public class MoviesGenresMapper extends Mapper<LongWritable, Text, UserAndGender, Text> {
	private UserAndGender user_gender=new UserAndGender();
	private String splitter="";
	private Text genres=new Text();
	@Override
	protected void setup(Mapper<LongWritable, Text, UserAndGender, Text>.Context context)
			throws IOException, InterruptedException {
		splitter=context.getConfiguration().get("SPLITTER");
	}
	@Override
	protected void map(LongWritable key, Text value, Mapper<LongWritable, Text, UserAndGender, Text>.Context context)
			throws IOException, InterruptedException {
		String[] val=value.toString().split(splitter);
		user_gender.setUserID(val[0]);
		if(val[1].equals("M")){
			//性别为M则用0标记
			user_gender.setGender(0);
		}else{
			//性别为F则用1标记
			user_gender.setGender(1);
		}
		user_gender.setAge(Integer.parseInt(val[2]));
		user_gender.setOccupation(val[3]);
		user_gender.setZip_code(val[4]);
		genres.set(val[6]);
		context.write(user_gender, genres);
	}
}



public class MoviesGenresReducer extends Reducer<UserAndGender, Text, Text, NullWritable> {
	@Override
	protected void reduce(UserAndGender key, Iterable<Text> value,
			Reducer<UserAndGender, Text, Text, NullWritable>.Context context) throws IOException, InterruptedException {
		//初始化一个HashMap集合,集合中的键为18种电影类型,每个键对应的值为0
		HashMap<String,Integer> genresCounts=new HashMap<String,Integer>();
		String[] genreslist={"Action","Adventure","Animation","Children's","Comedy","Crime","Documentary","Drama",
				"Fantasy","Film-Noir","Horror","Musical","Mystery","Romance","Sci-Fi","Thriller","War","Western"		
				};
	    for(int i=0;i<genreslist.length;i++){
			if(!genresCounts.containsKey(genreslist[i])){
				genresCounts.put(genreslist[i], 0);
				}
			}
	    //遍历值列表
		for (Text val : value) {
			//对每个元素进行分割
			String[] genres=val.toString().split("\\|");
			for(int i=0;i<genres.length;i++){
				//如果HashMap元素的键包含分割结果的元素,则该键对应的值加1
				if(genresCounts.containsKey(genres[i])){
				   genresCounts.put(genres[i], genresCounts.get(genres[i])+1);
				}
			}
		}
		//将HashMap集合中所有键对应的值根据逗号连接成字符串
		String result="";
		for(Map.Entry<String, Integer> kv:genresCounts.entrySet()){
			if(result.length()==0){
				result=kv.getValue().toString();
			}else{
				result=result+","+kv.getValue();
			}
		}
		
	    context.write(new Text(key.toString()+","+result), NullWritable.get());
	}

}

处理之后得到结果:

3.3 数据清洗

缺失值和异常值的处理方式如下图:

[外链图片转存失败(img-lZAFYwWj-1567943728834)(C:\Users\z\Desktop\批注 2019-09-08 174401.png)]

[外链图片转存失败(img-0peXZziS-1567943728836)(C:\Users\z\Desktop\批注 2019-09-08 174418.png)]

处理缺失值和异常值的代码:


public class DataProcessingMapper extends Mapper<LongWritable, Text, Text, NullWritable> {
	private String splitter="";
	enum DataProcessingCounter{
		NullData,
		AbnormalData
	}
	@Override
	protected void setup(Mapper<LongWritable, Text, Text, NullWritable>.Context context)
			throws IOException, InterruptedException {
		splitter=context.getConfiguration().get("SPLITTER");
	}
	@Override
	protected void map(LongWritable key, Text value, Mapper<LongWritable, Text, Text, NullWritable>.Context context)
			throws IOException, InterruptedException {
		String[] val = value.toString().split(splitter);
		for(int i=5;i<val.length;i++){
			//判断每个字段的值是否是空值,若是则用0替换
			if(val[i].equals("") || val[i].equals("null") || val[i].equals("NULL") || val[i].equals("NAN")){	
				context.getCounter(DataProcessingCounter.NullData).increment(1);
				val[i]="0";
			}else{
				context.getCounter(DataProcessingCounter.NullData).increment(0);
			}
			//判断每个字段的值是否是异常值,若是则用0替换
			if(Integer.parseInt(val[i])<0){
				context.getCounter(DataProcessingCounter.AbnormalData).increment(1);
				val[i]="0";
			}else{
				context.getCounter(DataProcessingCounter.AbnormalData).increment(0);
			}
		}
		
		String result="";
		//重新将字符创数组val拼接成字符串
		for(int i=0;i<val.length;i++){
			if(i==0){
				result=val[i];
			}else{
				result=result+splitter+val[i];
			}
		}
		context.write(new Text(result), NullWritable.get());
	}
}

3.4 划分数据集

一般来说分类算法由3个过程:

(1)通过归纳分析训练样本集建立分类器

(2)用验证数据集来选择最优的模型参数

(3)用已知类别的测试样本集评估分类器的准确性

本项目在建立M电影用户分类器之前,将处理之后的数据按8:1:1的比例随机划分数据集为训练数据集、测试数据集、验证数据集。

读取HDFS的数据并统计记录数的方法:

/**
	 * 读取原始数据并统计数据的记录数
	 * @param fs
	 * @param path
	 * @return
	 * @throws Exception
	 */
	public static int getSize(FileSystem fs,Path path) throws Exception{
		int count=0;
		FSDataInputStream is=fs.open(path);
		BufferedReader br=new BufferedReader(new InputStreamReader(is));
		String line="";
		while((line=br.readLine())!=null){
			count++;
		}
		br.close();
		is.close();
		return count;		
	}
	/**
	 *随机获取 80%原始数据的对应下标
	 * @param count
	 * @return
	 */
	public static Set<Integer> trainIndex(int count){
		Set<Integer> train_index=new HashSet<Integer>();
		int trainSplitNum=(int)(count*0.8);
		Random random=new Random();
		while(train_index.size()<trainSplitNum){
			int a=random.nextInt(count);
			train_index.add(a);
		}
		return train_index;	
	}
	/**
	 * 随机获取10%原始数据对应的下标
	 * @param count
	 * @param train_index
	 * @return
	 */
	public static Set<Integer> validateIndex(int count,Set<Integer> train_index){
		Set<Integer> validate_index=new HashSet<Integer>();
		int validateSplitNum=count-(int)(count*0.9);
		Random random=new Random();
		while(validate_index.size()<validateSplitNum){
			int a=random.nextInt(count);
			if(!train_index.contains(a)){
				validate_index.add(a);	
			}
		}
		return validate_index;	
	}

设置训练集的存储路径为/movie/trainData,验证数据集的存储路径为/movie/validateData,测试数据集的存储路径为/movie/testData。

将数据写入HDFS:

public class SplitData {
	public static void main(String[] args) throws Exception {
		Configuration conf=new Configuration();
		conf.set("fs.defaultFS", "master:8020");
		FileSystem fs=FileSystem.get(conf);
		//获取预处理之后的电影数据路径
		Path moviedata=new Path("/movie/processing_out/part-m-00000");
		//得到电影数据大小
		int datasize=getSize(fs, moviedata);
		//得到train数据对应原始下标
		Set<Integer> train_index=trainIndex(datasize);
		
		System.out.println(train_index.size());
		//得到validate数据对应原始数据的下标
		Set<Integer> validate_index=validateIndex(datasize,train_index);
		System.out.println(validate_index.size());
		//训练数据存放的路径
		Path train=new Path("hdfs://master:8020/movie/trainData");
		fs.delete(train,true);
		FSDataOutputStream os1=fs.create(train);
		BufferedWriter bw1=new BufferedWriter(new OutputStreamWriter(os1));
		//测试数据存放的路径
		Path test=new Path("hdfs://master:8020/movie/testData");
		fs.delete(test,true);
		FSDataOutputStream os2=fs.create(test);
		BufferedWriter bw2=new BufferedWriter(new OutputStreamWriter(os2));
		//验证数据存放的路径
		Path validate=new Path("hdfs://master:8020/movie/validateData");
		fs.delete(validate,true);
		FSDataOutputStream os3=fs.create(validate);
		BufferedWriter bw3=new BufferedWriter(new OutputStreamWriter(os3));
		//读取数据并将数据分为训练数据、测试数据以及验证数据写入到HDFS
		FSDataInputStream is=fs.open(moviedata);
		BufferedReader br=new BufferedReader(new InputStreamReader(is));
		String line="";
		int sum=0;
		int trainsize=0;
		int testsize=0;
		int validatesize=0;
		while((line=br.readLine())!=null){
			sum+=1;
			if(train_index.contains(sum)){
				trainsize+=1;
				bw1.write(line.toString());
				bw1.newLine();
			}else if(validate_index.contains(sum)){
				validatesize+=1;
				bw3.write(line.toString());
				bw3.newLine();
			}else{
				testsize+=1;
				bw2.write(line.toString());
				bw2.newLine();
			}
		}
		bw1.close();
		os1.close();
		bw2.close();
		os2.close();
		bw3.close();
		os3.close();
		br.close();
		is.close();
		fs.close();
	}

4 实现用户性别分类

4.1 KNN算法Hadooop实现思路

算法描述:

1.自定义值类型表示距离和类型,由于KNN算法是计算测试数据与已知类别的训练数据之间的距离,找到距离与测试数据最近的K个训练数据,再根据这些训练所属的类别的众数来判断测试数据的类别。所以在map阶段需要将测试数据与训练数据的距离及该训练数据的类别作为值输出,程序可以使用Hadoop内置的数据类型Text作为值类型输出距离及类别,但为了提高程序的执行效率,建议自定义值类型表示距离和类别。

2.map阶段,setup函数读取测试数据。在map函数里读取每条训练数据,遍历测试数据,计算读取进来的训练记录与每条测试数据的距离,计算距离采用的是欧式距离的计算方法,map输出的键是每条测试数据,输出的值是该测试数据与读取的训练数据的距离和训练数据的类别。

3.reduce阶段,函数初始化参数值,函数对相同键的值根据距离进行升序排序,取出前个值,输出读取进来的键和这个值中类别的众数

4.2 代码实现

4.2.1 自定义值类型

public class DistanceAndLabel implements Writable{
	private double distance;
	private String label;
	public DistanceAndLabel() {
	}
	public DistanceAndLabel(double distance,String label) {
		this.distance=distance;
		this.label=label;
	}
	public double getDistance() {
		return distance;
	}
	public void setDistance(double distance) {
		this.distance = distance;
	}
	public String getLabel() {
		return label;
	}
	public void setLabel(String label) {
		this.label = label;
	}
	/**
	 *先读取距离,再读取类别
	 */
	@Override
	public void readFields(DataInput in) throws IOException {
		this.distance=in.readDouble();
		this.label=in.readUTF();
		
	}
	/**
	 * 先把distnce写入out输出流
	 * 再把label写入out输出流
	 */
	@Override
	public void write(DataOutput out) throws IOException {
		out.writeDouble(distance);
		out.writeUTF(label);
		
	}
	/**
	 * 使用空格将距离和类别连接成字符串
	 */
	@Override
	public String toString() {
		return this.distance+","+this.label;
	}
}

4.2.2 在Mapper类中定义计算距离的方法

public class MovieClassifyMapper extends Mapper<LongWritable, Text, Text, DistanceAndLabel> {
	private DistanceAndLabel distance_label=new DistanceAndLabel();
	private String splitter="";
	ArrayList<String> testData=new ArrayList<String>();
	private String testPath="";
	@Override
	protected void setup(Mapper<LongWritable, Text, Text, DistanceAndLabel>.Context context)
			throws IOException, InterruptedException {
		Configuration conf=context.getConfiguration();
		splitter=conf.get("SPLITTER");
		testPath=conf.get("TESTPATH");
		//读取测试数据存于列表testData中
		FileSystem fs=FileSystem.get(conf);
		FSDataInputStream is=fs.open(new Path(testPath));
		BufferedReader br=new BufferedReader(new InputStreamReader(is));
		String line="";
		while((line=br.readLine())!=null){
			testData.add(line);
		}
		is.close();
		br.close();
	}
	@Override
	protected void map(LongWritable key, Text value, Mapper<LongWritable, Text, Text, DistanceAndLabel>.Context context)
			throws IOException, InterruptedException {
		double distance=0.0;
		String[] val=value.toString().split(splitter);
		String[] singleTrainData=Arrays.copyOfRange(val, 5, val.length);
		String label=val[1];
		for (String td: testData) {
			String[] test=td.split(splitter);
			String[] singleTestData=Arrays.copyOfRange(test, 5, test.length);
			distance=Distance(singleTrainData,singleTestData);
			distance_label.setDistance(distance);
			distance_label.setLabel(label);
			context.write(new Text(td), distance_label);		
		}
	}
	/**
	 * 计算训练数据与测试数据的距离
	 * @param singleTrainData
	 * @param singleTestData
	 * @return
	 */
	private double Distance(String[] singleTrainData, String[] singleTestData) {
		double sum=0.0;
		for(int i=0;i<singleTrainData.length;i++){
			sum+=Math.pow(Double.parseDouble(singleTrainData[i]), Double.parseDouble(singleTestData[i]));
		}
		return Math.sqrt(sum);
	}
}

4.2.3 Reduce类实现

public class MovieClassifyReducer extends Reducer<Text, DistanceAndLabel, Text, NullWritable> {
	private int k=0;
	@Override
	protected void setup(Reducer<Text, DistanceAndLabel, Text, NullWritable>.Context context)
			throws IOException, InterruptedException {
		//初始化K值
		k=context.getConfiguration().getInt("K",3);
	}
	@Override
	protected void reduce(Text key, Iterable<DistanceAndLabel> value,
			Reducer<Text, DistanceAndLabel, Text, NullWritable>.Context context) throws IOException, InterruptedException {
		String label=getMost(getTopK(sort(value)));
		context.write(new Text(label+","+key), NullWritable.get());	
	}
	/**
	 * 得到列表中类别的众数
	 * @param topK
	 * @return
	 */
	private String getMost(List<String> topK) {
		HashMap<String,Integer> labelTimes=new HashMap<String,Integer>();
		for (String str : topK) {
			String label=str.substring(str.lastIndexOf(",")+1,str.length());
			if(labelTimes.containsKey(label)){
				labelTimes.put(label, labelTimes.get(label)+1);
			}else{
				labelTimes.put(label, 1);
			}
		}
		int maxInt=Integer.MIN_VALUE;
		String mostLabel="";
		for(Map.Entry<String, Integer> kv:labelTimes.entrySet()){
			if(kv.getValue()>maxInt){
				maxInt=kv.getValue();
				mostLabel=kv.getKey();
			}
		}
		return mostLabel;
	}
	/**
	 * 取出列表中的前K个值
	 * @param sort
	 * @return
	 */
	private List<String> getTopK(List<String> sort) {
		return sort.subList(0, k);
	}
	/**
	 * 根据距离升序排序
	 * @param value
	 * @return 
	 */
	private List<String> sort(Iterable<DistanceAndLabel> value) {
		ArrayList<String> result=new ArrayList<String>();
		for(DistanceAndLabel val:value){
			result.add(val.toString());
		}
		String[] tmp=new String[result.size()];
		result.toArray(tmp);
		Arrays.sort(tmp, new Comparator<String>(){

			@Override
			public int compare(String o1, String o2) {
				double o1D=Double.parseDouble(o1.substring(0, o1.indexOf(",")));
				double o2D=Double.parseDouble(o2.substring(0, o2.indexOf(",")));
				if(o1D>o2D){
					return 1;
				}else if(o1D<o2D){
					return -1;
				}else{
					return 0;
				}
			}});
		return Arrays.asList(tmp);
	}
}

4.2.4 驱动类的实现

public class MovieClassify extends Configured implements Tool{
	@Override
	public int run(String[] args) throws Exception {
		if(args.length!=5){
			System.err.println("demo.MovieClassify <testinput> <traininput> <output> <k> <splitter>");
			System.exit(-1);
		}
		Configuration conf=getMyConfiguration();
		conf.setInt("K", Integer.parseInt(args[3]));
		conf.set("SPLITTER",args[4]);
		conf.set("TESTPATH", args[0]);
		Job job=Job.getInstance(conf, "movie_knn");
		job.setJarByClass(MovieClassify.class);//设置主类
		job.setMapperClass(MovieClassifyMapper.class);//设置Mapper类
		job.setReducerClass(MovieClassifyReducer.class);//设置Reducer类
		job.setMapOutputKeyClass(Text.class);//设置Mapper输出的键类型
		job.setMapOutputValueClass(DistanceAndLabel.class);//设置Mapper输出的值类型
		job.setOutputKeyClass(Text.class);//设置Reducer输出的键类型
		job.setOutputValueClass(NullWritable.class);//设置Reducer输出的值类型
		FileInputFormat.addInputPath(job, new Path(args[1]));//设置输入路径
		FileSystem.get(conf).delete(new Path(args[2]), true);//删除输出路径
		FileOutputFormat.setOutputPath(job, new Path(args[2]));//设置输出路径
		return job.waitForCompletion(true)?-1:1;//提交任务
	}
	public static void main(String[] args) {
		String[] myArgs={
				"/movie/testData",
				"/movie/trainData",
				"/movie/knnout",
				"3",
				","
		};
		try {
			ToolRunner.run(getMyConfiguration(), new MovieClassify(), myArgs);
		} catch (Exception e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}
	/**
	 * 设置连接Hadoop集群的配置
	 * @return
	 */
	public static Configuration getMyConfiguration(){
		Configuration conf = new Configuration();
		conf.setBoolean("mapreduce.app-submission.cross-platform",true);
		conf.set("fs.defaultFS", "hdfs://master:8020");// 指定namenode
		conf.set("mapreduce.framework.name","yarn"); // 指定使用yarn框架
		String resourcenode="master";
		conf.set("yarn.resourcemanager.address", resourcenode+":8032"); // 指定resourcemanager
		conf.set("yarn.resourcemanager.scheduler.address",resourcenode+":8030");// 指定资源分配器
		conf.set("mapreduce.jobhistory.address",resourcenode+":10020");
		conf.set("mapreduce.job.jar",JarUtil.jar(MovieClassify.class));
		return conf;	
	}
}

4.2.5 打包成jar包的工具

public class JarUtil {
    public static String jar(Class<?> cls){// 验证ok
        String outputJar =cls.getName()+".jar";
        String input = cls.getClassLoader().getResource("").getFile();
        input= input.substring(0,input.length()-1);
        input = input.substring(0,input.lastIndexOf("/")+1);
        input =input +"bin/";
        jar(input,outputJar);
        return outputJar;
    }
    private static void jar(String inputFileName, String outputFileName){
        JarOutputStream out = null;
        try{
            out = new JarOutputStream(new FileOutputStream(outputFileName));
            File f = new File(inputFileName);
            jar(out, f, "");
        }catch (Exception e){
            e.printStackTrace();
        }finally{
            try {
                out.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

    }
    private static void jar(JarOutputStream out, File f, String base) throws Exception {
        if (f.isDirectory()) {
            File[] fl = f.listFiles();
            base = base.length() == 0 ? "" : base + "/"; // 注意,这里用左斜杠
            for (int i = 0; i < fl.length; i++) {
                jar(out, fl[ i], base + fl[ i].getName());
            }
        } else {
            out.putNextEntry(new JarEntry(base));
            FileInputStream in = new FileInputStream(f);
            byte[] buffer = new byte[1024];
            int n = in.read(buffer);
            while (n != -1) {
                out.write(buffer, 0, n);
                n = in.read(buffer);
            }
            in.close();
        }
    }
}

5 评价分类结果的准确性

5.1 评价思路

准确率的计算公式:
准 确 率 = 正 确 识 别 的 个 体 总 数 ÷ 识 别 出 的 个 体 总 数 准确率=正确识别的个体总数÷识别出的个体总数 =÷
评价思路:

5.2 实现分类评价

评价代码之Mapper类:

public class ValidateMapper extends Mapper<LongWritable, Text, NullWritable, Text> {
	private String splitter="";
	@Override
	protected void setup(Mapper<LongWritable, Text, NullWritable, Text>.Context context)
			throws IOException, InterruptedException {
		splitter=context.getConfiguration().get("SPLITTER");
	}
	@Override
	protected void map(LongWritable key, Text value, Mapper<LongWritable, Text, NullWritable, Text>.Context context)
			throws IOException, InterruptedException {
		String[] val=value.toString().split(splitter);
		context.write(NullWritable.get(), new Text(val[0]+splitter+val[2]));
	}
}

Reducer类:

public class ValidateReducer extends Reducer<NullWritable, Text, DoubleWritable, NullWritable> {
	private String splitter="";
	@Override
	protected void setup(Reducer<NullWritable, Text, DoubleWritable, NullWritable>.Context context)
			throws IOException, InterruptedException {
		splitter=context.getConfiguration().get("SPLITTER");
	}
	@Override
	protected void reduce(NullWritable key, Iterable<Text> value,
			Reducer<NullWritable, Text, DoubleWritable, NullWritable>.Context context)
					throws IOException, InterruptedException {
		//初始化sum记录预测分类正确的个数
		int sum=0;
		//初始化count记录所有分类结果的记录数,也即测试数据的记录数
		int count=0;
		for (Text val: value) {
			count++;
			String predictLabel=val.toString().split(splitter)[0];
			String trueLabel=val.toString().split(splitter)[1];
			//判断预测分类的类别是否与正确分类的类别一样
			if(predictLabel.equals(trueLabel)){
				sum+=1;
			}
		}
		//计算正确率
		double accuracy=(double)sum/count;
		context.write(new DoubleWritable(accuracy), NullWritable.get());
	}
}

驱动类:

public class Validate extends Configured implements Tool{
	@Override
	public int run(String[] args) throws Exception {
		if(args.length!=3){
			System.err.println("demo01.Validate <input> <output> <splitter>");
			System.exit(-1);
		}
		Configuration conf=getMyConfiguration();
		conf.set("SPLITTER", args[2]);
		Job job=Job.getInstance(conf, "validate");
		job.setJarByClass(Validate.class);//设置主类
		job.setMapperClass(ValidateMapper.class);//设置Mapper类
		job.setReducerClass(ValidateReducer.class);//设置Reducer类
		job.setMapOutputKeyClass(NullWritable.class);//设置Mapper输出的键格式
		job.setMapOutputValueClass(Text.class);//设置Mapper输出的值格式
		job.setOutputKeyClass(DoubleWritable.class);//设置Reducer输出的键格式
		job.setOutputValueClass(NullWritable.class);//设置Reducer输出的值格式
		FileInputFormat.addInputPath(job, new Path(args[0]));//设置输入路径
		FileSystem.get(conf).delete(new Path(args[1]),true);//设置删除输出路径
		FileOutputFormat.setOutputPath(job, new Path(args[1]));//设置输出路径
		return job.waitForCompletion(true)?-1:1;
	}
	public static void main(String[] args) {
		String[] myArgs={
				"/movie/knnout/part-r-00000",
				"/movie/validateout",
				","
		};
		try {
			ToolRunner.run(getMyConfiguration(), new Validate(), myArgs);
		} catch (Exception e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}
	/**
	 * 设置连接Hadoop集群的配置
	 * @return
	 */
	public static Configuration getMyConfiguration(){
		Configuration conf = new Configuration();
		conf.setBoolean("mapreduce.app-submission.cross-platform",true);
		conf.set("fs.defaultFS", "hdfs://master:8020");// 指定namenode
		conf.set("mapreduce.framework.name","yarn"); // 指定使用yarn框架
		String resourcenode="master";
		conf.set("yarn.resourcemanager.address", resourcenode+":8032"); // 指定resourcemanager
		conf.set("yarn.resourcemanager.scheduler.address",resourcenode+":8030");// 指定资源分配器
		conf.set("mapreduce.jobhistory.address",resourcenode+":10020");
		conf.set("mapreduce.job.jar",JarUtil.jar(Validate.class));
		return conf;	
	}
}

5.3 寻找最优K值

KNN算法的K值会对分类结果产生重大影响。

下图是K值分别为3,4,5,6,7对应的准确率,从图中可以看出K值为3是准确率是最高的。

虽然在K=3,4,5,6,7中,K=3 的准确率是最高的,但并不意味着K=3 得到的模型就是最好的分类器。
对与K值的选取,可以利用验证数据集及迭代的算法思想,其思路为:

  1. 初始化最大准确率maxAccuracy为0.0及最优K值bestK为0
  2. 定义K值列表k,设置K值从2取到100,K值并非直接从2连续取到100,而是隔开取值,例如,K可以取2,3,5,9,15,30,55,70,80,95。
  3. 循环k列表,针对每一个K值,训练模型并利用验证数据集计算准确率accuracy,如果准确率大于最大准确率maxAccuracy,则将accuracy的值赋给maxAccuracy,K值赋给最优K值bestK,接着循环下一个K值。如果准确率小于或等于最大准确率maxAccuracy,则直接循环下一个K值。
  4. 循环结束之后输出最优K。

[外链图片转存失败(img-UMgiK5Xd-1567943728839)(C:\Users\z\AppData\Roaming\Typora\typora-user-images\1567942782033.png)]

针对上述选择最优K值的思路,编写一个ALLJob类来完成选择最优K值。ALLJob类中只有一个main方法,在该方法中循环K值,每循环一次则需调用实现用户性别分类的MapReduce程序,同时还需调用评价分类准确性的MapReducue程序。

选择最优K值代码:

public class AllJob {
	public static void main(String[] args) throws IOException {
		Configuration conf=new Configuration();
		conf.set("fs.defaultFS", "master:8020");
		FileSystem fs=FileSystem.get(conf);
		double maxAccuracy=0.0;
		int bestK=0;
		int[] k={2,3,5,9,15,30,55,70,80,100};
		for(int i=0;i<k.length;i++){
			double accuracy=0.0;
			String[] classifyArgs={
					"/movie/validateData",
					"/movie/trainData",
					"/movie/knnout",
					String.valueOf(k[i]),
					","
			};
			try {
				ToolRunner.run(demo.MovieClassify.getMyConfiguration(), new demo.MovieClassify(), classifyArgs);
			} catch (Exception e) {
				// TODO Auto-generated catch block
				e.printStackTrace();
			}
			String[] validateArgs={
					"/movie/knnout/part-r-00000",
					"/movie/validateout",
					","
			};
			try {
				ToolRunner.run(demo01.Validate.getMyConfiguration(),new demo01.Validate(),validateArgs);
			} catch (Exception e) {
				// TODO Auto-generated catch block
				e.printStackTrace();
			}
			FSDataInputStream is=fs.open(new Path("/movie/validateout/part-r-00000"));
			BufferedReader br=new BufferedReader(new InputStreamReader(is));
			String line="";
			while((line=br.readLine())!=null){
				accuracy=Double.parseDouble(line);
			}
			br.close();
			is.close();
			if(accuracy>maxAccuracy){
				maxAccuracy=accuracy;
				bestK=k[i];
			}
			System.out.println("K="+k[i]+"\t"+"accuracy="+accuracy);				
		}
		System.out.println("最优K值是:"+bestK+"\t"+"最优K值对应的准确率:"+maxAccuracy);
	}
}

5.4 KNN算法优缺点

优点:

  1. 简单,易于理解,易于实现,无需估计参数,无需训练;
  2. 适合对稀有事件进行分类;
  3. 特别适合于多分类问题(multi-modal,对象具有多个类别标签), KNN比SVM的表现要好;

缺点:

  1. 该算法计算量大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点。
  2. 维度灾难: 在计算距离的时候考虑的是实例所有属性 。但分类可能仅由2个属性决定,这中情况下属性的相似性度量会误导k-近邻算法的分类。

解决办法:(1)属性加权;(2)剔除不相关的属性。

从豆瓣批量获取看过电影的用户列表,并应用kNN算法预测用户性别 首先从豆瓣电影的“看过这部电影 的豆瓣成员”页面上来获取较为活跃的豆瓣电影用户。 获取数据 链接分析 这是看过"模仿游戏"的豆瓣成员的网页链接:http://movie.douban.com/subject/10463953/collections。 一页上显示了20名看过这部电影的豆瓣用户。当点击下一页时,当前连接变为:http://movie.douban.com/subject/10463953/collections?start=20。 由此可知,当请求下一页内容时,实际上就是将"start"后的索引增加20。 因此,我们可以设定base_url='http://movie.douban.com/subject/10463953/collections?start=',i=range(0,200,20),在循环中url=base_url+str(i)。 之所以要把i的最大值设为180,是因为后来经过测试,豆瓣只给出看过一部电影的最近200个用户。 读取网页 在访问时我设置了一个HTTP代理,并且为了防止访问频率过快而被豆瓣封ip,每读取一个网页后都会调用time.sleep(5)等待5秒。 在程序运行的时候干别的事情好了。 网页解析 本次使用BeautifulSoup库解析html。 每一个用户信息在html中是这样的: 七月 (银川) 2015-08-23   首先用读取到的html初始化soup=BeautifulSoup(html)。本次需要的信息仅仅是用户id和用户的电影主页,因此真正有用的信息在这段代码中: 因此在Python代码中通过td_tags=soup.findAll('td',width='80',valign='top')找到所有的块。 td=td_tags[0],a=td.a就可以得到 通过link=a.get('href')可以得到href属性,也就用户的电影主页链接。然后通过字符串查找也就可以得到用户ID了。
评论 19
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值