嵌入数据库,再也不用担心词的数量太多
嵌入模型是神经网络提取数据特征向量的重要工具,承担着整个模型的记忆功能。但随着数据量和数据类型的增加,需要保存在内存中的张量越来越大,机器成本也越来越高;而在训练过程中,大量数据的向量只有局部生效。既然如此,为什么不把向量放在数据库中按需提取?所以,这里封装了 nn.Embedding 模型,使其支持数据库存储,效果还行。
dict_embedding.py
import torch
import torch.nn as nn
import sqlite3
from configs.task_config import config
class DictEmbedding(nn.Module):
    """Embedding layer whose vectors live in a database and are cached on demand.

    Only ``cache_size`` vectors are kept in an ``nn.Embedding`` table. Each
    word id is mapped to a cache slot; when a word is not cached, a slot is
    recycled round-robin, the vector that occupied it is written back to the
    database, and the requested vector is loaded (or randomly initialised and
    stored when the database has no row for it).

    Args:
        cache_size: Number of vectors held in memory at once. A single batch
            must not contain more distinct word ids than this.
        hidden_dim: Dimensionality of every embedding vector.
        nn_pool: Database access object. The code calls ``insert(sql, params)``,
            ``query_where(sql, params)`` and ``insert_bath(fn)`` on it.
        table_name: Table holding ``(id, emb)`` rows. It is interpolated into
            the SQL text, so it must be a trusted identifier.
    """

    def __init__(self, cache_size, hidden_dim, nn_pool, table_name='embedding'):
        super().__init__()
        self.cache_size = cache_size
        self.hidden_dim = hidden_dim
        # Author's note: if word ids are mostly contiguous an nn.Parameter
        # would do; for randomly selected ids nn.Embedding performs better.
        self.emb_table = nn.Embedding(cache_size, hidden_dim)
        self.word2emb = {}   # word id -> cache slot
        self.emb2word = {}   # cache slot -> word id
        self.emb_buffer = 0  # next slot to try when a slot must be recycled
        self.table_name = table_name
        self.sqlite_pool = nn_pool

    @staticmethod
    def _to_blob(vector):
        """Serialise a tensor as float32 bytes, the layout ``update_emb`` reads back."""
        array = vector.detach().cpu().to(torch.float32).contiguous().numpy()
        return sqlite3.Binary(array.tobytes())

    def check_buffer(self, word_ids):
        """Pick the next cache slot that is not used by the current batch.

        Args:
            word_ids: Container of the word ids in the current batch; slots
                holding one of these words are skipped.

        Returns:
            ``(slot, evicted_word_id)``. ``evicted_word_id`` is ``None`` when
            the slot was still empty; otherwise its ``word2emb`` entry has
            already been removed.

        Raises:
            RuntimeError: If every slot is occupied by a word of the batch.
        """
        # Bounded scan: the original unbounded loop never terminated when the
        # whole cache was occupied by words of the current batch.
        for _ in range(self.cache_size):
            word_id = self.emb2word.get(self.emb_buffer)
            if word_id is None or word_id not in word_ids:
                self.word2emb.pop(word_id, None)
                return self.emb_buffer, word_id
            self.emb_buffer = (self.emb_buffer + 1) % self.cache_size
        raise RuntimeError(
            f"no free cache slot: all {self.cache_size} slots are used by the current batch")

    def get_index(self, word_ids):
        """Map word ids to cache slots, loading missing vectors into the cache.

        Args:
            word_ids: Iterable of hashable word ids (duplicates allowed).

        Returns:
            List of cache slot indices, one per input word id.

        Raises:
            ValueError: If the batch holds more distinct ids than ``cache_size``.
        """
        word_ids = list(word_ids)  # accept dict views / generators; indexed below
        batch_ids = set(word_ids)  # O(1) membership tests while recycling slots
        if len(batch_ids) > self.cache_size:
            raise ValueError(
                f"batch has {len(batch_ids)} distinct word ids but cache_size is {self.cache_size}")
        emb_ids = [self.word2emb.get(key) for key in word_ids]
        if None not in emb_ids:
            return emb_ids

        new_add = {}       # word id -> slot, for words placed during this call
        old_word_ids = []  # evicted word (or None) per entry of new_add, same order
        for k, slot in enumerate(emb_ids):
            if slot is not None:
                continue
            word_id = word_ids[k]
            placed = new_add.get(word_id)
            if placed is not None:
                # Duplicate of a word already placed earlier in this batch.
                emb_ids[k] = placed
                continue
            new_emb_id, old_word_id = self.check_buffer(batch_ids)
            self.emb2word[new_emb_id] = word_id
            self.word2emb[word_id] = new_emb_id
            emb_ids[k] = new_emb_id
            self.emb_buffer = (self.emb_buffer + 1) % self.cache_size
            new_add[word_id] = new_emb_id
            old_word_ids.append(old_word_id)
        self.update_emb(new_add, old_word_ids)
        return emb_ids

    def forward(self, word_ids):
        """Return the embedding vectors for ``word_ids`` (one row per id)."""
        emb_ids = self.get_index(word_ids)
        # Build the indices on the table's own device so the lookup cannot hit
        # a device mismatch, whatever device the module was moved to.
        index = torch.tensor(emb_ids, dtype=torch.long, device=self.emb_table.weight.device)
        return self.emb_table(index)

    def save_emb(self, word_id, emb_vector):
        """Insert or replace the stored vector of a single word."""
        self.sqlite_pool.insert(f"REPLACE INTO {self.table_name}(id, emb) VALUES (?, ?)",
                                (word_id, self._to_blob(emb_vector)))

    def save_emb_bath_(self, word_ids, emb_vectors):
        """Insert or replace many ``(word_id, vector)`` rows through ``insert_bath``."""
        data = [(word_id, self._to_blob(vector))
                for word_id, vector in zip(word_ids, emb_vectors)]

        def _executemany(cursor):
            cursor.executemany(f'REPLACE INTO {self.table_name}(id, emb) VALUES (?, ?)', data)

        self.sqlite_pool.insert_bath(_executemany)

    def save_emb_bath(self, word_ids=None, emb_vectors=None, update_cache=False):
        """Persist vectors to the database in one batch.

        Args:
            word_ids: Ids to store. Defaults to every word currently cached.
            emb_vectors: Vectors to store, aligned with ``word_ids``. Defaults
                to the vectors this module currently returns for ``word_ids``.
            update_cache: When true, also copy ``emb_vectors`` into the cache.
        """
        if word_ids is None:
            word_ids = self.word2emb
        word_ids = list(word_ids)  # snapshot: iterated several times below
        if emb_vectors is None:
            emb_vectors = self(word_ids)
        self.save_emb_bath_(word_ids, emb_vectors)
        if update_cache:
            emb_ids = self.get_index(word_ids)
            for slot, vector in zip(emb_ids, emb_vectors):
                self.emb_table.weight.data[slot] = vector

    def update_emb(self, new_add, old_word_ids):
        """Fill freshly assigned cache slots from the database.

        For every ``word_id -> slot`` pair in ``new_add`` the vector that
        occupied the slot is first written back under the evicted word id, then
        the slot receives the stored vector of ``word_id``; if the database has
        none, a random vector is generated, stored and cached.

        Args:
            new_add: Mapping of newly placed word id to its cache slot.
            old_word_ids: Evicted word id (or ``None``) for each entry of
                ``new_add``, in the same order.

        Raises:
            RuntimeError: If the pool returns no result for the lookup query.
        """
        if not new_add:
            return
        word_ids = list(new_add)
        placeholders = ', '.join(['(?)'] * len(word_ids))
        query = f"SELECT t.column1 AS eid, e.emb FROM (VALUES {placeholders}) AS t LEFT JOIN {self.table_name} e ON e.id = t.column1"
        rows = self.sqlite_pool.query_where(query, word_ids)
        if rows is None:
            raise RuntimeError("embedding lookup failed: the database returned no result")
        # Key the result by word id instead of relying on row order, which the
        # query does not fix with an ORDER BY.
        stored = {row[0]: row[1] for row in rows}
        weights = self.emb_table.weight.data
        for (word_id, slot), evicted in zip(new_add.items(), old_word_ids):
            if evicted is not None:
                # Always write the evicted vector back before the slot is
                # overwritten. `is not None` keeps word id 0 from being skipped.
                self.save_emb(evicted, weights[slot])
            blob = stored.get(word_id)
            if blob is None:
                # Unknown word: create a vector, store it and cache it.
                vector = torch.randn(self.hidden_dim)
                self.save_emb(word_id, vector)
            else:
                vector = torch.frombuffer(bytearray(blob), dtype=torch.float32)
            weights[slot] = vector
