1. 背景
DAG:是一个无回路的有向图。一个经典的应用是用于任务的调度,用来定义任务的依赖关系和流向, 根据整个DAG的定义,可以从中获取哪个任务该先执行,哪个任务后执行。哪些步骤是可以并行执行的。
2. DAG的定义
这里阐述一个简单的应用例子。推荐系统的通常需要进行多队列召回,然后进行粗排、精排、混排。可以将这些操作抽象成不同类型的rpc调用,在数据召回之后,还可以抽象出来两种动作,包括一种是归并截断、一种是切割分段。有点map-reduce的味道。
定义字段可以如下:
字段名 | 字段类型 | 字段含义 |
---|---|---|
task_id | int | 任务id |
task_name | string | 任务名 |
task_type | int | 任务类型 |
task_deps | int[] | 依赖任务列表 |
output | map<id, score> | 输出结果 |
output_type | int | 输出类型 |
实际上由于业务的复杂性,output的定义会更加复杂一点,这里相当于是简化了。然后将上面的字段转化成json配置文件,即可完成DAG任务图的配置定义。
3. DAG 代码实现
这里没有定义输出字段,所有的节点的入度应该是这个node的任务的输入数据,因为业务的复杂性,可以使用反射机制来定义输出字段。
package dag
import (
"container/list"
"fmt"
"sync"
mapset "github.com/deckarep/golang-set"
)
// 任务附加信息
type TaskInfo struct {
TaskId int // 任务id
TaskName string // 任务名称
TaskType int // 任务类型 0 rpc调用 1 归并排序 2 切割分路
RpcObj string // 调用的rpc请求
RetrieveNum int // 召回数量
CutNum int // 归并截断数
SplitNum int // 切割分路打分
}
// 任务节点信息
type TaskNode struct {
TaskInfo TaskInfo // 任务信息
OutEdge map[int]bool // 本任务依赖的其他任务、完成状况
InEdge map[int]bool // 依赖了本任务的其他任务、完成状况
OutCounter int //剩余依赖任务
InCounter int // 剩余被依赖任务
Done bool // 任务是否完成
}
// 定义任务依赖和任务的本身的信息
func NewNode(TaskInfo TaskInfo, deps []int) *TaskNode {
taskNode := new(TaskNode)
taskNode.TaskInfo = TaskInfo
taskNode.OutCounter = 0
taskNode.InCounter = 0
taskNode.Done = false
taskNode.OutEdge = make(map[int]bool)
taskNode.InEdge = make(map[int]bool)
for _, dep := range deps {
taskNode.OutEdge[dep] = false
taskNode.OutCounter += 1
}
return taskNode
}
type TaskGraph struct {
graph map[int]*TaskNode
todo mapset.Set
}
func (taskGraph TaskGraph) New() TaskGraph {
if taskGraph.graph == nil {
taskGraph.graph = make(map[int]*TaskNode)
}
if taskGraph.todo == nil {
taskGraph.todo = mapset.NewSet()
}
return taskGraph
}
func (taskGraph *TaskGraph) AddTask(TaskInfo TaskInfo, deps []int) bool {
if _, v := taskGraph.graph[TaskInfo.TaskId]; v {
return false
}
taskNode := NewNode(TaskInfo, deps)
taskGraph.graph[TaskInfo.TaskId] = taskNode
return true
}
func (taskGraph *TaskGraph) InitGraph() bool {
graph := taskGraph.graph
topStack := list.New()
tmpOutCounter := make(map[int]int)
for TaskId, node := range graph {
for depIter := range node.OutEdge {
destIter := graph[depIter]
if destIter == nil {
return false
}
destIter.InEdge[TaskId] = false
destIter.InCounter += 1
}
tmpOutCounter[TaskId] = node.OutCounter
if node.OutCounter == 0 {
topStack.PushBack(TaskId)
}
}
topCount := 0
for topStack.Len() > 0 {
topCount++
item := topStack.Front()
topStack.Remove(item)
TaskId := item.Value.(int)
node := graph[TaskId]
for iter := range node.InEdge {
tmpOutCounter[iter] -= 1
if tmpOutCounter[iter] == 0 {
topStack.PushBack(iter)
}
}
}
if topCount != len(graph) {
return false
}
for iter, node := range graph {
if node.OutCounter == 0 {
taskGraph.todo.Add(iter)
}
}
return true
}
func (taskGraph *TaskGraph) GetTodoTasks() []int {
var todo []int
for taskId := range taskGraph.todo.Iter() {
todo = append(todo, taskId.(int))
}
return todo
}
func (taskGraph *TaskGraph) MarkTaskDone(taskId int) bool {
if !taskGraph.todo.Contains(taskId) {
return false
}
taskGraph.todo.Remove(taskId)
node := taskGraph.graph[taskId]
node.Done = true
for k := range node.InEdge {
from := taskGraph.graph[k]
from.OutEdge[taskId] = true
from.OutCounter -= 1
if from.OutCounter == 0 { // 出度为0, 依赖完全, 进入待办
taskGraph.todo.Add(k)
}
}
for k := range node.OutEdge {
dest := taskGraph.graph[k]
dest.InEdge[taskId] = true
dest.InCounter -= 1
}
return true
}
func (taskGraph *TaskGraph) DoTask(taskId int) {
node := taskGraph.graph[taskId]
if node.TaskInfo.TaskType == 0 {
fmt.Println("do rpc job")
} else if node.TaskInfo.TaskType == 1 {
fmt.Println("do merge sort")
} else if node.TaskInfo.TaskType == 2 {
fmt.Println("do split sort")
}
}
func (taskGraph *TaskGraph) TaskSchedule() {
for true {
todo := taskGraph.GetTodoTasks()
if len(todo) <= 0 {
break
}
wg := sync.WaitGroup{}
wg.Add(len(todo))
for i := 0; i < len(todo); i++ {
go func(taskId int) {
taskGraph.DoTask(taskId)
}(todo[i])
}
for i := 0; i < len(todo); i++ {
if !taskGraph.MarkTaskDone(todo[i]) {
fmt.Println(" markTaskDone fail")
return
}
}
}
}
func (taskGraph *TaskGraph) PrintGraph() {
fmt.Println("-----------------------------------")
for k, node := range taskGraph.graph {
fmt.Println("任务名:", k)
if node.Done {
fmt.Println("是否完成:", "YES")
} else {
fmt.Println("是否完成:", "NO")
}
fmt.Println("(当前)依赖这些任务:")
for taskName, v := range node.OutEdge {
if !v {
fmt.Print(" ", taskName, " ")
}
}
fmt.Println()
fmt.Println("(当前)被这些任务依赖:")
for taskName, v := range node.InEdge {
if !v {
fmt.Print(" ", taskName, " ")
}
}
fmt.Println()
}
fmt.Println("-----------------------------------")
}