In Elasticsearch 7.0, ES introduced field types for high-dimensional vectors:
dense_vector stores a dense vector of floats. Each element is a single float value, which may be 0, negative, or positive, and the array may hold at most 1024 elements. (In the first 7.x releases the array length could differ between documents; as of 7.3 the dims mapping parameter fixes the number of dimensions.)
sparse_vector stores a sparse vector of floats. Each value is likewise a single float, which may be 0, negative, or positive. A sparse_vector is written as a one-level (non-nested) JSON object: the keys are the vector positions, i.e. integers encoded as strings in the range [0, 65535], and the values are the floats at those positions.
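To make the two formats concrete, here is a rough sketch with made-up values (the field names match the article_v1 mapping created below; the dense example is truncated, since that mapping fixes it at 200 dimensions):
# Illustrative document bodies only; all values are made up.
dense_example = {
    "title": "some news title",
    # dense_vector: a flat array of floats whose length must match the mapping's dims
    "title_dv": [0.12, -0.03, 0.0, 0.87]  # truncated; the real field needs 200 values
}
sparse_example = {
    "title": "some news title",
    # sparse_vector: a one-level object; keys are integer positions written as strings
    # (range [0, 65535]) and values are the floats at those positions
    "title_sv": {"3": 0.25, "150": -0.47, "64000": 1.02}
}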
Elasticsearch version: elasticsearch-7.3.0
Environment setup (note that the title field below uses the ik_smart analyzer, which requires the IK analysis plugin to be installed):
curl -H "Content-Type: application/json" -XPUT 'http://192.168.0.1:9200/article_v1/' -d '
{
  "settings": {
    "number_of_shards": 1,
    "number_of_replicas": 0
  },
  "mappings": {
    "dynamic": "strict",
    "properties": {
      "id": {
        "type": "keyword"
      },
      "title": {
        "analyzer": "ik_smart",
        "type": "text"
      },
      "title_dv": {
        "type": "dense_vector",
        "dims": 200
      },
      "title_sv": {
        "type": "sparse_vector"
      }
    }
  }
}
'
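Before loading real data, the mapping can be smoke-tested from Python. The sketch below is only illustrative: it assumes the article_v1 index above already exists on 192.168.0.1:9200 and writes one hand-made document that carries both vector fields.
# Smoke test for the article_v1 mapping: index one document with both vector fields.
import random
from elasticsearch import Elasticsearch

es = Elasticsearch([{'host': '192.168.0.1', 'port': 9200}])
doc = {
    "id": "smoke-test-1",
    "title": "smoke test",
    "title_dv": [random.uniform(-1, 1) for _ in range(200)],  # must match dims: 200
    "title_sv": {"7": 0.5, "42": -1.2, "1024": 0.33}
}
es.index(index="article_v1", id=doc["id"], body=doc)
print(es.get(index="article_v1", id=doc["id"])['_source']['title'])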
Test and verification code (Python 3; requires jieba, pymongo, gensim, and the elasticsearch client):
# -*- coding:utf-8 -*-
import os
import logging
import jieba
import pymongo
from elasticsearch import Elasticsearch
from gensim.models.doc2vec import TaggedDocument, Doc2Vec
logging.basicConfig(format='%(asctime)s:%(levelname)s:%(message)s', level=logging.INFO)
# News articles previously scraped from the web and stored in MongoDB
client = pymongo.MongoClient(host='192.168.0.1', port=27017)
db = client['news']
es = Elasticsearch([{'host': '192.168.0.1', 'port': 9200}], timeout=3600)
chinese_stop_words_file = os.path.abspath(os.getcwd() + os.sep + '..' + os.sep + 'static' + os.sep + 'dic' + os.sep + 'chinese_stop_words.txt')
with open(chinese_stop_words_file, 'r', encoding='utf-8') as f:
    chinese_stop_words = [line.strip() for line in f]
total_cut_word_count = 0
# Sentence segmentation: tokenize with jieba and drop Chinese stop words
def sentence_segment(sentence):
    global total_cut_word_count
    result = []
    cut_words = jieba.cut(sentence)
    for cut_word in cut_words:
        if cut_word not in chinese_stop_words:
            result.append(cut_word)
            total_cut_word_count += 1
    return result
# Prepare the training corpus: one TaggedDocument per news title
def prepare_doc_corpus():
    datas = db['netease_ent_news_detail'].find({"create_time": {"$ne": None}}).sort('create_time', pymongo.ASCENDING)
    print(datas.count())
    for i, data in enumerate(datas):
        if data['title'] is not None and data['content'] is not None:
            title = str(data['title']).strip()
            # use the string form of the MongoDB ObjectId as the document tag
            yield TaggedDocument(sentence_segment(title), [str(data['_id'])])
# Train the Doc2Vec model
def train_doc_model():
    # materialize the generator: gensim needs to iterate over the corpus more than once
    corpus = list(prepare_doc_corpus())
    doc2vec = Doc2Vec(vector_size=200, min_count=2, window=5, workers=4, epochs=20)
    doc2vec.build_vocab(corpus)
    doc2vec.train(corpus, total_examples=doc2vec.corpus_count, epochs=doc2vec.epochs)
    doc2vec.save('doc2vec.model')
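# Infer a dense vector for every news title with the trained model and index it into article_v1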
def insert_data_to_es():
    datas = db['netease_ent_news_detail'].find({"create_time": {"$ne": None}}).sort('create_time', pymongo.ASCENDING)
    print(datas.count())
    doc2vec = Doc2Vec.load('doc2vec.model')
    for data in datas:
        if data['title'] is not None and data['content'] is not None:
            sentence = str(data['title']).strip()
            title_dv = doc2vec.infer_vector(sentence_segment(sentence)).tolist()
            doc_id = str(data['_id'])  # ObjectId is not JSON serializable; use its string form
            body = {"id": doc_id, "title": data['title'], "title_dv": title_dv}
            es_result = es.create(index="article_v1", id=doc_id, body=body, ignore=[400, 409])
            print(es_result)
# cosineSimilarity scores each document by the cosine similarity between the query vector
# and its dense_vector; the +1 keeps the script score non-negative, since cosine similarity
# lies in [-1, 1] and Elasticsearch rejects negative scores
def search_es_dense_vector_1(sentence):
    doc2vec = Doc2Vec.load('doc2vec.model')
    query_vector = doc2vec.infer_vector(sentence_segment(sentence)).tolist()
    body = {
        "query": {
            "script_score": {
                "query": {
                    "match_all": {}
                },
                "script": {
                    "source": "cosineSimilarity(params.queryVector, doc['title_dv']) + 1",
                    "params": {
                        "queryVector": query_vector
                    }
                }
            }
        },
        "from": 0,
        "size": 5
    }
    result = es.search(index="article_v1", body=body)
    hits = result['hits']['hits']
    for hit in hits:
        source = hit['_source']
        for key, value in source.items():
            print('%s %s' % (key, value))
        print('----------')
# dotProduct scores each document by the dot product between the query vector and its
# dense_vector, with the same +1 offset; note that for unnormalized vectors the raw dot
# product can still push the score below zero, which Elasticsearch will reject
def search_es_dense_vector_2(sentence):
    doc2vec = Doc2Vec.load('doc2vec.model')
    query_vector = doc2vec.infer_vector(sentence_segment(sentence)).tolist()
    body = {
        "query": {
            "script_score": {
                "query": {
                    "match_all": {}
                },
                "script": {
                    "source": "dotProduct(params.queryVector, doc['title_dv']) + 1",
                    "params": {
                        "queryVector": query_vector
                    }
                }
            }
        },
        "from": 0,
        "size": 5
    }
    result = es.search(index="article_v1", body=body)
    hits = result['hits']['hits']
    for hit in hits:
        source = hit['_source']
        for key, value in source.items():
            print('%s %s' % (key, value))
        print('----------')
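The mapping also defines a sparse_vector field (title_sv) that the code above never queries. As a rough sketch only: the 7.3-era script_score documentation lists sparse counterparts of the functions used above (cosineSimilaritySparse / dotProductSparse, later deprecated together with the sparse_vector type), so a title_sv query would look roughly like the function below; it takes an already-built {position: weight} dict, since this article builds no sparse encoder. A minimal driver showing the intended call order follows.
# Rough sketch, not verified against a live cluster: score documents by sparse-vector
# similarity. cosineSimilaritySparse is assumed from the 7.3-era script_score docs and
# was deprecated in later releases along with sparse_vector itself.
def search_es_sparse_vector(query_sparse_vector):
    body = {
        "query": {
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    "source": "cosineSimilaritySparse(params.queryVector, doc['title_sv']) + 1",
                    "params": {"queryVector": query_sparse_vector}
                }
            }
        },
        "from": 0,
        "size": 5
    }
    result = es.search(index="article_v1", body=body)
    for hit in result['hits']['hits']:
        print(hit['_id'], hit['_score'])

if __name__ == '__main__':
    # Typical run order: train the model once, index the corpus, then query by title text.
    train_doc_model()
    insert_data_to_es()
    search_es_dense_vector_1(u'测试用的新闻标题')  # any title-like Chinese text works as a query
    search_es_dense_vector_2(u'测试用的新闻标题')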