import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
import gradio as gr
# 加载鸢尾花数据集
iris = datasets.load_iris()
x = iris.data
y = iris.target
print(x)
print(y)
# 划分数据集为训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=702)
# 数据标准化
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
# 创建KNN分类器
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(x_train, y_train)
# 定义gradio界面函数
def classify_iris(sepal_length, sepal_width, petal_length, petal_width):
# 将输入转换为numpy数组
inputs = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
# 数据标准化
inputs_scaled = scaler.transform(inputs)
# 使用KNN模型进行预测
prediction = knn.predict(inputs_scaled)
# 返回预测的花的种类
return iris.target_names[prediction[0]]
# 创建gradio界面
iris_sort = gr.Interface(fn=classify_iris,
inputs=["number", "number", "number", "number"],
outputs="text",
examples=[
[5.1, 3.5, 1.4, 0.2],
[7.0, 3.2, 4.7, 1.4],
[6.3, 3.3, 6.0, 2.5]
],
title="鸢尾花分类器")
# 启动gradio界面
iris_sort.launch(share=True)
浏览器打开地址: