Recently, in an interview with a well-known company, I was given a fairly hard question that I could not solve on the spot.
After coming back, I went through the MapReduce execution flow and some Spark documentation, and finally worked out a solution.
For an MR-based solution, see problem 2 of:
Hadoop学习之路(二十七)MapReduce的API使用(四)
https://www.cnblogs.com/qingyunzong/p/8639414.html
Problem description:
We have the following CSV-formatted data to process.
Four sample rows:
UserA,LocA,2018-01-09 12:00,60
UserA,LocA,2018-01-09 13:00,60
UserA,LocB,2018-01-09 11:00,20
UserA,LocA,2018-01-09 19:00,60
Meaning:
UserA,LocA,2018-01-09 12:00,60
means UserA stayed at LocA for 60 minutes starting at 12:00;
UserA,LocA,2018-01-09 13:00,60
means UserA stayed at LocA for 60 minutes starting at 13:00.
Processing logic:
1. For the same user at the same location, consecutive records are merged into one.
2. Merge rule: take the earliest start time and sum the stay durations. In the sample above, the 12:00 record at LocA ends at 13:00, exactly when the next LocA record starts, so the two merge into one record starting at 12:00 with a 120-minute stay.
Expected output:
UserA,LocA,2018-01-09 12:00,120
UserA,LocB,2018-01-09 11:00,20
UserA,LocA,2018-01-09 19:00,60
Requirement:
Solve it with Spark, MapReduce, or another distributed processing engine.
Test data:
UserA,LocA,2018-01-09 12:00:00,60
UserA,LocA,2018-01-09 13:00:00,60
UserA,LocB,2018-01-09 11:00:00,20
UserA,LocA,2018-01-09 19:00:00,60
UserB,LocA,2018-01-09 19:00:00,60
UserB,LocA,2018-01-09 18:00:00,60
UserB,LocA,2018-01-09 17:00:00,60
Custom class:
package com.offer.test.UserSession;
import java.io.Serializable;
/**
* Created by szh on 2019/3/12.
*/
public class UserInfo implements Serializable {

String user; // user ID
String loc; // location name
long time; // stay start time, epoch millis
long dalay; // stay duration in minutes
public String getUser() {
return user;
}
public void setUser(String user) {
this.user = user;
}
public String getLoc() {
return loc;
}
public void setLoc(String loc) {
this.loc = loc;
}
public long getTime() {
return time;
}
public void setTime(long time) {
this.time = time;
}
public long getDalay() {
return dalay;
}
public void setDalay(long dalay) {
this.dalay = dalay;
}
@Override
public String toString() {
return "UserInfo{" +
"user='" + user + '\'' +
", loc='" + loc + '\'' +
", time=" + time +
", dalay=" + dalay +
'}';
}
}
Approach 1:
My first idea for this problem was:
Step 1: Group the data by userID + loc as the key, using the groupByKey() operator.
Step 2: Iterate over the sequence that groupByKey() produces for each key and sort its elements by start time.
Step 3: Take the first element of the sorted list to initialize the merged record, record the expected next start time as firstElement.time + firstElement.delay * 60 * 1000, and move the pointer to the next element.
Step 4: If the expected next start time equals the current element's start time, the current record is merged into the run: 1. delay = delay + curNode.delay; 2. timeStamp = timeStamp + curNode.delay * 60 * 1000; then advance. Otherwise: 1. add the merged record to the result; 2. re-initialize delay and the other state from the current element. A Spark-independent sketch of this merge loop follows.
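Before wiring it into Spark, here is a minimal, Spark-independent sketch of the merge loop from Steps 3 and 4 (the Stay class and merge method are illustrative names I made up, not part of the original code). It assumes the input list is already sorted by start time and that all records share one (user, loc) key; times are epoch millis and the delay is in minutes, as in UserInfo:

import java.util.ArrayList;
import java.util.List;

public class MergeSketch {

    // Simple value holder mirroring the time/dalay fields of UserInfo.
    static class Stay {
        long startMillis;
        long delayMinutes;
        Stay(long startMillis, long delayMinutes) {
            this.startMillis = startMillis;
            this.delayMinutes = delayMinutes;
        }
    }

    // Merges contiguous stays; "sorted" must be ordered by startMillis
    // and contain records of a single (user, loc) key only.
    static List<Stay> merge(List<Stay> sorted) {
        List<Stay> result = new ArrayList<>();
        if (sorted.isEmpty()) {
            return result;
        }
        Stay current = new Stay(sorted.get(0).startMillis, sorted.get(0).delayMinutes);
        // End time (epoch millis) of the run currently being merged.
        long runEnd = current.startMillis + current.delayMinutes * 60 * 1000;
        for (int i = 1; i < sorted.size(); i++) {
            Stay next = sorted.get(i);
            if (next.startMillis == runEnd) {
                // Contiguous record: extend the run and sum the durations.
                current.delayMinutes += next.delayMinutes;
                runEnd += next.delayMinutes * 60 * 1000;
            } else {
                // Gap: flush the finished run and start a new one.
                result.add(current);
                current = new Stay(next.startMillis, next.delayMinutes);
                runEnd = next.startMillis + next.delayMinutes * 60 * 1000;
            }
        }
        result.add(current);
        return result;
    }
}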
The full Spark code is as follows:
package com.offer.test.UserSession;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.joda.time.DateTime;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import scala.Tuple2;
import java.util.*;
/**
* Created by szh on 2019/3/12.
*
* @author szh
*/
public class UserLocTest {
static final String FORMAT = "yyyy-MM-dd HH:mm:ss";
static final String SPLIT = ",";
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("url-test").setMaster("local[2]");
SparkContext jsc = new SparkContext(conf);
SparkSession sparkSession = SparkSession.builder().sparkContext(jsc).getOrCreate();
Dataset<Row> org = sparkSession
.read()
.option("timestampFormat", FORMAT)
.csv("file:///E:\\test-spark\\offline-java\\files\\offer\\userLoc\\test.csv")
.toDF(new String[]{"user", "loc", "time", "delay"});
org.show();
JavaPairRDD<String, UserInfo> tmpPairMid = org.toJavaRDD().mapPartitionsToPair((x) -> {
List<Tuple2<String, UserInfo>> result = new ArrayList<>();
Row tmpRow = null;
UserInfo tmpInfo = null;
DateTimeFormatter formatter = DateTimeFormat.forPattern(FORMAT);
while (x.hasNext()) {
tmpRow = x.next();
tmpInfo = new UserInfo();
tmpInfo.setDalay(Long.valueOf(tmpRow.getAs("delay")));
tmpInfo.setLoc(tmpRow.getAs("loc"));
DateTime datetime = formatter.parseDateTime(tmpRow.getAs("time"));
tmpInfo.setTime(datetime.getMillis());
tmpInfo.setUser(tmpRow.getAs("user"));
result.add(new Tuple2<>(tmpRow.getAs("user") + "-" + tmpRow.getAs("loc"), tmpInfo));
}
return result.iterator();
});
// TODO debug code: print the record count per key
Map<String, Long> tmpMap = tmpPairMid.groupByKey().countByKey();
for (Map.Entry<String, Long> tmp : tmpMap.entrySet()) {
System.out.println("key : " + tmp.getKey() + " , " + "value : " + tmp.getValue());
}
JavaRDD<String> finalData = tmpPairMid.groupByKey().mapPartitions(new FlatMapFunction<Iterator<Tuple2<String, Iterable<UserInfo>>>, String>() {
@Override
public Iterator<String> call(Iterator<Tuple2<String, Iterable<UserInfo>>> tuple2Iterator) throws Exception {
// the final return value
List<String> result = new ArrayList<>();
while (tuple2Iterator.hasNext()) {
Iterable<UserInfo> infoIterable = tuple2Iterator.next()._2();
Iterator<UserInfo> iterator = infoIterable.iterator();
List<UserInfo> tmpList = new ArrayList<>();
while (iterator.hasNext()) {
tmpList.add(iterator.next());
}
// TODO debug code: print the size of each group
System.out.println(tmpList.size());
tmpList.sort(new Comparator<UserInfo>() {
@Override
public int compare(UserInfo o1, UserInfo o2) {
if (o1.getTime() > o2.getTime()) {
return 1;
} else if (o1.getTime() < o2.getTime()) {
return -1;
}
return 0;
}
});
UserInfo tmpUser = null;
if (tmpList.size() > 0) {
//init
tmpUser = tmpList.get(0);
UserInfo preUser = tmpUser;
long preTime = tmpUser.getTime() + tmpUser.getDalay() * 60 * 1000;
long delay = tmpUser.getDalay();
for (int i = 1; i < tmpList.size(); i++) {
tmpUser = tmpList.get(i);
if (preTime == tmpUser.getTime()) {
// the current record starts exactly where the previous run ends: merge it
delay += tmpUser.getDalay();
preTime += tmpUser.getDalay() * 60 * 1000;
System.out.println("-" + preTime + "--");
} else {
// flush the previous merged record
StringBuffer testBuffer = new StringBuffer();
testBuffer.
append(preUser.getUser()).append(SPLIT).
append(preUser.getLoc()).append(SPLIT).
append(new DateTime(preUser.getTime()).toString(FORMAT)).append(SPLIT).
append(delay);
result.add(testBuffer.toString());
// start a new run from the current record
preUser = tmpUser;
preTime = tmpUser.getTime() + tmpUser.getDalay() * 60 * 1000;
delay = tmpUser.getDalay();
}
}
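// flush the last merged record of this key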
StringBuffer testBuffer = new StringBuffer();
testBuffer.
append(preUser.getUser()).append(SPLIT).
append(preUser.getLoc()).append(SPLIT).
append(new DateTime(preUser.getTime()).toString(FORMAT)).append(SPLIT).
append(delay);
result.add(testBuffer.toString());
}
}
return result.iterator();
}
});
finalData.saveAsTextFile("file:///E:\\test-spark\\offline-java\\out\\offer\\userLoc");
}
}
Approach 2:
In Approach 1 we used groupByKey and then sorted each group ourselves inside the iteration.
There is actually no need to materialize a sequence per key: it is enough to route all records with the same userID + loc to the same partition and sort each partition by timestamp in ascending order. Here the key is the string user-loc-timestampMillis; because the epoch-millis timestamps all have the same number of digits, lexicographic string order within one user-loc prefix agrees with numeric time order.
With this idea in mind, I checked whether Spark has a matching operator and found that Spark 2.0 provides one.
Spark operator: repartitionAndSortWithinPartitions
https://www.jianshu.com/p/5906ddb5bfcd
Official documentation:
Repartition the RDD according to the given partitioner and, within each resulting partition, sort records by their keys. This is more efficient than calling `repartition` and then sorting within each partition because it can push the sorting down into the shuffle machinery.
Function signatures:
def repartitionAndSortWithinPartitions(partitioner: Partitioner): JavaPairRDD[K, V]
def repartitionAndSortWithinPartitions(partitioner: Partitioner, comp: Comparator[K]) : JavaPairRDD[K, V]
Source code analysis:
def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = self.withScope {
new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering)
}
As the source shows, the method repartitions the RDD according to the given partitioner and sorts records by key within each resulting partition. Compared with partitioning first and then sorting each partition separately (e.g., via sortByKey), this is more efficient because the sorting is pushed down into the shuffle stage.
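As a quick illustration of the pattern (a toy sketch, not the solution code: the class name SecondarySortSketch and the sample keys are made up for this example), the snippet below routes composite user-loc-time string keys by their user+loc prefix and lets the shuffle sort the full keys within each partition:

import java.util.Arrays;
import org.apache.spark.Partitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

public class SecondarySortSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("secondary-sort").setMaster("local[2]");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        JavaPairRDD<String, Integer> pairs = jsc.parallelizePairs(Arrays.asList(
                new Tuple2<>("UserA-LocA-2", 1),
                new Tuple2<>("UserA-LocA-1", 1),
                new Tuple2<>("UserB-LocA-1", 1)));

        JavaPairRDD<String, Integer> sorted = pairs.repartitionAndSortWithinPartitions(
                new Partitioner() {
                    @Override
                    public int numPartitions() { return 2; }

                    @Override
                    public int getPartition(Object key) {
                        // Route by the user+loc prefix only, so every record of
                        // one (user, loc) lands in the same partition; the full
                        // key still sorts by its trailing timestamp.
                        String[] arr = ((String) key).split("-");
                        int h = (arr[0] + arr[1]).hashCode();
                        // Guard against a negative hashCode: Java's % can be negative.
                        return ((h % numPartitions()) + numPartitions()) % numPartitions();
                    }
                });

        // Within each partition the keys now come out in ascending order.
        sorted.collect().forEach(t -> System.out.println(t._1()));
        jsc.stop();
    }
}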
The overall approach of the code mirrors Approach 1. The code is as follows:
package com.offer.test.UserSession;
import org.apache.spark.Partitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.joda.time.DateTime;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import scala.Serializable;
import scala.Tuple2;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.Consumer;
/**
* Created by szh on 2019/3/12.
*
* @author szh
*/
public class UserLocTest2 implements Serializable {
static final String FORMAT = "yyyy-MM-dd HH:mm:ss";
static final String SPLIT = ",";
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("url-test").setMaster("local[2]");
SparkContext jsc = new SparkContext(conf);
SparkSession sparkSession = SparkSession.builder().sparkContext(jsc).getOrCreate();
Dataset<Row> org = sparkSession
.read()
.option("timestampFormat", FORMAT)
.csv("file:///E:\\test-spark\\offline-java\\files\\offer\\userLoc\\test.csv")
.toDF(new String[]{"user", "loc", "time", "delay"});
org.show();
JavaPairRDD<String, UserInfo> tmpMid = org.toJavaRDD().mapPartitionsToPair((x) -> {
List<Tuple2<String, UserInfo>> result = new ArrayList<>();
Row tmpRow = null;
UserInfo tmpInfo = null;
DateTimeFormatter formatter = DateTimeFormat.forPattern(FORMAT);
while (x.hasNext()) {
tmpRow = x.next();
tmpInfo = new UserInfo();
tmpInfo.setDalay(Long.valueOf(tmpRow.getAs("delay")));
tmpInfo.setLoc(tmpRow.getAs("loc"));
DateTime datetime = formatter.parseDateTime(tmpRow.getAs("time"));
tmpInfo.setTime(datetime.getMillis());
tmpInfo.setUser(tmpRow.getAs("user"));
result.add(new Tuple2<String, UserInfo>(tmpInfo.getUser() + "-" + tmpInfo.getLoc() + "-" + tmpInfo.getTime(), tmpInfo));
}
return result.iterator();
});
JavaPairRDD<String, UserInfo> afterPartAndCmp = tmpMid.repartitionAndSortWithinPartitions(new Partitioner() {
@Override
public int getPartition(Object key) {
// route by user + loc only, so all records of one (user, loc)
// land in the same partition
String tmp = (String) key;
String[] arr = tmp.split("-");
int hash = arr[0].hashCode() + arr[1].hashCode();
// guard against a negative sum: Java's % can yield negative values,
// and a negative partition index would make the shuffle fail
return ((hash % numPartitions()) + numPartitions()) % numPartitions();
}
@Override
public int numPartitions() {
return 10;
}
}
);
JavaRDD<String> finalData = afterPartAndCmp.mapPartitions(new FlatMapFunction<Iterator<Tuple2<String, UserInfo>>, String>() {
@Override
public Iterator<String> call(Iterator<Tuple2<String, UserInfo>> tuple2Iterator) throws Exception {
List<String> resultList = new ArrayList<>();
UserInfo preUser = null;
long preTime = -1;
long delay = -1;
// initialize the run from the first record in the partition
if (tuple2Iterator.hasNext()) {
preUser = tuple2Iterator.next()._2();
preTime = preUser.getTime() + preUser.getDalay() * 60 * 1000;
delay = preUser.getDalay();
}
while (tuple2Iterator.hasNext()) {
UserInfo tmpUser = tuple2Iterator.next()._2();
// merge only when the record belongs to the same user and location
// and starts exactly where the previous run ends; comparing the time
// alone could wrongly merge two adjacent keys whose times happen to line up
if (tmpUser.getUser().equals(preUser.getUser())
&& tmpUser.getLoc().equals(preUser.getLoc())
&& tmpUser.getTime() == preTime) {
preTime += tmpUser.getDalay() * 60 * 1000;
delay += tmpUser.getDalay();
} else {
// flush the previous merged record
StringBuffer stringBuffer = new StringBuffer();
stringBuffer.
append(preUser.getUser()).append(SPLIT).
append(preUser.getLoc()).append(SPLIT).
append(new DateTime(preUser.getTime()).toString(FORMAT)).append(SPLIT).
append(delay);
resultList.add(stringBuffer.toString());
// start a new run from the current record
preUser = tmpUser;
preTime = tmpUser.getTime() + tmpUser.getDalay() * 60 * 1000;
delay = tmpUser.getDalay();
}
}
if (preUser != null) {
// flush the last merged record
StringBuffer testBuffer = new StringBuffer();
testBuffer.
append(preUser.getUser()).append(SPLIT).
append(preUser.getLoc()).append(SPLIT).
append(new DateTime(preUser.getTime()).toString(FORMAT)).append(SPLIT).
append(delay);
resultList.add(testBuffer.toString());
}
return resultList.iterator();
}
});
// debug output: print the partitioned-and-sorted records
afterPartAndCmp.collect().forEach(new Consumer<Tuple2<String, UserInfo>>() {
@Override
public void accept(Tuple2<String, UserInfo> stringUserInfoTuple2) {
System.out.println(stringUserInfoTuple2._2().toString());
}
});
finalData.repartition(1).saveAsTextFile("file:///E:\\test-spark\\offline-java\\out\\offer\\userLoc2");
}
}
Result:
UserA,LocA,2018-01-09 12:00:00,120
UserA,LocA,2018-01-09 19:00:00,60
UserA,LocB,2018-01-09 11:00:00,20
UserB,LocA,2018-01-09 17:00:00,180