导入数据:
方法一:使用封装好的类。Scikit-learn 是一个用于机器学习的库,datasets
模块提供了一些常用的数据集供学习和实验使用,其中的 load_iris
方法类用于加载鸢尾花数据集(Iris 数据集)。
from sklearn.datasets import load_iris # 导入方法类
iris = load_iris() #导入数据集iris
iris_feature = iris.data #特征数据
iris_target = iris.target #分类数据
print (iris.data) #输出数据集
print (iris.target) #输出真实标签
print (len(iris.target) )
print (iris.data.shape ) #150个样本 每个样本4个特征
方法二:用uml,下载数据
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
names = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'class']
dataset = pandas.read_csv(url, names=names) #读取csv数据
print(dataset.describe())
方法三:下载数据集到本地。注意这里用的数据集第一行没有标签名(header指标签名)。
只要文件内是csv格式,函数就能正确读取,不一定必须是csv文件
path = 'iris.data' # 数据文件路径
data = pd.read_csv(path, header=None)
数据预处理:
这里把前四列作为特征,第五列作为标签(且把第五列转化为整数编码)
#方法一:
x = data[list(range(4))]
y = LabelEncoder().fit_transform(data[4]) #讲栾尾花类别编码
x = x.iloc[:, :4]
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=1)
#方法二:
iris = load_iris() #导入数据集iris
iris_feature = iris.data #特征数据
iris_target = iris.target #分类数据
确定更适合的深度和子树个数:http://t.csdn.cn/4px4x
训练与测试:
new_data = [[1.0, 1.0, 2.0,0.1], [0.5, 1.5, 2.5,0.2]]
测试时要输入二维数组,每一项要包含四个特征
#训练
clf.fit(x_train, y_train)
# # 使用训练好的 clf 对新数据进行预测
# new_data = [[1.0, 1.0, 2.0,0.1],
# [0.5, 1.5, 2.5,0.2]]
#
# predictions = clf.predict(new_data)
#
# # 查看预测结果
# print(predictions)
# 使用测试集进行预测
y_pred = clf.predict(x_test)
# 在测试集上评估模型性能
from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y_test, y_pred)
print("测试集准确率:", accuracy)
保存训练好的模型:
# 保存模型到文件
model_filename = 'trained_model.joblib'
joblib.dump(clf, model_filename)
把模型整合入java项目
import weka.core.SerializationHelper;
import weka.core.Instance;
import weka.core.Instances;
import weka.classifiers.trees.J48;
public class Main {
public static void main(String[] args) {
try {
// 加载模型
String modelFilename = "trained_model.joblib";
J48 model = (J48) SerializationHelper.read(modelFilename);
// 创建新的 Instance 对象,填入特征数据
Instance instance = new Instance(4); // 假设特征数为4
instance.setDataset(getDataset()); // 设置数据集
// 填入特征数据,假设特征值为 5.1, 3.5, 1.4, 0.2
instance.setValue(0, 5.1);
instance.setValue(1, 3.5);
instance.setValue(2, 1.4);
instance.setValue(3, 0.2);
// 进行预测
double prediction = model.classifyInstance(instance);
// 输出预测结果
System.out.println("预测结果:" + prediction);
} catch (Exception e) {
e.printStackTrace();
}
}
// 创建一个空的 Instances 对象,假设特征数为4
private static Instances getDataset() {
Instances dataset = new Instances("TestInstances", null, 4);
dataset.setClassIndex(dataset.numAttributes() - 1); // 设置目标属性的索引
return dataset;
}
}