Hanlp之文本分类

1、语料库格式

分类语料的根目录.目录必须满足如下结构:<br>
                     根目录<br>
                     ├── 分类A<br>
                     │   └── 1.txt<br>
                     │   └── 2.txt<br>
                     │   └── 3.txt<br>
                     ├── 分类B<br>
                     │   └── 1.txt<br>
                     │   └── ...<br>
                     └── ...<br>
 文件不一定需要用数字命名,也不需要以txt作为后缀名,但一定需要是文本文件.

2、项目格式

训练分类语料库要放到data/test/ 目录下

3、代码

(1)TestUtility类

/*
 * <author>Han He</author>
 * <email>me@hankcs.com</email>
 * <create-date>2018-06-23 11:05 PM</create-date>
 *
 * <copyright file="TestUtility.java">
 * Copyright (c) 2018, Han He. All Rights Reserved, http://www.hankcs.com/
 * This source is subject to Han He. Please contact Han He for more information.
 * </copyright>
 */
package com.cn.test.TextClassification;

import com.hankcs.hanlp.HanLP;

import java.io.*;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;

/**
 * @author hankcs
 */
public class TestUtility
{
    static
    {
        ensureFullData();
    }

    public static void ensureFullData()
    {
        ensureData(HanLP.Config.PerceptronCWSModelPath, "http://nlp.hankcs.com/download.php?file=data", HanLP.Config.PerceptronCWSModelPath.split("data")[0], false);
    }

    /**
     * 保证 name 存在,不存在时自动下载解压
     *
     * @param name 路径
     * @param url  下载地址
     * @return name的绝对路径
     */
    public static String ensureData(String name, String url)
    {
        return ensureData(name, url, null, true);
    }

    /**
     * 保证 name 存在,不存在时自动下载解压
     *
     * @param name 路径
     * @param url  下载地址
     * @return name的绝对路径
     */
    public static String ensureData(String name, String url, String parentPath, boolean overwrite)
    {
        File target = new File(name);
        if (target.exists()) return target.getAbsolutePath();
        try
        {
            File parentFile = parentPath == null ? new File(name).getParentFile() : new File(parentPath);
            if (!parentFile.exists()) parentFile.mkdirs();
            String filePath = downloadFile(url, parentFile.getAbsolutePath());
            if (filePath.endsWith(".zip"))
            {
                unzip(filePath, parentFile.getAbsolutePath(), overwrite);
            }
            return target.getAbsolutePath();
        }
        catch (Exception e)
        {
            System.err.printf("数据下载失败,请尝试手动下载 %s 到 %s 。原因如下:\n", url, target.getAbsolutePath());
            e.printStackTrace();
            System.exit(1);
            return null;
        }
    }

    /**
     * 保证 data/test/name 存在
     *
     * @param name
     * @param url
     * @return
     */
    public static String ensureTestData(String name, String url)
    {
        return ensureData(String.format("data/test/%s", name), url);
    }

