基于weka的中文文本分类(java版)

      

        本例子是用springboot完成,基于weka实现中文文本分类 , 本例子只是一个简单版,可以在简单版基础上进行扩展分词后再分类,最后达到程序帮我们分词、分类,大大提高效率、简化了人工分类的成本。

        首先什么是weka,百度给了解释:是一款免费的,非商业化(与之对应的是SPSS公司商业数据挖掘产品--Clementine )的,基于JAVA环境下开源的机器学习(machine learning)以及数据挖掘(data mining)软件。

         但是软件使用java写的,软件中有部分是做文本分类的,我们可以把它提取出来自己用,首先先创建一个springboot项目,如果不清楚springboot如何创建,先看我的博客  https://blog.csdn.net/sinat_23225111/article/details/77984344

           weka是怎么实现文本分类的呢?   其实可以简单理解为把所有分类、和每一类具体包括哪些词存到一个文件 weka称这个文件叫ARFF文件,然后把这个ARFF文件进行训练、成一个压缩版的模型,一个新词来了,就和模型就行比对,如果模型中有这个词,weka就把这个新词分到对应的分类中,举个例子   城市是一类,城市包括:深圳、北京、上海,先把这些词写到ARFF中,然后训练模型,  如果有个词是深圳的话 ,根据训练好的模型比对,程序就知道它属于城市这一分类。

         那ARFF文件是什么格式?见博客 https://blog.csdn.net/sinat_23225111/article/details/77983583

          再简单说就是机器是不知道怎么分类的,但是如果我们告诉他什么词属于什么类,当那个词再次被输入的时候,机器就知道它属于哪一类了。下面我们来看具体怎么实现。

          首先现在springboot中导入依赖,重要的是weka的依赖(已标红)

	<parent>
		<groupId>org.springframework.boot</groupId>
		<artifactId>spring-boot-starter-parent</artifactId>
		<version>1.5.9.RELEASE</version>
		<relativePath/> <!-- lookup parent from repository -->
	</parent>

	<properties>
		<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
		<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
		<java.version>1.8</java.version>
	</properties>

	<dependencies>
		<dependency>
			<groupId>org.mybatis.spring.boot</groupId>
			<artifactId>mybatis-spring-boot-starter</artifactId>
			<version>1.3.1</version>
		</dependency>

		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-web</artifactId>
			<exclusions>
				<exclusion>
					<groupId>org.springframework.boot</groupId>
					<artifactId>spring-boot-starter-logging</artifactId>
				</exclusion>
			</exclusions>
		</dependency>

		<dependency>
			<groupId>mysql</groupId>
			<artifactId>mysql-connector-java</artifactId>
			<scope>runtime</scope>
		</dependency>
		<dependency>
			<groupId>org.projectlombok</groupId>
			<artifactId>lombok</artifactId>
			<optional>true</optional>
		</dependency>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-test</artifactId>
			<scope>test</scope>
		</dependency>

		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-log4j2</artifactId>
		</dependency>

		<dependency>
			<groupId>org.apache.commons</groupId>
			<artifactId>commons-lang3</artifactId>
			<version>3.3.2</version>
		</dependency>
		<dependency>
			<groupId>nz.ac.waikato.cms.weka</groupId>
			<artifactId>weka-stable</artifactId>
			<version>3.8.1</version>
		</dependency>
		<dependency>
			<groupId>com.alibaba</groupId>
			<artifactId>druid</artifactId>
			<version>1.0.11</version>
		</dependency>
	</dependencies>		<dependency>
			<groupId>nz.ac.waikato.cms.weka</groupId>
			<artifactId>weka-stable</artifactId>
			<version>3.8.1</version>
		</dependency>
		<dependency>
			<groupId>com.alibaba</groupId>
			<artifactId>druid</artifactId>
			<version>1.0.11</version>
		</dependency>
	</dependencies>

        配置文件如下

################################server################################
server.port=8601
server.session.timeout=30
server.tomcat.max-threads=5000
server.tomcat.uri-encoding=UTF-8

