Serializing an sklearn model to JSON (sklearn-json)

Preface

Requirement: export a trained scikit-learn model as JSON so that it can be passed between different programming languages.
Solution: use sklearn-json.


Installing sklearn-json

pip install sklearn-json

Note: requires scikit-learn >= 0.21.3.


Usage

Serializing a model to JSON

Using a decision tree classifier as the example:

from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
import sklearn_json as skljson

# data
wine = load_wine()

# train/test split (no random_state is set, so the split, and therefore the fitted tree, varies between runs)
Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data, wine.target, test_size=0.3)

# train a decision tree classifier
clf = tree.DecisionTreeClassifier(criterion='gini', max_depth=5, random_state=0)
clf = clf.fit(Xtrain, Ytrain) # after fit, clf is the model

# save model to json
skljson.to_json(clf, "tree_model")  # key step: write the fitted model to the file "tree_model"

At this point the decision tree classifier has been saved as JSON.
JSON is human-readable; opening the "tree_model" file shows the following:

{
  "meta": "decision-tree",
  "feature_importances_": [
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.4245297894999867,
    0.0,
    0.0,
    0.39702026297677956,
    0.02368943909521627,
    0.046465929707028834,
    0.10829457872098862
  ],
  "max_features_": 13,
  "n_classes_": 3,
  "n_features_": 13,
  "n_outputs_": 1,
  "tree_": {
    "max_depth": 3,
    "node_count": 11,
    "nodes": [
      [
        1,
        4,
        9,
        3.819999933242798,
        0.6619406867845994,
        124,
        124.0
      ],
      [
        2,
        3,
        11,
        3.694999933242798,
        0.08869659275283936,
        43,
        43.0
      ],
      [
        -1,
        -1,
        -2,
        -2.0,
        0.0,
        41,
        41.0
      ],
      [
        -1,
        -1,
        -2,
        -2.0,
        0.0,
        2,
        2.0
      ],
      [
        5,
        8,
        6,
        1.5800000429153442,
        0.5639384240207286,
        81,
        81.0
      ],
      [
        6,
        7,
        10,
        0.9699999988079071,
        0.054012345679012363,
        36,
        36.0
      ],
      [
        -1,
        -1,
        -2,
        -2.0,
        0.0,
        35,
        35.0
      ],
      [
        -1,
        -1,
        -2,
        -2.0,
        0.0,
        1,
        1.0
      ],
      [
        9,
        10,
        12,
        670.0,
        0.19753086419753085,
        45,
        45.0
      ],
      [
        -1,
        -1,
        -2,
        -2.0,
        0.0,
        5,
        5.0
      ],
      [
        -1,
        -1,
        -2,
        -2.0,
        0.0,
        40,
        40.0
      ]
    ],
    "values": [
      [
        [
          42.0,
          47.0,
          35.0
        ]
      ],
      [
        [
          2.0,
          41.0,
          0.0
        ]
      ],
      [
        [
          0.0,
          41.0,
          0.0
        ]
      ],
      [
        [
          2.0,
          0.0,
          0.0
        ]
      ],
      [
        [
          40.0,
          6.0,
          35.0
        ]
      ],
      [
        [
          0.0,
          1.0,
          35.0
        ]
      ],
      [
        [
          0.0,
          0.0,
          35.0
        ]
      ],
      [
        [
          0.0,
          1.0,
          0.0
        ]
      ],
      [
        [
          40.0,
          5.0,
          0.0
        ]
      ],
      [
        [
          0.0,
          5.0,
          0.0
        ]
      ],
      [
        [
          40.0,
          0.0,
          0.0
        ]
      ]
    ],
    "nodes_dtype": [
      "<i8",
      "<i8",
      "<i8",
      "<f8",
      "<f8",
      "<i8",
      "<f8"
    ]
  },
  "classes_": [
    0,
    1,
    2
  ],
  "params": {
    "ccp_alpha": 0.0,
    "class_weight": null,
    "criterion": "gini",
    "max_depth": 5,
    "max_features": null,
    "max_leaf_nodes": null,
    "min_impurity_decrease": 0.0,
    "min_impurity_split": null,
    "min_samples_leaf": 1,
    "min_samples_split": 2,
    "min_weight_fraction_leaf": 0.0,
    "presort": "deprecated",
    "random_state": 0,
    "splitter": "best"
  }
}

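This JSON file contains all the data that describes the trained decision tree classifier.
To understand what each field means, especially the core tree_.nodes, it helps to compare the file with a graphviz visualization of the same tree.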
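To make the layout of tree_.nodes concrete, here is a minimal sketch in plain Python (no sklearn needed) that gives the seven positional fields of each node a name, in the order listed in nodes_dtype, and uses them to recompute feature_importances_ from the file above. The field names (left_child, right_child, feature, threshold, impurity, n_node_samples, weighted_n_node_samples) are an assumption based on scikit-learn's internal node structure for the version that wrote this file; newer scikit-learn releases add extra fields. The helpers decode_nodes and feature_importances are hypothetical names used only for illustration.

```python
# Field names for each entry of tree_.nodes, in "nodes_dtype" order.
# Assumption: matches scikit-learn's internal node record for the version
# that produced the file (newer releases append more fields).
NODE_FIELDS = (
    "left_child", "right_child", "feature", "threshold",
    "impurity", "n_node_samples", "weighted_n_node_samples",
)
TREE_LEAF = -1  # child index that marks a leaf (leaves also use feature -2)

# The "tree_.nodes" array copied from the tree_model file shown above.
NODES = [
    [1, 4, 9, 3.819999933242798, 0.6619406867845994, 124, 124.0],
    [2, 3, 11, 3.694999933242798, 0.08869659275283936, 43, 43.0],
    [-1, -1, -2, -2.0, 0.0, 41, 41.0],
    [-1, -1, -2, -2.0, 0.0, 2, 2.0],
    [5, 8, 6, 1.5800000429153442, 0.5639384240207286, 81, 81.0],
    [6, 7, 10, 0.9699999988079071, 0.054012345679012363, 36, 36.0],
    [-1, -1, -2, -2.0, 0.0, 35, 35.0],
    [-1, -1, -2, -2.0, 0.0, 1, 1.0],
    [9, 10, 12, 670.0, 0.19753086419753085, 45, 45.0],
    [-1, -1, -2, -2.0, 0.0, 5, 5.0],
    [-1, -1, -2, -2.0, 0.0, 40, 40.0],
]


def decode_nodes(raw_nodes):
    """Turn each positional node array into a dict keyed by field name."""
    return [dict(zip(NODE_FIELDS, node)) for node in raw_nodes]


def feature_importances(raw_nodes, n_features):
    """Recompute impurity-based importances from the node records.

    Each internal node credits its split feature with the weighted impurity
    decrease w_node*imp_node - w_left*imp_left - w_right*imp_right;
    the totals are then normalized so that they sum to 1.
    """
    nodes = decode_nodes(raw_nodes)
    totals = [0.0] * n_features
    for node in nodes:
        if node["left_child"] == TREE_LEAF:
            continue  # leaves do not split, so they contribute nothing
        left = nodes[node["left_child"]]
        right = nodes[node["right_child"]]
        decrease = (
            node["weighted_n_node_samples"] * node["impurity"]
            - left["weighted_n_node_samples"] * left["impurity"]
            - right["weighted_n_node_samples"] * right["impurity"]
        )
        totals[node["feature"]] += decrease
    norm = sum(totals)
    return [t / norm for t in totals] if norm > 0 else totals


if __name__ == "__main__":
    root = decode_nodes(NODES)[0]
    print("root splits on feature", root["feature"], "at", root["threshold"])
    print(feature_importances(NODES, 13))
```

With the nodes above, the recomputed values should agree with the feature_importances_ array in the file up to floating-point rounding (about 0.425 for feature 6 and 0.397 for feature 9), which is a quick way to confirm that the node fields have been read in the right order.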

Deserialization

Deserialization in Python looks like this:

# continuing from the code above

model = skljson.from_json("tree_model")  # key step: load the model back from the file

print(model.score(Xtrain, Ytrain)) # accuracy on the training set
print(model.score(Xtest, Ytest)) # accuracy on the test set

print(model.predict(Xtest)) # predictions for the test set

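The file is plain JSON, so it can be parsed in another language such as Java with any standard JSON library; to reproduce predict() there, you have to re-implement the tree traversal over tree_.nodes yourself.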
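Because the file only needs a JSON parser, prediction can be reproduced without sklearn by walking tree_.nodes from the root. The following minimal sketch shows the traversal a port to Java or another language would need. load_model and predict_one are hypothetical helper names, and the sketch assumes a single-output classifier whose node arrays begin with left_child, right_child, feature, threshold (the order given in nodes_dtype).

```python
import json

TREE_LEAF = -1  # child index that marks a leaf node


def load_model(path):
    """Read a file written by skljson.to_json using only the json module."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def predict_one(model, x):
    """Predict the class label of a single sample x by walking tree_.nodes.

    Assumes n_outputs_ == 1 and node arrays laid out as
    [left_child, right_child, feature, threshold, ...] per "nodes_dtype".
    """
    nodes = model["tree_"]["nodes"]
    values = model["tree_"]["values"]
    node_id = 0  # start at the root
    while nodes[node_id][0] != TREE_LEAF:
        left, right, feature, threshold = nodes[node_id][:4]
        # a sample goes to the left child when x[feature] <= threshold
        node_id = left if x[feature] <= threshold else right
    counts = values[node_id][0]  # per-class sample counts at the leaf
    # index of the largest count; ties resolve to the first class, like argmax
    best = max(range(len(counts)), key=counts.__getitem__)
    return model["classes_"][best]
```

For example, model = load_model("tree_model") followed by predict_one(model, Xtest[0]) should normally match clf.predict(Xtest[:1])[0]. One caveat: scikit-learn casts inputs to float32 before comparing them with the thresholds, so a feature value extremely close to a threshold could in principle be routed differently by this sketch.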

References

sklearn-json currently supports JSON serialization and deserialization for many commonly used sklearn algorithms.

For details, see:
https://pypi.org/project/sklearn-json/
