FfDL 官方目前的任务挂起(halt)功能仅仅实现了将数据库里面的 Status 字段修改为 HALTED,没有实现真正意义上的 pod 销毁和任务状态信息保留,resume 接口更是没有实现。由于项目需要,需要实现这部分功能。
思路:复用 restapi 的 PatchModel 接口,当消息是 halt 时标记挂起操作,resume 时标记恢复操作。
挂起实现逻辑:请求参数为训练id,从mongo数据库中获取需要挂起任务的信息,将status标记为HALT状态,调用lcm gRPC调用,lcm gRPC调用k8s api销毁该任务相关的pod。
恢复实现逻辑:请求参数为训练id,从mongo数据库中获取需要恢复的训练任务信息,查询status是否为HALT状态,若是,将status标记为pending,调用lcm的gRPC调用,调用k8s api创建训练任务相关pod。
tensorflow实验:每1000次迭代保存ckpt文件,在训练过程中调用halt挂起任务,随后调用resume恢复任务。实验代码可以从保存的ckpt中恢复之前保存的训练结果,并继续训练。
贴上修改代码:
From 89cc0990e6b012723bdbe0d981be3ba1ac00bbf4 Mon Sep 17 00:00:00 2001
From: James <mydota@qq.com>
Date: Mon, 18 Feb 2019 11:28:59 +0000
Subject: [PATCH] add trainer halt and resume
Signed-off-by: James <mydota@qq.com>
---
restapi/api_v1/server/models_impl.go | 3 +-
trainer/trainer/trainer_impl.go | 54 ++++++++++++++++++++++++----
2 files changed, 50 insertions(+), 7 deletions(-)
diff --git a/restapi/api_v1/server/models_impl.go b/restapi/api_v1/server/models_impl.go
index f0776a1..a04c6ed 100644
--- a/restapi/api_v1/server/models_impl.go
+++ b/restapi/api_v1/server/models_impl.go
@@ -795,7 +795,7 @@ func patchModel(params models.PatchModelParams) middleware.Responder {
logr := logger.LocLogger(logWithUpdateStatusParams(params))
logr.Debugf("patchModel invoked: %v", params.HTTPRequest.Header)
- if params.Payload.Status != "halt" {
+ if (params.Payload.Status != "halt" && params.Payload.Status != "resume") {
return models.NewPatchModelBadRequest().WithPayload(&restmodels.Error{
Error: "Bad request",
Code: http.StatusBadRequest,
@@ -814,6 +814,7 @@ func patchModel(params models.PatchModelParams) middleware.Responder {
TrainingId: params.ModelID,
UserId: getUserID(params.HTTPRequest),
Status: grpc_trainer_v2.Status_HALTED,
+ StatusMessage: params.Payload.Status,
})
//
if err != nil {
diff --git a/trainer/trainer/trainer_impl.go b/trainer/trainer/trainer_impl.go
index d34a4f0..721b5a2 100644
--- a/trainer/trainer/trainer_impl.go
+++ b/trainer/trainer/trainer_impl.go
@@ -70,7 +70,7 @@ const (
collectionNameTrainingJobs = "training_jobs"
collectionNameJobHistory = "job_history"
- debugLogsMode = false
+ debugLogsMode = true
oldEndpointInternalPageSize = 10
@@ -604,7 +604,7 @@ func (s *trainerService) CreateTrainingJob(ctx context.Context, req *grpc_traine
qHandler = s.queues["ANY"]
}
- rateLimited := true
+ rateLimited := true
qSize, err := qHandler.Size()
logGpuTypeQueueSize := fmt.Sprintf("%s_%s", gpuType, "queue_size")
logr.WithFields(logrus.Fields{
@@ -617,6 +617,7 @@ func (s *trainerService) CreateTrainingJob(ctx context.Context, req *grpc_traine
rateLimited = s.rateLimitTrainingJob(tr, logr)
}
+ //rateLimited = true
if rateLimited {
// either queue was not empty or rate-limiting was needed, so send this job to the queue
logr.Infof("training job %s is rate-limited, adding to queue %s", tr.TrainingID, gpuType)
@@ -733,7 +734,28 @@ func (s *trainerService) GetTrainingStatusID(ctx context.Context, req *grpc_trai
func (s *trainerService) UpdateTrainingJob(ctx context.Context, req *grpc_trainer_v2.UpdateRequest) (*grpc_trainer_v2.UpdateResponse, error) {
logr := logger.LocLogger(logWith(req.TrainingId, req.UserId))
- logr.Debugf("UpdateTrainingJob called for training %s", req.TrainingId)
+ logr.Debugf("UpdateTrainingJob called for training %s message %s", req.TrainingId, req.StatusMessage)
+
+ if(req.Status == grpc_trainer_v2.Status_HALTED) {
+ training, err := s.repo.Find(req.TrainingId)
+ if err != nil {
+ logr.WithError(err).Errorf("Cannot retrieve training '%s'", req.TrainingId)
+ return nil, err
+ }
+ ts := training.TrainingStatus
+ if (ts.Status == grpc_trainer_v2.Status_HALTED && req.StatusMessage == "resume") {
+ s.ResumeTrainingJob(ctx, &grpc_trainer_v2.ResumeRequest{
+ TrainingId: req.TrainingId,
+ UserId: req.UserId,
+ })
+ } else if (ts.Status != grpc_trainer_v2.Status_FAILED && ts.Status != grpc_trainer_v2.Status_COMPLETED && req.StatusMessage == "halt") {
+ s.HaltTrainingJob(ctx, &grpc_trainer_v2.HaltRequest{
+ TrainingId: req.TrainingId,
+ UserId: req.UserId,
+ })
+ }
+ return &grpc_trainer_v2.UpdateResponse{TrainingId: req.TrainingId}, nil
+ }
return updateTrainingJobPostLock(s, req)
}
@@ -1132,7 +1154,7 @@ func (s *trainerService) HaltTrainingJob(ctx context.Context, req *grpc_trainer_
logr.Debugf("Kubernetes job '%s' no longer exists.", job.JobId)
// update the status in mongo
- _, err = updateTrainingJobPostLock(s, &grpc_trainer_v2.UpdateRequest{
+/* _, err = updateTrainingJobPostLock(s, &grpc_trainer_v2.UpdateRequest{
TrainingId: req.TrainingId,
UserId: req.UserId,
Status: grpc_trainer_v2.Status_HALTED,
@@ -1142,7 +1164,15 @@ func (s *trainerService) HaltTrainingJob(ctx context.Context, req *grpc_trainer_
if err != nil {
logr.WithError(err).Errorln("Unable to update job status to halted")
return nil, err
- }
+ }*/
+ training, _ := s.repo.Find(req.TrainingId)
+ ts := training.TrainingStatus
+ ts.Status = grpc_trainer_v2.Status_HALTED
+ err = s.repo.Store(training)
+ if err != nil {
+ logr.WithError(err).Errorf("Failed updating status of training %s in DB", req.TrainingId)
+ return nil, err
+ }
return &grpc_trainer_v2.HaltResponse{TrainingId: job.JobId, UserId: job.UserId, Status: grpc_trainer_v2.Status_HALTED}, nil
}
@@ -1151,7 +1181,19 @@ func (s *trainerService) HaltTrainingJob(ctx context.Context, req *grpc_trainer_
func (s *trainerService) ResumeTrainingJob(ctx context.Context, req *grpc_trainer_v2.ResumeRequest) (*grpc_trainer_v2.ResumeResponse, error) {
logr := logger.LocLogger(logWith(req.TrainingId, req.UserId))
- logr.Debugf("HaltTrainingJob called")
+ logr.Debugf("ResumeTrainingJob called")
+
+ training, err := s.repo.Find(req.TrainingId)
+ if err != nil {
+ logr.WithError(err).Errorf("Cannot retrieve training '%s'", req.TrainingId)
+ return nil, err
+ }
+ err = s.submitJobToLCM(training, logr)
+ if err != nil {
+ // err logged in submitJobToLCM
+ return nil, err
+ }
+
return nil, gerrf(codes.Unimplemented, "ResumeTrainingJob not implemented yet")
}
--
2.17.1