使用 Java + ONNX Runtime 进行语义相似度计算

1. 下载模型到本地

huggingface-cli download BAAI/bge-m3 --include "onnx/*.*" --local-dir . --local-dir-use-symlinks False

（通配符需加引号，避免被 shell 提前展开；文件会下载到当前目录下的 onnx/ 子目录中，下面代码中的 tokenizer.json 与 model.onnx 路径需指向该目录。）

package com.yucl.demo.djl;

import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

/**
 * Demo: computes sentence embeddings with a local ONNX export of {@code BAAI/bge-m3}
 * (tokenized with the DJL HuggingFace tokenizer, executed with ONNX Runtime) and compares
 * them with cosine similarity.
 */
public class OnnxEmbeddingDemo {

    /** Tokenizer location used when no path is given as {@code args[0]}. */
    private static final String DEFAULT_TOKENIZER_PATH = "D:\\llm\\bge-m3-onnx\\tokenizer.json";

    /** Model location used when no path is given as {@code args[1]}. */
    private static final String DEFAULT_MODEL_PATH = "D:\\llm\\bge-m3-onnx\\model.onnx";

    /**
     * Embeds three sample sentences and prints the pairwise cosine similarities
     * (sentence 0 vs 1, 0 vs 2, 1 vs 2), one value per line.
     *
     * @param args optional overrides: {@code args[0]} is the tokenizer.json path and
     *             {@code args[1]} the model.onnx path; missing entries fall back to the defaults
     * @throws Exception if the tokenizer or model cannot be loaded, or inference fails
     */
    public static void main(String[] args) throws Exception {
        String tokenizerPath = args.length > 0 ? args[0] : DEFAULT_TOKENIZER_PATH;
        String modelPath = args.length > 1 ? args[1] : DEFAULT_MODEL_PATH;
        String[] sentences = new String[] { "I like you", "我喜欢你", "我讨厌你" };

        // The tokenizer holds native resources too, so it is closed alongside the session.
        try (HuggingFaceTokenizer tokenizer =
                        HuggingFaceTokenizer.newInstance(Paths.get(tokenizerPath), Map.of());
                OrtEnvironment environment = OrtEnvironment.getEnvironment();
                OrtSession session = environment.createSession(modelPath)) {
            float[][] embeddings = emb(environment, session, tokenizer, sentences);
            int[][] pairs = { { 0, 1 }, { 0, 2 }, { 1, 2 } };
            for (int[] pair : pairs) {
                double similarity = cosineSimilarity(embeddings[pair[0]], embeddings[pair[1]]);
                System.out.println(similarity);
            }
        }
    }

    /**
     * Tokenizes {@code sentences} and runs them through the model as a single batch.
     *
     * @param environment the ONNX Runtime environment used to allocate input tensors
     * @param session     a session for a model with {@code input_ids} / {@code attention_mask}
     *                    inputs and a {@code sentence_embedding} output
     * @param tokenizer   tokenizer matching the model
     * @param sentences   texts to embed
     * @return the {@code sentence_embedding} output, expected to hold one row per input sentence;
     *         an empty array when {@code sentences} is empty
     * @throws OrtException          if tensor creation or inference fails
     * @throws IllegalStateException if the model produced no {@code sentence_embedding} output
     */
    public static float[][] emb(OrtEnvironment environment, OrtSession session, HuggingFaceTokenizer tokenizer,
            String[] sentences) throws OrtException {
        if (sentences.length == 0) {
            // An empty batch cannot be turned into a tensor; there is nothing to embed.
            return new float[0][];
        }
        Encoding[] encodings = tokenizer.batchEncode(sentences);
        long[][] ids = new long[encodings.length][];
        long[][] masks = new long[encodings.length][];
        for (int i = 0; i < encodings.length; i++) {
            ids[i] = encodings[i].getIds();
            masks[i] = encodings[i].getAttentionMask();
        }

        try (OnnxTensor inputIds = OnnxTensor.createTensor(environment, padToRectangle(ids));
                OnnxTensor attentionMask = OnnxTensor.createTensor(environment, padToRectangle(masks));
                OrtSession.Result results = session.run(
                        Map.of("input_ids", inputIds, "attention_mask", attentionMask))) {
            // getValue() copies the tensor into a Java array, so it stays valid after results close.
            return (float[][]) results.get("sentence_embedding")
                    .orElseThrow(() -> new IllegalStateException("Model has no 'sentence_embedding' output"))
                    .getValue();
        }
    }

    /**
     * Right-pads every row with zeros to the length of the longest row.
     * The result is a rectangular matrix that can be turned into a 2-D tensor.
     *
     * <p>Whether batchEncode pads depends on the tokenizer configuration. Padded slots get
     * attention-mask 0, which should make the model ignore them. This assumes the exported model
     * honours attention_mask; confirm that for other exports. When the rows are already equal
     * length, the result is an equal copy.
     */
    private static long[][] padToRectangle(long[][] rows) {
        int width = 0;
        for (long[] row : rows) {
            width = Math.max(width, row.length);
        }
        long[][] padded = new long[rows.length][width];
        for (int i = 0; i < rows.length; i++) {
            System.arraycopy(rows[i], 0, padded[i], 0, rows[i].length);
        }
        return padded;
    }

    /**
     * Computes the cosine similarity of two equal-length vectors.
     *
     * <p>Accumulation is done in {@code double} to limit rounding error on long vectors.
     *
     * @param vectorA first vector
     * @param vectorB second vector, same length as {@code vectorA}
     * @return a value in [-1, 1]; {@code 0.0} if either vector has zero norm, where cosine
     *         similarity is undefined
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static double cosineSimilarity(float[] vectorA, float[] vectorB) {
        if (vectorA.length != vectorB.length) {
            throw new IllegalArgumentException(
                    "Vector lengths differ: " + vectorA.length + " vs " + vectorB.length);
        }
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < vectorA.length; i++) {
            double a = vectorA[i];
            double b = vectorB[i];
            dotProduct += a * b;
            normA += a * a;
            normB += b * b;
        }
        // A single zero vector would otherwise yield 0/0 = NaN.
        if (normA == 0.0 || normB == 0.0) {
            return 0.0;
        }
        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
    }

}

源码地址：https://github.com/yucl80/ai-demo-java

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值