大改,未验证
This commit is contained in:
152
src/core/embedding_client.py
Normal file
152
src/core/embedding_client.py
Normal file
@@ -0,0 +1,152 @@
# -*- coding: utf-8 -*-
"""
Embedding vector client (local-model approach).

Runs a lightweight Chinese embedding model locally via sentence-transformers:
zero API calls, zero cost, low latency.
"""
|
||||
import logging
|
||||
import hashlib
|
||||
import threading
|
||||
from typing import List, Optional
|
||||
|
||||
from src.config.unified_config import get_config
|
||||
from src.core.cache_manager import cache_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingClient:
    """Local embedding vector client.

    Lazily loads a sentence-transformers model on first use and caches the
    resulting vectors through ``cache_manager``.  When the feature is disabled
    (or the model cannot be loaded) every embed call returns ``None`` so
    callers can degrade to keyword matching.
    """

    def __init__(self):
        config = get_config()
        self.enabled = config.embedding.enabled      # feature switch; False => all embeds return None
        self.model_name = config.embedding.model     # sentence-transformers model id
        self.dimension = config.embedding.dimension  # expected vector dimension
        self.cache_ttl = config.embedding.cache_ttl  # cache TTL passed to cache_manager.set

        # Model is loaded lazily on first use; the lock makes the one-time
        # load safe when several threads race into _get_model().
        self._model = None
        self._lock = threading.Lock()

        if self.enabled:
            logger.info(f"Embedding 客户端初始化: model={self.model_name} (本地模式)")
        else:
            logger.info("Embedding 功能已禁用,将使用关键词匹配降级")

    def _get_model(self):
        """Lazily load the model (downloaded on first call).

        Returns the ``SentenceTransformer`` instance, or ``None`` when loading
        fails — in that case ``self.enabled`` is flipped to ``False`` so later
        calls short-circuit instead of retrying a known-bad load.
        """
        if self._model is not None:
            return self._model

        # Double-checked locking: re-test under the lock so only one thread
        # performs the (slow) download/load.
        with self._lock:
            if self._model is not None:
                return self._model

            try:
                import os
                # Point HuggingFace at a mirror to work around download
                # problems from mainland China; a pre-set HF_ENDPOINT wins.
                os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

                from sentence_transformers import SentenceTransformer
                logger.info(f"正在加载 embedding 模型: {self.model_name} ...")
                self._model = SentenceTransformer(self.model_name)
                logger.info(f"Embedding 模型加载完成: {self.model_name}")
                return self._model
            except ImportError:
                logger.error(
                    "sentence-transformers 未安装,请运行: pip install sentence-transformers"
                )
                self.enabled = False
                return None
            except Exception as e:
                logger.error(f"加载 embedding 模型失败: {e}")
                self.enabled = False
                return None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def embed_text(self, text: str) -> Optional[List[float]]:
        """Embed a single text, consulting the cache first.

        Returns the normalized vector as a list of floats, or ``None`` when
        the feature is disabled, the text is blank, or inference fails.
        """
        if not self.enabled or not text.strip():
            return None

        cache_key = self._cache_key(text)
        cached = cache_manager.get(cache_key)
        if cached is not None:
            return cached

        model = self._get_model()
        if model is None:
            return None

        try:
            vec = model.encode(text, normalize_embeddings=True).tolist()
            cache_manager.set(cache_key, vec, self.cache_ttl)
            return vec
        except Exception as e:
            logger.error(f"Embedding 生成失败: {e}")
            return None

    def embed_batch(self, texts: List[str]) -> List[Optional[List[float]]]:
        """Embed many texts; cached entries are reused, the rest are batched.

        Returns one entry per input, in order; blank texts and failures yield
        ``None`` at their position (best-effort, never raises).
        """
        if not self.enabled:
            return [None] * len(texts)

        results: List[Optional[List[float]]] = [None] * len(texts)
        uncached_indices = []
        uncached_texts = []

        # 1. Check the cache first; blank texts stay None.
        for i, t in enumerate(texts):
            if not t.strip():
                continue
            cached = cache_manager.get(self._cache_key(t))
            if cached is not None:
                results[i] = cached
            else:
                uncached_indices.append(i)
                uncached_texts.append(t)

        if not uncached_texts:
            return results

        # 2. Batch inference for the cache misses.
        model = self._get_model()
        if model is None:
            return results

        try:
            vectors = model.encode(
                uncached_texts, normalize_embeddings=True, batch_size=32
            ).tolist()

            for j, vec in enumerate(vectors):
                idx = uncached_indices[j]
                results[idx] = vec
                cache_manager.set(self._cache_key(uncached_texts[j]), vec, self.cache_ttl)
        except Exception as e:
            # Best-effort: cached hits gathered above are still returned.
            logger.error(f"批量 embedding 生成失败: {e}")

        return results

    def test_connection(self) -> bool:
        """Smoke-test that the model is usable (embeds a short probe text)."""
        try:
            vec = self.embed_text("测试连接")
            return vec is not None and len(vec) > 0
        except Exception as e:
            logger.error(f"Embedding 模型测试失败: {e}")
            return False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _cache_key(self, text: str) -> str:
        """Cache key derived from the model name plus a text hash.

        The model name is included so that changing ``config.embedding.model``
        cannot serve stale vectors from a different model/vector space (the
        old static key ``emb:<md5>`` collided across models).  md5 is fine
        here: the key is a cache bucket, not a security boundary.
        """
        h = hashlib.md5(text.encode("utf-8")).hexdigest()
        return f"emb:{self.model_name}:{h}"
|
||||
Reference in New Issue
Block a user