1. Set up a Maven project. The project's pom.xml file is:
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>my-elastisearch-aknn5</groupId>
<artifactId>my-elastisearch-aknn5</artifactId>
<version>0.0.1-SNAPSHOT</version>
<packaging>jar</packaging>
<name>my-elastisearch-aknn5</name>
<url>http://maven.apache.org</url>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>3.8.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.elasticsearch</groupId>
<artifactId>elasticsearch</artifactId>
<version>5.3.3</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.6.1</version>
</dependency>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-api</artifactId>
<version>2.6.1</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<artifactId>maven-assembly-plugin</artifactId>
<version>2.3</version>
<configuration>
<appendAssemblyId>false</appendAssemblyId>
<outputDirectory>${project.build.directory}/releases/</outputDirectory>
<descriptors>
<descriptor>${basedir}/src/assembly/plugin.xml</descriptor>
</descriptors>
</configuration>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
</plugins>
</build>
</project>
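The maven-assembly-plugin above references ${basedir}/src/assembly/plugin.xml, an assembly descriptor that is not shown here. A minimal sketch of one possible descriptor follows; the descriptor's source path, the elasticsearch output directory, and the zip layout expected by your Elasticsearch version are assumptions that should be verified against the plugin structure of your target release:
<?xml version="1.0"?>
<assembly>
<id>plugin</id>
<formats>
<format>zip</format>
</formats>
<includeBaseDirectory>false</includeBaseDirectory>
<files>
<file>
<!-- Assumed location of the plugin-descriptor.properties from step 2; filtering lets Maven substitute ${...} properties if used. -->
<source>${basedir}/src/main/resources/plugin-descriptor.properties</source>
<outputDirectory>elasticsearch</outputDirectory>
<filtered>true</filtered>
</file>
</files>
<dependencySets>
<dependencySet>
<outputDirectory>elasticsearch</outputDirectory>
<useProjectArtifact>true</useProjectArtifact>
<useTransitiveFiltering>true</useTransitiveFiltering>
<!-- Elasticsearch itself is provided by the server; do not bundle it. -->
<excludes>
<exclude>org.elasticsearch:elasticsearch</exclude>
</excludes>
</dependencySet>
</dependencySets>
</assembly>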
2. To build an Elasticsearch plugin, we also need to add the following plugin-descriptor.properties configuration file:
description=aknn
version=1.0
name=aknn
classname=org.elasticsearch.plugin.aknn.AknnPlugin
java.version=1.8
elasticsearch.version=5.3.3
Here, classname specifies the fully qualified class name of the plugin, java.version specifies the JDK version, and elasticsearch.version specifies the Elasticsearch version.
3. To build an Elasticsearch plugin, the plugin class must extend Plugin and implement the ActionPlugin interface. Our plugin class is:
package org.elasticsearch.plugin.aknn;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.IndexScopedSettings;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.SettingsFilter;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestHandler;
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;
public class AknnPlugin extends Plugin implements ActionPlugin {
private static final Setting<String> SETTINGS =
new Setting<>("aknn.sample.setting", "foo", (value) -> value, Setting.Property.NodeScope);
@Override
public List<Setting<?>> getSettings() {
return Arrays.asList(SETTINGS);
}
public List<RestHandler> getRestHandlers(final Settings settings,
final RestController restController,
final ClusterSettings clusterSettings,
final IndexScopedSettings indexScopedSettings,
final SettingsFilter settingsFilter,
final IndexNameExpressionResolver indexNameExpressionResolver,
final Supplier<DiscoveryNodes> nodesInCluster) {
return Arrays.asList(new AknnRestAction(settings, restController));
}
}
4. To build the plugin's model, we also need to define how to construct a model that stays resident in Elasticsearch memory. The model is persisted in the cluster as a document in an index and cached in memory on each node. The code that builds the model and serves requests against it is:
package org.elasticsearch.plugin.aknn;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.StopWatch;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static java.lang.Math.min;
import static org.elasticsearch.rest.RestRequest.Method.GET;
import static org.elasticsearch.rest.RestRequest.Method.POST;
public class AknnRestAction extends BaseRestHandler {
public static String NAME = "_aknn";
private final String NAME_SEARCH = "_aknn_search";
private final String NAME_INDEX = "_aknn_index";
private final String NAME_CREATE = "_aknn_create";
// TODO: check how parameters should be defined at the plugin level.
private final String HASHES_KEY = "_aknn_hashes";
private final String VECTOR_KEY = "_aknn_vector";
private final Integer K1_DEFAULT = 99;
private final Integer K2_DEFAULT = 10;
// TODO: add an option to the index endpoint handler that empties the cache.
private Map<String, LshModel> lshModelCache = new HashMap<>();
@Inject
public AknnRestAction(Settings settings, RestController controller) {
super(settings);
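// REST routes served by this handler (dispatched in prepareRequest below):
// GET /{index}/{type}/{id}/_aknn_search, POST /_aknn_index, POST /_aknn_create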
controller.registerHandler(GET, "/{index}/{type}/{id}/" + NAME_SEARCH, this);
controller.registerHandler(POST, NAME_INDEX, this);
controller.registerHandler(POST, NAME_CREATE, this);
}
public String getName() {
return NAME;
}
@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
if (restRequest.path().endsWith(NAME_SEARCH))
return handleSearchRequest(restRequest, client);
else if (restRequest.path().endsWith(NAME_INDEX))
return handleIndexRequest(restRequest, client);
else
return handleCreateRequest(restRequest, client);
}
public static Double euclideanDistance(List<Double> A, List<Double> B) {
Double squaredDistance = 0.;
for (int i = 0; i < A.size(); i++)
squaredDistance += Math.pow(A.get(i) - B.get(i), 2);
return Math.sqrt(squaredDistance);
}
private RestChannelConsumer handleSearchRequest(RestRequest restRequest, NodeClient client) throws IOException {
StopWatch stopWatch = new StopWatch("StopWatch to Time Search Request");
// Parse request parameters.
stopWatch.start("Parse request parameters");
final String index = restRequest.param("index");
final String type = restRequest.param("type");
final String id = restRequest.param("id");
final Integer k1 = restRequest.paramAsInt("k1", K1_DEFAULT);
final Integer k2 = restRequest.paramAsInt("k2", K2_DEFAULT);
stopWatch.stop();
logger.info("Get query document at {}/{}/{}", index, type, id);
stopWatch.start("Get query document");
GetResponse queryGetResponse = client.prepareGet(index, type, id).get();
Map<String, Object> baseSource = queryGetResponse.getSource();
stopWatch.stop();
logger.info("Parse query document hashes");
stopWatch.start("Parse query document hashes");
@SuppressWarnings("unchecked")
Map<String, Long> queryHashes = (Map<String, Long>) baseSource.get(HASHES_KEY);
stopWatch.stop();
stopWatch.start("Parse query document vector");
@SuppressWarnings("unchecked")
List<Double> queryVector = (List<Double>) baseSource.get(VECTOR_KEY);
stopWatch.stop();
// Retrieve the documents with most matching hashes. https://stackoverflow.com/questions/10773581
logger.info("Build boolean query from hashes");
stopWatch.start("Build boolean query from hashes");
QueryBuilder queryBuilder = QueryBuilders.boolQuery();
for (Map.Entry<String, Long> entry : queryHashes.entrySet()) {
String termKey = HASHES_KEY + "." + entry.getKey();
((BoolQueryBuilder) queryBuilder).should(QueryBuilders.termQuery(termKey, entry.getValue()));
}
stopWatch.stop();
logger.info("Execute boolean search");
stopWatch.start("Execute boolean search");
SearchResponse approximateSearchResponse = client
.prepareSearch(index)
.setTypes(type)
.setFetchSource("*", HASHES_KEY)
.setQuery(queryBuilder)
.setSize(k1)
.get();
stopWatch.stop();
// Compute exact KNN on the approximate neighbors.
// Recreate the SearchHit structure, but remove the vector and hashes.
logger.info("Compute exact distance and construct search hits");
stopWatch.start("Compute exact distance and construct search hits");
List<Map<String, Object>> modifiedSortedHits = new ArrayList<>();
for (SearchHit hit: approximateSearchResponse.getHits()) {
// In newer client versions this would be: Map<String, Object> hitSource = hit.getSourceAsMap();
Map<String, Object> hitSource = hit.getSource();
@SuppressWarnings("unchecked")
List<Double> hitVector = (List<Double>) hitSource.get(VECTOR_KEY);
hitSource.remove(VECTOR_KEY);
hitSource.remove(HASHES_KEY);
modifiedSortedHits.add(new HashMap<String, Object>() {{
put("_index", hit.getIndex());
put("_id", hit.getId());
put("_type", hit.getType());
put("_score", euclideanDistance(queryVector, hitVector));
put("_source", hitSource);
}});
}
stopWatch.stop();
logger.info("Sort search hits by exact distance");
stopWatch.start("Sort search hits by exact distance");
modifiedSortedHits.sort(Comparator.comparingDouble(x -> (Double) x.get("_score")));
stopWatch.stop();
logger.info("Timing summary\n {}", stopWatch.prettyPrint());
return channel -> {
XContentBuilder builder = channel.newBuilder();
builder.startObject();
builder.field("took", stopWatch.totalTime().getMillis());
builder.field("timed_out", false);
builder.startObject("hits");
builder.field("max_score", 0);
// In some cases there will not be enough approximate matches to return *k2* hits. For example, this could
// be the case if the number of bits per table in the LSH model is too high, over-partioning the space.
builder.field("total", min(k2, modifiedSortedHits.size()));
builder.field("hits", modifiedSortedHits.subList(0, min(k2, modifiedSortedHits.size())));
builder.endObject();
builder.endObject();
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
};
}
private RestChannelConsumer handleCreateRequest(RestRequest restRequest, NodeClient client) throws IOException {
StopWatch stopWatch = new StopWatch("StopWatch to time create request");
logger.info("Parse request");
stopWatch.start("Parse request");
XContentParser xContentParser = XContentHelper.createParser(
restRequest.getXContentRegistry(), restRequest.content(), restRequest.getXContentType());
Map<String, Object> contentMap = xContentParser.mapOrdered();
@SuppressWarnings("unchecked")
Map<String, Object> sourceMap = (Map<String, Object>) contentMap.get("_source");
final String _index = String.valueOf( contentMap.get("_index"));
final String _type = String.valueOf( contentMap.get("_type"));
final String _id = String.valueOf( contentMap.get("_id"));
final String description = String.valueOf( sourceMap.get("_aknn_description"));
final Integer nbTables = (Integer) sourceMap.get("_aknn_nb_tables");
final Integer nbBitsPerTable = (Integer) sourceMap.get("_aknn_nb_bits_per_table");
final Integer nbDimensions = (Integer) sourceMap.get("_aknn_nb_dimensions");
@SuppressWarnings("unchecked")
final List<List<Double>> vectorSample = (List<List<Double>>) contentMap.get("_aknn_vector_sample");
stopWatch.stop();
logger.info("Fit LSH model from sample vectors");
stopWatch.start("Fit LSH model from sample vectors");
LshModel lshModel = new LshModel(nbTables, nbBitsPerTable, nbDimensions, description);
lshModel.fitFromVectorSample(vectorSample);
stopWatch.stop();
logger.info("Serialize LSH model");
stopWatch.start("Serialize LSH model");
Map<String, Object> lshSerialized = lshModel.toMap();
stopWatch.stop();
logger.info("Index LSH model");
stopWatch.start("Index LSH model");
IndexResponse indexResponse = client.prepareIndex(_index, _type, _id)
.setSource(lshSerialized)
.get();
stopWatch.stop();
logger.info("Timing summary\n {}", stopWatch.prettyPrint());
return channel -> {
XContentBuilder builder = channel.newBuilder();
builder.startObject();
builder.field("took", stopWatch.totalTime().getMillis());
builder.endObject();
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
};
}
private RestChannelConsumer handleIndexRequest(RestRequest restRequest, NodeClient client) throws IOException {
StopWatch stopWatch = new StopWatch("StopWatch to time bulk indexing request");
logger.info("Parse request parameters");
stopWatch.start("Parse request parameters");
XContentParser xContentParser = XContentHelper.createParser(
restRequest.getXContentRegistry(), restRequest.content(), restRequest.getXContentType());
Map<String, Object> contentMap = xContentParser.mapOrdered();
final String index = String.valueOf( contentMap.get("_index"));
final String type = String.valueOf( contentMap.get("_type"));
final String aknnURI = String.valueOf( contentMap.get("_aknn_uri"));
@SuppressWarnings("unchecked")
final List<Map<String, Object>> docs = (List<Map<String, Object>>) contentMap.get("_aknn_docs");
logger.info("Received {} docs for indexing", docs.size());
stopWatch.stop();
// TODO: check if the index exists. If not, create a mapping which does not index continuous values.
// This is rather low priority, as I tried it via Python and it doesn't make much difference.
// Check if the LshModel has been cached. If not, retrieve the Aknn document and use it to populate the model.
LshModel lshModel;
if (! lshModelCache.containsKey(aknnURI)) {
// Get the Aknn document.
logger.info("Get Aknn model document from {}", aknnURI);
stopWatch.start("Get Aknn model document");
String[] annURITokens = aknnURI.split("/");
GetResponse aknnGetResponse = client.prepareGet(annURITokens[0], annURITokens[1], annURITokens[2]).get();
stopWatch.stop();
// Instantiate LSH from the source map.
logger.info("Parse Aknn model document");
stopWatch.start("Parse Aknn model document");
lshModel = LshModel.fromMap(aknnGetResponse.getSourceAsMap());
stopWatch.stop();
// Save for later.
lshModelCache.put(aknnURI, lshModel);
} else {
logger.info("Get Aknn model document from local cache");
stopWatch.start("Get Aknn model document from local cache");
lshModel = lshModelCache.get(aknnURI);
stopWatch.stop();
}
// Prepare documents for batch indexing.
logger.info("Hash documents for indexing");
stopWatch.start("Hash documents for indexing");
BulkRequestBuilder bulkIndexRequest = client.prepareBulk();
for (Map<String, Object> doc: docs) {
@SuppressWarnings("unchecked")
Map<String, Object> source = (Map<String, Object>) doc.get("_source");
@SuppressWarnings("unchecked")
List<Double> vector = (List<Double>) source.get(VECTOR_KEY);
source.put(HASHES_KEY, lshModel.getVectorHashes(vector));
bulkIndexRequest.add(client
.prepareIndex(index, type, String.valueOf(doc.get("_id")))
.setSource(source));
}
stopWatch.stop();
logger.info("Execute bulk indexing");
stopWatch.start("Execute bulk indexing");
BulkResponse bulkIndexResponse = bulkIndexRequest.get();
stopWatch.stop();
logger.info("Timing summary\n {}", stopWatch.prettyPrint());
if (bulkIndexResponse.hasFailures()) {
logger.error("Indexing failed with message: {}", bulkIndexResponse.buildFailureMessage());
return channel -> {
XContentBuilder builder = channel.newBuilder();
builder.startObject();
builder.field("took", stopWatch.totalTime().getMillis());
builder.field("error", bulkIndexResponse.buildFailureMessage());
builder.endObject();
channel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, builder));
};
}
logger.info("Indexed {} docs successfully", docs.size());
return channel -> {
XContentBuilder builder = channel.newBuilder();
builder.startObject();
builder.field("size", docs.size());
builder.field("took", stopWatch.totalTime().getMillis());
builder.endObject();
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
};
}
}
This class implements the three REST endpoints: creating the model, searching with it, and indexing documents into the Elasticsearch cluster through it. To make the request formats concrete, hedged examples follow (the index, type, id, and vector values are illustrative assumptions, not fixed by the plugin).
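Create a model. Note that _aknn_vector_sample sits at the top level, next to _source, and per fitFromVectorSample must contain 2 * _aknn_nb_tables * _aknn_nb_bits_per_table sample vectors, each of length _aknn_nb_dimensions:
POST /_aknn_create
{
"_index": "aknn_models",
"_type": "aknn_model",
"_id": "demo_model",
"_source": {
"_aknn_description": "demo LSH model",
"_aknn_nb_tables": 1,
"_aknn_nb_bits_per_table": 2,
"_aknn_nb_dimensions": 2
},
"_aknn_vector_sample": [[0.1, 0.2], [0.3, 0.1], [0.9, 0.8], [0.7, 0.6]]
}
Index documents, hashing each _aknn_vector with the model referenced by _aknn_uri (the index/type/id of the model document):
POST /_aknn_index
{
"_index": "demo_vectors",
"_type": "vector",
"_aknn_uri": "aknn_models/aknn_model/demo_model",
"_aknn_docs": [
{ "_id": "1", "_source": { "_aknn_vector": [0.11, 0.22] } },
{ "_id": "2", "_source": { "_aknn_vector": [0.85, 0.75] } }
]
}
Search for the k2 nearest neighbors of an already-indexed document; the handler retrieves the top k1 approximate matches by shared hashes, then re-ranks them by exact Euclidean distance:
GET /demo_vectors/vector/1/_aknn_search?k1=100&k2=10
The LshModel class used by the handlers above is defined as follows: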
package org.elasticsearch.plugin.aknn;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public class LshModel {
private Integer nbTables;
private Integer nbBitsPerTable;
private Integer nbDimensions;
private String description;
private List<RealMatrix> midpoints;
private List<RealMatrix> normals;
private List<RealMatrix> normalsTransposed;
private List<RealVector> thresholds;
public LshModel(Integer nbTables, Integer nbBitsPerTable, Integer nbDimensions, String description) {
this.nbTables = nbTables;
this.nbBitsPerTable = nbBitsPerTable;
this.nbDimensions = nbDimensions;
this.description = description;
this.midpoints = new ArrayList<>();
this.normals = new ArrayList<>();
this.normalsTransposed = new ArrayList<>();
this.thresholds = new ArrayList<>();
}
public void fitFromVectorSample(List<List<Double>> vectorSample) {
RealMatrix vectorsA, vectorsB, midpoint, normal, vectorSampleMatrix;
vectorSampleMatrix = MatrixUtils.createRealMatrix(vectorSample.size(), this.nbDimensions);
for (int i = 0; i < vectorSample.size(); i++)
for (int j = 0; j < this.nbDimensions; j++)
vectorSampleMatrix.setEntry(i, j, vectorSample.get(i).get(j));
for (int i = 0; i < vectorSampleMatrix.getRowDimension(); i += (nbBitsPerTable * 2)) {
// Select two subsets of nbBitsPerTable vectors.
vectorsA = vectorSampleMatrix.getSubMatrix(i, i + nbBitsPerTable - 1, 0, nbDimensions - 1);
vectorsB = vectorSampleMatrix.getSubMatrix(i + nbBitsPerTable, i + 2 * nbBitsPerTable - 1, 0, nbDimensions - 1);
// Compute the midpoint between each pair of vectors.
midpoint = vectorsA.add(vectorsB).scalarMultiply(0.5);
midpoints.add(midpoint);
// Compute the normal vectors for each pair of vectors.
normal = vectorsB.subtract(midpoint);
normals.add(normal);
}
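// Note: fitFromVectorSample fills only midpoints and normals. normalsTransposed and
// thresholds are derived in fromMap() when the serialized model is loaded back, so
// getVectorHashes() must only be called on a model obtained via fromMap().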
}
public Map<String, Long> getVectorHashes(List<Double> vector) {
RealMatrix xDotNT, vectorAsMatrix;
RealVector threshold;
Map<String, Long> hashes = new HashMap<>();
Long hash;
Integer i, j;
// Have to convert the vector to a matrix to support multiplication below.
// TODO: if the List<Double> vector argument can be changed to an Array double[] or float[], this would be faster.
vectorAsMatrix = MatrixUtils.createRealMatrix(1, nbDimensions);
for (i = 0; i < nbDimensions; i++)
vectorAsMatrix.setEntry(0, i, vector.get(i));
// Compute the hash for this vector with respect to each table.
for (i = 0; i < nbTables; i++) {
xDotNT = vectorAsMatrix.multiply(normalsTransposed.get(i));
threshold = thresholds.get(i);
hash = 0L;
for (j = 0; j < nbBitsPerTable; j++)
if (xDotNT.getEntry(0, j) > threshold.getEntry(j))
hash += (long) Math.pow(2, j);
hashes.put(i.toString(), hash);
}
return hashes;
}
@SuppressWarnings("unchecked")
public static LshModel fromMap(Map<String, Object> serialized) {
LshModel lshModel = new LshModel(
(Integer) serialized.get("_aknn_nb_tables"), (Integer) serialized.get("_aknn_nb_bits_per_table"),
(Integer) serialized.get("_aknn_nb_dimensions"), (String) serialized.get("_aknn_description"));
// TODO: figure out how to cast directly to List<double[][]> or double[][][] and use MatrixUtils.createRealMatrix.
List<List<List<Double>>> midpointsRaw = (List<List<List<Double>>>) serialized.get("_aknn_midpoints");
List<List<List<Double>>> normalsRaw = (List<List<List<Double>>>) serialized.get("_aknn_normals");
for (int i = 0; i < lshModel.nbTables; i++) {
RealMatrix midpoint = MatrixUtils.createRealMatrix(lshModel.nbBitsPerTable, lshModel.nbDimensions);
RealMatrix normal = MatrixUtils.createRealMatrix(lshModel.nbBitsPerTable, lshModel.nbDimensions);
for (int j = 0; j < lshModel.nbBitsPerTable; j++) {
for (int k = 0; k < lshModel.nbDimensions; k++) {
midpoint.setEntry(j, k, midpointsRaw.get(i).get(j).get(k));
normal.setEntry(j, k, normalsRaw.get(i).get(j).get(k));
}
}
lshModel.midpoints.add(midpoint);
lshModel.normals.add(normal);
lshModel.normalsTransposed.add(normal.transpose());
}
for (int i = 0; i < lshModel.nbTables; i++) {
RealMatrix normal = lshModel.normals.get(i);
RealMatrix midpoint = lshModel.midpoints.get(i);
RealVector threshold = new ArrayRealVector(lshModel.nbBitsPerTable);
for (int j = 0; j < lshModel.nbBitsPerTable; j++)
threshold.setEntry(j, normal.getRowVector(j).dotProduct(midpoint.getRowVector(j)));
lshModel.thresholds.add(threshold);
}
return lshModel;
}
public Map<String, Object> toMap() {
return new HashMap<String, Object>() {{
put("_aknn_nb_tables", nbTables);
put("_aknn_nb_bits_per_table", nbBitsPerTable);
put("_aknn_nb_dimensions", nbDimensions);
put("_aknn_description", description);
put("_aknn_midpoints", midpoints.stream().map(realMatrix -> realMatrix.getData()).collect(Collectors.toList()));
put("_aknn_normals", normals.stream().map(normals -> normals.getData()).collect(Collectors.toList()));
}};
}
}
This class defines how the LSH model is generated and applied: each hash table is a set of nbBitsPerTable hyperplanes, and a vector's hash for a table sets bit j when the vector's projection onto normal j exceeds that hyperplane's threshold. For example, with nbBitsPerTable = 4, bit tests of 1, 0, 1, 1 for j = 0..3 give the table hash 2^0 + 2^2 + 2^3 = 13.
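As a minimal standalone sketch of that hashing rule (a hypothetical helper for illustration, not part of the plugin), assuming one table's normals and thresholds are already available as plain arrays:
// Bit j of the table hash is set when x lies on the positive side of
// hyperplane j, i.e. when dot(x, normals[j]) > thresholds[j].
static long hashOneTable(double[] x, double[][] normals, double[] thresholds) {
long hash = 0L;
for (int j = 0; j < normals.length; j++) {
double dot = 0.0;
for (int k = 0; k < x.length; k++)
dot += x[k] * normals[j][k];
if (dot > thresholds[j])
hash += 1L << j; // same encoding as Math.pow(2, j) in getVectorHashes
}
return hash;
}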
5. Build and package the project with Maven to produce the my-elastisearch-aknn5-0.0.1-SNAPSHOT.jar package (with the assembly configuration above, the packaged plugin is also written to target/releases/). Copy it into the plugins directory under the installation directory of every node in the Elasticsearch cluster, then restart each node. With that, our Elasticsearch AKNN plugin is built and deployed.