本文使用 Java 深度学习框架 DL4J 实现 Word2Vec 模型的构建。
将训练语料保存在 all_data.txt文件内,格式为每一行一条样本,且经过分词、过滤处理。
如:
//原文本样本
String raw = "超半数省份出台供给侧改革方案,降低要素成本成难点。";
//分词过滤后,空格相间隔
String washed = "超 半数 省份 出台 供给 侧 改革 方案 降低 要素 成本 成 难点";
模型构建代码如下:
import java.io.File;
import java.io.IOException;
import java.util.Collection;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
/**
* 实现Word2Vec模型建立
* @author windows
*
*/
public class Word2VecModel {
/** 训练语料位置 */
public static String washed_data_path = "D:/Document/data/";
public static void main(String[] args) throws IOException {
//训练语料保存路径
String filePath = washed_data_path + "all_data.txt";
System.out.println("Load & Vectorize Sentences....");
SentenceIterator iter = new LineSentenceIterator(new File(filePath));
iter.setPreProcessor(new SentencePreProcessor() {
public String preProcess(String sentence) {
return sentence.toLowerCase();
}
});
//分词器,以空格为间隔
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
System.out.println("Building model...");
//模型构建
Word2Vec vec = new Word2Vec.Builder()
.minWordFrequency(5) //最小词频
.iterations(1) //迭代次数
.layerSize(128) //词向量维度
.seed(42) //随机种子数
.windowSize(5) //窗口大小
.iterate(iter) //文本迭代
.tokenizerFactory(t) //分词器
.build();
System.out.println("Fitting Word2Vec model....");
vec.fit();
//Saving model
System.out.println("Save vectors....");
WordVectorSerializer.writeWord2VecModel(vec, "D:/Document/data/text_category/word2vec.txt");
//Reload model
// System.out.println("Reload model....");
// Word2Vec vec = WordVectorSerializer.readWord2VecModel("D:/Document/data/text_category/word2vec.txt");
//找出相近词
System.out.println("Closest Words:");
Collection<String> lst = vec.wordsNearest("财经", 10);
System.out.println("10 Words closest :");
System.out.println(lst);
System.out.println("end.");
}
}
项目 jar 包依赖(运行时还需引入版本一致的 ND4J 后端,例如 nd4j-native-platform):
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nlp</artifactId>
<version>1.0.0-beta4</version>
</dependency>
完!