机器学习决策树(实战)

文章介绍了三种导入鸢尾花数据集的方法,包括使用scikit-learn的load_iris函数和pandas的read_csv函数。之后,数据被分为特征和标签,经过预处理后进行训练和测试。模型使用了train_test_split进行划分,并计算了测试集上的准确率。最后,模型被保存为joblib格式,并展示了如何在Java项目中使用weka进行加载和预测。
摘要由CSDN通过智能技术生成

导入数据:

方法一:使用封装好的类。Scikit-learn 是一个用于机器学习的库,datasets 模块提供了一些常用的数据集供学习和实验使用,其中的 load_iris 方法类用于加载鸢尾花数据集(Iris 数据集)。

from sklearn.datasets import load_iris  # 导入方法类
 
iris = load_iris()  #导入数据集iris
iris_feature = iris.data    #特征数据
iris_target = iris.target   #分类数据
print (iris.data)          #输出数据集
print (iris.target)        #输出真实标签
print (len(iris.target) )
print (iris.data.shape )   #150个样本 每个样本4个特征

方法二:用uml,下载数据

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
names = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'class']
dataset = pandas.read_csv(url, names=names) #读取csv数据
print(dataset.describe())

方法三:下载数据集到本地。注意这里用的数据集第一行没有标签名(header指标签名)。

只要文件内是csv格式,函数就能正确读取,不一定必须是csv文件

 path = 'iris.data'  # 数据文件路径
    data = pd.read_csv(path, header=None)

数据预处理:

这里把前四列作为特征,第五列作为标签(且把第五列转化为整数编码)

 #方法一:
x = data[list(range(4))]
    y = LabelEncoder().fit_transform(data[4])   #讲栾尾花类别编码
 
    x = x.iloc[:, :4]
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=1)

#方法二:

iris = load_iris()  #导入数据集iris
iris_feature = iris.data    #特征数据
iris_target = iris.target   #分类数据

确定更适合的深度和子树个数:http://t.csdn.cn/4px4x

 训练与测试:

new_data = [[1.0, 1.0, 2.0,0.1], [0.5, 1.5, 2.5,0.2]]   

测试时要输入二维数组,每一项要包含四个特征

    #训练
    clf.fit(x_train, y_train)

    # # 使用训练好的 clf 对新数据进行预测
    # new_data = [[1.0, 1.0, 2.0,0.1],
    #      [0.5, 1.5, 2.5,0.2]]
    #
    # predictions = clf.predict(new_data)
    #
    # # 查看预测结果
    # print(predictions)

    # 使用测试集进行预测
    y_pred = clf.predict(x_test)

    # 在测试集上评估模型性能
    from sklearn.metrics import accuracy_score

    accuracy = accuracy_score(y_test, y_pred)
    print("测试集准确率:", accuracy)

保存训练好的模型:

# 保存模型到文件
model_filename = 'trained_model.joblib'
joblib.dump(clf, model_filename)

把模型整合入java项目

import weka.core.SerializationHelper;
import weka.core.Instance;
import weka.core.Instances;
import weka.classifiers.trees.J48;

public class Main {
    public static void main(String[] args) {
        try {
            // 加载模型
            String modelFilename = "trained_model.joblib";
            J48 model = (J48) SerializationHelper.read(modelFilename);

            // 创建新的 Instance 对象,填入特征数据
            Instance instance = new Instance(4); // 假设特征数为4
            instance.setDataset(getDataset()); // 设置数据集

            // 填入特征数据,假设特征值为 5.1, 3.5, 1.4, 0.2
            instance.setValue(0, 5.1);
            instance.setValue(1, 3.5);
            instance.setValue(2, 1.4);
            instance.setValue(3, 0.2);

            // 进行预测
            double prediction = model.classifyInstance(instance);

            // 输出预测结果
            System.out.println("预测结果:" + prediction);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    // 创建一个空的 Instances 对象,假设特征数为4
    private static Instances getDataset() {
        Instances dataset = new Instances("TestInstances", null, 4);
        dataset.setClassIndex(dataset.numAttributes() - 1); // 设置目标属性的索引
        return dataset;
    }
}

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值