weka平台的二次开发

安装与使用

weka的相关资料可以在https://ml.cms.waikato.ac.nz/weka/获取
下载后的软件界面如图
在这里插入图片描述

进行实验

在这里插入图片描述
点击“Explorer”,进入如下界面,打开文件,随机选择weka平台官方提供的数据集
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
之后选择分类,进去如下界面
在这里插入图片描述
下图中选择数据处理方式,cross——validation的Folds一般选择5
在这里插入图片描述
之后进行分类方法选择,例如贝叶斯分类
在这里插入图片描述
点击start可以得到分类结果
在这里插入图片描述
这里我们解释一下“Confusion Matrix”的含义。
=== Confusion Matrix ===
a b <-- classified as
170 31 | a = no-recurrence-events
47 38 | b = recurrence-events

这个矩阵是说,原本是“no-recurrence-events”的实例,有170个被正确的预测为“no-recurrence-events”,有64个错误的预测成了“recurrence-events”;原本是“recurrence”的实例,有47个被错误的预测为“no-recurrence-events”,有38个正确的预测成了“recurrence-events”。170+31+47+38= 286是实例总数,而(170+38)/286= 0.7272727正好是正确分类的实例所占比例。这个矩阵对角线上的数字越大,说明预测得越好。

多个实验结果比较

在这里插入图片描述
点击“Experimenter”,进入下图界面,之后点击“New”
在这里插入图片描述
之后点击“Browse”,选择文件保存的路径,
在这里插入图片描述
下面的红框中加入想要比较的数据集
在这里插入图片描述
下面的红框中选择想要进行对比的方法
在这里插入图片描述
在这里插入图片描述
进入Run界面,点击“Start”
在这里插入图片描述
进入Analyse界面,点击“File”,选择你之前保存的文件,例如我之前保存在桌面的test文件
在这里插入图片描述在这里插入图片描述
点击“Perform test”,右边出现对比结果表格
在这里插入图片描述

二次开发

步骤

打开Weka的安装路径,选择weka-src.jar文件,复杂到你想要实验的目录下,解压,基于此进行二次开发
在这里插入图片描述
用IntelliJ IDEA打开解压后的文件
在这里插入图片描述
实现自己的分类器,先在weka-src/src/main/java/weka/classifiers下新建一个Package,然后新建一个java文件
在这里插入图片描述
在这里插入图片描述
在gui包下面的GenericPropertiesCreator.props文件中增加你刚刚新建的包,例如我的包名称是xw
在这里插入图片描述打开gui包下面的GUIChooser文件,运行,可以直接打开weka平台
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
新建的java我命名为NB,打开weka平台,找到新建的NB方法,实现分类
在这里插入图片描述
在这里插入图片描述
一般情况下分类器只需要重载下面两个函数

public void buildClassifier(Instances instances) throws Exception
public double [] distributionForInstance(Instance instance) throws Exception

朴素贝叶斯实现代码

package weka.classifiers.xw;

import weka.core.*;
import weka.classifiers.*;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Discretize;

/**
 * Implement the NB classifier supporting both discrete and continuous attributes.
 */
public class NB extends AbstractClassifier {

    // 用于存储 朴素贝叶斯 属性参数
    protected List<Integer>[][] m_Distributions;

    // 用于存储 朴素贝叶斯 类别参数
    protected List<Integer> m_ClassDistribution;

    // 用于存储连续属性的均值和标准差
    protected double[][] m_Means;
    protected double[][] m_StdDevs;

    // 类别参数 的 种类数量
    protected int m_NumClasses;

    // 存储训练数据
    protected Instances m_Instances;


