在 Spring Boot 项目中，把预训练好的 Word2Vec（W2V）模型 .bin 文件放入项目并引入相关依赖，用 Java 对句子分词、对词向量求平均并计算余弦相似度，从而得到两个句子的语义相似度。注意：.bin 模型需要自己训练或下载。
1. 引入pom依赖
常用的 Maven 依赖都可以在 Maven Repository（Search/Browse/Explore）上查到：
https://mvnrepository.com/
这里引入深度学习库 Deeplearning4j（deeplearning4j-nlp 及其 ND4J 本地后端 nd4j-native-platform）和中文分词库 jieba-analysis：
<!-- DL4J NLP module: provides Word2Vec and WordVectorSerializer used by SentenceSimilarityCalculator -->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nlp</artifactId>
<version>1.0.0-beta7</version>
</dependency>
<!-- ND4J native backend needed by DL4J at runtime.
     NOTE(review): keep this version identical to deeplearning4j-nlp. The -platform artifact
     presumably pulls native binaries for every OS, which makes the build large; confirm whether
     a single-platform setup is preferable for deployment. -->
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>1.0.0-beta7</version>
</dependency>
<!-- Jieba Chinese word segmenter: provides JiebaSegmenter used to split sentences into words -->
<dependency>
<groupId>com.huaban</groupId>
<artifactId>jieba-analysis</artifactId>
<version>1.0.2</version>
</dependency>
2. 编写基于 Word2Vec 的句子相似度计算类
处理流程：分词 → 过滤停用词和标点符号以提升效果 → 对各词向量求平均 → 计算两个平均向量的余弦相似度
package com.example.demo.service.Impl.data;
import com.huaban.analysis.jieba.JiebaSegmenter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
/**
 * Scores the semantic similarity of two sentences with a pre-trained Word2Vec model.
 *
 * <p>Pipeline: segment each sentence with Jieba, drop stopwords and tokens made only of
 * punctuation/symbols/whitespace, average the vectors of the remaining in-vocabulary words, and
 * compare the two averages with cosine similarity.
 *
 * <p>Loading the model is expensive; build one instance and reuse it.
 */
public class SentenceSimilarityCalculator {
    /** Tokens consisting only of punctuation (P), symbols (S), separators (Z) or whitespace. */
    private static final Pattern NON_WORD_TOKEN = Pattern.compile("[\\pP\\pS\\pZ\\s]+");

    // mapSimilarity: inputs in [SOURCE_LOW, SOURCE_HIGH) are stretched onto [TARGET_LOW, SOURCE_HIGH).
    private static final double SOURCE_LOW = 0.0;
    private static final double SOURCE_HIGH = 0.2;
    private static final double TARGET_LOW = -1.0;

    private final Word2Vec word2Vec;
    private final JiebaSegmenter segmenter;
    private final Set<String> stopwords;
    /** Length of every word vector in the model, probed once at construction. */
    private final int vectorSize;

    /**
     * Loads the Word2Vec model and prepares the segmenter and stopword list.
     *
     * @param modelPath file-system path of the Word2Vec model file (e.g. a .bin file)
     * @throws IllegalStateException if the loaded model exposes no word vectors
     */
    public SentenceSimilarityCalculator(String modelPath) {
        word2Vec = WordVectorSerializer.readWord2VecModel(modelPath);
        segmenter = new JiebaSegmenter();
        stopwords = initializeStopwords();
        vectorSize = probeVectorSize(word2Vec);
    }

    /**
     * Builds the stopword set. This is only an example list; the ASCII punctuation below is also
     * caught by {@link #NON_WORD_TOKEN}, so extend it with common function words for real gains.
     */
    private static Set<String> initializeStopwords() {
        Set<String> words = new HashSet<>();
        words.add(")");
        words.add("(");
        words.add(",");
        return words;
    }

    /** Reads the vector length from the first vocabulary entry; fails fast on an empty model. */
    private static int probeVectorSize(Word2Vec model) {
        String firstWord = model.vocab().wordAtIndex(0);
        double[] sample = firstWord == null ? null : model.getWordVector(firstWord);
        if (sample == null) {
            throw new IllegalStateException("Word2Vec model has no word vectors");
        }
        return sample.length;
    }

    /**
     * Removes stopwords and tokens that carry no meaning (punctuation, symbols, whitespace, empty).
     * This is the place to add per-word weighting (e.g. TF-IDF); currently every word counts once.
     */
    private List<String> filterAndWeightWords(List<String> words) {
        List<String> filteredWords = new ArrayList<>(words.size());
        for (String word : words) {
            if (word.isEmpty() || stopwords.contains(word) || NON_WORD_TOKEN.matcher(word).matches()) {
                continue;
            }
            filteredWords.add(word);
        }
        return filteredWords;
    }

    /**
     * Returns the cosine similarity of the mean word vectors of the two sentences.
     *
     * @return a value nominally in [-1, 1]; 0.0 when either sentence has no in-vocabulary words
     */
    public double calculateSimilarity(String sentence1, String sentence2) {
        double[] vector1 = sentenceVector(sentence1);
        double[] vector2 = sentenceVector(sentence2);
        return cosineSimilarity(vector1, vector2);
    }

    /** Segments, filters and averages one sentence into a single vector. */
    private double[] sentenceVector(String sentence) {
        List<String> words = filterAndWeightWords(segmenter.sentenceProcess(sentence));
        return calculateAverageVector(words.toArray(new String[0]));
    }

    /**
     * Averages the vectors of the words found in the model's vocabulary. Unknown words are skipped
     * and excluded from the divisor. If no word is known, the zero vector is returned instead of
     * dividing 0 by 0.
     */
    private double[] calculateAverageVector(String[] words) {
        double[] sumVector = new double[vectorSize];
        int found = 0;
        for (String word : words) {
            double[] wordVector = word2Vec.getWordVector(word);
            if (wordVector == null) {
                continue; // out-of-vocabulary word
            }
            for (int i = 0; i < wordVector.length; i++) {
                sumVector[i] += wordVector[i];
            }
            found++;
        }
        if (found > 0) {
            for (int i = 0; i < sumVector.length; i++) {
                sumVector[i] /= found;
            }
        }
        // Printed so the segmentation and vectors can be inspected on the console.
        System.out.println("Words: " + Arrays.toString(words));
        System.out.println("SumVector: " + Arrays.toString(sumVector));
        return sumVector;
    }

    /**
     * Cosine similarity of two equal-length vectors; 0.0 if either is the zero vector (which would
     * otherwise give 0/0 = NaN).
     */
    private double cosineSimilarity(double[] vector1, double[] vector2) {
        double dotProduct = 0.0;
        double norm1 = 0.0;
        double norm2 = 0.0;
        for (int i = 0; i < vector1.length; i++) {
            dotProduct += vector1[i] * vector2[i];
            norm1 += vector1[i] * vector1[i];
            norm2 += vector2[i] * vector2[i];
        }
        if (norm1 == 0.0 || norm2 == 0.0) {
            return 0.0;
        }
        return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2));
    }

    /**
     * Optionally stretches low similarities for display. Values >= 0.2 are returned unchanged.
     * Values in [0.0, 0.2) are mapped linearly onto [-1.0, 0.2). Values <= 0.0 (cosine similarity
     * can be negative) are clamped to -1.0, so the result stays within [-1, 1]. NaN passes through.
     */
    public static double mapSimilarity(double similarity) {
        if (similarity >= SOURCE_HIGH) {
            return similarity;
        }
        if (similarity <= SOURCE_LOW) {
            return TARGET_LOW;
        }
        // (actual - actualStart) / actualSpan * targetSpan + targetStart
        return (similarity - SOURCE_LOW) / (SOURCE_HIGH - SOURCE_LOW) * (SOURCE_HIGH - TARGET_LOW)
                + TARGET_LOW;
    }
}
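3. 在需要计算的地方调用计算方法（注意：加载模型开销很大，应只创建一次 calculator 并复用，例如注册为 Spring 单例 Bean；另外 Maven 默认不会把 src/main/java 下的非 .java 文件打包进 jar，部署时建议把模型放在 jar 外部的目录并使用其路径）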
// Path to your own model file. A relative path resolves against the working directory, and Maven
// does not copy non-.java files from src/main/java into the jar by default; use an external
// file path when deploying.
String modelPath = "src/main/java/com/example/demo/service/Impl/data/pretrained_word2vec.bin";
// Loading the model is expensive: create the calculator once and reuse it for all comparisons.
SentenceSimilarityCalculator calculator = new SentenceSimilarityCalculator(modelPath);
double sentenceSimilarity = calculator.calculateSimilarity(sentence1, sentence2);
// mapSimilarity is static, so call it on the class rather than through the instance.
double mappedSimilarity = SentenceSimilarityCalculator.mapSimilarity(sentenceSimilarity);
mappedSimilarity *= 100; // scale to a percentage-style score
如果需要打印中间结果以便观察：
// Path to your own model file (see the packaging caveat: use an external file path when deployed).
String modelPath = "src/main/java/com/example/demo/service/Impl/data/pretrained_word2vec.bin";
// Loading the model is expensive: create the calculator once and reuse it for all comparisons.
SentenceSimilarityCalculator calculator = new SentenceSimilarityCalculator(modelPath);
double sentenceSimilarity = calculator.calculateSimilarity(sentence1, sentence2);
System.out.println("sentenceSimilarity between the two sentences: " + sentenceSimilarity);
// mapSimilarity is static, so call it on the class rather than through the instance.
double mappedSimilarity = SentenceSimilarityCalculator.mapSimilarity(sentenceSimilarity);
System.out.println("mappedSimilarity between the two sentences: " + mappedSimilarity);
mappedSimilarity *= 100; // scale to a percentage-style score
System.out.println("FinalMappedSimilarity between the two sentences: " + mappedSimilarity);
4. 输入（或从其他地方获取）需要计算语义相似度的句子（这两个变量需要在第 3 步的代码之前定义）
// Example inputs; in practice these could come from a request, a database, etc.
final String sentence1 = "比利时大个子费莱尼,传奇轰炸机";
final String sentence2 = "泰山老队长费莱尼,来时英雄去时传奇";