Main class diagram of the worker-side processing
The diagram below shows the main worker-side classes: receiving a task, determining its type, replacing parameter placeholders, submitting it to the cluster, and writing part of the logs back (only the classes I consider central are drawn).
TaskExecuteThread
The previous post mainly covered NettyRemotingClient and TaskExecuteProcessor, and ended by creating a TaskExecuteThread object and handing it to a thread pool to run. This post goes through that class; on to the code (padding the word count a bit):
/**
 * task scheduler thread
 */
public class TaskExecuteThread implements Runnable {

    /**
     * logger
     */
    private final Logger logger = LoggerFactory.getLogger(TaskExecuteThread.class);

    /**
     * task instance
     */
    private TaskExecutionContext taskExecutionContext;

    /**
     * abstract task
     */
    private AbstractTask task;

    /**
     * task callback service
     */
    protected TaskCallbackService taskCallbackService;

    /**
     * taskExecutionContextCacheManager
     */
    protected TaskExecutionContextCacheManager taskExecutionContextCacheManager;

    protected WorkerSleepTaskContextCacheManager sleepTaskContextCacheManager;

    /**
     * constructor
     *
     * @param taskExecutionContext taskExecutionContext
     * @param taskCallbackService taskCallbackService
     */
    public TaskExecuteThread(TaskExecutionContext taskExecutionContext, TaskCallbackService taskCallbackService) {
        this.taskExecutionContext = taskExecutionContext;
        this.taskCallbackService = taskCallbackService;
        this.taskExecutionContextCacheManager = SpringApplicationContext.getBean(TaskExecutionContextCacheManagerImpl.class);
        this.sleepTaskContextCacheManager = WorkerSleepTaskContextCacheManager.getInstance();
        if (this.taskCallbackService == null) {
            this.taskCallbackService = SpringApplicationContext.getBean(TaskCallbackService.class);
        }
    }

    @Override
    public void run() {
        TaskExecuteResponseCommand responseCommand = new TaskExecuteResponseCommand(taskExecutionContext.getTaskInstanceId());
        try {
            logger.info("script path : {}", taskExecutionContext.getExecutePath());

            // task node
            TaskNode taskNode = JSONObject.parseObject(taskExecutionContext.getTaskJson(), TaskNode.class);

            // copy hdfs/minio file to local
            downloadResource(taskExecutionContext.getExecutePath(),
                    taskExecutionContext.getResources(),
                    logger);

            taskExecutionContext.setTaskParams(taskNode.getParams());
            taskExecutionContext.setEnvFile(CommonUtils.getSystemEnvPath());
            taskExecutionContext.setDefinedParams(getGlobalParamsMap());

            // set task timeout
            setTaskTimeout(taskExecutionContext, taskNode);

            taskExecutionContext.setTaskAppId(String.format("%s_%s_%s",
                    taskExecutionContext.getProcessDefineId(),
                    taskExecutionContext.getProcessInstanceId(),
                    taskExecutionContext.getTaskInstanceId()));

            // custom logger
            Logger taskLogger = LoggerFactory.getLogger(LoggerUtils.buildTaskId(LoggerUtils.TASK_LOGGER_INFO_PREFIX,
                    taskExecutionContext.getProcessDefineId(),
                    taskExecutionContext.getProcessInstanceId(),
                    taskExecutionContext.getTaskInstanceId()));

            task = TaskManager.newTask(taskExecutionContext, taskLogger);

            // task init
            task.init();

            // task handle
            task.handle();

            // task result process
            task.after();

            responseCommand.setStatus(task.getExitStatus().getCode());
            responseCommand.setEndTime(new Date());
            responseCommand.setProcessId(task.getProcessId());
            responseCommand.setAppIds(task.getAppIds());
            logger.info("task instance id : {},task final status : {},exitStatusCode:{}",
                    taskExecutionContext.getTaskInstanceId(), task.getExitStatus(), task.getExitStatusCode());
        } catch (Exception e) {
            logger.error("task scheduler failure", e);
            kill();
            responseCommand.setStatus(ExecutionStatus.FAILURE.getCode());
            responseCommand.setEndTime(new Date());
            responseCommand.setProcessId(task.getProcessId());
            responseCommand.setAppIds(task.getAppIds());
            task.setExitStatusCode(Constants.EXIT_CODE_FAILURE);
        } finally {
            if (task.getExitStatusCode() != Constants.EXIT_CODE_SLEEP) {
                taskExecutionContextCacheManager.removeByTaskInstanceId(taskExecutionContext.getTaskInstanceId());
                ResponceCache.get().cache(taskExecutionContext.getTaskInstanceId(), responseCommand.convert2Command(), Event.RESULT);
                taskCallbackService.sendResult(taskExecutionContext.getTaskInstanceId(), responseCommand.convert2Command());
            } else {
                if (task instanceof ExternalCheckTask) {
                    ExternalCheckTask checkTask = (ExternalCheckTask) this.task;
                    sleepTaskContextCacheManager.cacheSleepTask(System.currentTimeMillis() / 1000, checkTask.getCheckInterval(), checkTask);
                }
            }
        }
    }

    /**
     * get global params map
     *
     * @return global params map
     */
    private Map<String, String> getGlobalParamsMap() {
        Map<String, String> globalParamsMap = new HashMap<>(16);
        if (taskExecutionContext.getScheduleParams() != null) {
            globalParamsMap.putAll(taskExecutionContext.getScheduleParams());
        }
        // global params string
        String globalParamsStr = taskExecutionContext.getGlobalParams();
        if (globalParamsStr != null) {
            List<Property> globalParamsList = JSONObject.parseArray(globalParamsStr, Property.class);
            globalParamsMap.putAll(globalParamsList.stream().collect(Collectors.toMap(Property::getProp, Property::getValue)));
        }
        return globalParamsMap;
    }

    /**
     * set task timeout
     *
     * @param taskExecutionContext TaskExecutionContext
     * @param taskNode taskNode
     */
    private void setTaskTimeout(TaskExecutionContext taskExecutionContext, TaskNode taskNode) {
        // the default timeout is the maximum value of the integer
        taskExecutionContext.setTaskTimeout(Integer.MAX_VALUE);
        TaskTimeoutParameter taskTimeoutParameter = taskNode.getTaskTimeoutParameter();
        if (taskTimeoutParameter.getEnable()) {
            // get timeout strategy
            taskExecutionContext.setTaskTimeoutStrategy(taskTimeoutParameter.getStrategy().getCode());
            switch (taskTimeoutParameter.getStrategy()) {
                case WARN:
                    break;
                case FAILED:
                    if (Integer.MAX_VALUE > taskTimeoutParameter.getInterval() * 60) {
                        taskExecutionContext.setTaskTimeout(taskTimeoutParameter.getInterval() * 60);
                    }
                    break;
                case WARNFAILED:
                    if (Integer.MAX_VALUE > taskTimeoutParameter.getInterval() * 60) {
                        taskExecutionContext.setTaskTimeout(taskTimeoutParameter.getInterval() * 60);
                    }
                    break;
                default:
                    logger.error("not support task timeout strategy: {}", taskTimeoutParameter.getStrategy());
                    throw new IllegalArgumentException("not support task timeout strategy");
            }
        }
    }

    /**
     * kill task
     */
    public void kill() {
        if (task != null) {
            try {
                task.cancelApplication(true);
            } catch (Exception e) {
                logger.error(e.getMessage(), e);
            }
        }
    }

    /**
     * download resource file
     *
     * @param execLocalPath execLocalPath
     * @param projectRes projectRes
     * @param logger logger
     */
    private void downloadResource(String execLocalPath,
                                  Map<String, String> projectRes,
                                  Logger logger) throws Exception {
        if (MapUtils.isEmpty(projectRes)) {
            return;
        }
        Set<Map.Entry<String, String>> resEntries = projectRes.entrySet();
        for (Map.Entry<String, String> resource : resEntries) {
            String fullName = resource.getKey();
            String tenantCode = resource.getValue();
            File resFile = new File(execLocalPath, fullName);
            if (!resFile.exists()) {
                try {
                    // query the tenant code of the resource according to the name of the resource
                    String resHdfsPath = HadoopUtils.getHdfsResourceFileName(tenantCode, fullName);
                    logger.info("get resource file from hdfs :{}", resHdfsPath);
                    HadoopUtils.getInstance().copyHdfsToLocal(resHdfsPath, execLocalPath + File.separator + fullName, false, true);
                } catch (Exception e) {
                    logger.error(e.getMessage(), e);
                    throw new RuntimeException(e.getMessage());
                }
            } else {
                logger.info("file : {} exists ", resFile.getName());
            }
        }
    }
}
The previous post showed the task execution context and taskCallbackService (which holds the handle for talking back to the master) being passed into this class's constructor to build an instance. Since the class implements the Runnable interface, the real work happens in run(); let's walk through it:
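Before diving in, a quick toy illustration of the pattern the previous post ended on: the processor builds a Runnable per task and drops it into a worker-owned thread pool, then returns immediately. Nothing below is the project's real wiring; the pool size and names are made up.

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Toy sketch only: a worker-owned pool that runs one Runnable per task.
public class WorkerPoolSketch {
    public static void main(String[] args) {
        ExecutorService workerExecService = Executors.newFixedThreadPool(20); // pool size is an assumption

        // stands in for new TaskExecuteThread(taskExecutionContext, taskCallbackService)
        Runnable taskExecuteThread = () ->
                System.out.println("run(): parse TaskNode, download resources, init/handle/after, callback");

        workerExecService.submit(taskExecuteThread);
        workerExecService.shutdown();
    }
}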
1. Parse the TaskNode; it holds many parameters needed later, so they are pulled out here.
2. If the task context carries resources, download them into the current tenant's task directory on this node (resources are usually UDF functions or task jars).
3. Then some more preparation work: env file, global params, timeout, task app id, task logger... (a sketch of the global-params step follows this list)
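As promised in item 3, here is a minimal, self-contained sketch of the global-params step, mirroring getGlobalParamsMap() above. It assumes the same fastjson JSONObject the listing uses; the Property stand-in and the JSON string are made up for illustration, and the real method also merges the schedule params first.

import com.alibaba.fastjson.JSONObject;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class GlobalParamsSketch {

    // hypothetical stand-in for the project's Property class
    public static class Property {
        private String prop;
        private String value;
        public String getProp() { return prop; }
        public void setProp(String prop) { this.prop = prop; }
        public String getValue() { return value; }
        public void setValue(String value) { this.value = value; }
    }

    public static void main(String[] args) {
        // made-up globalParams JSON for illustration
        String globalParamsStr = "[{\"prop\":\"bizDate\",\"value\":\"2021-01-01\"}]";
        List<Property> globalParamsList = JSONObject.parseArray(globalParamsStr, Property.class);
        Map<String, String> definedParams = globalParamsList.stream()
                .collect(Collectors.toMap(Property::getProp, Property::getValue));
        System.out.println(definedParams); // {bizDate=2021-01-01}
    }
}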
Next, a task object is created for the specific task type: task = TaskManager.newTask(taskExecutionContext, taskLogger); let's look at TaskManager's newTask method.
/**
 * create new task
 *
 * @param taskExecutionContext taskExecutionContext
 * @param logger logger
 * @return AbstractTask
 * @throws IllegalArgumentException illegal argument exception
 */
public static AbstractTask newTask(TaskExecutionContext taskExecutionContext,
                                   Logger logger)
        throws IllegalArgumentException {
    switch (EnumUtils.getEnum(TaskType.class, taskExecutionContext.getTaskType())) {
        case SHELL:
            return new ShellTask(taskExecutionContext, logger);
        case PROCEDURE:
            return new ProcedureTask(taskExecutionContext, logger);
        case SQL:
            return new SqlTask(taskExecutionContext, logger);
        case MR:
            return new MapReduceTask(taskExecutionContext, logger);
        case SPARK:
            return new SparkTask(taskExecutionContext, logger);
        case FLINK:
            return new FlinkTask(taskExecutionContext, logger);
        case PYTHON:
            return new PythonTask(taskExecutionContext, logger);
        case HTTP:
            return new HttpTask(taskExecutionContext, logger);
        case DATAX:
            return new DataxTask(taskExecutionContext, logger);
        case SQOOP:
            return new SqoopTask(taskExecutionContext, logger);
        default:
            logger.error("unsupport task type: {}", taskExecutionContext.getTaskType());
            throw new IllegalArgumentException("not support task type");
    }
}
It just switches on the task type and creates the matching class; then comes the classic template method pattern.
task.init(), task.handle(), task.after() — each task type overrides these as it needs (a toy sketch of the pattern is below).
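A toy sketch of that template-method idea (not the real AbstractTask/ShellTask source, which also track exit status, process id, yarn application ids and so on): the parent fixes the init/handle/after skeleton and each task type fills in handle().

// Toy sketch of the template-method pattern used by the task classes.
public class TemplateSketch {

    abstract static class AbstractTaskSketch {
        void init() { System.out.println("common init: build command from params"); }
        abstract void handle() throws Exception;   // each task type supplies the real work
        void after() { System.out.println("common post-processing: collect result"); }
    }

    static class ShellTaskSketch extends AbstractTaskSketch {
        @Override
        void handle() { System.out.println("run the shell script, wait for the exit code"); }
    }

    public static void main(String[] args) throws Exception {
        AbstractTaskSketch task = new ShellTaskSketch(); // what TaskManager.newTask() picks by type
        task.init();
        task.handle();
        task.after();
    }
}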
…
Finally, back in the finally block, taskCallbackService is asked to send the task result to the master. Ha, another quick post.
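One last detail worth calling out: in the finally block the response is first cached in ResponceCache and only then sent via taskCallbackService, presumably so it can be re-sent if the master never acknowledges it. A toy sketch of that cache-then-send reading (my interpretation, not the real ResponceCache API):

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Toy sketch: keep the response until the master acks, so it can be re-sent later.
public class ResponseCacheSketch {
    private static final Map<Integer, String> pending = new ConcurrentHashMap<>();

    static void send(int taskInstanceId, String response) {
        pending.put(taskInstanceId, response);   // cache first
        System.out.println("send to master: " + response);
    }

    static void onAck(int taskInstanceId) {
        pending.remove(taskInstanceId);          // forget only once acked
    }

    public static void main(String[] args) {
        send(1001, "{\"status\":\"SUCCESS\"}");
        onAck(1001);
        System.out.println("still pending: " + pending.size());
    }
}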