Kai - Golang实现的目标检测云服务

YOLO/Darknet是目前比较流行的Object Detection算法(后面统一称为Darknet),在GPU上的表现不但速度快而且准确率很高。但是使用起来不方便,只提供了命令行接口和简单的Python接口。所以我想用RESTful来实现一个云端的Darknet服务kai

选择用Go的原因不是考虑并发,而是goroutine之间的同步能方便的处理,适合实现Pipeline的功能。问题来了,Darknet是c语言实现的,那Go必须得用cgo进行封装,才能调用c函数。目标是为了实现三个基本功能:1. 图片检测 2. 视频检测 3. 摄像头检测。为了方便使用我修改了Darknet的部分代码,然后重新定义下面几个函数:

// Set a gpu device
void set_gpu(int gpu);

// Recognize a image
void image_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename,
	float thresh, float hier_thresh, char *outfile);

// Recognize a video
void video_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename,
	float thresh, float hier_thresh, char *outfile);

// Recognize a camera stream
void camera_detector(char *datacfg, char *cfgfile, char *weightfile, int camindex,
        float thresh, float hier_thresh, char *outpath);
复制代码

有了这几个函数,就好办了,下面用cgo导入相应的库和头文件即可:

// #cgo pkg-config: opencv
// #cgo linux LDFLAGS: -ldarknet -lm -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand -lcudnn
// #cgo darwin LDFLAGS: -ldarknet
// #include "yolo.h"
import "C"

// SetGPU set a gpu device you want
func SetGPU(gpu int) {
	C.set_gpu(C.int(gpu))
}

// ImageDetector recognize a image
func ImageDetector(dc, cf, wf, fn string, t, ht float64, of ...string) {
    ...
}

// VideoDetector recognize a video
func VideoDetector(dc, cf, wf, fn string, t, ht float64, of ...string) {
    ...
}

// CameraDetector recognize a camera stream
func CameraDetector(dc, cf, wf string, i int, t, ht float64, of ...string) {
    ...
}
复制代码

这样对Darknet的封装go-yolo就完成了。

下面进入主题,介绍一下kai的实现。

kai的设计目标如下:

  • 后端基于Darknet(不支持训练)
  • 提供RESTful接口进行图片和视频的检测
  • 支持Amazon S3下载和上传
  • 支持Ftp下载和上传
  • 支持检测结果持久化到MongoDB

架构图是这样的

这里重点介绍一下Kai的Pipeline机制,这里的Pipeline包括下载(Download),检测(Yolo)和上(Upload)传这一系列流程。 先上个图:

这里的难点在于下载(Download),检测(Yolo)和上传(Upload)这三个步骤可以配置不同的Goroutine数量,而这三步之间是一个同步操作。

  1. 首先需要定义3个buffered channel来进行同步
// KaiServer represents the server for processing all job requests
type KaiServer struct {
	net.Listener
	logger        *logging.Logger
	config        types.ServerConfig
	listenAddr    string
	listenNetwork string
	router        *Router
	server        *http.Server
	db            db.Storage
	// jobDownBuff is the buffered channel for job downloading
	jobDownBuff chan types.Job
	// jobDownBuff is the buffered channel for job todo
	jobTodoBuff chan types.Job
	// jobDownBuff is the buffered channel for job done
	jobDoneBuff chan types.Job
}
复制代码
  1. Pipeline的执行流程如下
// Pipeline contains downloading, processing and uploading a job
func Pipeline(logger *logging.Logger, config types.ServerConfig, dbInstance db.Storage, jobDownBuff chan types.Job,
	jobTodoBuff chan types.Job, jobDoneBuff chan types.Job, job types.Job) {
	logger.Infof("pipeline-job %+v", job)

	// download a job
	setupAndDownloadJob(logger, config.System, dbInstance, job, jobDownBuff)

	// jobDownBuff -> jobTodoBuff -> jobDoneBuff
	yoloJob(logger, config, dbInstance, jobDownBuff, jobTodoBuff, jobDoneBuff)

	// upload a job
	uploadJob(logger, dbInstance, jobDoneBuff)
}
复制代码
  1. 下载(Download)
// setupAndDownloadJob setup and download jobs into jobDownBuff
func setupAndDownloadJob(logger *logging.Logger, config types.SystemConfig,
	dbInstance db.Storage, job types.Job, jobDownBuff chan<- types.Job) {

	go func() {
		logger.Infof("start setup and download a job: %+v", job)
		newJob, err := SetupJob(logger, job.ID, dbInstance, config)
		job = *newJob
		if err != nil {
			logger.Error("setup-job failed", err)
			return
		}

		downloadFunc := downloaders.GetDownloadFunc(job.Source)
		if err := downloadFunc(logger, config, dbInstance, job.ID); err != nil {
			logger.Error("download failed", err)
			job.Status = types.JobError
			job.Details = err.Error()
			dbInstance.UpdateJob(job.ID, job)
			return
		}

		jobDownBuff <- job
	}()
}
复制代码
  1. 检测(Yolo)
func yoloJob(logger *logging.Logger, config types.ServerConfig, dbInstance db.Storage,
	jobDownBuff <-chan types.Job, jobTodoBuff chan types.Job, jobDoneBuff chan types.Job) {

	go func() {
		job, ok := <-jobDownBuff
		if !ok {
			logger.Info("job download buffer is closed")
			return
		}
		logger.Infof("start a yolo job: %+v", job)
		// limit the number of job in the jobTodoBuff
		jobTodoBuff <- job
		jobTodo, ok := <-jobTodoBuff
		if !ok {
			logger.Info("job todo buffer is closed")
			return
		}

		nGpu := config.System.NGpu
		t := yolo.NewTask(config.Yolo, jobTodo.Media.Cate, nGpu, jobTodo.LocalSource, jobTodo.LocalDestination)
		logger.Debugf("yolo task: %+v", *t)
		yolo.StartTask(t, logger, dbInstance, jobTodo.ID)
		jobDoneBuff <- job
	}()
}
复制代码
  1. 上传(Upload)
func uploadJob(logger *logging.Logger, dbInstance db.Storage, jobDoneBuff <-chan types.Job) {
	go func() {
		jobDone, ok := <-jobDoneBuff
		if !ok {
			logger.Info("job done buffer is closed")
			return
		}
		logger.Infof("start a upload job: %+v", jobDone)

		uploadFunc := uploaders.GetUploadFunc(jobDone.Destination)
		if err := uploadFunc(logger, dbInstance, jobDone.ID); err != nil {
			logger.Error("upload failed", err)
			jobDone.Status = types.JobError
			jobDone.Details = err.Error()
			dbInstance.UpdateJob(jobDone.ID, jobDone)
			return
		}

		logger.Info("erasing temporary files")
		if err := CleanSwap(dbInstance, jobDone.ID); err != nil {
			logger.Error("erasing temporary files failed", err)
		}

		jobDone.Status = types.JobFinished
		dbInstance.UpdateJob(jobDone.ID, jobDone)

		logger.Infof("end a job: %+v", jobDone)
	}()
}
复制代码

到此,这个项目主要机制都已经介绍完了,如果大家有兴趣的可以去点击下面的项目主页。

项目链接: go-yolo kai

转载于:https://juejin.im/post/5abb03986fb9a028cf328276

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值