使用 DL4J 训练中文词向量

使用 DL4J 训练中文词向量

1 预处理

对中文语料的预处理,主要包括:分词、去停用词以及一些根据实际场景制定的规则。

package ai.mole.test;

import org.ansj.domain.Term;
import org.ansj.splitWord.analysis.ToAnalysis;
import org.nlpcn.commons.lang.tire.domain.Forest;
import org.nlpcn.commons.lang.tire.library.Library;

import java.io.*;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Pattern;

public class Preprocess {
    private static final Pattern NUMERIC_PATTERN = Pattern.compile("^[.\\d]+$");
    private static final Pattern ENGLISH_WORD_PATTERN = Pattern.compile("^[a-z]+$");

    public static void main(String[] args) {
        String inPath1 = "D:\\MyData\\XUGP3\\Desktop\\测试分词\\test1.txt";
        String inPath2 = "D:\\MyData\\XUGP3\\Desktop\\测试分词\\stop_words.txt";
        String outPath = "D:\\MyData\\XUGP3\\Desktop\\测试分词\\result1.txt";
        String encoding = "utf-8";

        PrintWriter writer = null;
        Forest forest = null;
        try {
            writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(outPath), encoding));
            forest = Library.makeForest(Test.class.getResourceAsStream("/library/userLibrary.dic"));

            List<String> lineList = IOUtil.readLines(new FileInputStream(inPath1), encoding);
            List<String> stopWordList = IOUtil.readLines(new FileInputStream(inPath2), encoding);

            for (String line : lineList) {
                String[] cols = line.split("\\t", -1);

                if (cols.length < 2) {
                    continue;
                }

                String text = cols[0].trim().toLowerCase() + " " + cols[1].trim().toLowerCase();

                // 分词
                List<Term> termList = ToAnalysis.parse(text, forest).getTerms();
                List<String> wordList = new LinkedList<>();
                for (Term term : termList) {
                    String word = term.getName();

                    if (word.length() < 2) {
                        continue;
                    }

                    if (stopWordList.contains(word)) {
                        continue;
                    }

                    if (isNumeric(word)) {
                        continue;
                    }

                    if (isEnglishWord(word)) {
                        continue;
                    }

                    wordList.add(word);
                }

                if (wordList.size() > 5) {
                    String outStr = listToLine(wordList);
                    writer.println(outStr);
                }
            }
        } catch (FileNotFoundException e) {
            System.out.println("The file does not exist or the path is not correct!!!");
            System.exit(-1);
        } catch (UnsupportedEncodingException e) {
            System.out.println("Does not support the current character set!!!");
        } catch (IOException e) {
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (writer != null) {
                writer.close();
            }
        }
    }

    private static boolean isNumeric(String text) {
        return NUMERIC_PATTERN.matcher(text).matches();
    }

    private static boolean isEnglishWord(String text) {
        return ENGLISH_WORD_PATTERN.matcher(text).matches();
    }

    private static String listToLine(List<String> list) {
        StringBuilder sb = new StringBuilder();
        for (int i=0; i<list.size(); i++) {
            sb.append(list.get(i));
            if (i != list.size()-1) {
                sb.append(" ");
            }
        }
        return sb.toString();
    }
}

2 训练

训练的代码非常简单,可以直接看官网的教程,至于 word2vec 的原理可以看皮提果的博文。

package ai.mole.test;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.util.Collection;

public class TrainWord2VecModel {
    private static Logger log = LoggerFactory.getLogger(TrainWord2VecModel.class);

