对于迭代运算,通常应用于一些数学计算,机器学习算法以及图计算等领域,在Flink中,对于批处理作业,也提供了相应的迭代运算,主要分为下面两大类:
- Bulk Iterate
- Delta Iterate
什么是迭代运算?
所谓迭代运算,就是给定一个初值,用所给的算法公式计算初值得到一个中间结果,然后将中间结果作为输入参数进行反复计算,在满足一定条件的时候得到计算结果。
Bulk Iterate
这种迭代方式称为全量迭代,它会将整个数据输入,经过一定的迭代次数,最终得到你想要的结果,如下图所示:
从上图可以看出,该迭代过程主要分为以下几步:
- Iteration Input(迭代输入):是初始输入值或者上一次迭代计算的结果
- Step Function(step函数):它迭代计算DataSet,由一系列的operator组成,比如map,flatMap,join等,取决于具体的业务逻辑。
- Next Partial Solution(中间结果):每一次迭代计算的结果,被发送到下一次迭代计算中。
- Iteration Result(迭代结果):最后一次迭代输出的结果,被输出到datasink或者发送到下游处理。
它迭代的结束条件是:
- 达到最大迭代次数
- 自定义收敛聚合函数
在官方文档中有下面这样一个例子,给定一组数据,输出迭代10次每次加1后的结果,如下图所示,这个例子比较简单,这里就不贴代码了:
这里介绍一些比较有意思的例子,使用蒙洛卡特方法来计算圆周率。
蒙洛卡特思想的核心就是:假设这里有一个半径为1的圆,它的面积S=PiR2=Pi,所以我们只要计算出这个圆的面积就可以计算出圆周率了。这里我们可以在一个边长为1的正方形中计算圆的四分之一扇形的面积,这样扇形的面积的4倍就是整个圆的面积了。如何计算扇形的面积?可以使用概率的方法,假设在这个正方形中有n个点,那么有m个点落在了扇形中,那么S扇形:S正方形=m:n。这样就可以计算出扇形的面积,最终计算出圆周率了。
最终实现代码如下:
/**
* 使用蒙洛卡特方法计算圆周率
*/
public class IterativePi {
public static void main(String[] args) throws Exception {
final ExecutionEnvironment env=ExecutionEnvironment.getExecutionEnvironment();
//迭代次数
int iterativeNum=100000;
Random random=new Random(1);
IterativeDataSet<Integer> iterativeDataSet=env.fromElements(0).iterate(iterativeNum);
DataSet<Integer> mapResult=iterativeDataSet.map(new MapFunction<Integer, Integer>() {
@Override
public Integer map(Integer value) throws Exception {
double x=random.nextDouble();
double y=random.nextDouble();
value+=(x*x+y*y<=1)?1:0;
return value;
}
});
//迭代结束的条件
DataSet<Integer> result=iterativeDataSet.closeWith(mapResult);
result.map(new MapFunction<Integer, Double>() {
@Override
public Double map(Integer count) throws Exception {
return count/(double)iterativeNum*4;
}
}).print();
}
}
Delta Iterate
这种迭代方式称为增量迭代,它并不是每次去迭代全量的数据,而是有两个数据集,WorkSet和SolutionSet,每次输入这两个数据集进行迭代运算(这两个数据集可以相等),然后对workset进行迭代运算并且不断的更新solutionset,直到达到迭代次数或者workset为空,输出迭代计算结果。如下图所示:
主要需要下面的几步:
- Iteration Input:读取初始WorkSet和初始Solution Set作为第一次迭代计算的输入。
- Step Function:step函数,每次迭代计算dataset,由map,flatmap以及join等操作组成的,具体有业务逻辑决定。
- Next Workset/Update Solution Set:Next WorkSet驱动迭代计算,将计算结果反馈到下一次迭代计算中,Solution Set将被不断的更新。两个数据集都在step函数中被迭代计算。
- Iteration Result:在最后一次迭代计算完成后,Solution Set将被输出或者输入下游处理。
迭代终止的条件:
- 达到迭代次数或者work Set为空(默认)
- 自定义聚合器收敛
其代码编写模型如下:
IterationState workset = getInitialState();
IterationState solution = getInitialSolution();
while (!terminationCriterion()) {
(delta, workset) = step(workset, solution);
solution.update(delta)
}
setFinalState(solution);
下面以一个连通体算法:最小传播值为例,计算每一个连通体中的最小ID值。
首先我们需要明白什么是连通图,如下图所示,就是两个连通图:
那么什么是最小传播值呢?在上面的图中,1,2,3,4组成了一个连通图,在这个连通图中,对每一个顶点进行编号,求出ID值最小的顶点,比如上面的图一中最小值是1。如果初始输入值是一条条边,我们最终要计算输出形如这样的元组对(vertixID,minimumID),比如(1,1),(2,1),(3,1),(4,1)这样,图二也类似。
其迭代过程如下图所示:
最终实现代码如下:
public class IterativeGraph {
public static void main(String[] args) throws Exception {
final ExecutionEnvironment env=ExecutionEnvironment.getExecutionEnvironment();
int iterativeNum=100;
//顶点
DataSet<Long> vertix=env.fromElements(1L,2L,3L,4L,5L,6L,7L);
//边
DataSet<Tuple2<Long,Long>> edges=env.fromElements(
Tuple2.of(1L, 2L),
Tuple2.of(2L, 3L),
Tuple2.of(2L, 4L),
Tuple2.of(3L, 4L),
Tuple2.of(5L, 6L),
Tuple2.of(5L, 7L),
Tuple2.of(6L, 7L)
);
//单向边转为双向边
edges=edges.flatMap(new FlatMapFunction<Tuple2<Long,Long>, Tuple2<Long,Long>>() {
@Override
public void flatMap(Tuple2<Long, Long> tuple, Collector<Tuple2<Long, Long>> collector) throws Exception {
collector.collect(tuple);
collector.collect(Tuple2.of(tuple.f1,tuple.f0));
}
});
//initialSolutionSet,将顶点映射为(vertix,vertix)的形式
DataSet<Tuple2<Long,Long>> initialSolutionSet=vertix.map(new MapFunction<Long, Tuple2<Long, Long>>() {
@Override
public Tuple2<Long, Long> map(Long vertix) throws Exception {
return Tuple2.of(vertix,vertix);
}
});
//initialWorkSet
DataSet<Tuple2<Long,Long>> initialWorkSet=vertix.map(new MapFunction<Long, Tuple2<Long, Long>>() {
@Override
public Tuple2<Long, Long> map(Long vertix) throws Exception {
return Tuple2.of(vertix,vertix);
}
});
//第一个字段做迭代运算
DeltaIteration<Tuple2<Long,Long>,Tuple2<Long,Long>> iterative=
initialSolutionSet.iterateDelta(initialWorkSet,iterativeNum,0);
//数据集合边做join操作,然后求出当前顶点的邻居顶点的最小ID值
DataSet<Tuple2<Long,Long>> changes=iterative.getWorkset().join(edges).where(0).equalTo(0).with(new NeighborWithComponentIDJoin())
.groupBy(0).aggregate(Aggregations.MIN,1)
//和solution set进行join操作,更新solution set,如果当前迭代结果中的最小ID小于solution中的ID值,则发送到下一次迭代运算中继续运算,否则不发送
.join(iterative.getSolutionSet()).where(0).equalTo(0)
.with(new ComponetIDFilter());
//关闭迭代计算
DataSet<Tuple2<Long,Long>> result=iterative.closeWith(changes,changes);
result.print();
}
public static class NeighborWithComponentIDJoin implements JoinFunction<Tuple2<Long,Long>,Tuple2<Long,Long>,Tuple2<Long,Long>>{
@Override
public Tuple2<Long, Long> join(Tuple2<Long, Long> t1, Tuple2<Long, Long> t2) throws Exception {
return Tuple2.of(t2.f1,t1.f1);
}
}
public static class ComponetIDFilter implements FlatJoinFunction<Tuple2<Long,Long>,Tuple2<Long,Long>,Tuple2<Long,Long>> {
@Override
public void join(Tuple2<Long, Long> t1, Tuple2<Long, Long> t2, Collector<Tuple2<Long, Long>> collector) throws Exception {
if(t1.f1<t2.f1){
collector.collect(t1);
}
}
}
}
最终计算结果如下所示:
(7,5)
(3,1)
(6,5)
(5,5)
(1,1)
(4,1)
(2,1)
欢迎加入大数据交流群:731423890
参考资料:
https://ci.apache.org/projects/flink/flink-docs-release-1.6/dev/batch/iterations.html