deeplearning4j 使用 VGG19 提取图片特征向量并比对相似度(Spring Boot + Elasticsearch 环境)


一、在桌面创建两个目录用于读图:template(存放模板图片 1.png)和 target(存放待比对的图片)

二、POM

<dependency>
    <groupId>org.springframework.data</groupId>
    <artifactId>spring-data-elasticsearch</artifactId>
</dependency>

<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-zoo</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
<!-- DL4J needs an ND4J backend on the classpath at runtime; deeplearning4j-core/zoo do not bring one in,
     so creating any INDArray fails without it. nd4j-native-platform is the CPU backend
     (swap for an nd4j-cuda-* artifact to run on GPU). Keep its version in lock-step with DL4J. -->
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
<dependency>
    <groupId>org.elasticsearch</groupId>
    <artifactId>elasticsearch</artifactId>
</dependency>
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>transport</artifactId>
</dependency>
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>elasticsearch-rest-client</artifactId>
</dependency>
<dependency>
    <groupId>org.elasticsearch.plugin</groupId>
    <artifactId>transport-netty4-client</artifactId>
</dependency>

三、code

import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.zoo.model.VGG19;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import javax.annotation.PostConstruct;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;

/**
 * Image-similarity service backed by an ImageNet-pretrained VGG19 network.
 *
 * <p>Each call to {@link #find(MultipartFile)} runs every image in {@link #IMAGE_FOLDER} through
 * VGG19, re-indexes the resulting vectors into Elasticsearch via {@link INDArrayPojoRepository},
 * and returns the file name of the image whose vector is closest (cosine similarity) to the
 * query image.
 */
@Service("vgg19Service")
public class Vgg19ServiceImpl implements Vgg19Service {

    /** Template image used as the query when the caller does not upload one. */
    private static final String TEMPLATE_IMAGE_PATH = "C:\\Users\\Administrator\\Desktop\\template\\1.png";

    /** Folder whose images are indexed and compared against the query image. */
    private static final String IMAGE_FOLDER = "C:\\Users\\Administrator\\Desktop\\target";

    /** Height, width and channel count the images are resized to before entering VGG19. */
    private static final int IMAGE_HEIGHT = 224;
    private static final int IMAGE_WIDTH = 224;
    private static final int IMAGE_CHANNELS = 3;

    /**
     * Pretrained VGG19 graph, loaded once in {@link #init()}. An instance field (not static) so it
     * is owned by this Spring-managed bean rather than written globally from an instance method.
     */
    private ComputationGraph vgg19Model;

    /**
     * Applies the mean-pixel subtraction the pretrained VGG weights expect; without it the
     * vectors are computed on inputs distributed differently from the training data.
     */
    private final VGG16ImagePreProcessor preProcessor = new VGG16ImagePreProcessor();

    @Autowired
    private INDArrayPojoRepository indArrayPojoRepository;

    /**
     * Loads the pretrained VGG19 weights once, after the bean has been constructed.
     *
     * @throws IOException if the pretrained weights cannot be downloaded or read
     */
    @PostConstruct
    public void init() throws IOException {
        VGG19 vgg19 = VGG19.builder().build();
        vgg19Model = (ComputationGraph) vgg19.initPretrained();
    }

    /**
     * Re-indexes every regular file in {@link #IMAGE_FOLDER} and returns the one most similar to
     * the query image.
     *
     * @param file uploaded query image; when {@code null} or empty, {@link #TEMPLATE_IMAGE_PATH}
     *     is used instead
     * @return the file name of the most similar image, or {@code null} if the folder does not
     *     exist or contains no regular files
     * @throws IOException if the query image or one of the folder's images cannot be read
     */
    @Override
    public String find(MultipartFile file) throws IOException {
        NativeImageLoader imageLoader = new NativeImageLoader(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS);
        INDArray templateFeatures = extractFeatures(loadQueryImage(imageLoader, file));

        File[] imageFiles = new File(IMAGE_FOLDER).listFiles();
        // The index is rebuilt from scratch on every call so it mirrors the folder's current contents.
        indArrayPojoRepository.deleteAll();
        if (imageFiles == null) {
            return null;
        }

        String bestMatch = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        long id = 1L;
        for (File imageFile : imageFiles) {
            // listFiles() also returns sub-directories, which the image loader cannot decode.
            if (!imageFile.isFile()) {
                continue;
            }
            INDArray currentFeatures = extractFeatures(imageLoader.asMatrix(imageFile));
            // Store the real network output: the entity maps this field as a dense_vector with
            // dims = 1000, so a vector of any other length is rejected by Elasticsearch.
            indArrayPojoRepository.save(new ImagesArrayPojo(id++, currentFeatures.toDoubleVector()));

            double score = Transforms.cosineSim(templateFeatures, currentFeatures);
            if (score > bestScore) {
                bestScore = score;
                bestMatch = imageFile.getName();
            }
        }
        return bestMatch;
    }