    /**
     * Downloads a file from a URL
     *
     * @param fileURL  HTTP URL of the file to be downloaded
     * @param savePath path of the directory to save the file
     * @throws IOException
     * @author www.codejava.net
     */
    public static String downloadFile(String fileURL, String savePath)
        throws IOException
    {
        System.err.printf("Downloading %s to %s\n", fileURL, savePath);
        HttpURLConnection httpConn = request(fileURL);
        while (httpConn.getResponseCode() == HttpURLConnection.HTTP_MOVED_PERM || httpConn.getResponseCode() == HttpURLConnection.HTTP_MOVED_TEMP)
        {
            httpConn = request(httpConn.getHeaderField("Location"));
        }

        // always check HTTP response code first
        if (httpConn.getResponseCode() == HttpURLConnection.HTTP_OK)
        {
            String fileName = "";
            String disposition = httpConn.getHeaderField("Content-Disposition");
            String contentType = httpConn.getContentType();
            int contentLength = httpConn.getContentLength();

            if (disposition != null)
            {
                // extracts file name from header field
                int index = disposition.indexOf("filename=");
                if (index > 0)
                {
                    fileName = disposition.substring(index + 10,
                                                     disposition.length() - 1);
                }
            }
            else
            {
                // extracts file name from URL
                fileName = new File(httpConn.getURL().getPath()).getName();
            }

//            System.out.println("Content-Type = " + contentType);
//            System.out.println("Content-Disposition = " + disposition);
//            System.out.println("Content-Length = " + contentLength);
//            System.out.println("fileName = " + fileName);

            // opens input stream from the HTTP connection
            InputStream inputStream = httpConn.getInputStream();
            String saveFilePath = savePath;
            if (new File(savePath).isDirectory())
                saveFilePath = savePath + File.separator + fileName;
            String realPath;
            if (new File(saveFilePath).isFile())
            {
                System.err.printf("Use cached %s instead.\n", fileName);
                realPath = saveFilePath;
            }
            else
            {
                saveFilePath += ".downloading";

                // opens an output stream to save into file
                FileOutputStream outputStream = new FileOutputStream(saveFilePath);

                int bytesRead;
                byte[] buffer = new byte[4096];
                long start = System.currentTimeMillis();
                int progress_size = 0;
                while ((bytesRead = inputStream.read(buffer)) != -1)
                {
                    outputStream.write(buffer, 0, bytesRead);
                    long duration = (System.currentTimeMillis() - start) / 1000;
                    duration = Math.max(duration, 1);
                    progress_size += bytesRead;
                    int speed = (int) (progress_size / (1024 * duration));
                    float ratio = progress_size / (float) contentLength;
                    float percent = ratio * 100;
                    int eta = (int) (duration / ratio * (1 - ratio));
                    int minutes = eta / 60;
                    int seconds = eta % 60;

                    System.err.printf("\r%.2f%%, %d MB, %d KB/s, ETA %d min %d s", percent, progress_size / (1024 * 1024), speed, minutes, seconds);
                }
                System.err.println();
                outputStream.close();
                realPath = saveFilePath.substring(0, saveFilePath.length() - ".downloading".length());
                if (!new File(saveFilePath).renameTo(new File(realPath)))
                    throw new IOException("Failed to move file");
            }
            inputStream.close();
            httpConn.disconnect();

            return realPath;
        }
        else
        {
            httpConn.disconnect();
            throw new IOException("No file to download. Server replied HTTP code: " + httpConn.getResponseCode());
        }
    }

    private static HttpURLConnection request(String url) throws IOException
    {
        HttpURLConnection httpConn = (HttpURLConnection) new URL(url).openConnection();
        httpConn.setRequestProperty("User-Agent", "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.4; en-US; rv:1.9.2.2) Gecko/20100316 Firefox/3.6.2");
        return httpConn;
    }

    private static void unzip(String zipFilePath, String destDir, boolean overwrite)
    {
        System.err.println("Unzipping to " + destDir);
        File dir = new File(destDir);
        // create output directory if it doesn't exist
        if (!dir.exists()) dir.mkdirs();
        FileInputStream fis;
        //buffer for read and write data to file
        byte[] buffer = new byte[4096];
        try
        {
            fis = new FileInputStream(zipFilePath);
            ZipInputStream zis = new ZipInputStream(fis);
            ZipEntry ze = zis.getNextEntry();
            while (ze != null)
            {
                String fileName = ze.getName();
                File newFile = new File(destDir + File.separator + fileName);
                if (overwrite || !newFile.exists())
                {
                    if (ze.isDirectory())
                    {
                        //create directories for sub directories in zip
                        newFile.mkdirs();
                    }
                    else
                    {
                        new File(newFile.getParent()).mkdirs();
                        FileOutputStream fos = new FileOutputStream(newFile);
                        int len;
                        while ((len = zis.read(buffer)) > 0)
                        {
                            fos.write(buffer, 0, len);
                        }
                        fos.close();
                        //close this ZipEntry
                        zis.closeEntry();
                    }
                }
                ze = zis.getNextEntry();
            }
            //close last ZipEntry
            zis.closeEntry();
            zis.close();
            fis.close();
            new File(zipFilePath).delete();
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }
}

(2)ModelTrain类

package com.cn.test.TextClassification;

import com.hankcs.hanlp.classification.classifiers.IClassifier;
import com.hankcs.hanlp.classification.classifiers.NaiveBayesClassifier;
import com.hankcs.hanlp.classification.models.NaiveBayesModel;
import com.hankcs.hanlp.corpus.io.IOUtil;

import java.io.File;


public class ModelTrain {
    /**
     * 搜狗文本分类语料库5个类目,每个类目下1000篇文章,共计5000篇文章
     */
    public static final String CORPUS_FOLDER = TestUtility.ensureTestData("搜狗文本分类语料库迷你版", "");