    public static void main(String[] args) throws IOException {
        String corpusPath = "/data/analyze/xgp/words.txt";
        String vectorsPath = "/data/analyze/xgp/word_vectors.txt";

        log.info("Start Training...");
        long st = System.currentTimeMillis();

        log.info("Load & vectorize sentences...");
        SentenceIterator iter = new BasicLineIterator(new File(corpusPath));
        TokenizerFactory t = new DefaultTokenizerFactory();
//        t.setTokenPreProcessor(new CommonPreprocessor());

        log.info("Building model...");
        Word2Vec vec = new Word2Vec.Builder()
                .minWordFrequency(50)
                .iterations(1)
                .epochs(100)
                .layerSize(500)
                .seed(42)
                .windowSize(5)
                .iterate(iter)
                .tokenizerFactory(t)
                .build();

        log.info("Fitting word2vec model...");
        vec.fit();

        log.info("Writing word vectors to text file...");
//        WordVectorSerializer.writeWord2VecModel(vec, vectorsPath);
        WordVectorSerializer.writeWordVectors(vec, vectorsPath);

        log.info("Closest words:");
        Collection<String> bydWordList = vec.wordsNearest("比亚迪", 10);
        Collection<String> changanWordList = vec.wordsNearest("长安", 10);
        System.out.print(bydWordList);
        System.out.println(changanWordList);

        log.info("10 words closest to '比亚迪': {}", bydWordList);
        log.info("10 words closest to '长安': {}", changanWordList);

        long et = System.currentTimeMillis();
        log.info("Training is completed, and the time taken is " + (et-st) + " ms.");
        System.out.println("Training is completed, and the time taken is " + (et-st) + " ms.");
    }
}

3 调用

调用训练好的词向量也非常简单,只需要调用 WordVectorSerializer 类的静态方法 readWord2VecModel 就可以了,提供的输入参数就是训练好的词向量路径。

Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel("D:\\MyData\\XUGP3\\Desktop\\测试分词\\vectors.txt");
Collection<String> bydWordList = word2Vec.wordsNearest("比亚迪", 10);
Collection<String> changanWordList = word2Vec.wordsNearest("长安", 10);
System.out.println(bydWordList);
System.out.println(changanWordList);

附录 - maven 依赖

<dependencies>
    <dependency>
        <groupId>org.apdplat</groupId>
        <artifactId>word</artifactId>
        <version>1.3</version>
    </dependency>

    <!-- ND4J backend. You need one in every DL4J project. Normally define artifactId as either "nd4j-native-platform" or "nd4j-cuda-7.5-platform" -->
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>${nd4j.backend}</artifactId>
        <version>${nd4j.version}</version>
    </dependency>

    <!-- Core DL4J functionality -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>${dl4j.version}</version>
    </dependency>

    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nlp</artifactId>
        <version>${dl4j.version}</version>
    </dependency>

    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-zoo</artifactId>
        <version>${dl4j.version}</version>
    </dependency>

    <!-- deeplearning4j-ui is used for visualization: see http://deeplearning4j.org/visualization -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-ui_${scala.binary.version}</artifactId>
        <version>${dl4j.version}</version>
    </dependency>

    <!-- ParallelWrapper & ParallelInference live here -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-parallel-wrapper_${scala.binary.version}</artifactId>
        <version>${dl4j.version}</version>
    </dependency>

    <!-- Next 2: used for MapFileConversion Example. Note you need *both* together -->
    <dependency>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-hadoop</artifactId>
        <version>${datavec.version}</version>
    </dependency>

    <dependency>
        <groupId>org.apache.hadoop</groupId>
        <artifactId>hadoop-common</artifactId>
        <version>${hadoop.version}</version>
    </dependency>


    <!-- Arbiter - used for hyperparameter optimization (grid/random search) -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>arbiter-deeplearning4j</artifactId>
        <version>${arbiter.version}</version>
    </dependency>
    
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>arbiter-ui_2.11</artifactId>
        <version>${arbiter.version}</version>
    </dependency>

    <!-- datavec-data-codec: used only in video example for loading video data -->
    <dependency>
        <artifactId>datavec-data-codec</artifactId>
        <groupId>org.datavec</groupId>
        <version>${datavec.version}</version>
    </dependency>
</dependencies>

转载于:https://www.cnblogs.com/xugenpeng/p/9144656.html

