在本教程中,我们将简单了解FastAPI,并用它构建一个用于机器学习(ML)模型推理的API。随后,我们将结合Jinja2模板,打造一个完善的Web界面。这个项目简单有趣,即使你对API和Web开发的知识有限,也可以轻松上手并自主完成。
FastAPI简介
FastAPI是一个流行且现代的Web框架,用于用Python构建API。它专为高效、快速开发而设计,充分利用Python的标准类型提示,带来极佳的开发体验。FastAPI易学易用,只需少量代码即可开发高性能API。Uber、Netflix、微软等公司都在使用FastAPI构建API和应用程序。其设计非常适合为机器学习模型的推理和测试创建API端点,甚至可以通过集成Jinja2模板,打造完整的Web应用。
模型训练
我们将使用最常见的Iris(鸢尾花)数据集训练一个随机森林分类器。训练完成后,会展示模型评估指标,并将模型以pickle格式保存。
train_model.py:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
import joblib
# 加载鸢尾花数据集
iris = load_iris()
X, y = iris.data, iris.target
# 划分训练集与测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# 训练随机森林分类器
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)
# 模型评估
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
report = classification_report(y_test, y_pred, target_names=iris.target_names)
print(f"模型准确率: {accuracy}")
print("分类报告:")
print(report)
# 保存模型
joblib.dump(clf, "iris_model.pkl")
执行命令:
$ python train_model.py
输出示例:
模型准确率: 1.0
分类报告:
precision recall f1-score support
setosa 1.00 1.00 1.00 10
versicolor 1.00 1.00 1.00 9
virginica 1.00 1.00 1.00 11
accuracy 1.00 30
macro avg 1.00 1.00 1.00 30
weighted avg 1.00 1.00 1.00 30
使用FastAPI构建ML推理API
接下来,我们将安装FastAPI和Uvicorn库,用于构建模型推理API。
$ pip install fastapi uvicorn
在app.py
文件中,我们将:
-
加载之前保存的模型。
-
创建输入和预测的Python类(需指定数据类型)。
-
编写预测函数,并使用
@app.post
装饰器将其定义为/predict
的POST端点。 -
预测函数接收
IrisInput
类的数据,并以IrisPrediction
类返回结果。 -
使用
uvicorn.run
函数运行应用,指定主机和端口。
app.py:
from fastapi import FastAPI
from pydantic import BaseModel
import joblib
import numpy as np
from sklearn.datasets import load_iris
# 加载已训练模型
model = joblib.load("iris_model.pkl")
app = FastAPI()
class IrisInput(BaseModel):
sepal_length: float
sepal_width: float
petal_length: float
petal_width: float
class IrisPrediction(BaseModel):
predicted_class: int
predicted_class_name: str
@app.post("/predict", response_model=IrisPrediction)
def predict(data: IrisInput):
# 转换输入为numpy数组
input_data = np.array(
[[data.sepal_length, data.sepal_width, data.petal_length, data.petal_width]]
)
# 预测
predicted_class = model.predict(input_data)[0]
predicted_class_name = load_iris().target_names[predicted_class]
return IrisPrediction(
predicted_class=predicted_class, predicted_class_name=predicted_class_name
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=8000)
运行Python文件:
$ python app.py
FastAPI服务器启动后,通过链接即可访问:
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
默认首页暂无内容,只有/predict
POST请求端点,因此首页不会显示任何内容。
利用FastAPI测试API
我们可以通过SwaggerUI接口方便地测试API,只需在链接后加上/docs
即可访问。
在SwaggerUI中点击“/predict”,编辑输入数据并运行预测,在响应区域即可看到结果。例如,预测结果为“Virginica”。这样可以直接在Swagger界面测试模型,确保部署前一切运行正常。
为Web应用构建UI界面
除了Swagger UI,我们可以自定义一个简洁美观的网页界面。为此,需要在应用中集成Jinja2Templates。Jinja2Templates允许我们用HTML文件自定义网页组件,搭建真正的Web界面。
主要步骤如下:
-
初始化Jinja2Templates并指定HTML文件目录。
-
定义异步路由,将根路径("/")的访问返回index.html页面。
-
修改
predict
函数,接收Request和Form表单数据。 -
定义异步POST端点
/predict
,接收表单参数,预测结果后用TemplateResponse渲染result.html页面。 -
其余代码与前述类似。
引人入胜的标题:
集成Jinja2模板,开发网页界面
下面是集成Jinja2模板后的app.py
代码:
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
import joblib
import numpy as np
from sklearn.datasets import load_iris
# 加载已训练模型
model = joblib.load("iris_model.pkl")
# 初始化FastAPI
app = FastAPI()
# 配置模板目录
templates = Jinja2Templates(directory="templates")
# Pydantic数据模型
class IrisInput(BaseModel):
sepal_length: float
sepal_width: float
petal_length: float
petal_width: float
class IrisPrediction(BaseModel):
predicted_class: int
predicted_class_name: str
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
@app.post("/predict", response_model=IrisPrediction)
async def predict(
request: Request,
sepal_length: float = Form(...),
sepal_width: float = Form(...),
petal_length: float = Form(...),
petal_width: float = Form(...),
):
# 构造输入数组
input_data = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
# 预测
predicted_class = model.predict(input_data)[0]
predicted_class_name = load_iris().target_names[predicted_class]
return templates.TemplateResponse(
"result.html",
{
"request": request,
"predicted_class": predicted_class,
"predicted_class_name": predicted_class_name,
"sepal_length": sepal_length,
"sepal_width": sepal_width,
"petal_length": petal_length,
"petal_width": petal_width,
},
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=8000)
创建前端HTML模板
-
在
app.py
同级目录下新建templates
文件夹。 -
在
templates
文件夹内,新建index.html
和result.html
两个文件。
index.html:用户输入界面
<!DOCTYPE html>
<html>
<head>
<title>Iris Flower Prediction</title>
</head>
<body>
<h1>Predict Iris Flower Species</h1>
<form action="/predict" method="post">
<label for="sepal_length">Sepal Length:</label>
<input type="number" step="any" id="sepal_length" name="sepal_length" required><br>
<label for="sepal_width">Sepal Width:</label>
<input type="number" step="any" id="sepal_width" name="sepal_width" required><br>
<label for="petal_length">Petal Length:</label>
<input type="number" step="any" id="petal_length" name="petal_length" required><br>
<label for="petal_width">Petal Width:</label>
<input type="number" step="any" id="petal_width" name="petal_width" required><br>
<button type="submit">Predict</button>
</form>
</body>
</html>
这份HTML代码会生成一个网页表单,方便用户输入鸢尾花的“萼片”和“花瓣”长度与宽度,然后通过POST请求提交到/predict
端点。
result.html:结果显示界面
<!DOCTYPE html>
<html>
<head>
<title>Prediction Result</title>
</head>
<body>
<h1>Prediction Result</h1>
<p>Sepal Length: {{ sepal_length }}</p>
<p>Sepal Width: {{ sepal_width }}</p>
<p>Petal Length: {{ petal_length }}</p>
<p>Petal Width: {{ petal_width }}</p>
<h2>Predicted Class: {{ predicted_class_name }} (Class ID: {{ predicted_class }})</h2>
<a href="/">Predict Again</a>
</body>
</html>
这段HTML将在页面上显示用户输入的各项参数和模型预测结果,包括预测类别ID与名称,并提供返回首页重新预测的按钮。
运行并体验Web应用
再次运行Python应用:
$ python app.py
你将看到如下信息:
INFO: Started server process [2932]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
INFO: 127.0.0.1:63153 - "GET / HTTP/1.1" 200 OK
在浏览器访问 http://127.0.0.1:8000,首页将展示输入界面,用户可填写“萼片”和“花瓣”的长度与宽度,点击“Predict”按钮,跳转到结果页面,显示预测结果。同时可点击“Predict Again”回到首页继续测试。
更多资源
所有源码、数据、模型和相关信息可以在 GitHub 仓库 kingabzpro/FastAPI-for-ML 获取,欢迎 star ⭐ 支持。
结语
如今,许多大型企业都采用FastAPI来为其机器学习模型创建API端点,实现模型的无缝部署与集成。FastAPI开发速度快、编码简单、功能丰富,非常适合现代数据技术栈的需求。想进入这一领域,最好的方式就是多做项目并善于文档记录,这将帮你积累经验、提升能力,也便于初筛环节展示你的实力。招聘方会通过你的项目作品集评估是否适合团队。所以,不妨现在就开始用FastAPI搭建属于你的ML项目吧!