# -*- coding: utf-8 -*-
"""
Embedding vector client (local-model approach).

Runs a lightweight Chinese embedding model locally via sentence-transformers:
zero API calls, zero cost, low latency.
"""
import logging
import hashlib
import threading
from typing import List, Optional

from src.config.unified_config import get_config
from src.core.cache_manager import cache_manager

logger = logging.getLogger(__name__)


class EmbeddingClient:
    """Local embedding vector client.

    Lazily loads a sentence-transformers model on first use (thread-safe,
    double-checked locking) and caches per-text vectors through
    ``cache_manager`` under an MD5-derived key. When disabled or on any
    failure, methods return None so callers can fall back to keyword
    matching.
    """

    def __init__(self):
        config = get_config()
        self.enabled = config.embedding.enabled
        self.model_name = config.embedding.model
        self.dimension = config.embedding.dimension
        self.cache_ttl = config.embedding.cache_ttl
        self._model = None  # loaded lazily by _get_model()
        self._lock = threading.Lock()  # guards the one-time model load

        if self.enabled:
            logger.info(f"Embedding 客户端初始化: model={self.model_name} (本地模式)")
        else:
            logger.info("Embedding 功能已禁用,将使用关键词匹配降级")

    def _get_model(self):
        """Lazily load the model (downloaded and loaded on first call).

        Double-checked locking ensures concurrent callers trigger at most
        one load. On failure the client disables itself (``self.enabled``)
        so later calls degrade gracefully instead of retrying forever.

        Returns:
            The SentenceTransformer instance, or None when unavailable.
        """
        if self._model is not None:
            return self._model

        with self._lock:
            # Re-check under the lock: another thread may have loaded it.
            if self._model is not None:
                return self._model
            try:
                import os

                # Use the HuggingFace mirror to work around download
                # problems from mainland China networks.
                os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

                from sentence_transformers import SentenceTransformer

                logger.info(f"正在加载 embedding 模型: {self.model_name} ...")
                self._model = SentenceTransformer(self.model_name)
                logger.info(f"Embedding 模型加载完成: {self.model_name}")
                return self._model
            except ImportError:
                logger.error(
                    "sentence-transformers 未安装,请运行: pip install sentence-transformers"
                )
                self.enabled = False
                return None
            except Exception as e:
                logger.error(f"加载 embedding 模型失败: {e}")
                self.enabled = False
                return None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def embed_text(self, text: str) -> Optional[List[float]]:
        """Embed a single text, preferring the cache.

        Args:
            text: Input text. None, empty, or whitespace-only input
                yields None instead of raising.

        Returns:
            The normalized embedding vector as a list of floats, or None
            when disabled, the input is empty, or inference fails.
        """
        # Guard against None as well as empty/whitespace-only input
        # (the previous version raised AttributeError on None).
        if not self.enabled or not text or not text.strip():
            return None

        cache_key = self._cache_key(text)
        cached = cache_manager.get(cache_key)
        if cached is not None:
            return cached

        model = self._get_model()
        if model is None:
            return None

        try:
            vec = model.encode(text, normalize_embeddings=True).tolist()
            cache_manager.set(cache_key, vec, self.cache_ttl)
            return vec
        except Exception as e:
            logger.error(f"Embedding 生成失败: {e}")
            return None

    def embed_batch(self, texts: List[str]) -> List[Optional[List[float]]]:
        """Embed a batch of texts, reusing cached vectors where possible.

        Args:
            texts: Input texts. None/empty entries map to None in the
                result instead of raising.

        Returns:
            A list aligned with ``texts``; each entry is a vector or None.
        """
        if not self.enabled:
            return [None] * len(texts)

        results: List[Optional[List[float]]] = [None] * len(texts)
        uncached_indices = []
        uncached_texts = []

        # 1. Check the cache first; collect the misses for batch inference.
        for i, t in enumerate(texts):
            if not t or not t.strip():  # skip None and blank entries
                continue
            cached = cache_manager.get(self._cache_key(t))
            if cached is not None:
                results[i] = cached
            else:
                uncached_indices.append(i)
                uncached_texts.append(t)

        if not uncached_texts:
            return results

        # 2. Batch inference for the cache misses.
        model = self._get_model()
        if model is None:
            return results

        try:
            vectors = model.encode(
                uncached_texts, normalize_embeddings=True, batch_size=32
            ).tolist()
            for j, vec in enumerate(vectors):
                idx = uncached_indices[j]
                results[idx] = vec
                cache_manager.set(self._cache_key(uncached_texts[j]), vec, self.cache_ttl)
        except Exception as e:
            logger.error(f"批量 embedding 生成失败: {e}")

        return results

    def test_connection(self) -> bool:
        """Check that the model is usable by embedding a probe string."""
        try:
            vec = self.embed_text("测试连接")
            return vec is not None and len(vec) > 0
        except Exception as e:
            logger.error(f"Embedding 模型测试失败: {e}")
            return False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _cache_key(text: str) -> str:
        """Build a cache key from an MD5 hash of the text (not security-sensitive)."""
        h = hashlib.md5(text.encode("utf-8")).hexdigest()
        return f"emb:{h}"