sqlite_pool.py:负责调用数据库。这里使用 SQLite 是因为它比较方便;改成其它数据库效果会更好,一些词向量库还支持向量搜索算法。
import sqlite3
import queue
import logging
# Set the root logger's threshold to INFO for the whole process.
logging.getLogger().setLevel(logging.INFO)
# Module-level registry of ConnectionPool objects keyed by pool name;
# get_pool() looks pools up here and stores newly created ones.
connection_pools = {}
class ConnectionPool:
    """Fixed-size pool of SQLite connections handed out through a queue.

    Connections are opened eagerly in the constructor with
    ``isolation_level=None`` (autocommit) and ``check_same_thread=False``.
    Every helper borrows one connection, runs its statement(s) and puts the
    connection back, also when the statement fails.

    Args:
        max_connections: Number of connections opened for the pool.
        timeout: Seconds to wait for a free connection before giving up.
        dbpath: Path of the SQLite database file.
    """

    def __init__(self, max_connections=2, timeout=1, dbpath="nn.db"):
        self.max_connections = max_connections
        self.timeout = timeout
        self.connections = queue.Queue(max_connections)
        self.dbpath = dbpath
        for _ in range(max_connections):
            conn = self.create_connection(self.dbpath)
            if conn is not None:
                self.connections.put(conn)

    @staticmethod
    def create_connection(dbpath):
        """Open one autocommit connection; return ``None`` (and log) on failure."""
        try:
            # isolation_level=None turns off implicit transactions.
            return sqlite3.connect(dbpath, isolation_level=None, check_same_thread=False)
        except sqlite3.Error as e:
            # The message needs a %s placeholder: passing the error as a bare
            # extra argument makes the logging module fail to format the record.
            logging.error("Error creating connection: %s", e)
            return None

    def get_connection(self):
        """Borrow a connection, waiting up to ``timeout`` seconds.

        Returns:
            A connection, or ``None`` when none became available in time.
        """
        try:
            return self.connections.get(timeout=self.timeout)
        except queue.Empty:
            logging.error("Connection timeout. No available connections.")
            return None

    def release_connection(self, conn):
        """Return a borrowed connection to the pool (``None`` is ignored)."""
        if conn is None:
            return
        if self.connections.full():
            logging.warning("Connection pool is full. Cannot release connection.")
        else:
            self.connections.put(conn)

    def _finish(self, conn, cursor):
        """Close ``cursor`` if it was created and hand ``conn`` back to the pool."""
        if cursor is not None:
            cursor.close()
        self.release_connection(conn)

    def execute_query(self, sql, query_data=None, insert_data=None, update_data=None):
        """Run one statement with at most one set of parameters.

        The first non-empty argument among ``query_data``, ``insert_data`` and
        ``update_data`` is bound to ``sql``. A row is fetched only for
        ``query_data`` or when no parameters are given at all.

        Args:
            sql: Statement to execute.
            query_data: Parameters of a SELECT; the first result row is returned.
            insert_data: Parameters of an INSERT/REPLACE; nothing is fetched.
            update_data: Parameters of an UPDATE; nothing is fetched.

        Returns:
            The first result row, or ``None`` when nothing was fetched, no
            connection was available, or the statement failed (the error is
            logged).
        """
        conn = self.get_connection()
        if conn is None:
            return None
        cursor = None  # stays None if conn.cursor() itself fails
        try:
            row = None
            cursor = conn.cursor()
            if query_data:
                cursor.execute(sql, query_data)
                row = cursor.fetchone()
            elif insert_data:
                cursor.execute(sql, insert_data)
            elif update_data:
                cursor.execute(sql, update_data)
            else:
                cursor.execute(sql)
                row = cursor.fetchone()
            return row
        except sqlite3.Error as e:
            conn.rollback()
            logging.error("Error executing query: %s", e)
            return None
        finally:
            self._finish(conn, cursor)

    def query_where(self, sql, query_data):
        """Run a parameterised SELECT and return all rows.

        Returns:
            List of result rows, or ``None`` when no connection was available
            or the statement failed (the error is logged).
        """
        conn = self.get_connection()
        if conn is None:
            return None
        cursor = None
        try:
            cursor = conn.cursor()
            cursor.execute(sql, query_data)
            return cursor.fetchall()
        except sqlite3.Error as e:
            logging.error("Error executing query: %s", e)
            return None
        finally:
            self._finish(conn, cursor)

    def query(self, sql, query_data):
        """Return the first row of a parameterised SELECT (see ``execute_query``)."""
        return self.execute_query(sql, query_data=query_data)

    def insert(self, sql, insert_data):
        """Execute a parameterised INSERT/REPLACE (see ``execute_query``)."""
        return self.execute_query(sql, insert_data=insert_data)

    def update(self, sql, update_data):
        """Execute a parameterised UPDATE (see ``execute_query``)."""
        return self.execute_query(sql, update_data=update_data)

    def insert_bath(self, fn):
        """Run a caller-supplied bulk operation on a pooled cursor.

        Args:
            fn: Callable receiving the cursor, typically calling
                ``cursor.executemany``. SQLite errors it raises are logged
                and swallowed.

        Returns:
            Always ``None``.
        """
        conn = self.get_connection()
        if conn is None:
            return None
        cursor = None
        try:
            cursor = conn.cursor()
            # WAL mode for multi-processing
            cursor.execute('PRAGMA journal_mode=wal')  # https://www.coder.work/article/2441365
            cursor.execute('PRAGMA synchronous=OFF')
            fn(cursor)
        except sqlite3.Error as e:
            conn.rollback()
            logging.error("Error executing query: %s", e)
        finally:
            self._finish(conn, cursor)
