A database-backed embedding model

Back the embedding table with a database, and vocabulary size is no longer a worry.

An embedding layer is how a neural network turns tokens into feature vectors, and it effectively carries the model's memory. But as the volume and variety of data grow, so does the tensor that must be held in memory, and machine cost climbs with it. During training only a small, local subset of those vectors is active at any moment, so why not keep them in a database and fetch them on demand? That is what this post does: it wraps nn.Embedding with database-backed storage. The results are decent.

dict_embedding.py

import torch
import torch.nn as nn
import sqlite3

from configs.task_config import config


class DictEmbedding(nn.Module):
    def __init__(self, cache_size, hidden_dim, nn_pool, table_name='embedding'):
        super(DictEmbedding, self).__init__()
        self.cache_size = cache_size
        self.hidden_dim = hidden_dim

        # For mostly contiguous word ids, nn.Parameter would do; for random lookups,
        # nn.Embedding performs better. This table is a fixed-size in-memory cache.
        self.emb_table = nn.Embedding(cache_size, hidden_dim)

        # word id -> cache slot, and the reverse mapping
        self.word2emb = {}
        self.emb2word = {}

        # ring-buffer cursor: the next cache slot considered for eviction
        self.emb_buffer = 0

        self.table_name = table_name
        self.sqlite_pool = nn_pool

    def check_buffer(self, word_ids):
        """Find the next evictable cache slot, skipping slots whose resident word
        is needed by the current batch. This requires cache_size to exceed the
        number of distinct ids per batch; otherwise every slot can hold a word
        from the current batch and the scan never ends."""
        word_id = self.emb2word.get(self.emb_buffer)
        while word_id is not None and word_id in word_ids:
            self.emb_buffer = (self.emb_buffer + 1) % self.cache_size
            word_id = self.emb2word.get(self.emb_buffer)

        # the evicted word loses its slot and must be re-fetched from the database
        self.word2emb.pop(word_id, None)

        return self.emb_buffer, word_id

    def get_index(self, word_ids):
        """Map word ids to cache slots, loading missing rows from the database."""
        emb_ids = [self.word2emb.get(key) for key in word_ids]
        if None in emb_ids:
            new_add = {}       # word id -> freshly assigned slot
            old_word_ids = []  # words evicted to make room (their history is saved)

            for k, v in enumerate(emb_ids):
                if v is None:
                    word_id = word_ids[k]
                    # the same word may repeat within one batch
                    emb_id = new_add.get(word_id)
                    if emb_id is not None:
                        emb_ids[k] = emb_id
                        continue
                    new_emb_id, old_word_id = self.check_buffer(word_ids)

                    self.emb2word[new_emb_id] = word_id
                    self.word2emb[word_id] = new_emb_id

                    emb_ids[k] = new_emb_id

                    self.emb_buffer = (self.emb_buffer + 1) % self.cache_size

                    new_add[word_id] = new_emb_id
                    old_word_ids.append(old_word_id)

            self.update_emb(new_add, old_word_ids)
        return emb_ids

    def forward(self, word_ids):
        emb_ids = self.get_index(word_ids)
        # embedding indices must be integer tensors; build them directly on the target device
        emb_vs = self.emb_table(torch.tensor(emb_ids, dtype=torch.long, device=config.device))
        return emb_vs

    def save_emb(self, word_id, emb_vector):
        # store the raw float32 bytes of the vector as a BLOB
        self.sqlite_pool.insert(f"REPLACE INTO {self.table_name}(id, emb) VALUES (?, ?)",
                                (word_id, sqlite3.Binary(emb_vector.cpu().detach().numpy())))

    def save_emb_batch_(self, word_ids, emb_vectors):
        embeddings = [sqlite3.Binary(e.cpu().detach().numpy()) for e in emb_vectors]

        data = [(i, embs) for i, embs in zip(word_ids, embeddings)]

        def _executemany(cursor):
            cursor.executemany(f'REPLACE INTO {self.table_name}(id, emb) VALUES (?, ?)', data)

        self.sqlite_pool.insert_batch(_executemany)

    def save_emb_batch(self, word_ids=None, emb_vectors=None, update_cache=False):
        if word_ids is None:
            # default to everything currently resident in the cache
            word_ids = list(self.word2emb.keys())

        if emb_vectors is None:
            emb_vectors = self(word_ids)

        self.save_emb_batch_(word_ids, emb_vectors)

        if update_cache:
            emb_ids = self.get_index(word_ids)
            for k, emb_id in enumerate(emb_ids):
                self.emb_table.weight.data[emb_id] = emb_vectors[k]

    def update_emb(self, new_add, old_word_ids):
        word_ids = list(new_add.keys())
        emb_ids = list(new_add.values())

        # one round trip: LEFT JOIN the requested ids against the embedding table,
        # so ids missing from the database come back with emb = NULL
        placeholders = ', '.join(['(?)'] * len(word_ids))
        query = f"SELECT t.column1 AS eid, e.emb FROM (VALUES {placeholders}) AS t LEFT JOIN {self.table_name} e ON e.id = t.column1"
        out = self.sqlite_pool.query_where(query, word_ids)

        for row, old_word, emb_id in zip(out, old_word_ids, emb_ids):
            requested_id, blob = row
            if blob is None:
                # not in the database yet: initialize a fresh vector, persist it,
                # and write it into the emb_table cache
                emb_vector = torch.randn(self.hidden_dim)
                self.save_emb(requested_id, emb_vector)
                self.emb_table.weight.data[emb_id] = emb_vector
            else:
                if old_word:
                    # save the evicted word's current embedding before its slot is overwritten
                    self.save_emb(old_word, self.emb_table.weight.data[emb_id])
                self.emb_table.weight.data[emb_id] = torch.frombuffer(bytearray(blob), dtype=torch.float32)
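
To make the slot-recycling easier to follow, here is a minimal, self-contained sketch of the same ring-buffer eviction with torch and SQLite stripped out. RingCache and the ids are illustrative names, not part of the module above. Note the implicit constraint, which applies to DictEmbedding as well: cache_size must exceed the number of distinct ids in a batch, or the eviction scan cannot terminate.

class RingCache:
    """Toy version of DictEmbedding's slot recycling."""

    def __init__(self, cache_size):
        self.cache_size = cache_size
        self.word2slot = {}  # word id -> cache slot
        self.slot2word = {}  # cache slot -> word id
        self.cursor = 0      # next slot considered for eviction

    def get_slot(self, word_id, batch):
        if word_id in self.word2slot:
            return self.word2slot[word_id]
        # skip slots whose resident word is needed by the current batch
        while self.slot2word.get(self.cursor) in batch:
            self.cursor = (self.cursor + 1) % self.cache_size
        slot = self.cursor
        evicted = self.slot2word.get(slot)
        if evicted is not None:
            del self.word2slot[evicted]  # evicted word must be reloaded later
        self.slot2word[slot] = word_id
        self.word2slot[word_id] = slot
        self.cursor = (self.cursor + 1) % self.cache_size
        return slot


cache = RingCache(cache_size=3)
print([cache.get_slot(w, {10, 20, 30}) for w in (10, 20, 30)])  # [0, 1, 2]
print([cache.get_slot(w, {40, 10}) for w in (40, 10)])          # [1, 0]: 40 evicts 20, 10 stays cached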

sqlite_pool.py handles database access. SQLite is used here because it is convenient; swapping in another database would work even better, and some vector stores additionally support vector-similarity search.

import sqlite3
import queue
import logging

logging.getLogger().setLevel(logging.INFO)
connection_pools = {}

class ConnectionPool:
    def __init__(self, max_connections=2, timeout=1, dbpath="nn.db"):
        self.max_connections = max_connections
        self.timeout = timeout
        self.connections = queue.Queue(max_connections)
        self.dbpath = dbpath

        for _ in range(max_connections):
            conn = self.create_connection(self.dbpath)
            if conn is not None:
                self.connections.put(conn)

    @staticmethod
    def create_connection(dbpath):
        try:
            conn = sqlite3.connect(dbpath, isolation_level=None, check_same_thread=False)  # autocommit: no implicit transactions
            return conn
        except sqlite3.Error as e:
            logging.error("Error creating connection:", str(e))
            return None

    def get_connection(self):
        try:
            conn = self.connections.get(timeout=self.timeout)
            return conn
        except queue.Empty:
            logging.error("Connection timeout. No available connections.}")
            return None

    def release_connection(self, conn):
        if conn is not None:
            if self.connections.full():
                logging.warning("Connection pool is full. Cannot release connection.")
            else:
                self.connections.put(conn)

    def execute_query(self, sql, query_data=None, insert_data=None, update_data=None):
        """

        Args:
            query_data:
            sql:
            insert_data:
            update_data:

        Returns:

        """
        conn = self.get_connection()
        if conn is None:
            return None

        try:
            row = None
            cursor = conn.cursor()

            if query_data:
                cursor.execute(sql, query_data)
                row = cursor.fetchone()
            elif insert_data:
                cursor.execute(sql, insert_data)
            elif update_data:
                cursor.execute(sql, update_data)
            else:
                cursor.execute(sql)
                row = cursor.fetchone()
            return row

        except sqlite3.Error as e:
            conn.rollback()
            logging.error("Error executing query:", str(e))
        finally:
            cursor.close()
            self.release_connection(conn)

    def query_where(self, sql, query_data):
        conn = self.get_connection()
        if conn is None:
            return None

        try:
            cursor = conn.cursor()
            cursor.execute(sql, query_data)
            row = cursor.fetchall()
            return row

        except sqlite3.Error as e:
            logging.error("Error executing query:", str(e))
        finally:
            cursor.close()
            self.release_connection(conn)

    def query(self, sql, query_data):
        return self.execute_query(sql, query_data=query_data)

    def insert(self, sql, insert_data):
        return self.execute_query(sql, insert_data=insert_data)

    def update(self, sql, update_data):
        return self.execute_query(sql, update_data=update_data)

    def insert_batch(self, fn):
        conn = self.get_connection()
        if conn is None:
            return None

        try:
            cursor = conn.cursor()
            # WAL mode for multi-processing
            cursor.execute('PRAGMA journal_mode=wal')  # https://www.coder.work/article/2441365
            cursor.execute('PRAGMA synchronous=OFF')   # trade durability for write speed

            fn(cursor)
        except sqlite3.Error as e:
            conn.rollback()
            logging.error("Error executing query:", str(e))
        finally:
            cursor.close()
            self.release_connection(conn)


def get_pool(dbpath, pool_name, max_connections=2, timeout=3):
    pool = connection_pools.get(pool_name)
    if pool is None:
        pool = ConnectionPool(max_connections=max_connections, timeout=timeout, dbpath=dbpath)
        connection_pools[pool_name] = pool
    return pool

New table

CREATE TABLE IF NOT EXISTS "embedding" (
  "id" INTEGER NOT NULL,
  "code" real,
  "emb" blob NOT NULL,
  "score" real NOT NULL DEFAULT 0,
  "date" TIMESTAMP default (datetime('now', 'localtime')),
  PRIMARY KEY ("id")
);
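
Before first use, the table can be created through the same connection pool; a minimal sketch, assuming the schema above (the database path and pool name below are placeholders):

from data.sqlite_pool import get_pool

CREATE_SQL = '''
CREATE TABLE IF NOT EXISTS "embedding" (
  "id" INTEGER NOT NULL,
  "code" real,
  "emb" blob NOT NULL,
  "score" real NOT NULL DEFAULT 0,
  "date" TIMESTAMP default (datetime('now', 'localtime')),
  PRIMARY KEY ("id")
);
'''

pool = get_pool("nn.db", "nn_pool")  # placeholder path and pool name
pool.execute_query(CREATE_SQL)       # DDL runs through the pool like any other statement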

Usage

    from data.sqlite_pool import get_pool
    import tiktoken

    # cache_size = 50
    # hidden_dim = 16

    nn_pool = get_pool(config.db.nn_db, config.db.nn_db_name)

    emb = DictEmbedding(1000, 256, nn_pool).to(config.device)
    emb.eval()

    enc = tiktoken.get_encoding("cl100k_base")

    with torch.no_grad():
        out = emb(enc.encode("你好"))
        ids = enc.encode("你好gsdf145")
        out2 = emb(ids)
        print(out.size(), out2.size())

    emb.save_emb_batch(ids, out2)
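
During training you would flush the cache back to SQLite periodically, so that eviction never discards learned state. A sketch of that pattern; the batch source and the loss below are stand-ins, not part of the original code:

emb.train()
optimizer = torch.optim.Adam(emb.parameters(), lr=1e-3)

for step, token_ids in enumerate(batches):  # `batches` yields lists of token ids
    vectors = emb(token_ids)                # cache misses are loaded from SQLite
    loss = vectors.pow(2).mean()            # placeholder objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        # with no arguments, save_emb_batch persists every word currently
        # resident in the cache back to the database
        emb.save_emb_batch()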