卷积神经网络(Convolutional Neural Network, CNN)是一种专门针对图像、视频等结构化数据设计的深度学习模型,它在计算机视觉、语音识别、自然语言处理等多个领域都有广泛应用。CNN的核心设计理念源于对生物视觉系统的模拟,尤其是大脑皮层中视觉信息处理的方式,其主要特点包括局部感知、权重共享、多层级抽象以及空间不变性。以下是CNN技术的详细介绍: ### **1. 局部感知与卷积操作** **卷积层**是CNN的基本构建块,它通过使用一组可学习的滤波器(或称为卷积核)对输入图像进行扫描。每个滤波器在图像上滑动(卷积),并以局部区域(感受野)内的像素值与滤波器权重进行逐元素乘法后求和,生成一个输出值。这一过程强调了局部特征的重要性,因为每个滤波器仅对一小部分相邻像素进行响应,从而能够捕获图像中的边缘、纹理、颜色分布等局部特征。 ### **2. 权重共享** 在CNN中,同一滤波器在整个输入图像上保持相同的权重(参数)。这意味着,无论滤波器在图像的哪个位置应用,它都使用相同的参数集来提取特征。这种权重共享显著减少了模型所需的参数数量,增强了模型的泛化能力,并且体现了对图像平移不变性的内在假设,即相同的特征(如特定形状或纹理)不论出现在图像的哪个位置,都应由相同的滤波器识别。 ### **3. 池化操作** **池化层**通常紧随卷积层之后,用于进一步降低数据维度并引入一定的空间不变性。常见的池化方法有最大池化和平均池化,它们分别取局部区域的最大值或平均值作为输出。池化操作可以减少模型对微小位置变化的敏感度,同时保留重要的全局或局部特征。 ### **4. 多层级抽象** CNN通常包含多个卷积和池化层堆叠在一起,形成深度网络结构。随着网络深度的增加,每一层逐渐提取更复杂、更抽象的特征。底层可能识别边缘、角点等低级特征,中间层识别纹理、部件等中级特征,而高层可能识别整个对象或场景等高级语义特征。这种层级结构使得CNN能够从原始像素数据中自动学习到丰富的表示,无需人工设计复杂的特征。 ### **5. 激活函数与正则化** CNN中通常使用非线性激活函数(如ReLU、sigmoid、tanh等)来引入非线性表达能力,使得网络能够学习复杂的决策边界。为了防止过拟合,CNN常采用正则化技术,如L2正则化(权重衰减)来约束模型复杂度,以及Dropout技术,在训练过程中随机丢弃一部分神经元的输出,以增强模型的泛化性能。 ### **6. 应用场景** CNN在诸多领域展现出强大的应用价值,包括但不限于: - **图像分类**:如识别图像中的物体类别(猫、狗、车等)。 - **目标检测**:在图像中定位并标注出特定对象的位置及类别。 - **语义分割**:对图像中的每个像素进行分类,确定其所属的对象或背景类别。 - **人脸识别**:识别或验证个体身份。 - **图像生成**:通过如生成对抗网络(GANs)等技术创建新的、逼真的图像。 - **医学影像分析**:如肿瘤检测、疾病诊断等。 - **自然语言处理**:如文本分类、情感分析、词性标注等,尽管这些任务通常结合其他类型的网络结构(如循环神经网络)。 ### **7. 发展与演变** CNN的概念起源于20世纪80年代,但其影响力在硬件加速(如GPU)和大规模数据集(如ImageNet)出现后才真正显现。经典模型如LeNet-5用于手写数字识别,而AlexNet、VGG、GoogLeNet、ResNet等现代架构在图像识别竞赛中取得突破性成果,推动了CNN技术的快速发展。如今,CNN已经成为深度学习图像处理领域的基石,并持续创新,如引入注意力机制、残差学习、深度可分离卷积等先进思想。 综上所述,卷积神经网络通过其独特的局部感知、权重共享、多层级抽象等特性,高效地从图像数据中提取特征并进行学习,已成为解决图像和视频处理任务不可或缺的工具,并在众多实际应用中取得了卓越的效果。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值