################################datasource-druid################################
spring.datasource.driverClass=com.mysql.jdbc.Driver
spring.datasource.url=jdbc:mysql://localhost:3306/test?useUnicode=true&characterEncoding=utf-8
spring.datasource.username=root
spring.datasource.password=12345
spring.datasource.platform=mysql
spring.datasource.type=com.alibaba.druid.pool.DruidDataSource
spring.datasource.initialSize=1
spring.datasource.minIdle=3
spring.datasource.maxActive=20
spring.datasource.maxWait=60000
spring.datasource.timeBetweenEvictionRunsMillis=60000
spring.datasource.minEvictableIdleTimeMillis=30000
spring.datasource.validationQuery=select 'x'
spring.datasource.testWhileIdle=true
spring.datasource.testOnBorrow=false
spring.datasource.testOnReturn=false
spring.datasource.poolPreparedStatements=true
spring.datasource.maxPoolPreparedStatementPerConnectionSize=20
spring.datasource.filters=stat,wall,slf4j
spring.datasource.connectionProperties=druid.stat.mergeSql=true;druid.stat.slowSqlMillis=5000


#########################################weka#########################################
model.path=E://model.arff

接下来我们要把词和分类写入ARFF中 ,但是那么多词事先些好吗?太多了吧,而且后期为了让机器知道更到的词怎么分,要不断的往模型中加入,怎么实现动态扩展?

    我们可以把分类以及每一类都有那些词放到数据库中、用java代码自动生成ARFF文件,具体见https://blog.csdn.net/sinat_23225111/article/details/78022802

  目前只讲解关键部分代码,如需要全部代码请访问最上面的github地址,modelPath为ARFF生成模型的地址,如果是本地,可以放在任何一个盘中,我是从配置文件中读出,放在E盘,如果是服务器,模型需要放在服务器的特定路径下
 

@Service
public class ClassifyService {
    private static final Logger logger = getLogger(ClassifyService.class);

    @Autowired
    private ClassifyMapper classifyMapper;
    @Autowired
    private KeyWordMapper keyWordMapper;
   
    @Value("${model.path}")
    private String modelPath;

    @Transactional
    public void createWekaModel() {

        // 从数据库查找到所有的收入分类数据
        List<Classify> allClassifyList = classifyMapper.selectAll();
        if (allClassifyList == null || allClassifyList.isEmpty()) {
            logger.error("没有从数据库查找到分类数据");
            return;
        }
        logger.info("分类模型训练开始");
        generateInstanceAndModelLearn(allClassifyList);
    }


    private void generateInstanceAndModelLearn(List<Classify> allClassifyList) {
        // 生成Instances(每个Instances对应一个ARFF)
        Instances trainData = generateInstance(allClassifyList);
        // 模型学习
        FilteredClassifier evaluateAndLearn = WekaUtil.evaluateAndLearn(trainData);
        WekaUtil.saveModel(modelPath, evaluateAndLearn);
        logger.info("收入分类模型训练完毕并存储到硬盘");
    }

    /**
     * 程序构建Arff文件
     *
     * @param allClassifyList
     * @return
     */
    private Instances generateInstance(List<Classify> allClassifyList) {
        // 得到所有的分类名
        List<String> varietyOfClassify = new ArrayList<>(100);
        for (Classify classify : allClassifyList) {
            varietyOfClassify.add(classify.getClassifyName());
        }
        // 构建@data数据
        List<CreateData> entities = createArffData(allClassifyList);

        ArrayList<Attribute> attributes = new ArrayList<>();
        attributes.add(new Attribute("@@class@@", varietyOfClassify));
        attributes.add(new Attribute("text", true));

        // 构建instances
        Instances instances = new Instances("classify", attributes, 500);
        // 设置分类的索引
        instances.setClassIndex(instances.numAttributes() - 1);

        // 添加数据到@data
        for (CreateData secRepoEntity : entities) {
            Instance instance = new DenseInstance(attributes.size());
            // 必须放在创建一个新的instance后 否则会报没加入Dataset异常
            instance.setDataset(instances);
            if (varietyOfClassify.contains(secRepoEntity.getClassifyName())) {
                instance.setValue(0, secRepoEntity.getClassifyName());
                instance.setValue(1, secRepoEntity.getTestValue());
            }
            instances.add(instance);
        }

        instances.setClassIndex(0);
        return instances;
    }