def get_pool(dbpath, pool_name, max_connections=2, timeout=3):
    """Return the pool registered under ``pool_name``, creating it on first use.

    Args:
        dbpath: Database path used only when a new pool has to be created.
        pool_name: Key of the pool in the module-level ``connection_pools``.
        max_connections: Pool size for a newly created pool.
        timeout: Seconds a newly created pool waits for a free connection.

    Returns:
        The existing pool for ``pool_name`` if there is one, otherwise a new
        ``ConnectionPool`` that is also stored in ``connection_pools``.
    """
    cached = connection_pools.get(pool_name)
    if cached is not None:
        return cached
    created = ConnectionPool(dbpath=dbpath, timeout=timeout, max_connections=max_connections)
    connection_pools[pool_name] = created
    return created
建表语句(new table):
-- Storage table for the embedding vectors, one row per word id.
-- The Python code shown with this schema only reads and writes "id" and "emb";
-- it saves rows with REPLACE INTO (id, emb), so the other columns take their
-- defaults each time a row is saved.
CREATE TABLE IF NOT EXISTS "embedding" (
"id" INTEGER NOT NULL, -- word id; primary key
"code" real, -- not referenced by the visible code; purpose unknown (TODO confirm)
"emb" blob NOT NULL, -- raw bytes of the vector; the visible loader decodes them as float32
"score" real NOT NULL DEFAULT 0, -- not referenced by the visible code (TODO confirm)
"date" TIMESTAMP default (datetime('now', 'localtime')), -- local time at which the row was written
PRIMARY KEY ("id")
);
使用示例:
from data.sqlite_pool import get_pool
import tiktoken

# Cache slots and vector width used in this example.
# Alternative settings left by the author: cache_size = 50, hidden_dim = 16
CACHE_SIZE = 1000
HIDDEN_DIM = 256

# Shared connection pool for the embedding database named in the project config.
nn_pool = get_pool(config.db.nn_db, config.db.nn_db_name)

# Database-backed embedding layer, moved to the configured device, inference mode.
emb = DictEmbedding(CACHE_SIZE, HIDDEN_DIM, nn_pool).to(config.device)
emb.eval()

# Token ids produced by tiktoken are used as the word ids.
enc = tiktoken.get_encoding("cl100k_base")

with torch.no_grad():
    out = emb(enc.encode("你好"))
    ids = enc.encode("你好gsdf145")
    out2 = emb(ids)

print(out.size(), out2.size())

# Write the vectors of the second lookup back to the database in one batch.
emb.save_emb_bath(ids, out2)