    /**
     * 模型保存路径
     */
    public static final String MODEL_PATH = "data/test/classification-model.ser";


    public static NaiveBayesModel trainOrLoadModel()
    {
        NaiveBayesModel model = (NaiveBayesModel) IOUtil.readObjectFrom(MODEL_PATH);
        if (model != null) return model;

        File corpusFolder = new File(CORPUS_FOLDER);
        if (!corpusFolder.exists() || !corpusFolder.isDirectory())
        {
            System.err.println("没有文本分类语料!");
            System.exit(1);
        }
        try{
            IClassifier classifier = new NaiveBayesClassifier(); // 创建分类器,更高级的功能请参考IClassifier的接口定义
            classifier.train(CORPUS_FOLDER);                     // 训练后的模型支持持久化,下次就不必训练了
            model = (NaiveBayesModel) classifier.getModel();
            IOUtil.saveObjectTo(model, MODEL_PATH);
        }catch (Exception e){
            e.printStackTrace();
        }
        return model;
    }
}

(3)InitOneObject类

package com.cn.test.TextClassification;

import com.hankcs.hanlp.classification.classifiers.IClassifier;
import com.hankcs.hanlp.classification.classifiers.NaiveBayesClassifier;

public class InitOneObject {
    public static final InitOneObject instance= new InitOneObject();
    //获取classifier对象,训练后的模型支持持久化,下次就不必训练了.
    public IClassifier classifier = new NaiveBayesClassifier(ModelTrain.trainOrLoadModel());
}

(4)TrainMain类

package com.cn.test.TextClassification;

import com.hankcs.hanlp.classification.classifiers.IClassifier;

import java.io.IOException;

public class TrainMain {
    public static void main(String[] args) throws IOException
    {
        //IClassifier classifier = new NaiveBayesClassifier(trainOrLoadModel());
        IClassifier classifier= InitOneObject.instance.classifier;
        predict(classifier, "C罗获2018环球足球奖最佳球员 德尚荣膺最佳教练");
        predict(classifier, "英国造航母耗时8年仍未服役 被中国速度远远甩在身后");
        predict(classifier, "研究生考录模式亟待进一步专业化");
        predict(classifier, "如果真想用食物解压,建议可以食用燕麦");
        predict(classifier, "锄禾日当午,汗滴禾下土");
    }

    private static void predict(IClassifier classifier, String text)
    {
        System.out.printf("《%s》 属于分类 【%s】\n", text, classifier.classify(text));
    }
}

4、运行结果

模式:训练集
文本编码:UTF-8
根目录:C:\MyselfApplication\MyProject\HanLp\data\test\搜狗文本分类语料库迷你版
加载中...
[体育]...100.00% 1000 篇文档
[健康]...100.00% 1000 篇文档
[军事]...100.00% 1000 篇文档
[教育]...100.00% 1000 篇文档
[汽车]...100.00% 1000 篇文档
耗时 15868 ms 加载了 5 个类目,共 5000 篇文档
原始数据集大小:5000
使用卡方检测选择特征中...耗时 189 ms,选中特征数:18156 / 80986 = 22.42%
贝叶斯统计结束
《C罗获2018环球足球奖最佳球员 德尚荣膺最佳教练》 属于分类 【体育】
《英国造航母耗时8年仍未服役 被中国速度远远甩在身后》 属于分类 【军事】
《研究生考录模式亟待进一步专业化》 属于分类 【教育】
《如果真想用食物解压,建议可以食用燕麦》 属于分类 【健康】
《锄禾日当午,汗滴禾下土》 属于分类 【健康】

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

郝少

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值