前言
本文为"以图搜图"功能做技术储备：用 AIAS（基于 DJL）提取图片特征向量，存入阿里云 Elasticsearch 的 Proxima 向量索引中进行相似度检索。下面直接上代码。
创建ES向量索引
PUT test_img
{
"settings": {
"index.codec": "proxima", // vector codec for the proxima_vector field below (presumably required by it — confirm in the Aliyun vector plugin docs)
"index.vector.algorithm": "hnsw", // ANN algorithm; search requests use the matching "hnsw" query
"index.number_of_replicas":1,
"index.number_of_shards":3
},
"mappings": {
"properties": {
"feature": {
"type": "proxima_vector", // vector-indexed field holding the image feature
"dim": 512, // 512 dimensions; presumably must equal the length of the float[] produced by the image encoder — confirm
"vector_type": "float"
},
"goods_id":{
"type": "long"
}
}
}
}
集成AIAS的jar包
注意：本文运行环境为 CentOS 8.2；图片特征提取引擎要求 CentOS 7.9 及以上版本才能运行。
图片处理所需的 jar 包（参考 AIAS 官方示例）：
<!-- DJL (Deep Java Library) artifacts used for image feature extraction.
     All ai.djl* artifacts are pinned to 0.22.0 here; presumably they should be upgraded together — confirm. -->
<!-- DJL core API -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.22.0</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
<version>0.22.0</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
<version>0.22.0</version>
</dependency>
<!-- Pytorch: inference engine and its model zoo -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.22.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-model-zoo</artifactId>
<version>0.22.0</version>
</dependency>
<!-- OpenCV extension: ImageUtil casts loaded images to org.opencv.core.Mat to release them -->
<dependency>
<groupId>ai.djl.opencv</groupId>
<artifactId>opencv</artifactId>
<version>0.22.0</version>
</dependency>
提取图片特征
先定义一个实体类
/**
 * Feature record for one goods item: its image feature vector plus identifying fields.
 *
 * <p>Lombok {@code @Data} generates the getters, setters, equals/hashCode and toString.
 */
@Data
public class Fc_GoodsFeature {

    /** Goods id; copied into the ES document, where it becomes the document id. */
    Long goods_id;

    /** Feature vector extracted from the goods image. */
    float[] feature;

    /** Goods name; stored in ES and returned with search hits. */
    String goods_name;
}
定义图片特征提取工具类
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import lombok.extern.slf4j.Slf4j;
import org.opencv.core.Mat;
@Slf4j
public class ImageUtil {
public static Criteria<Image, float[]> criteria = null;
public static ZooModel model = null;
public static Predictor<Image, float[]> predictor = null;
public static ImageFactory imageFactory = null;
public static void ImageInit(){
try{
criteria = new ImageEncoderModel().criteria();
model = ModelZoo.loadModel(criteria);
predictor = model.newPredictor();
imageFactory = ImageFactory.getInstance();
}catch (Exception e){
e.printStackTrace();
}
}
public static float[] featureExtraction(String url){
float[] feature = null;
try {
log.info("开始获取图片");
Image img = imageFactory.fromUrl("http://it.mz.gold"+url);
log.info("获取图片结束");
log.info("开始提取特征");
feature = predictor.predict(img);
((Mat) img.getWrappedImage()).release();
log.info("提取特征结束");
}catch (Exception e){
e.printStackTrace();
}
return feature;
}
}
编写同步到ES的方法
/**
 * Copies an extracted goods feature into the ES document model and indexes it.
 *
 * <p>Any exception is logged (with its stack trace) rather than propagated, so callers can
 * treat indexing as best effort.
 *
 * @param goods goods id, feature vector and name to index
 */
public static void indexFeature(Fc_GoodsFeature goods) {
    try {
        Es_GoodsFeature esGoodsFeature = new Es_GoodsFeature();
        esGoodsFeature.setGoods_id(goods.getGoods_id());
        esGoodsFeature.setFeature(goods.getFeature());
        esGoodsFeature.setGoods_name(goods.getGoods_name());
        indexFeature(esGoodsFeature, 0);
    } catch (Exception e) {
        // One log entry carrying the cause, instead of printStackTrace() plus a message-only log.
        log.error("添加到ES异常:{}", goods, e);
    }
}
/**
 * Indexes one goods document into the feature index, retrying on I/O failure.
 *
 * <p>The goods id is used as the ES document id. {@code IOException}s are retried until the
 * attempt counter exceeds 2 (three attempts when called with 0). Other exceptions, e.g. an NPE
 * from a null goods id, propagate to the caller.
 *
 * @param esGoods  document to index; its goods id must be non-null
 * @param retryNum number of retries already spent (callers normally pass 0)
 */
