// Accumulates counts in state by extending FlatMapFunction, and implements the
// CheckpointedFunction interface to hook into checkpoint snapshotting and state initialization.
class CheckpointCount extends FlatMapFunction[(Int, Long), (Int, Long, Long)] with CheckpointedFunction {

  // Running count for this operator (subtask) instance; snapshotted into
  // operatorState and restored from it on recovery.
  private var operatorCount: Long = _
  // Per-key running count, managed by Flink's keyed state backend.
  private var keyedState: ValueState[Long] = _
  // Non-keyed (operator) state used to persist operatorCount across checkpoints.
  private var operatorState: ListState[Long] = _

  /**
   * Emits (key, per-key count, per-operator count) for every input element.
   *
   * @param value the input element as (key, value)
   * @param out   the collector for result tuples
   */
  override def flatMap(value: (Int, Long), out: Collector[(Int, Long, Long)]): Unit = {
    // NOTE(review): before the first update for a key, keyedState.value() yields
    // the state's default (null, unboxing to 0), so counting starts at 1 — confirm.
    val keyedCount = keyedState.value() + 1
    keyedState.update(keyedCount)
    operatorCount = operatorCount + 1
    out.collect((value._1, keyedCount, operatorCount))
  }

  /**
   * Checkpoint hook: replaces the operator-state list with the current
   * operatorCount so it can be restored after a failure.
   *
   * @param context the snapshot context supplied by the runtime
   */
  override def snapshotState(context: FunctionSnapshotContext): Unit = {
    operatorState.clear()
    operatorState.add(operatorCount)
  }

  /**
   * Initialization hook: registers the keyed and operator state descriptors
   * and, when recovering from a checkpoint, restores operatorCount.
   *
   * @param context the initialization context supplied by the runtime
   */
  override def initializeState(context: FunctionInitializationContext): Unit = {
    // ListState.get() returns a java.lang.Iterable; asScala makes .sum available.
    import scala.collection.JavaConverters._
    keyedState = context.getKeyedStateStore.getState(
      new ValueStateDescriptor[Long]("keyedState", createTypeInformation[Long]))
    operatorState = context.getOperatorStateStore.getListState(
      new ListStateDescriptor[Long]("operatorState", createTypeInformation[Long]))
    // Bug fix: the original `operatorState.get()sum` was missing both the
    // member-access dot and the Java-to-Scala conversion, so it did not compile.
    println("Start=" + operatorState.get().asScala.sum)
    if (context.isRestored) {
      val restored = operatorState.get().asScala.sum
      println("Restored=" + restored)
      operatorCount = restored
    }
  }
}
/**
 * Reads "key value" strings from a local Kafka topic, counts per key and per
 * operator subtask via CheckpointCount, and prints the results. Checkpoints
 * are written to a local filesystem state backend.
 *
 * @param args command-line arguments (unused)
 */
def main(args: Array[String]): Unit = {
  val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment

  val properties = new Properties()
  properties.setProperty("bootstrap.servers", "localhost:9092")
  properties.setProperty("group.id", "flink_consumer")

  env.setStateBackend(new FsStateBackend("file:///D:/output/state_checkpoint/"))
  env.getCheckpointConfig.setCheckpointingMode(CheckpointingMode.EXACTLY_ONCE)
  // NOTE(review): noRestart() means the job will NOT recover automatically after
  // a failure even though checkpointing is enabled — confirm this is intended.
  env.setRestartStrategy(RestartStrategies.noRestart())
  // Removed a discarded read of getCheckpointConfig.isCheckpointingEnabled here:
  // its result was unused (and checkpointing was not yet enabled at that point).
  env.enableCheckpointing(6000) // checkpoint every 6 seconds
  env.getCheckpointConfig.enableExternalizedCheckpoints(ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION)

  val consumer = new FlinkKafkaConsumer010[String]("test", new SimpleStringSchema, properties)
  val stream = env.addSource(consumer)
  // Parse "key value" lines into (Int, Long); split each record once instead of twice.
  val inputStream: DataStream[(Int, Long)] = stream
    .filter(_.length > 4)
    .map { w =>
      val parts = w.split(" ")
      (parts(0).toInt, parts(1).toLong)
    }
  // val inputStream: DataStream[(Int, Long)] = env.fromElements((2, 2L), (4, 1L), (5, 4L), (4, 2L))
  inputStream.keyBy(0).flatMap(new CheckpointCount()).print()
  env.execute("EventTime processing example")
}