    public void buildClassifier(Instances instances) throws Exception {

        // 初始化训练数据
        m_Instances = instances;
        // 初始化类别参数 的 种类数量
        m_NumClasses = instances.numClasses();

        // 初始化 属性参数
        m_Distributions = new List[instances.numAttributes() - 1][m_NumClasses];
        m_Means = new double[instances.numAttributes() - 1][m_NumClasses];
        m_StdDevs = new double[instances.numAttributes() - 1][m_NumClasses];
        for (int i = 0; i < instances.numAttributes() - 1; i++) {
            for (int j = 0; j < m_NumClasses; j++) {
                m_Distributions[i][j] = new ArrayList<>();
                m_Means[i][j] = 0;
                m_StdDevs[i][j] = 0;
            }
        }

        // 初始化 类别参数
        m_ClassDistribution = new ArrayList<>();
        for (int i = 0; i < m_NumClasses; i++) {
            m_ClassDistribution.add(0);
        }

        // 获取属性参数的枚举类型
        Enumeration attributeEnumeration = instances.enumerateAttributes();
        // 遍历属性参数
        while (attributeEnumeration.hasMoreElements()) {
            // 获取属性参数
            Attribute attribute = (Attribute) attributeEnumeration.nextElement();
            // 获取属性参数的索引
            int attributeIndex = attribute.index();

            // 检查属性是否为离散型
            if (attribute.isNominal()) {
                // 获取属性参数的值的枚举类型
                Enumeration attributeValueEnumeration = attribute.enumerateValues();
                // 遍历属性参数的值
                while (attributeValueEnumeration.hasMoreElements()) {
                    // 获取属性参数的值
                    String attributeValue = (String) attributeValueEnumeration.nextElement();
                    // 遍历类别参数
                    for (int classIndex = 0; classIndex < m_NumClasses; classIndex++) {
                        // 初始化属性参数的某个值的某个类别参数的计数
                        m_Distributions[attributeIndex][classIndex].add(0);
                    }
                }
            }
        }

        // 遍历训练数据
        for (int instanceIndex = 0; instanceIndex < instances.numInstances(); instanceIndex++) {
            // 获取训练数据的实例
            Instance instance = instances.instance(instanceIndex);
            // 获取训练数据的类别参数的值
            int classValue = (int) instance.classValue();
            // 计数
            m_ClassDistribution.set(classValue, m_ClassDistribution.get(classValue) + 1);

            // 遍历属性参数
            for (int attributeIndex = 0; attributeIndex < instances.numAttributes() - 1; attributeIndex++) {
                // 获取属性参数
                Attribute attribute = instances.attribute(attributeIndex);
                if (attribute.isNominal()) {
                    // 获取训练数据的属性参数的值
                    int attributeValue = (int) instance.value(attributeIndex);

                    // 确保 m_Distributions[attributeIndex][classValue] 有足够的空间
                    while (m_Distributions[attributeIndex][classValue].size() <= attributeValue) {
                        m_Distributions[attributeIndex][classValue].add(0);
                    }

                    // 计数
                    m_Distributions[attributeIndex][classValue].set(attributeValue,
                            m_Distributions[attributeIndex][classValue].get(attributeValue) + 1);
                } else if (attribute.isNumeric()) {
                    // 更新均值和标准差
                    double value = instance.value(attributeIndex);
                    m_Means[attributeIndex][classValue] += value;
                    m_StdDevs[attributeIndex][classValue] += value * value;
                }
            }
        }

        // 计算连续属性的均值和标准差
        for (int attributeIndex = 0; attributeIndex < instances.numAttributes() - 1; attributeIndex++) {
            Attribute attribute = instances.attribute(attributeIndex);
            if (attribute.isNumeric()) {
                for (int classIndex = 0; classIndex < m_NumClasses; classIndex++) {
                    int count = m_ClassDistribution.get(classIndex);
                    if (count > 1) {
                        m_Means[attributeIndex][classIndex] /= count;
                        m_StdDevs[attributeIndex][classIndex] = Math.sqrt(
                                m_StdDevs[attributeIndex][classIndex] / count - m_Means[attributeIndex][classIndex] * m_Means[attributeIndex][classIndex]);
                    }
                }
            }
        }
    }

    /**
     * Calculates the class membership probabilities for the given test instance
     *
     * @param instance the instance to be classified
     * @return predicted class probability distribution
     * @exception Exception if there is a problem generating the prediction
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        // 初始化预测概率数组
        double[] predictionProbability = new double[m_NumClasses];

        // 遍历类别参数
        for (int classIndex = 0; classIndex < m_NumClasses; classIndex++) {
            // 初始化预测概率
            double prediction = 1;

            // 遍历属性参数
            for (int attributeIndex = 0; attributeIndex < m_Instances.numAttributes() - 1; attributeIndex++) {
                // 获取属性参数
                Attribute attribute = m_Instances.attribute(attributeIndex);
                if (attribute.isNominal()) {
                    // 获取属性参数的值
                    int attributeValue = (int) instance.value(attributeIndex);
                    // 获取当前属性可能的取值数
                    int attributeValueCount = m_Distributions[attributeIndex][classIndex].size();

                    // 检查索引是否有效
                    if (attributeValue >= attributeValueCount) {
                        throw new Exception("Attribute value index out of bounds: " + attributeValue);
                    }

                    // 计算条件概率P(x|c) (当前属性值在当前类别下占的比例) (拉普拉斯平滑)
                    double p_x_c = (double) (m_Distributions[attributeIndex][classIndex].get(attributeValue) + 1) /
                            (m_ClassDistribution.get(classIndex) + attributeValueCount);

                    // 计算预测概率
                    prediction *= p_x_c;
                } else if (attribute.isNumeric()) {
                    // 计算连续属性的概率密度函数值
                    double mean = m_Means[attributeIndex][classIndex];
                    double stdDev = m_StdDevs[attributeIndex][classIndex];
                    double value = instance.value(attributeIndex);
                    double p_x_c = (1 / (Math.sqrt(2 * Math.PI) * stdDev)) *
                                   Math.exp(-Math.pow(value - mean, 2) / (2 * Math.pow(stdDev, 2)));
                    prediction *= p_x_c;
                }
            }

            // 计算先验概率P(c) (当前类别占总类别的比例) (拉普拉斯平滑)
            double p_c = (double) (m_ClassDistribution.get(classIndex) + 1) /
                    (m_Instances.numInstances() + m_NumClasses);

            // 计算预测概率
            predictionProbability[classIndex] = prediction * p_c;
        }

        // 归一化
        Utils.normalize(predictionProbability);

        // 返回预测概率数组
        return predictionProbability;
    }

    public static void main(String[] argv) {
        runClassifier(new NB(), argv);
    }
}

数据准备

WEKA必须从CSV文件的第一行读取属性名,否则就会把第一行的各属性值读成变量名。因此我们对于Matllab给出的CSV文件需要用UltraEdit打开,手工添加一行属性名。注意属性名的个数要跟数据属性的个数一致,仍用逗号隔开。

.csv -> .arff

将CSV转换为ARFF最迅捷的办法是使用WEKA所带的命令行工具。
运行WEKA的主程序,出现GUI后可以点击下方按钮进入相应的模块。我们点击进入“Simple CLI”模块提供的命令行功能。在新窗口的最下方(上方是不能写字的)输入框写上
java weka.core.converters.CSVLoader filename.csv > filename.arff
即可完成转换。
在WEKA 3.5中提供了一个“Arff Viewer”模块,我们可以用它打开一个CSV文件将进行浏览,然后另存为ARFF文件。
进入“Exploer”模块,从上方的按钮中打开CSV文件然后另存为ARFF文件亦可。

  • 27
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值