使用Go语言上线TensorFlow模型

昨天搞了一天用Go语言部署TensorFlow模型,把整个过程记录一下,以备大家参考(现在没有题图,以后在搞一个图)。

首先我们要有一个已经保存好的TensorFlow模型,也就是.pb文件。这个文件固化了计算图和权重,Go语言只需要根据这个代码跑相应的Session就行了。关于如何产生.pb文件,如果大家有兴趣的话可以私信我,我可以根据大家的需求情况写一份文档。具体部分可以参见tf.saved_model

然后编译TF的源代码得到libtensorflow.so和libtensorflow_framework.so。也可以用官网上的下载链接(我没用过,大家可以尝试一下)。需要注意的是请务必保证保存模型的TF版本和这个动态链接库的TF版本一致,不然的话后面的Go代码可能会挂(大坑)。如果需要编译的话可以参考TF的官方文档,如果有兴趣的话同上请私信我。

有了这个东西,为了让ld能够找到这两个文件,Linux上需要设置$LIBRARY_PATH和$LD_LIBRARY_PATH这两个环境变量。

export LIBRARY_PATH=[.so文件所在的目录]
export LD_LIBRARY_PATH=[.so文件所在的目录]

然后是下载我们的依赖包。可以使用下边的命令。第一个是下载依赖,第二个是测试下载的依赖有没有问题。如果第二个出错,就证明前面的步骤有问题。

go get github.com/tensorflow/tensorflow/tensorflow/go
go get github.com/tensorflow/tensorflow/tensorflow/go

接下来就可以愉快的载入模型开始玩了。下面是载入模型的示例代码。载入模型的时候需要给模型所在的文件夹和模型的名字(模型的名字可以用saved_model_cli这个工具来查看)。后面的一段是我自己家的,意思是打印出当前模型图里面所有的Operator。这个代码返回一个tf.SavedModel的struct,这个struct有两个成员,第一个是Session,第二个是Graph。如果大家对于TF的python API很熟应该知道这两个是什么东西。

func LoadModel(modelPath string, modelNames []string) *tf.SavedModel {
    model, err := tf.LoadSavedModel(modelPath, modelNames, nil) // 载入模型
    if err != nil {
        log.Fatal("LoadSavedModel(): %v", err)
    }

    log.Println("List possible ops in graphs") // 打印出所有的Operator
    for _, op := range model.Graph.Operations() {
        //log.Printf("Op name: %v, on device: %v", op.Name(), op.Device())
        log.Printf("Op name: %v", op.Name())
    }
    return model
}

有了Session和Graph之后,我们就能跑这个模型了。我这边用的是gin这个web框架,直接把输入的JSON编码成TensorFlow接受的输入,然后调用Session.Run方法来跑整个计算图。这个方法传三个参数,第一个参数是一个map,把每个tf.Output类型映射成一个tf.Tensor。前面一个在知道输入的Operator的情况下,可以通过Operator.Output(0)方法拿到,后面一个,可以使用tf.NewTensor这个函数,传入输入的Go数组来生成。如果大家熟悉TensorFlow的Python API的话,我们会发现,第一个类似与feed_dict这个参数。第二个参数是输出的张量的列表。我们同样可以在拿到Operator以后,通过Operator.Output(0)方法拿到,注意要把他们包装成一个[]Output类型,即使里面只有一个元素。第三个是不执行的Operator的列表,这里我们设置成nil。

func main () {
    m := LoadModel("../freeze_model", []string{"serve"})
    s := m.Session
    // ...
    ServeJSON := func (c *gin.Context) {
        var json map[string] int64
        if c.BindJSON(&json) == nil {
            log.Println(json)
        }
        ret, err:= s.Run(MapGraphInputs(CreateMapFromJSON(json), m),
            GetGraphOutputs([]string{"prob"}, m), nil)
        if err != nil {
            log.Fatal("Error in executing graph...", err)
        }
        // ...
    }
    // ...
}

然后我们编译一下源代码,跑一下,发现gin框架起来了,我们就可以用这个可执行文件做web服务了~这个可执行文件和.so文件,以及模型文件一起,完全可以一起copy到docker的container里面,这样就可以用k8s愉快的和这个模型玩耍了

  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值