    /**
     * 准备ArffData数据
     *
     * @param allClassifyList
     * @return
     */
    private List<CreateData> createArffData(List<Classify> allClassifyList) {
        List<CreateData> createArffData = new ArrayList<>();
        for (Classify classify : allClassifyList) {
            List<KeyWord> classifyKeywordByClassifyId = keyWordMapper.selectByClassifyId(classify.getId());
            for (int i = 0; i < classifyKeywordByClassifyId.size(); i++) {
                createArffData.add(new CreateData(classify.getClassifyName(), classifyKeywordByClassifyId.get(i).getKeywordName()));
            }
        }
        return createArffData;
    }

    @Transactional
    public String getResultByExecuteParticipleAndClassify(String word) {
        try {
            if (StringUtils.isBlank(word)) {
                return "";
            }
            logger.info("需要分类的词是" + word);

            // 加载词库模型
            FilteredClassifier model = WekaUtil.loadModel(modelPath);
            List<Classify> allClassifyList = classifyMapper.selectAll();
            List<String> nameString= allClassifyList.stream().map(Classify::getClassifyName).collect(Collectors.toList());
            // 得到分类结果
            String result = makeInstance(model, nameString,word);
            logger.info("分类结果" + result);
            return result;
        } catch (Exception e) {
            logger.error("wordclassify error ,  detail message:{}", e);
        }
        return "";
    }

    /**
     * 生成一个新的instance用于得出结果
     *
     * @param evaluateAndLearn
     * @param varietyOfClassify
     */
    public String makeInstance(FilteredClassifier evaluateAndLearn, List<String> varietyOfClassify,String word) {

        // 添加第一个分类值
        FastVector fvNominalVal = new FastVector(50);
        for (String classify : varietyOfClassify) {
            fvNominalVal.addElement(classify);
        }
        Attribute attribute1 = new Attribute("@@class@@", fvNominalVal);
        Attribute attribute2 = new Attribute("text", (FastVector) null);

        FastVector fvWekaAttributes = new FastVector(2);
        fvWekaAttributes.addElement(attribute1);
        fvWekaAttributes.addElement(attribute2);
        Instances instances = new Instances("cardniu_text_classify", fvWekaAttributes, 1);
        // 设置索引
        instances.setClassIndex(0);
        // 创造一个新instance
        DenseInstance instance = new DenseInstance(2);
        instance.setValue(attribute2, word);
        instances.add(instance);
        double pred;
        try {
            pred = evaluateAndLearn.classifyInstance(instances.instance(0));
            return instances.classAttribute().value((int) pred);
        } catch (Exception e) {
            logger.info(e.getMessage());
        }
        return "";
    }
}@Value("${model.path}")
    private String modelPath;

    @Transactional
    public void createWekaModel() {

        // 从数据库查找到所有的收入分类数据
        List<Classify> allClassifyList = classifyMapper.selectAll();
        if (allClassifyList == null || allClassifyList.isEmpty()) {
            logger.error("没有从数据库查找到分类数据");
            return;
        }
        logger.info("分类模型训练开始");
        generateInstanceAndModelLearn(allClassifyList);
    }


    private void generateInstanceAndModelLearn(List<Classify> allClassifyList) {
        // 生成Instances(每个Instances对应一个ARFF)
        Instances trainData = generateInstance(allClassifyList);
        // 模型学习
        FilteredClassifier evaluateAndLearn = WekaUtil.evaluateAndLearn(trainData);
        WekaUtil.saveModel(modelPath, evaluateAndLearn);
        logger.info("收入分类模型训练完毕并存储到硬盘");
    }

    /**
     * 程序构建Arff文件
     *
     * @param allClassifyList
     * @return
     */
    private Instances generateInstance(List<Classify> allClassifyList) {
        // 得到所有的分类名
        List<String> varietyOfClassify = new ArrayList<>(100);
        for (Classify classify : allClassifyList) {
            varietyOfClassify.add(classify.getClassifyName());
        }
        // 构建@data数据
        List<CreateData> entities = createArffData(allClassifyList);

        ArrayList<Attribute> attributes = new ArrayList<>();
        attributes.add(new Attribute("@@class@@", varietyOfClassify));
        attributes.add(new Attribute("text", true));

        // 构建instances
        Instances instances = new Instances("classify", attributes, 500);
        // 设置分类的索引
        instances.setClassIndex(instances.numAttributes() - 1);

        // 添加数据到@data
        for (CreateData secRepoEntity : entities) {
            Instance instance = new DenseInstance(attributes.size());
            // 必须放在创建一个新的instance后 否则会报没加入Dataset异常
            instance.setDataset(instances);
            if (varietyOfClassify.contains(secRepoEntity.getClassifyName())) {
                instance.setValue(0, secRepoEntity.getClassifyName());
                instance.setValue(1, secRepoEntity.getTestValue());
            }
            instances.add(instance);
        }

        instances.setClassIndex(0);
        return instances;
    }


    /**
     * 准备ArffData数据
     *
     * @param allClassifyList
     * @return
     */
    private List<CreateData> createArffData(List<Classify> allClassifyList) {
        List<CreateData> createArffData = new ArrayList<>();
        for (Classify classify : allClassifyList) {
            List<KeyWord> classifyKeywordByClassifyId = keyWordMapper.selectByClassifyId(classify.getId());
            for (int i = 0; i < classifyKeywordByClassifyId.size(); i++) {
                createArffData.add(new CreateData(classify.getClassifyName(), classifyKeywordByClassifyId.get(i).getKeywordName()));
            }
        }
        return createArffData;
    }

    @Transactional
    public String getResultByExecuteParticipleAndClassify(String word) {
        try {
            if (StringUtils.isBlank(word)) {
                return "";
            }
            logger.info("需要分类的词是" + word);

            // 加载词库模型
            FilteredClassifier model = WekaUtil.loadModel(modelPath);
            List<Classify> allClassifyList = classifyMapper.selectAll();
            List<String> nameString= allClassifyList.stream().map(Classify::getClassifyName).collect(Collectors.toList());
            // 得到分类结果
            String result = makeInstance(model, nameString,word);
            logger.info("分类结果" + result);
            return result;
        } catch (Exception e) {
            logger.error("wordclassify error ,  detail message:{}", e);
        }
        return "";
    }

    /**
     * 生成一个新的instance用于得出结果
     *
     * @param evaluateAndLearn
     * @param varietyOfClassify
     */
    public String makeInstance(FilteredClassifier evaluateAndLearn, List<String> varietyOfClassify,String word) {

        // 添加第一个分类值
        FastVector fvNominalVal = new FastVector(50);
        for (String classify : varietyOfClassify) {
            fvNominalVal.addElement(classify);
        }
        Attribute attribute1 = new Attribute("@@class@@", fvNominalVal);
        Attribute attribute2 = new Attribute("text", (FastVector) null);

        FastVector fvWekaAttributes = new FastVector(2);
        fvWekaAttributes.addElement(attribute1);
        fvWekaAttributes.addElement(attribute2);
        Instances instances = new Instances("cardniu_text_classify", fvWekaAttributes, 1);
        // 设置索引
        instances.setClassIndex(0);
        // 创造一个新instance
        DenseInstance instance = new DenseInstance(2);
        instance.setValue(attribute2, word);
        instances.add(instance);
        double pred;
        try {
            pred = evaluateAndLearn.classifyInstance(instances.instance(0));
            return instances.classAttribute().value((int) pred);
        } catch (Exception e) {
            logger.info(e.getMessage());
        }
        return "";
    }
}

        WekaUtil如下,里面重要的是算法,我这里经过测试,选择了一种分类效果比较好的算法(已标红)

