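I tested this module and it works: it produces code in the target language. However, the results (the output score values) were not quite the same as those of the original model. After some modifications of my own, the exported model was usable. My machine learning model is an object detection model.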
m2cgen (Model 2 Code Generator) is a lightweight library that provides an easy way to transpile trained statistical models into native code (Python, C, Java, Go).
- Installation
- Supported Languages
- Supported Models
- Classification Output
- Usage
Installation
Supported Python version is >= 3.4.
pip install m2cgen
Supported Languages
Python
Java
C
Go
Classification Output
Linear/Linear SVM
Binary: scalar value; signed distance of the sample to the hyperplane for the second class.
Multiclass: vector value; signed distance of the sample to the hyperplane for each class.
Comment: the output is consistent with the output of LinearClassifierMixin.decision_function.
SVM
Binary: scalar value; signed distance of the sample to the hyperplane for the second class.
Multiclass: vector value; one-vs-one score for each class, shape (n_samples, n_classes * (n_classes-1) / 2).
Comment: the output is consistent with the output of BaseSVC.decision_function when the decision_function_shape is set to ovo.
Tree/Random Forest/XGBoost/LightGBM
Binary: vector value; class probabilities.
Multiclass: vector value; class probabilities.
Comment: the output is consistent with the output of the predict_proba method of DecisionTreeClassifier/ForestClassifier/XGBClassifier/LGBMClassifier.
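To make the output shapes described above more concrete, here is a minimal pure-Python sketch. It does not call m2cgen or scikit-learn. The helper names (linear_scores, ovo_pairs) and the toy weights are hypothetical, chosen only to illustrate a single scalar score for a binary linear model, one score per class for a multiclass linear model, and the n_classes * (n_classes-1) / 2 one-vs-one scores of a multiclass SVM.

```python
from itertools import combinations


def linear_scores(weights, intercepts, x):
    """Return the raw score (w . x + b) for each weight row.

    Hypothetical helper: a single row models the binary case and yields
    one scalar; several rows model the multiclass case and yield a list
    with one score per class. The sign of a score tells on which side
    of the corresponding hyperplane the sample lies.
    """
    scores = [
        sum(w_i * x_i for w_i, x_i in zip(row, x)) + b
        for row, b in zip(weights, intercepts)
    ]
    return scores[0] if len(scores) == 1 else scores


def ovo_pairs(n_classes):
    """List the class pairs that a one-vs-one scheme scores."""
    return list(combinations(range(n_classes), 2))


x = [1.0, 2.0]  # toy sample with two features

# Binary linear model: a single scalar score.
binary_score = linear_scores([[0.5, -1.0]], [0.25], x)

# Multiclass linear model with 3 classes: one score per class.
multi_scores = linear_scores(
    [[0.5, -1.0], [1.0, 0.0], [-0.5, 0.5]], [0.25, 0.0, 1.0], x
)

# Multiclass SVM with 4 classes: 4 * 3 / 2 = 6 one-vs-one scores.
pairs = ovo_pairs(4)

print(binary_score, multi_scores, len(pairs))
```

Tree-based models differ from both cases: according to the description above they return class probabilities, so their output is a vector even in the binary case.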
Usage
Here’s a simple example of how a linear model trained in a Python environment can be represented in Java code:
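# Note: load_boston was removed in scikit-learn 1.2, so this example needs an older release.
from sklearn.datasets import load_boston
from sklearn import linear_model
import m2cgen as m2c

# Load the Boston housing data (13 features per sample).
boston = load_boston()
X, y = boston.data, boston.target

# Fit an ordinary least-squares regression.
estimator = linear_model.LinearRegression()
estimator.fit(X, y)

# Transpile the fitted model into Java source code (returned as a string).
code = m2c.export_to_java(estimator)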
Generated Java code:
public class Model {
    public static double score(double[] input) {
        return (((((((((((((36.45948838508965) + ((input[0]) * (-0.10801135783679647))) + ((input[1]) * (0.04642045836688297))) + ((input[2]) * (0.020558626367073608))) + ((input[3]) * (2.6867338193449406))) + ((input[4]) * (-17.76661122830004))) + ((input[5]) * (3.8098652068092163))) + ((input[6]) * (0.0006922246403454562))) + ((input[7]) * (-1.475566845600257))) + ((input[8]) * (0.30604947898516943))) + ((input[9]) * (-0.012334593916574394))) + ((input[10]) * (-0.9527472317072884))) + ((input[11]) * (0.009311683273794044))) + ((input[12]) * (-0.5247583778554867));
    }
}
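The generated expression is the intercept plus a weighted sum of the 13 input features. As a rough sketch, the same computation can be written in Python with the numbers copied from the Java code above. The names INTERCEPT, COEFFICIENTS, and score are illustrative and are not m2cgen output.

```python
# Values copied from the generated Java code; the names are illustrative.
INTERCEPT = 36.45948838508965
COEFFICIENTS = [
    -0.10801135783679647,
    0.04642045836688297,
    0.020558626367073608,
    2.6867338193449406,
    -17.76661122830004,
    3.8098652068092163,
    0.0006922246403454562,
    -1.475566845600257,
    0.30604947898516943,
    -0.012334593916574394,
    -0.9527472317072884,
    0.009311683273794044,
    -0.5247583778554867,
]


def score(features):
    """Return intercept + sum(feature * coefficient) for one sample."""
    if len(features) != len(COEFFICIENTS):
        raise ValueError(
            f"expected {len(COEFFICIENTS)} features, got {len(features)}"
        )
    result = INTERCEPT
    # Accumulate left to right, in the same order as the nested Java expression.
    for coefficient, value in zip(COEFFICIENTS, features):
        result += value * coefficient
    return result


# With all features at zero, only the intercept remains.
baseline = score([0.0] * 13)

# Setting only feature 3 to 1.0 adds exactly that feature's coefficient.
one_hot = [0.0] * 13
one_hot[3] = 1.0
shifted = score(one_hot)
```

Evaluating a function like this on a few rows and comparing the values with the original model's predictions is one way to check that a transpiled model reproduces the original scores.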