public static void indexFeature(Es_GoodsFeature esGoods, int retryNum) {
    final int maxRetries = 2;
    log.debug("索引到ES开始:{}", esGoods);
    // index_name is the index name; the type must be "_doc" on 7.0+; the doc id is the goods id.
    IndexRequest indexRequest = new IndexRequest(esFeatureConfig.getIndex(), "_doc", esGoods.getGoods_id().toString())
            .source(JSON.toJSONString(esGoods), XContentType.JSON);
    int attempt = retryNum;
    // Iterative retry: the previous recursive version logged the "end" line once per level.
    while (true) {
        try {
            // Synchronous call with the default RequestOptions.
            esFeatureClient.index(indexRequest, RequestOptions.DEFAULT);
            break;
        } catch (IOException e) {
            log.error("添加到ES异常:{}", esGoods, e);
            attempt++;
            if (attempt > maxRetries) {
                log.error("ES数据同步失败,放弃重试:{}", esGoods);
                break;
            }
            log.info("重新同步ES数据{}:{}", attempt, esGoods);
        }
    }
    log.debug("索引到ES结束:{}", esGoods);
}
同步图片特征到ES
String coverImg = "图片地址";
// ImageUtil.featureExtraction prepends its own host, so reduce a full URL (http or https) to its
// "/data..." path. Guard indexOf: a URL without "/data" used to throw StringIndexOutOfBoundsException.
int dataPathStart = coverImg.indexOf("/data");
boolean isFullUrl = coverImg.startsWith("http://") || coverImg.startsWith("https://");
if (isFullUrl && dataPathStart >= 0) {
    coverImg = coverImg.substring(dataPathStart);
}
float[] feature = ImageUtil.featureExtraction(coverImg); // extract the image feature vector
if (feature == null) {
    return; // extraction failed; nothing to index
}
Fc_GoodsFeature goodsFeature = new Fc_GoodsFeature();
goodsFeature.setGoods_id(goods.getGoods_id());
goodsFeature.setFeature(feature);
goodsFeature.setGoods_name(goods.getGoods_name());
EsPriceService.indexFeature(goodsFeature); // sync the image feature to ES
搜索图片
注意：特征提取的并发量越大，占用的内存越多；用锁把并发量限制在上限以内，内存占用就不会继续增长。
// Serializes image feature extraction, which is memory- and CPU-heavy; without it, concurrent
// extractions multiply that cost. final so the lock instance can never be swapped while held.
public static final ReentrantLock imgLock = new ReentrantLock();
/**
 * Initialises the Elasticsearch connection for the feature index.
 *
 * <p>Stores {@code esFeaCon}, builds a {@code RestHighLevelClient} that authenticates with
 * basic auth (required by the Aliyun Elasticsearch cluster), and prepares the shared
 * {@code searchRequest} / {@code searchSourceBuilder} for the configured index. Failures are
 * logged, not thrown, so the client may remain unset afterwards.
 *
 * @param esFeaCon host, port, credentials and index name of the feature ES cluster
 */
public static void EsFeatureInit(EsFeatureConfig esFeaCon) {
    try {
        log.info("ES初始化开始...");
        esFeatureConfig = esFeaCon;
        // Username/password are the ones set when the Aliyun ES instance was created
        // (also the Kibana console login).
        final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
        credentialsProvider.setCredentials(AuthScope.ANY,
                new UsernamePasswordCredentials(esFeatureConfig.getUsername(), esFeatureConfig.getPassword()));
        // Host/port: the cluster's public endpoint, shown on the instance's basic-info page.
        RestClientBuilder restClientBuilder = RestClient
                .builder(new HttpHost(esFeatureConfig.getHost(), esFeatureConfig.getPort()))
                .setHttpClientConfigCallback(
                        httpClientBuilder -> httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider));
        esClient = new RestHighLevelClient(restClientBuilder);
        searchRequest = new SearchRequest(esFeatureConfig.getIndex());
        searchSourceBuilder = new SearchSourceBuilder();
        log.info("ES初始化完毕!");
    } catch (Exception e) {
        log.error("ES初始化失败", e);
    }
}
/**
 * Image search endpoint: extracts the query image's feature vector and returns the id of the
 * most similar goods from the HNSW vector index.
 *
 * <p>Returns the top hit's document id, the placeholder {@code "111"} when the index returns no
 * hits, or {@code e.toString()} when anything fails. Only the top hit is used.
 *
 * @param msg currently unused
 */
@GetMapping("ImgSearch")
public String ImgSearch(String msg) {
    String str = "111";
    String imgUrl = "图片地址";
    try {
        float[] feature;
        // Serialize the memory/CPU-heavy extraction. lock() before try + unlock() in finally
        // releases only a lock this thread holds; the old catch-block check isLocked() was also
        // true when ANOTHER thread held the lock, making unlock() throw IllegalMonitorStateException.
        imgLock.lock();
        try {
            feature = ImageUtil.featureExtraction(imgUrl); // feature vector of the query image
        } finally {
            imgLock.unlock();
        }
        if (feature == null) {
            // Otherwise the query would carry "vector": null and fail inside ES.
            throw new IllegalStateException("图片特征提取失败:" + imgUrl);
        }
        // Build the ES hnsw query.
        String hnswQuery = "{\n" +
                " \"hnsw\": { \n" +
                " \"feature\": {\n" +
                " \"vector\": " + Arrays.toString(feature) + ", \n" +
                " \"size\": 10 \n" +
                " }\n" +
                " }\n" +
                " }";
        // Per-request builder and request: the shared static ones were mutated by concurrent
        // requests, so one search could run with another request's query.
        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
                .query(QueryBuilders.wrapperQuery(hnswQuery))
                .fetchSource(new String[]{"goods_id", "goods_name"}, null); // only these fields are returned
        // Optional similarity threshold: sourceBuilder.minScore(0.9f);
        // Reuse the index configured in EsFeatureInit.
        SearchRequest request = new SearchRequest(searchRequest.indices()).source(sourceBuilder);
        log.info("开始ES搜索");
        SearchHit[] results = esClient.search(request, RequestOptions.DEFAULT).getHits().getHits();
        log.info("ES搜索结束");
        if (results.length == 0) {
            log.info("ES未搜索到相似商品");
            return str;
        }
        SearchHit topHit = results[0];
        str = topHit.getId();
        log.info("id: "+topHit.getId()+" 商品名:"+topHit.getSourceAsMap().get("goods_name")+" 相似度:"+topHit.getScore());
    } catch (Exception e) {
        log.error("图片搜索异常", e);
        return e.toString();
    }
    return str;
}
总结
记录到这了。