public class WekaUtil {
	private static final Logger logger = getLogger(WekaUtil.class);

	private WekaUtil() {
	}

	/**
	 * 以ARFF格式加载数据生成Instances
	 * 
	 * @param fileName
	 * @return
	 */
	public static Instances loadDataset(String fileName) {
		try (BufferedReader reader = new BufferedReader(new FileReader(fileName))) {

			ArffReader arff = new ArffReader(reader);
			Instances trainData = arff.getData();
			// 设置分类索引为0 必须在训练库之前加上这条代码
			trainData.setClassIndex(0);
			return trainData;
		} catch (IOException e) {
			logger.error(e.getMessage(), e);
		}
		return null;
	}

	/**
	 * 这个方法构建了一个分类器,然后分类器根据读到的数据进行训练、学习,生成模型
	 * 
	 * @param trainData
	 * @return FilteredClassifier
	 */
	public static FilteredClassifier evaluateAndLearn(Instances trainData) {
		try {

			StringToWordVector filter = new StringToWordVector();
			filter.setAttributeIndices("last");
			FilteredClassifier classifier = new FilteredClassifier();
			classifier.setFilter(filter);

			// 可选择不同算法 这里选择效率比较高的算法
			classifier.setClassifier(new RandomTree());

			classifier.buildClassifier(trainData);

			Evaluation eval = new Evaluation(trainData);
			eval.crossValidateModel(classifier, trainData, 2, new Random(1));

			trainData.setClassIndex(0);

			return classifier;
		} catch (Exception e) {
			logger.error(e.getMessage(), e);
		}
		return null;
	}

