A Spark interview question: merging consecutive session dwell-time records

 

I recently ran into a fairly hard question while interviewing at a well-known company, and I could not work out a solution on the spot.

After coming back I reviewed the MapReduce execution flow and some Spark documentation, and eventually solved the problem.

 

A MapReduce solution to a closely related problem can be found in:

Hadoop学习之路(二十七)MapReduce的API使用(四)

https://www.cnblogs.com/qingyunzong/p/8639414.html

(see the second problem in that post)

 

Problem statement:

The data to be processed is in CSV format.

Four sample rows:

UserA,LocA,2018-01-09 12:00,60
UserA,LocA,2018-01-09 13:00,60
UserA,LocB,2018-01-09 11:00,20
UserA,LocA,2018-01-09 19:00,60

Meaning:

UserA,LocA,2018-01-09 12:00,60
means UserA stayed at LocA for 60 minutes starting at 12:00;

UserA,LocA,2018-01-09 13:00,60
means UserA stayed at LocA for 60 minutes starting at 13:00.

Processing logic:
1. For the same user at the same location, consecutive records (each starting exactly when the previous one ends) are merged.
2. Merge rule: the merged record takes the earliest start time, and the durations are summed.

Expected output:
UserA,LocA,2018-01-09 12:00,120
UserA,LocB,2018-01-09 11:00,20
UserA,LocA,2018-01-09 19:00,60


Requirement:
Solve this with Spark, MapReduce, or another distributed processing engine.
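
Both solutions below boil the "consecutive records" check down to arithmetic on epoch milliseconds: a record joins the current run when the previous start time plus the accumulated duration lands exactly on its start time. A tiny stand-alone sketch of that check (my own illustration, not part of the original problem):

public class MergePredicateSketch {

    public static void main(String[] args) {

        // First two sample rows: UserA stays at LocA from 12:00 for 60 min,
        // then again from 13:00 for 60 min.
        long prevStartMillis = 1515470400000L;   // 2018-01-09 12:00 (UTC+8)
        long prevDurationMin = 60;
        long nextStartMillis = 1515474000000L;   // 2018-01-09 13:00 (UTC+8)

        // Consecutive when the previous record ends exactly where the next one starts.
        boolean mergeable = prevStartMillis + prevDurationMin * 60 * 1000 == nextStartMillis;
        System.out.println(mergeable);           // true -> merged into 12:00 with 120 min
    }
}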

 

 

Test data:

UserA,LocA,2018-01-09 12:00:00,60
UserA,LocA,2018-01-09 13:00:00,60
UserA,LocB,2018-01-09 11:00:00,20
UserA,LocA,2018-01-09 19:00:00,60
UserB,LocA,2018-01-09 19:00:00,60
UserB,LocA,2018-01-09 18:00:00,60
UserB,LocA,2018-01-09 17:00:00,60

The custom class:

package com.offer.test.UserSession;

import java.io.Serializable;
import java.util.Comparator;

/**
 * Created by szh on 2019/3/12.
 */
public class UserInfo implements Serializable
        //, Comparable
        {

    String user;    // user id

    String loc;     // location id

    long time;      // record start time, in epoch milliseconds

    long dalay;     // dwell time in minutes (spelling kept as in the original code)


    public String getUser() {
        return user;
    }

    public void setUser(String user) {
        this.user = user;
    }

    public String getLoc() {
        return loc;
    }

    public void setLoc(String loc) {
        this.loc = loc;
    }


    public long getTime() {
        return time;
    }

    public void setTime(long time) {
        this.time = time;
    }

    public long getDalay() {
        return dalay;
    }

    public void setDalay(long dalay) {
        this.dalay = dalay;
    }

//    @Override
//    public int compareTo(Object o) {
//
//        UserInfo tmp = (UserInfo) o;
//        String key1 = this.user + this.loc;
//        String key2 = tmp.getUser() + tmp.getLoc();
//
//        if (key1.hashCode() == key2.hashCode()) {
//            if (this.getTime() > tmp.getTime()) {
//                return 1;
//            } else if (this.getTime() < tmp.getTime()) {
//                return -1;
//            }
//            return 0;
//        } else if (key1.hashCode() > key2.hashCode()) {
//            return 1;
//        } else {
//            return -1;
//        }
//    }

    @Override
    public String toString() {
        return "UserInfo{" +
                "user='" + user + '\'' +
                ", loc='" + loc + '\'' +
                ", time=" + time +
                ", dalay=" + dalay +
                '}';
    }
}

 

Approach 1:

My first idea when facing this problem was:

Step 1: group the data with userID + loc as the key, using the groupByKey() operator.

Step 2: for each key produced by groupByKey(), iterate over its sequence of values and sort the elements by start time.

Step 3: take the first element of the sorted list as the initial accumulated record.

Record the time point at which the next record would have to start in order to be merged, firstElement.time.toLong + delay * 60 * 1000, then move the pointer to the next element.

Step 4:

If that expected merge time equals the current element's start time, the current record is merged into the accumulated one: 1. delay = delay + curNode.delay; 2. the expected merge time is advanced by curNode.delay * 60 * 1000; then move on to the next element.

Otherwise: 1. the accumulated record is appended to the result; 2. the accumulation is restarted from the current element (reset its start time, expected merge time and delay).
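
A condensed sketch of just this merge loop, assuming the records of one (user, loc) group have already been sorted ascending by start time; it reuses the UserInfo class above and the same java.util and joda-time imports as the full program that follows:

// Per-group merge: collapse consecutive records into one, summing the durations.
static List<String> mergeGroup(List<UserInfo> sorted) {
    List<String> out = new ArrayList<>();
    if (sorted.isEmpty()) {
        return out;
    }

    UserInfo pre = sorted.get(0);
    long expectedNextStart = pre.getTime() + pre.getDalay() * 60 * 1000;
    long delay = pre.getDalay();

    for (int i = 1; i < sorted.size(); i++) {
        UserInfo cur = sorted.get(i);
        if (cur.getTime() == expectedNextStart) {
            // Consecutive record: extend the current run.
            delay += cur.getDalay();
            expectedNextStart += cur.getDalay() * 60 * 1000;
        } else {
            // Gap: emit the accumulated run and start a new one.
            out.add(pre.getUser() + "," + pre.getLoc() + ","
                    + new DateTime(pre.getTime()).toString("yyyy-MM-dd HH:mm:ss") + "," + delay);
            pre = cur;
            expectedNextStart = cur.getTime() + cur.getDalay() * 60 * 1000;
            delay = cur.getDalay();
        }
    }

    // Emit the last run of the group.
    out.add(pre.getUser() + "," + pre.getLoc() + ","
            + new DateTime(pre.getTime()).toString("yyyy-MM-dd HH:mm:ss") + "," + delay);
    return out;
}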

 

The full code is as follows:

package com.offer.test.UserSession;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.joda.time.DateTime;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import scala.Tuple2;

import java.util.*;

/**
 * Created by szh on 2019/3/12.
 *
 * @author szh
 */
public class UserLocTest {

    static final String FORMAT = "yyyy-MM-dd HH:mm:ss";
    static final String SPLIT = ",";

    public static void main(String args[]) {

        SparkConf conf = new SparkConf().setAppName("url-test").setMaster("local[2]");
        SparkContext jsc = new SparkContext(conf);

        SparkSession sparkSession = SparkSession.builder().sparkContext(jsc).getOrCreate();

        Dataset<Row> org = sparkSession
                .read()
                .option("timestampFormat", FORMAT)
                .csv("file:///E:\\test-spark\\offline-java\\files\\offer\\userLoc\\test.csv")
                .toDF(new String[]{"user", "loc", "time", "delay"});

        org.show();

        JavaPairRDD<String, UserInfo> tmpPairMid = org.toJavaRDD().mapPartitionsToPair((x) -> {

            List<Tuple2<String, UserInfo>> result = new ArrayList<>();
            Row tmpRow = null;
            UserInfo tmpInfo = null;
            DateTimeFormatter formatter = DateTimeFormat.forPattern(FORMAT);

            while (x.hasNext()) {

                tmpRow = x.next();
                tmpInfo = new UserInfo();

                tmpInfo.setDalay(Long.valueOf(tmpRow.getAs("delay")));
                tmpInfo.setLoc(tmpRow.getAs("loc"));
                DateTime datetime = formatter.parseDateTime(tmpRow.getAs("time"));
                tmpInfo.setTime(datetime.getMillis());
                tmpInfo.setUser(tmpRow.getAs("user"));

                result.add(new Tuple2<>(tmpRow.getAs("user") + "-" + tmpRow.getAs("loc"), tmpInfo));
            }

            return result.iterator();
        });

        // Test code: print how many records each user-loc key contains
        Map<String, Long> tmpMap = tmpPairMid.groupByKey().countByKey();
        for (Map.Entry<String, Long> tmp : tmpMap.entrySet()) {
            System.out.println("key : " + tmp.getKey() + " , " + "value : " + tmp.getValue());
        }


        JavaRDD<String> finalData = tmpPairMid.groupByKey().mapPartitions(new FlatMapFunction<Iterator<Tuple2<String, Iterable<UserInfo>>>, String>() {

            @Override
            public Iterator<String> call(Iterator<Tuple2<String, Iterable<UserInfo>>> tuple2Iterator) throws Exception {

                // final list of merged records returned by this partition
                List<String> result = new ArrayList<>();

                while (tuple2Iterator.hasNext()) {

                    Iterable<UserInfo> infoIterable = tuple2Iterator.next()._2();
                    Iterator<UserInfo> iterator = infoIterable.iterator();

                    List<UserInfo> tmpList = new ArrayList<>();
                    while (iterator.hasNext()) {
                        tmpList.add(iterator.next());
                    }

                    // Test code: size of the current group
                    System.out.println(tmpList.size());

                    tmpList.sort(new Comparator<UserInfo>() {
                        @Override
                        public int compare(UserInfo o1, UserInfo o2) {
                            if (o1.getTime() > o2.getTime()) {
                                return 1;
                            } else if (o1.getTime() < o2.getTime()) {
                                return -1;
                            }
                            return 0;
                        }
                    });


                    UserInfo tmpUser = null;
                    if (tmpList.size() > 0) {

                        //init
                        tmpUser = tmpList.get(0);
                        UserInfo preUser = tmpUser;
                        long preTime = tmpUser.getTime() + tmpUser.getDalay() * 60 * 1000;
                        long delay = tmpUser.getDalay();

                        for (int i = 1; i < tmpList.size(); i++) {

                            tmpUser = tmpList.get(i);

                            if (preTime == tmpUser.getTime()) {

                                // consecutive with the previous record: extend the current run
                                delay += tmpUser.getDalay();
                                preTime += tmpUser.getDalay() * 60 * 1000;

                                System.out.println("-" + preTime + "--");

                            } else {

                                // emit the accumulated record
                                StringBuffer testBuffer = new StringBuffer();
                                testBuffer.
                                        append(preUser.getUser()).append(SPLIT).
                                        append(preUser.getLoc()).append(SPLIT).
                                        append(new DateTime(preUser.getTime()).toString(FORMAT)).append(SPLIT).
                                        append(delay);
                                result.add(testBuffer.toString());

                                // restart the accumulation from the current record
                                preUser = tmpUser;
                                preTime = tmpUser.getTime() + tmpUser.getDalay() * 60 * 1000;
                                delay = tmpUser.getDalay();
                            }
                        }

                        // emit the last accumulated run of this group
                        StringBuffer testBuffer = new StringBuffer();
                        testBuffer.
                                append(preUser.getUser()).append(SPLIT).
                                append(preUser.getLoc()).append(SPLIT).
                                append(new DateTime(preUser.getTime()).toString(FORMAT)).append(SPLIT).
                                append(delay);
                        result.add(testBuffer.toString());
                    }

                }

                return result.iterator();
            }
        });


        finalData.saveAsTextFile("file:///E:\\test-spark\\offline-java\\out\\offer\\userLoc");

//        JavaRDD<String> zz = test.map((x) -> {
//            System.out.println(x.toString());
//            return x.toString();
//        });
//        zz.count();

        //RelationalGroupedDataset relationalGroupedDataset = org.groupBy(new Column[]{new Column("user"),new Column("loc")});

    }

}

 

 

Approach 2:

In Approach 1 we used groupByKey and then iterated over each group to sort it ourselves.

In fact there is no need to materialize a sequence per key: it is enough to send all records with the same userID + loc to the same partition, and then sort the records inside each partition by timestamp in ascending order.

Based on this idea I went looking for a matching Spark operator, and found that Spark provides exactly such an operator.

 

Spark operator: repartitionAndSortWithinPartitions

Reference: https://www.jianshu.com/p/5906ddb5bfcd

 


Official documentation:

Repartition the RDD according to the given partitioner and, within each resulting partition, sort records by their keys. This is more efficient than calling `repartition` and then sorting within each partition because it can push the sorting down into the shuffle machinery.



Function prototypes:

def repartitionAndSortWithinPartitions(partitioner: Partitioner): JavaPairRDD[K, V]

def repartitionAndSortWithinPartitions(partitioner: Partitioner, comp: Comparator[K])  : JavaPairRDD[K, V]


Source code analysis:

def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = self.withScope {  
  new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering)
}


      From the source we can see that the method repartitions the RDD according to the given partitioner and sorts the records by key within each resulting partition. Compared with sortByKey-style approaches that repartition first and then sort inside each partition, this is more efficient because the sorting is pushed down into the shuffle stage itself.
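
As a quick, self-contained toy example of the operator, separate from the session problem (my own illustration; the data and names are made up):

import java.util.Arrays;

import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;

import scala.Tuple2;

public class RepartitionSortDemo {

    public static void main(String[] args) {

        JavaSparkContext jsc = new JavaSparkContext("local[2]", "repartition-sort-demo");

        JavaPairRDD<Integer, String> pairs = jsc.parallelizePairs(Arrays.asList(
                new Tuple2<>(3, "c"), new Tuple2<>(1, "a"),
                new Tuple2<>(2, "b"), new Tuple2<>(1, "d")));

        // Keys are hashed into 2 partitions; the shuffle itself sorts each partition by key.
        JavaPairRDD<Integer, String> sorted =
                pairs.repartitionAndSortWithinPartitions(new HashPartitioner(2));

        // Print each partition's content to show the per-partition key ordering.
        sorted.glom().collect().forEach(partition -> System.out.println(partition));

        jsc.stop();
    }
}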

 

The code follows essentially the same merging logic as Approach 1; the main differences are the composite key (user-loc-timestampMillis) and the use of repartitionAndSortWithinPartitions in place of groupByKey.
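
One detail worth noting before the code: the custom Comparator is left commented out, so the keys fall back to their natural String ordering. That is sufficient here because every epoch-millisecond value in the test data has the same number of digits, so lexicographic order of the user-loc-timestampMillis keys matches chronological order within one (user, loc) group. A tiny sanity check of that assumption (my own illustration):

public class KeyOrderSketch {

    public static void main(String[] args) {

        // Composite keys exactly as built in the program below: user-loc-epochMillis.
        String k1 = "UserA-LocA-" + 1515470400000L;  // 2018-01-09 12:00 (UTC+8)
        String k2 = "UserA-LocA-" + 1515474000000L;  // 2018-01-09 13:00 (UTC+8)

        // Both millisecond values have 13 digits, so String order agrees with time order.
        System.out.println(k1.compareTo(k2) < 0);    // true
    }
}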

The full code is as follows:

package com.offer.test.UserSession;

import org.apache.spark.Partitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.joda.time.DateTime;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import scala.Serializable;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.Consumer;

/**
 * Created by szh on 2019/3/12.
 *
 * @author szh
 */
public class UserLocTest2 implements Serializable {

    static final String FORMAT = "yyyy-MM-dd HH:mm:ss";
    static final String SPLIT = ",";

    public static void main(String args[]) {

        SparkConf conf = new SparkConf().setAppName("url-test").setMaster("local[2]");
        SparkContext jsc = new SparkContext(conf);

        SparkSession sparkSession = SparkSession.builder().sparkContext(jsc).getOrCreate();

        Dataset<Row> org = sparkSession
                .read()
                .option("timestampFormat", FORMAT)
                .csv("file:///E:\\test-spark\\offline-java\\files\\offer\\userLoc\\test.csv")
                .toDF(new String[]{"user", "loc", "time", "delay"});

        org.show();

        JavaPairRDD<String, UserInfo> tmpMid = org.toJavaRDD().mapPartitionsToPair((x) -> {

            List<Tuple2<String, UserInfo>> result = new ArrayList<>();
            Row tmpRow = null;
            UserInfo tmpInfo = null;
            DateTimeFormatter formatter = DateTimeFormat.forPattern(FORMAT);

            while (x.hasNext()) {

                tmpRow = x.next();
                tmpInfo = new UserInfo();

                tmpInfo.setDalay(Long.valueOf(tmpRow.getAs("delay")));
                tmpInfo.setLoc(tmpRow.getAs("loc"));
                DateTime datetime = formatter.parseDateTime(tmpRow.getAs("time"));
                tmpInfo.setTime(datetime.getMillis());
                tmpInfo.setUser(tmpRow.getAs("user"));

                result.add(new Tuple2<String, UserInfo>(tmpInfo.getUser() + "-" + tmpInfo.getLoc() + "-" + tmpInfo.getTime(), tmpInfo));
            }

            return result.iterator();
        });

        JavaPairRDD<String, UserInfo> afterPartAndCmp = tmpMid.repartitionAndSortWithinPartitions(new Partitioner() {
            @Override
            public int getPartition(Object key) {

                String tmp = (String) key;
                String[] arr = tmp.split("-");

                // hashCode() can be negative, so normalize the modulo result to
                // avoid returning an invalid (negative) partition index.
                int hash = arr[0].hashCode() + arr[1].hashCode();
                return (hash % numPartitions() + numPartitions()) % numPartitions();
            }

            @Override
            public int numPartitions() {
                return 10;
            }
        }

//        , new Comparator<String>()  {
//            @Override
//            public int compare(String o1, String o2) {
//
//                String[] tmp1 = o1.split("-");
//                String key1 = tmp1[0] + tmp1[1];
//                Long value1 = Long.valueOf(tmp1[2]);
//
//                String[] tmp2 = o2.split("-");
//                String key2 = tmp2[0] + tmp2[1];
//                Long value2 = Long.valueOf(tmp2[2]);
//
//                if (key1.hashCode() > key2.hashCode()) {
//                    return 1;
//                } else if (key1.hashCode() < key2.hashCode()) {
//                    return -1;
//                } else {
//                    if (key1.equals(key2)) {
//                        if (value1 > value2) {
//                            return 1;
//                        } else if (value1 < value2) {
//                            return -1;
//                        }
//                        return 0;
//                    } else {
//                        return 1;
//                    }
//                }
//            }
//        }

        );


//       afterPartAndCmp.foreachPartition(new VoidFunction<Iterator<Tuple2<String, UserInfo>>>() {
//           @Override
//           public void call(Iterator<Tuple2<String, UserInfo>> tuple2Iterator) throws Exception {
//               while (tuple2Iterator.hasNext()){
//                   System.out.println(tuple2Iterator.next()._2().toString());
//               }
//           }
//       });

        JavaRDD<String> finalData = afterPartAndCmp.mapPartitions(new FlatMapFunction<Iterator<Tuple2<String, UserInfo>>, String>() {
            @Override
            public Iterator<String> call(Iterator<Tuple2<String, UserInfo>> tuple2Iterator) throws Exception {

                List<String> resultList = new ArrayList<>();

                UserInfo preUser = null;
                long preTime = -1;
                long delay = -1;

                //init
                if (tuple2Iterator.hasNext()) {
                    preUser = tuple2Iterator.next()._2();
                    preTime = preUser.getTime() + preUser.getDalay() * 60 * 1000;
                    delay = preUser.getDalay();
                }


                while (tuple2Iterator.hasNext()) {

                    UserInfo tmpUser = tuple2Iterator.next()._2();

                    // Merge only when the record is consecutive AND belongs to the same
                    // user/loc run; records with different keys can land in the same partition.
                    if (tmpUser.getTime() == preTime
                            && tmpUser.getUser().equals(preUser.getUser())
                            && tmpUser.getLoc().equals(preUser.getLoc())) {

                        preTime += tmpUser.getDalay() * 60 * 1000;
                        delay += tmpUser.getDalay();
                    }else{

                        // emit the accumulated record
                        StringBuffer stringBuffer = new StringBuffer();
                        stringBuffer.
                                append(preUser.getUser()).append(SPLIT).
                                append(preUser.getLoc()).append(SPLIT).
                                append(new DateTime(preUser.getTime()).toString(FORMAT)).append(SPLIT).
                                append(delay);
                        resultList.add(stringBuffer.toString());

                        // restart the accumulation from the current record
                        preUser = tmpUser;
                        preTime = tmpUser.getTime() + tmpUser.getDalay() * 60 * 1000;
                        delay = tmpUser.getDalay();
                    }

                }

                if(preUser != null){
                    // emit the last accumulated run of this partition
                    StringBuffer testBuffer = new StringBuffer();
                    testBuffer.
                            append(preUser.getUser()).append(SPLIT).
                            append(preUser.getLoc()).append(SPLIT).
                            append(new DateTime(preUser.getTime()).toString(FORMAT)).append(SPLIT).
                            append(delay);
                    resultList.add(testBuffer.toString());
                }

                return resultList.iterator();
            }
        });


        // Debug: print the repartitioned, partition-sorted records before saving the result.
        afterPartAndCmp.collect().forEach(new Consumer<Tuple2<String, UserInfo>>() {
            @Override
            public void accept(Tuple2<String, UserInfo> stringUserInfoTuple2) {
                System.out.println(stringUserInfoTuple2._2().toString());
            }
        });

        finalData.repartition(1).saveAsTextFile("file:///E:\\test-spark\\offline-java\\out\\offer\\userLoc2");
    }

}

 

 

Result:

 

UserA,LocA,2018-01-09 12:00:00,120
UserA,LocA,2018-01-09 19:00:00,60
UserA,LocB,2018-01-09 11:00:00,20
UserB,LocA,2018-01-09 17:00:00,180

UserB's three back-to-back one-hour records (17:00, 18:00, 19:00) are merged into a single 180-minute record starting at 17:00, and UserA's 12:00 and 13:00 records into a single 120-minute one, which matches the expected output.

 
