在Python中把机器学习模型转成API的代码样例

前因

今天想了解系统如何调用机器学习(或深度学习)算法实现识别分类,在CDSN上搜到了这篇文章,让我受益匪浅。全文参考至这篇文章,只是代码有些许修改,大家一定要先看这篇文章,整体思路讲得非常好。

美中不足的是这篇文章的代码我跑的时候还是出现了一些问题,且没有测试代码,修改之后发个文章,供大家学习参考。

具体错误我就不分析了,个人比较懒,放上代码就溜了。

模型背景

这里我们以Kaggle上最受欢迎的数据集——泰坦尼克 为例进行讲解。这个数据集主要是个分类问题,我们的任务是根据表格数据预测乘客的生存概率。为了进一步简化,我们只用四个变量:age(年龄)、sex(性别)、embarked(登船港口:C=Cherbourg, Q=Queenstown, S=Southampton)和survived。其中survived是个类别标签。

总体测试流程

1.运行model.py,训练模型,将训练好的模型打包成pkl文件;
2.运行api.py,打开web服务,加载模型,监听端口;
3.另开一个终端,运行测试代码,发送post请求,终端输出返回的模型预测结果。

1. 模型代码model.py

# Import dependencies
import pandas as pd
import numpy as np

# Load the dataset in a dataframe object and include only four features as mentioned
url = "http://s3.amazonaws.com/assets.datacamp.com/course/Kaggle/train.csv"
df = pd.read_csv(url)
include = ['Age', 'Sex', 'Embarked', 'Survived'] # Only four features
df_ = df[include]

# 往下5行的代码不能运行,自己根据df_制作可迭代数据
# iterable = [str(x) for x in df_.dtypes]

# Data Preprocessing
categoricals = []
for col, col_type in zip(include, df_.dtypes):
     if str(col_type) == 'object':
          categoricals.append(col)
     else:
          df_[col].fillna(0, inplace=True)

df_ohe = pd.get_dummies(df_, columns=categoricals, dummy_na=True)

# Logistic Regression classifier
from sklearn.linear_model import LogisticRegression
dependent_variable = 'Survived'
x = df_ohe[df_ohe.columns.difference([dependent_variable])]
y = df_ohe[dependent_variable]
lr = LogisticRegression()
lr.fit(x, y)

# Save your model
import joblib
joblib.dump(lr, 'model.pkl')
print("Model dumped!")

# Load the model that you just saved
lr = joblib.load('model.pkl')

# Saving the data columns from training
model_columns = list(x.columns)
joblib.dump(model_columns, 'model_columns.pkl')
print("Models columns dumped!")

2.服务端代码api.py

from flask import Flask, request, jsonify
import joblib
import traceback
import pandas as pd
import numpy as np
import sys
import json

app = Flask(__name__)

@app.route('/')
def hello():
    return "这是根目录!"

@app.route('/predict', methods=['POST']) # Your API endpoint URL would consist /predict
def predict():
    if lr:
        try:
            with open('request.txt','w') as f:
                f.write(request.json)

            json_ = request.json
            query = pd.get_dummies(pd.DataFrame(json.loads(json_)))
            query = query.reindex(columns=model_columns, fill_value=0)

            prediction = list(lr.predict(query))

            return jsonify({'prediction': str(prediction)})

        except:

            return jsonify({'trace': traceback.format_exc()})
    else:
        print ('Train the model first')
        return ('No model here to use')

if __name__ == '__main__':
    try:
        port = int(sys.argv[1]) # This is for a command-line input
    except:
        port = 12345 # If you don't provide any port the port will be set to 12345

    lr = joblib.load("model.pkl") # Load "model.pkl"
    print ('Model loaded')
    model_columns = joblib.load("model_columns.pkl") # Load "model_columns.pkl"
    print ('Model columns loaded')

    app.run(port=port, debug=True)

3. 测试代码my_request.py

import requests
import json

# POST 请求的 URL
url = 'http://localhost:12345/predict'

# POST 请求的参数(表单数据)
data = [
    {"Age": 85, "Sex": "male", "Embarked": "S"},
    {"Age": 24, "Sex": "female", "Embarked": "C"},
    {"Age": 3, "Sex": "male", "Embarked": "C"},
    {"Age": 21, "Sex": "male", "Embarked": "S"}
]

# 发送 POST 请求
response = requests.post(url, json=json.dumps(data))

# 检查响应状态码
if response.status_code == 200:
    # 打印响应内容
    print(response.text)
else:
    print('请求失败')

2023-6-6补充
这里在my_request.py发起的请求是list,发起post请求时转换为了string格式,所以api.py里直接获取内容写入txt和加载为json是没问题的。
但如果是从网页端发起请求(例如postman测试),需要做一定修改。
这里在my_request.py文件中不再将数据转化为string,在json.loads处改为加载request.data。代码如下:

# my_request.py
# 发送 POST 请求
response = requests.post(url, json=data)

# api.py
        try:
            with open('request.txt','w') as f:
                f.write(str(request.data))

            # json_ = request.json 拿到的是数据原型

            query = pd.get_dummies(pd.DataFrame(json.loads(request.data)))

这样flask就能处理list数据了,但是处理字符串数据的话又会有一点问题(在json.loads处),后续用到再来补充吧…

参考文献

[1]:在Python中把机器学习模型转成API的具体步骤

  • 3
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

只想睡觉111

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值