	/**
	 * 把分类器模型存储到文件中
	 * 
	 * @param fileName
	 * @param classifier
	 */
	public static void saveModel(String fileName, FilteredClassifier classifier) {
		try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(fileName))) {
			out.writeObject(classifier);
		} catch (IOException e) {
			logger.error(e.getMessage(), e);
		}
	}

	/**
	 * 这个方法加载分类器模型
	 * 
	 * @param fileName
	 * @return
	 */
	public static FilteredClassifier loadModel(String fileName) {
		try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(fileName))) {

			Object tmp = in.readObject();
			return (FilteredClassifier) tmp;

		} catch (Exception e) {
			logger.error(e.getMessage(), e);
		}
		return null;
	}

}// 可选择不同算法 这里选择效率比较高的算法
			classifier.setClassifier(new RandomTree());

			classifier.buildClassifier(trainData);

			Evaluation eval = new Evaluation(trainData);
			eval.crossValidateModel(classifier, trainData, 2, new Random(1));

			trainData.setClassIndex(0);

			return classifier;
		} catch (Exception e) {
			logger.error(e.getMessage(), e);
		}
		return null;
	}

	/**
	 * 把分类器模型存储到文件中
	 * 
	 * @param fileName
	 * @param classifier
	 */
	public static void saveModel(String fileName, FilteredClassifier classifier) {
		try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(fileName))) {
			out.writeObject(classifier);
		} catch (IOException e) {
			logger.error(e.getMessage(), e);
		}
	}

	/**
	 * 这个方法加载分类器模型
	 * 
	 * @param fileName
	 * @return
	 */
	public static FilteredClassifier loadModel(String fileName) {
		try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(fileName))) {

			Object tmp = in.readObject();
			return (FilteredClassifier) tmp;

		} catch (Exception e) {
			logger.error(e.getMessage(), e);
		}
		return null;
	}

}

数据库设计如下:service层把这些数据写入weka的ARFF文件中。

DROP TABLE IF EXISTS `classify`;
CREATE TABLE `classify` (
  `id` int(4) NOT NULL AUTO_INCREMENT,
  `classify_name` varchar(255) DEFAULT NULL,
  PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=4 DEFAULT CHARSET=utf8;


-- ----------------------------
-- Records of classify
-- ----------------------------
INSERT INTO `classify` VALUES ('1', '人名');
INSERT INTO `classify` VALUES ('2', '地名');
INSERT INTO `classify` VALUES ('3', '省份');


DROP TABLE IF EXISTS `key_word`;
CREATE TABLE `key_word` (
  `id` int(4) NOT NULL AUTO_INCREMENT,
  `classify_id` int(4) DEFAULT NULL,
  `keyword_name` varchar(255) DEFAULT NULL,
  PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=7 DEFAULT CHARSET=utf8;


-- ----------------------------
-- Records of key_word
-- ----------------------------
INSERT INTO `key_word` VALUES ('1', '1', '张三');
INSERT INTO `key_word` VALUES ('2', '1', '李四');
INSERT INTO `key_word` VALUES ('3', '2', '超市');
INSERT INTO `key_word` VALUES ('4', '2', '学校');
INSERT INTO `key_word` VALUES ('5', '3', '广东省');
INSERT INTO `key_word` VALUES ('6', '3', '黑龙江省');

验证结果 ,首先启动项目,第一步训练模型(把数据库中数据写入ARFF并生成模型存储在本地)

第二部,验证结果,我输入数据库的广东省,对应数据库中的省份,所以weka分类结果为省份;如果我输入张三,则对应的是人名,可想而知,当你把大量数据放入数据库,weka就能分类的越来越准,机器在不断地学习,模型在不断地完善。我们可以把训练模型变为定时任务,或者再加上分类,就可以对句子进行分类了。

    本demo已上传github 地址:https://github.com/CharlsShan/word-classify/tree/master/src/main 

  • 6
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一杯咖啡半杯糖

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值