目前深度学习越来越受到重视，深度学习框架也层出不穷，例如谷歌的TensorFlow。它主要通过Python接口来使用，对于许多不太熟悉Python的程序员来说可能不够方便。这里介绍一个基于Java的深度学习框架——DL4J（Deeplearning4j）。本博客主要介绍如何在代码层面使用DL4J训练Word2Vec模型，一起来看一下吧~
【代码】
package com.xzw.dl4j;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Collection;
import java.util.Locale;
import java.util.regex.Pattern;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.EndingPreProcessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
/**
 * Trains a DL4J Word2Vec model on a corpus read line by line, writes the learned word vectors
 * to a file, and prints a word-similarity and a nearest-words query.
 *
 * <p>Usage: {@code Word2VecTest [corpusFile [outputFile]]}. Any omitted argument falls back to
 * the original hard-coded path.
 *
 * @author xzw
 */
public class Word2VecTest {

    /** Corpus used when no path is given on the command line. */
    private static final String DEFAULT_CORPUS_PATH =
            "C://Users//Machenike//Desktop//zzz//raw_sentences.txt";

    /** Destination for the written word vectors when no path is given on the command line. */
    private static final String DEFAULT_OUTPUT_PATH =
            "C://Users//Machenike//Desktop//zzz//words.txt";

    /** Matches one digit; compiled once instead of once per token (as String.replaceAll does). */
    private static final Pattern DIGIT = Pattern.compile("\\d");

    // Training hyperparameters passed to Word2Vec.Builder.
    private static final int BATCH_SIZE = 1000;
    private static final int MIN_WORD_FREQUENCY = 5;
    private static final int LAYER_SIZE = 150;
    private static final int ITERATIONS = 3;
    private static final double LEARNING_RATE = 0.025;
    private static final double MIN_LEARNING_RATE = 1e-3;
    private static final double NEGATIVE_SAMPLE = 10;

    /**
     * Loads the corpus, trains the model, saves the vectors and prints sample evaluations.
     *
     * @param args optional {@code [corpusFile [outputFile]]}; defaults are used when absent
     * @throws IOException if the corpus is missing/unreadable or the vectors cannot be written
     */
    // NOTE(review): presumably suppresses a deprecation warning from the DL4J API used here
    // (e.g. WordVectorSerializer.writeWordVectors) — confirm against the DL4J version in use.
    @SuppressWarnings("deprecation")
    public static void main(String[] args) throws IOException {
        File corpus = new File(args.length > 0 ? args[0] : DEFAULT_CORPUS_PATH);
        String outputPath = args.length > 1 ? args[1] : DEFAULT_OUTPUT_PATH;

        System.out.println("Load data...");
        SentenceIterator iterator = createSentenceIterator(corpus);
        try {
            System.out.println("Tokenize data...");
            TokenizerFactory tokenizer = createTokenizerFactory();

            System.out.println("Build model...");
            Word2Vec vec = buildModel(iterator, tokenizer);

            //train
            System.out.println("Learning...");
            vec.fit();

            //model save
            System.out.println("Save model...");
            WordVectorSerializer.writeWordVectors(vec, outputPath);

            System.out.println("Evaluate model...");
            printSimilarity(vec, "people", "money");
            printNearest(vec, "day", 10);
        } finally {
            // The iterator reads from the corpus file; release it once it is no longer needed.
            iterator.finish();
        }
    }

    /**
     * Creates an iterator yielding each line of {@code corpus} as one lower-cased sentence.
     *
     * @throws FileNotFoundException if {@code corpus} is not an existing regular file
     */
    private static SentenceIterator createSentenceIterator(File corpus) throws IOException {
        if (!corpus.isFile()) {
            throw new FileNotFoundException("Corpus file not found: " + corpus.getAbsolutePath());
        }
        SentenceIterator iterator = new LineSentenceIterator(corpus);
        iterator.setPreProcessor(new SentencePreProcessor() {
            private static final long serialVersionUID = 1L;

            @Override
            public String preProcess(String sentence) {
                // Locale.ROOT makes lower-casing independent of the JVM default locale
                // (a Turkish default locale would turn "I" into a dotless "ı").
                return sentence.toLowerCase(Locale.ROOT);
            }
        });
        return iterator;
    }

    /**
     * Creates a {@link DefaultTokenizerFactory} whose tokens are lower-cased, passed through
     * DL4J's {@link EndingPreProcessor}, and then have every digit replaced by the letter 'd'.
     */
    private static TokenizerFactory createTokenizerFactory() {
        final EndingPreProcessor endingPreProcessor = new EndingPreProcessor();
        TokenizerFactory factory = new DefaultTokenizerFactory();
        factory.setTokenPreProcessor(new TokenPreProcess() {
            @Override
            public String preProcess(String token) {
                String base = endingPreProcessor.preProcess(token.toLowerCase(Locale.ROOT));
                // Collapse digits so that e.g. "1999" and "2024" both become "dddd".
                return DIGIT.matcher(base).replaceAll("d");
            }
        });
        return factory;
    }

    /** Configures (but does not train) a Word2Vec model over the given sentences. */
    // NOTE(review): kept in case any builder option below is deprecated in the DL4J version used.
    @SuppressWarnings("deprecation")
    private static Word2Vec buildModel(SentenceIterator sentences, TokenizerFactory tokenizer) {
        return new Word2Vec.Builder()
                .batchSize(BATCH_SIZE)
                .minWordFrequency(MIN_WORD_FREQUENCY)
                .useAdaGrad(false)
                .layerSize(LAYER_SIZE)
                .iterations(ITERATIONS)
                .learningRate(LEARNING_RATE)
                .minLearningRate(MIN_LEARNING_RATE)
                .negativeSample(NEGATIVE_SAMPLE)
                .iterate(sentences)
                .tokenizerFactory(tokenizer)
                .build();
    }

    /** Prints the similarity of two words, or a notice when either is not in the vocabulary. */
    private static void printSimilarity(Word2Vec vec, String word1, String word2) {
        // Words rarer than MIN_WORD_FREQUENCY are dropped from the vocabulary.
        if (!vec.hasWord(word1) || !vec.hasWord(word2)) {
            System.out.println(String.format(
                    "Cannot compare %s and %s: word not in vocabulary", word1, word2));
            return;
        }
        double similarity = vec.similarity(word1, word2);
        System.out.println(String.format(
                "The similarity between %s and %s is %f", word1, word2, similarity));
    }

    /** Prints the {@code ranking} words nearest to {@code word}, if it is in the vocabulary. */
    private static void printNearest(Word2Vec vec, String word, int ranking) {
        if (!vec.hasWord(word)) {
            System.out.println(String.format("Cannot find similar words: %s not in vocabulary", word));
            return;
        }
        Collection<String> nearest = vec.wordsNearest(word, ranking);
        System.out.println(String.format("Similar word to %s is %s", word, nearest));
    }
}
【用到的数据集】
【保存的Word2Vec模型】
【运行结果】