Python 是当下最流行的机器学习语言,大多数机器学习从业者都是直接使用 Python 工作,所以有很多开源的资源可以使用。而Go 语言的速度很快,能很好地处理并发,可以编译成单一的二进制文件。所以在实际开发时我们往往想利用二者的优势,用Python做模型训练,用Go做预测服务。
GO语言TensorFlow环境配置,之前已经有介绍 Go TensorFlow 环境配置
那么今天我们就来实现Go 语言部署Python 训练机器模型。
首先需要注意Python模型保存时需要保存为指定的格式:
from tensorflow.python.saved_model.builder_impl import SavedModelBuilder
with tf.Session() as session:
# 训练模型操作。。。
# 保存模型
builder = SavedModelBuilder("存储路径")
# 保存时需要定义tag
builder.add_meta_graph_and_variables(session, ["tag"])
builder.save()
Go语言运行模型,
具体代码:
package main
import (
"fmt"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
func main() {
m, err := tf.LoadSavedModel("modelPath", []string{"modelTag"}, nil) // 载入模型
if err != nil {
// 模型加载失败
fmt.Printf("err: %v", err)
}
// 打印出所有的Operator
for _, op := range m.Graph.Operations() {
fmt.Printf("Op name: %v", op.Name())
}
// 构造输入Tensor。根据你的模型入参格式来定义
x := [1][8]int32{
{0,1,2,3,4,5,6,7},
}
tensor_x, err := tf.NewTensor(x)
if err != nil {
fmt.Printf("err: %s", err.Error())
return
}
kb, err := tf.NewTensor(float32(1))
if err != nil {
fmt.Printf("err: %s", err.Error())
return
}
s := m.Session
feeds := map[tf.Output]*tf.Tensor{
// operation name 需要根据你的模型入参来写
m.Graph.Operation("input_x").Output(0): tensor_x,
m.Graph.Operation("keep_prob").Output(0): kb,
}
fetches := []tf.Output{
// 输出层的name 也要根据你的模型写
m.Graph.Operation("score/ArgMax").Output(0),
}
result, err:= s.Run(feeds, fetches,nil)
if err != nil {
// 模型预测失败
fmt.Printf("err: %s ", err.Error())
}
fmt.Printf("%#v", result)
}
需要注意的是:
- 载入模型时,需要传参:模型保存路径,Python保存模型时定义的Tag。这里tag 需要为[]string{}类型。
- 模型入参需要通过tf.NewTensor() 转为Tensor;
- 输入输出节点操作名称需要根据Python定义的模型操作节点名称填写;
- 输出节点和Python类似,可以传递多个操作名获取多个值。
【来思Go】,Let's Go!欢迎关注留言交流学习!