这是我的第315篇原创文章。
一、引言
用训练好的模型对新数据进行预测,在机器学习工程上有一个更专业的名词叫做「推理 Inference」。一般情况下,推理又分为:静态推理与动态推理。
静态推理通过集中对批量数据进行推理,并将结果存放在数据表或数据库中。当有需要的时候,再直接通过查询来获得推理结果。而动态推理一般将模型部署到服务器中。当有需要时,通过向服务器发送请求来获得模型返回的预测结果。
与静态推理不同的是,动态推理的过程是实时计算的,而静态推理是提前批量处理好的。静态推理适合于对大批量数据进行处理,因为动态推理面对大数据量时非常耗时。但是静态推理无法实时更新,而动态推理的结果是即时计算结果。
静态推理相信大家都很熟悉了,因为前面的内容中,我们对新数据预测实际上就类似于静态推理的过程。你只需要使用 scikit-learn 提供的 predict
操作即可完成。接下来,我们重点讨论动态推理的过程,并教你使用 RESTful API 的方式部署 scikit-learn 模型并完成动态推理。
二、项目结构
项目结构:
实现步骤:
-
执行train.py训练模型,把训练好的模型存储下来
-
基于FlaskWeb 应用框架构建一个 RESTful API,执行run.py 启动 Flask app
-
执行predict.py向目的地址发送请求并返回结果
注意:用Flask构建API的方法时,就是首先先训练模型,然后把模型给序列化了,当线上的测试数据来的时候,就直接使用已经训练好的模型,如果上线后模型表现效果不好,还是需要再训练模型的。
三、实现过程
3.1 训练保存模型
核心代码:
df = load_dataset("titanic") # 加载泰坦尼克数据集
X = df[["pclass", "sex", "embarked"]] # 特征
y = df["alive"] # 目标
X = pd.get_dummies(X) # 独热编码
model = RandomForestClassifier() # 随机森林
np.mean(cross_val_score(model, X, y, cv=5)) # 5 次交叉验证求平均
model.fit(X, y) # 训练模型
joblib.dump(model, "titanic.pkl") # 保存模型
模型保存的文件:
3.2 构建RESTful API
核心代码:
app = Flask(__name__)
@app.route('/')
def index():
return 'Please use the POST method to get predictions.'
@app.route("/", methods=["POST"]) # 请求方法为 POST
def predict():
try:
json_ = request.json # 解析请求数据
query_df = pd.DataFrame(json_) # 将 json 变为 DataFrame
columns_onehot = ["pclass", "sex_female", "sex_male",
"embarked_C", "embarked_Q", "embarked_S"] # 独热编码 DataFrame 列名
query = pd.get_dummies(query_df).reindex(
columns=columns_onehot, fill_value=0) # 将请求数据 DataFrame 处理成独热编码样式
clf = joblib.load("titanic.pkl") # 加载模型
predictions = clf.predict(query) # 模型推理
return jsonify({"predict": list(predictions)}) # 返回推理结果
except Exception as e:
return f"Error: {e}"
if __name__ == "__main__":
app.run(host='0.0.0.0', debug=True)
执行run.py,启动Flask App,网页访问127.0.0.1:5000,结果:
3.3 发送数据请求并返回预测结果
向服务器发送数据,并获得预测的结果:
# 向服务器发送请求获得预测结果
sample = [
{"pclass": 1, "sex": "male", "embarked": "C"},
{"pclass": 2, "sex": "female", "embarked": "S"},
{"pclass": 3, "sex": "male", "embarked": "Q"},
{"pclass": 3, "sex": "female", "embarked": "S"},
]
# 稍等片刻,Render 线上服务存在冷却启动时间
# requests.post(url="https://titanic-demo.onrender.com", json=sample).content
result = requests.post(url='http://127.0.0.1:5000', json=sample).content
print(result)
我们向服务器传送了4条记录,返回4个预测结果,分别为no、yes、no、no:
作者简介:
读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。