    /** Reads the query image from the upload, falling back to the template file on disk. */
    private static INDArray loadQueryImage(NativeImageLoader imageLoader, MultipartFile file) throws IOException {
        if (file == null || file.isEmpty()) {
            return imageLoader.asMatrix(new File(TEMPLATE_IMAGE_PATH));
        }
        try (InputStream in = file.getInputStream()) {
            return imageLoader.asMatrix(in);
        }
    }

    /** Normalizes {@code image} in place and returns VGG19's output vector for it. */
    private INDArray extractFeatures(INDArray image) {
        preProcessor.transform(image);
        return vgg19Model.outputSingle(image);
    }
}

Java 实体类

/**
 * Elasticsearch document holding one image's feature vector.
 *
 * <p>Stored in index {@code images_double}. {@code ndDoubleArray} is mapped as a
 * {@code dense_vector} with 1000 dimensions, so every array saved here must have exactly that
 * length.
 */
@Document(indexName = "images_double")
@NoArgsConstructor
@AllArgsConstructor
@Data
public class ImagesArrayPojo {

    /** Document identifier (annotated with {@code @Id}). */
    @Id
    private Long id;

    /** Image feature vector, mapped as a 1000-dimension dense_vector. */
    @Field(type = FieldType.Dense_Vector, dims = 1000)
    private double[] ndDoubleArray;
}

实体类需要搭配以下依赖使用:

<dependency>
     <groupId>org.springframework.boot</groupId>
     <artifactId>spring-boot-starter-data-elasticsearch</artifactId>
 </dependency>

四、es查询脚本

注意查看官方文档:不同版本 ES 的脚本写法稍有不同(例如较新的 7.x 版本推荐将 cosineSimilarity 的第二个参数写成字段名字符串 'ndDoubleArray',而不是 doc['ndDoubleArray']),这里使用的是 7.4.2。

# script.disable_dynamic and ES_MIN_MEM/ES_MAX_MEM are pre-5.x options that ES 7.x does not use:
# inline Painless scripts (needed by script_score) are enabled by default, and the heap is set via ES_JAVA_OPTS.
docker run -d --name es -p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" -e ES_JAVA_OPTS="-Xms128m -Xmx128m" elasticsearch:7.4.2
{
  "query": {
    "script_score": {
      "query": {
        "match_all": {}
      },
      "script": {
        "source": "cosineSimilarity(params.query_vector,doc['ndDoubleArray']) + 1.0",
        "params": {
          "query_vector": [维度数组]
        }
      }
    }
  }
}

五、Repository 代码(以下代码未经测试)

import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.annotations.Query;
import org.springframework.data.elasticsearch.repository.ElasticsearchRepository;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;

/**
 * Spring Data Elasticsearch repository for {@link ImagesArrayPojo} documents.
 */
@Repository
public interface INDArrayPojoRepository extends ElasticsearchRepository<ImagesArrayPojo, Long> {

    /**
     * Ranks all stored images by cosine similarity to {@code queryVector}.
     *
     * <p>Spring Data Elasticsearch uses the {@code @Query} string as the query clause itself, so
     * it must not carry a {@code "query"}/{@code "size"}/{@code "from"} envelope; paging comes
     * from {@code pageable}. Placeholders are zero-based, so the vector is {@code ?0}
     * ({@code ?1} would bind the {@link Pageable}). The {@code doc['field']} argument form of
     * {@code cosineSimilarity} matches Elasticsearch 7.4; newer versions prefer the field name
     * string.
     *
     * <p>NOTE(review): how a {@code double[]} is rendered into {@code [?0]} depends on the Spring
     * Data Elasticsearch version's parameter conversion - confirm it yields a plain numeric list.
     *
     * @param queryVector feature vector of the query image; presumably must have the 1000 dims
     *     declared on the entity mapping
     * @param pageable    page of results to return
     * @return stored documents scored by {@code cosineSimilarity + 1.0} (shifted so the score is
     *     never negative)
     */
    @Query("{\"script_score\": {"
            + "\"query\": {\"match_all\": {}},"
            + "\"script\": {"
            + "\"source\": \"cosineSimilarity(params.query_vector, doc['ndDoubleArray']) + 1.0\","
            + "\"params\": {\"query_vector\": [?0]}"
            + "}}}")
    Page<ImagesArrayPojo> findBySimilarity(@Param("queryVector") double[] queryVector, Pageable pageable);
}

总结

思路:首先使用 deeplearning4j 加载 VGG19 模型提取图片的特征向量,然后将向量存入 ES,后续搜索时使用 ES 的余弦相似度脚本(cosineSimilarity)进行查询。

  • 4
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值