package cn.spark.study.core;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;
/**
* A k-mer is a substring of length K (K > 0) of a DNA sequence; counting how
* often every k-mer occurs is a central step in many DNA sequence analyses.
* For example, the 3-mers of GTTGC are GTT, TTG and TGC.
*
* @author 连儒达
*
*/
public class Kmer {
/**
* The args array carries three parameters:
* 1. the FASTQ file stored on HDFS
* 2. k (k > 0), the k-mer length
* 3. N (N > 0), how many of the most frequent k-mers to report
*
* The per-partition local top N is computed with JavaPairRDD.mapPartitions()
*
* @param args
*/
public static void main(String[] args){
SparkConf conf = new SparkConf().setAppName("Kmer");
JavaSparkContext ctx = new JavaSparkContext(conf);
if(args.length < 3){
System.err.println("Usage: Kmer <fastq-file> <K> <N>");
System.exit(1);
}
final String fastqFileName = args[0];
final int K = Integer.parseInt(args[1]);
final int N = Integer.parseInt(args[2]);
JavaRDD<String> records = ctx.textFile(fastqFileName,1);
/**
* Broadcast K and N as read-only shared values to the executors
*/
final Broadcast<Integer> broadcastK = ctx.broadcast(K);
final Broadcast<Integer> broadcastN = ctx.broadcast(N);
/**
* Keep only the sequence lines: drop FASTQ header lines ('@'), separator lines ('+')
* and quality lines (which in this sample start with ';', '!' or '~')
*/
JavaRDD<String> filter = records.filter(new Function<String,Boolean>(){
private static final long serialVersionUID = 1L;
@Override
public Boolean call(String v1) throws Exception {
// guard against blank lines before looking at the first character
if(v1.isEmpty()){
return false;
}
String firstChar = v1.substring(0,1);
// drop header, separator and quality lines; keep only the sequence lines
return !(firstChar.equals("@") ||
firstChar.equals("+") ||
firstChar.equals(";") ||
firstChar.equals("!") ||
firstChar.equals("~"));
}
});
/**
* Emit every k-mer as a (kmer, 1) pair
*/
JavaPairRDD<String,Integer> kmers = filter.flatMapToPair(new PairFlatMapFunction<String,String,Integer>(){
private static final long serialVersionUID = 1L;
@Override
public Iterable<Tuple2<String, Integer>> call(String s) throws Exception {
int K = broadcastK.value();
List<Tuple2<String,Integer>> list = new ArrayList<Tuple2<String,Integer>>();
for(int i = 0;i < s.length() - K + 1; i++){
String kmer = s.substring(i, i + K);// the k-mer starting at position i
list.add(new Tuple2<String,Integer>(kmer,1));
}
return list;
}
});
// Debug only: collect() pulls the whole RDD back to the driver, so skip this for large inputs
List<Tuple2<String, Integer>> debug1 = kmers.collect();
for(Tuple2<String, Integer> t2 : debug1){
System.out.println("debug1 t2._1 = "+t2._1);
System.out.println("debug1 t2._2 = "+t2._2);
}
/**
* Sum the counts per k-mer
*/
JavaPairRDD<String,Integer> kmersGrouped = kmers.reduceByKey(new Function2<Integer,Integer,Integer>(){
private static final long serialVersionUID = 1L;
@Override
public Integer call(Integer v1, Integer v2) throws Exception {
return v1 + v2;
}
});
// Debug only: again collects everything to the driver
List<Tuple2<String, Integer>> debug2 = kmersGrouped.collect();
for(Tuple2<String, Integer> t2 : debug2){
System.out.println("debug2 t2._1 = "+t2._1);
System.out.println("debug2 t2._2 = "+t2._2);
}
/**
* Build a local top N for every partition: each partition keeps only its N
* most frequent (kmer, count) pairs; the work is done with JavaPairRDD.mapPartitions().
* Note: the TreeMap is keyed by frequency, so two k-mers with the same count
* collide and only one of them survives.
*/
JavaRDD<SortedMap<Integer,String>> partitions = kmersGrouped.mapPartitions(
new FlatMapFunction<Iterator<Tuple2<String,Integer>>,SortedMap<Integer,String>>(){
private static final long serialVersionUID = 1L;
@Override
public Iterable<SortedMap<Integer, String>> call(Iterator<Tuple2<String, Integer>> iter)
throws Exception {
int N = broadcastN.value();
SortedMap<Integer,String> topN = new TreeMap<Integer,String>();
while(iter.hasNext()){
Tuple2<String,Integer> tuple = iter.next();
String kmer = tuple._1;
int frequency = tuple._2;
topN.put(frequency, kmer);
if(topN.size() > N){
topN.remove(topN.firstKey());
}
}
System.out.println("topN = "+topN);
return Collections.singletonList(topN);
}
});
/**
* Merge the per-partition top N maps into the final top N
*/
SortedMap<Integer,String> finaltopN = new TreeMap<Integer,String>();
List<SortedMap<Integer,String>> alltopN = partitions.collect();
for(SortedMap<Integer,String> localtopN : alltopN){
for(Map.Entry<Integer, String> entry : localtopN.entrySet()){
finaltopN.put(entry.getKey(), entry.getValue());
if(finaltopN.size() > N){
finaltopN.remove(finaltopN.firstKey());
}
}
}
// TreeMap keys come back in ascending order, so iterate backwards to print the highest counts first
List<Integer> frequencies = new ArrayList<Integer>(finaltopN.keySet());
for(int i = frequencies.size() - 1;i >= 0;i--){
System.out.println("top N\t" + frequencies.get(i) + "\t" + finaltopN.get(frequencies.get(i)));
}
/**
* Alternative: let Spark compute the top 5 directly with takeOrdered() and a
* comparator that orders by descending count
*/
// with a descending comparator, takeOrdered() returns the highest counts first (top() would return the lowest)
List<Tuple2<String,Integer>> finalTop5 = kmersGrouped.takeOrdered(5,TupleComparatorDescending.INSTANCE);
for(Tuple2<String, Integer> t2 : finalTop5){
System.out.println("JavaPairRDD.takeOrdered Top5 t2._1 = "+t2._1 + " t2._2 = " + t2._2);
}
ctx.close();
System.exit(0);
}
}
/**
* Orders (kmer, count) tuples by descending count, so takeOrdered() returns the most frequent k-mers first
*/
class TupleComparatorDescending implements Comparator<Tuple2<String,Integer>>,Serializable{
private static final long serialVersionUID = 1L;
final static TupleComparatorDescending INSTANCE = new TupleComparatorDescending();
@Override
public int compare(Tuple2<String, Integer> o1, Tuple2<String, Integer> o2) {
return -o1._2.compareTo(o2._2);
}
}
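
The Iterable-returning call() methods above match the Spark 1.x Java API (the submit script below uses Spark 1.5). From Spark 2.0 on, FlatMapFunction and PairFlatMapFunction return an Iterator instead, so the two flat-map functions need a small change. A sketch of the k-mer extraction step in that 2.x form (an assumption that Spark 2.x is on the classpath; the rest of the job stays the same):

// Spark 2.x variant: PairFlatMapFunction.call() returns an Iterator, not an Iterable
JavaPairRDD<String, Integer> kmers = filter.flatMapToPair(
        new PairFlatMapFunction<String, String, Integer>() {
            private static final long serialVersionUID = 1L;
            @Override
            public Iterator<Tuple2<String, Integer>> call(String s) throws Exception {
                int K = broadcastK.value();
                List<Tuple2<String, Integer>> list = new ArrayList<Tuple2<String, Integer>>();
                for (int i = 0; i < s.length() - K + 1; i++) {
                    list.add(new Tuple2<String, Integer>(s.substring(i, i + K), 1));
                }
                return list.iterator(); // hand back an iterator over the collected pairs
            }
        });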
Test data:
hadoop dfs -cat /tKmer.fastq
@EAS54_6_R1_2_1_413_324
CCCTTCTTGTCTTCAGCGTTTCTCC
+
;;3;;;;;;;;;;;;7;;;;;;;88
@EAS54_6_R1_2_1_540_792
TTGGCAGGCCAAGGCCGATGGATCA
+
;;;;;;;;;;;7;;;;;-;;;3;83
@EAS54_6_R1_2_1_443_348
GTTGCTTCTGGCGTGGGTGGGGGGG
+EAS54_6_R1_2_1_443_348
;;;;;;;;;;;9;7;;.7;393333
Script:
cat jKmer.sh
input=hdfs://spark01:9000/tKmer.fastq
K=3
N=5
/usr/local/spark1.5/bin/spark-submit \
--class cn.spark.study.core.Kmer \
--num-executors 3 \
--driver-memory 100m \
--executor-memory 100m \
--executor-cores 3 \
/usr/local/spark-text/java/kmer/Kmer.jar $input $K $N
Run results:
top N 6 GGG
top N 5 TGG
top N 4 TCT
top N 3 TTG
top N 2 GCG
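
As a quick sanity check of the counts above, here is a plain-Java sketch with no Spark involved (the class name KmerLocalCheck is made up for this example; the three reads are hard-coded from the test data) that counts overlapping 3-mers and prints the five most frequent:

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class KmerLocalCheck {
    public static void main(String[] args) {
        // the three sequence lines from tKmer.fastq
        String[] reads = {
            "CCCTTCTTGTCTTCAGCGTTTCTCC",
            "TTGGCAGGCCAAGGCCGATGGATCA",
            "GTTGCTTCTGGCGTGGGTGGGGGGG"
        };
        int k = 3;
        Map<String, Integer> counts = new HashMap<String, Integer>();
        for (String read : reads) {
            for (int i = 0; i + k <= read.length(); i++) {
                String kmer = read.substring(i, i + k);
                Integer c = counts.get(kmer);
                counts.put(kmer, c == null ? 1 : c + 1);
            }
        }
        // sort the (kmer, count) entries by descending count
        List<Map.Entry<String, Integer>> entries =
                new ArrayList<Map.Entry<String, Integer>>(counts.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<String, Integer>>() {
            @Override
            public int compare(Map.Entry<String, Integer> a, Map.Entry<String, Integer> b) {
                return b.getValue().compareTo(a.getValue());
            }
        });
        for (int i = 0; i < 5 && i < entries.size(); i++) {
            System.out.println(entries.get(i).getValue() + "\t" + entries.get(i).getKey());
        }
    }
}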