大改,未验证
This commit is contained in:
152
src/core/embedding_client.py
Normal file
152
src/core/embedding_client.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Embedding 向量客户端(本地模型方案)
|
||||
使用 sentence-transformers 在本地运行轻量级中文 embedding 模型
|
||||
零 API 调用、零成本、低延迟
|
||||
"""
|
||||
|
||||
import logging
|
||||
import hashlib
|
||||
import threading
|
||||
from typing import List, Optional
|
||||
|
||||
from src.config.unified_config import get_config
|
||||
from src.core.cache_manager import cache_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingClient:
    """Local embedding client.

    Runs a lightweight Chinese embedding model in-process via
    sentence-transformers — zero API calls, zero cost, low latency.
    Vectors are cached through ``cache_manager`` keyed by an MD5 of the
    input text.
    """

    def __init__(self):
        config = get_config()
        self.enabled = config.embedding.enabled
        self.model_name = config.embedding.model
        self.dimension = config.embedding.dimension
        self.cache_ttl = config.embedding.cache_ttl

        self._model = None  # lazily loaded SentenceTransformer instance
        self._lock = threading.Lock()  # guards one-time model loading

        if self.enabled:
            logger.info(f"Embedding 客户端初始化: model={self.model_name} (本地模式)")
        else:
            logger.info("Embedding 功能已禁用,将使用关键词匹配降级")

    def _get_model(self):
        """Lazily load the model (downloads on first call).

        Double-checked locking: the unlocked fast path avoids lock
        contention once the model is loaded. On any load failure the
        client disables itself (``self.enabled = False``) so callers
        fall back to keyword matching instead of retrying forever.

        Returns:
            The SentenceTransformer instance, or None on failure.
        """
        if self._model is not None:
            return self._model

        with self._lock:
            if self._model is not None:
                return self._model

            try:
                import os
                # Point HuggingFace at a mirror to work around download
                # problems from mainland China.
                os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

                from sentence_transformers import SentenceTransformer
                logger.info(f"正在加载 embedding 模型: {self.model_name} ...")
                self._model = SentenceTransformer(self.model_name)
                logger.info(f"Embedding 模型加载完成: {self.model_name}")
                return self._model
            except ImportError:
                logger.error(
                    "sentence-transformers 未安装,请运行: pip install sentence-transformers"
                )
                self.enabled = False
                return None
            except Exception as e:
                logger.error(f"加载 embedding 模型失败: {e}")
                self.enabled = False
                return None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def embed_text(self, text: str) -> Optional[List[float]]:
        """Embed a single text, consulting the cache first.

        Returns:
            A list of floats (L2-normalized vector), or None when the
            client is disabled, the text is blank, or inference fails.
        """
        if not self.enabled or not text.strip():
            return None

        cache_key = self._cache_key(text)
        cached = cache_manager.get(cache_key)
        if cached is not None:
            return cached

        model = self._get_model()
        if model is None:
            return None

        try:
            vec = model.encode(text, normalize_embeddings=True).tolist()
            cache_manager.set(cache_key, vec, self.cache_ttl)
            return vec
        except Exception as e:
            logger.error(f"Embedding 生成失败: {e}")
            return None

    def embed_batch(self, texts: List[str]) -> List[Optional[List[float]]]:
        """Embed a batch of texts.

        Returns a list aligned with ``texts``; blank or failed entries
        are None. Fix vs. the original: duplicate uncached texts are now
        encoded once (the old code ran inference and wrote the cache
        once per occurrence).
        """
        if not self.enabled:
            return [None] * len(texts)

        results: List[Optional[List[float]]] = [None] * len(texts)
        # Map unique uncached text -> result indices awaiting its vector.
        pending = {}

        # 1. Consult the cache first.
        for i, t in enumerate(texts):
            if not t.strip():
                continue  # blank input stays None, same as embed_text
            cached = cache_manager.get(self._cache_key(t))
            if cached is not None:
                results[i] = cached
            else:
                pending.setdefault(t, []).append(i)

        if not pending:
            return results

        # 2. Batch inference over the unique misses.
        model = self._get_model()
        if model is None:
            return results

        unique_texts = list(pending.keys())
        try:
            vectors = model.encode(
                unique_texts, normalize_embeddings=True, batch_size=32
            ).tolist()

            for t, vec in zip(unique_texts, vectors):
                cache_manager.set(self._cache_key(t), vec, self.cache_ttl)
                for idx in pending[t]:
                    results[idx] = vec
        except Exception as e:
            logger.error(f"批量 embedding 生成失败: {e}")

        return results

    def test_connection(self) -> bool:
        """Smoke-test the model by embedding a short fixed string."""
        try:
            vec = self.embed_text("测试连接")
            return vec is not None and len(vec) > 0
        except Exception as e:
            logger.error(f"Embedding 模型测试失败: {e}")
            return False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _cache_key(text: str) -> str:
        """Cache key derived from the MD5 of the UTF-8 text (not security-sensitive)."""
        h = hashlib.md5(text.encode("utf-8")).hexdigest()
        return f"emb:{h}"
|
||||
@@ -58,11 +58,49 @@ class WorkOrder(Base):
|
||||
# 关联处理过程记录
|
||||
process_history = relationship("WorkOrderProcessHistory", back_populates="work_order", order_by="WorkOrderProcessHistory.process_time")
|
||||
|
||||
class ChatSession(Base):
    """Chat session model — groups the turns of a multi-round conversation into one session."""
    __tablename__ = "chat_sessions"

    id = Column(Integer, primary_key=True)
    session_id = Column(String(100), unique=True, nullable=False)  # unique session identifier
    user_id = Column(String(100), nullable=True)  # user identifier
    work_order_id = Column(Integer, ForeignKey("work_orders.id"), nullable=True)
    title = Column(String(200), nullable=True)  # session title (summary of the first message)
    status = Column(String(20), default="active")  # active, ended
    message_count = Column(Integer, default=0)  # number of message rounds
    source = Column(String(50), nullable=True)  # origin: websocket, api, feishu
    ip_address = Column(String(45), nullable=True)  # 45 chars fits IPv6 text form
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
    ended_at = Column(DateTime, nullable=True)

    # Messages belonging to this session, ordered by timestamp
    messages = relationship("Conversation", back_populates="chat_session", order_by="Conversation.timestamp")

    def to_dict(self):
        """Serialize the session to a JSON-friendly dict (datetimes as ISO-8601 strings or None)."""
        return {
            'id': self.id,
            'session_id': self.session_id,
            'user_id': self.user_id,
            'work_order_id': self.work_order_id,
            'title': self.title,
            'status': self.status,
            'message_count': self.message_count,
            'source': self.source,
            'ip_address': self.ip_address,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'updated_at': self.updated_at.isoformat() if self.updated_at else None,
            'ended_at': self.ended_at.isoformat() if self.ended_at else None,
        }
|
||||
|
||||
|
||||
class Conversation(Base):
|
||||
"""对话记录模型"""
|
||||
__tablename__ = "conversations"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
session_id = Column(String(100), ForeignKey("chat_sessions.session_id"), nullable=True) # 关联会话
|
||||
work_order_id = Column(Integer, ForeignKey("work_orders.id"))
|
||||
user_message = Column(Text, nullable=False)
|
||||
assistant_response = Column(Text, nullable=False)
|
||||
@@ -79,6 +117,7 @@ class Conversation(Base):
|
||||
cpu_usage = Column(Float) # CPU使用率
|
||||
|
||||
work_order = relationship("WorkOrder", back_populates="conversations")
|
||||
chat_session = relationship("ChatSession", back_populates="messages")
|
||||
|
||||
class KnowledgeEntry(Base):
|
||||
"""知识库条目模型"""
|
||||
|
||||
@@ -86,6 +86,7 @@ class QueryOptimizer:
|
||||
for conv in conversations:
|
||||
conversation_list.append({
|
||||
'id': conv.id,
|
||||
'session_id': conv.session_id,
|
||||
'user_message': conv.user_message,
|
||||
'assistant_response': conv.assistant_response,
|
||||
'timestamp': conv.timestamp.isoformat() if conv.timestamp else None,
|
||||
|
||||
164
src/core/vector_store.py
Normal file
164
src/core/vector_store.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
向量存储与检索
|
||||
使用 numpy 实现轻量级向量索引(无需额外依赖)
|
||||
支持从 DB 加载已有 embedding 构建索引,增量更新
|
||||
"""
|
||||
|
||||
import logging
|
||||
import json
|
||||
import threading
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
|
||||
from src.core.database import db_manager
|
||||
from src.core.models import KnowledgeEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VectorStore:
    """Lightweight in-memory vector index based on numpy cosine similarity.

    Rows are L2-normalized on insert, so cosine similarity reduces to a
    dot product at query time. ``self._ids[i]`` identifies row ``i`` of
    ``self._matrix``. All index mutation and reads happen under an RLock.
    """

    def __init__(self):
        self._lock = threading.RLock()
        # Index data: entry id per matrix row.
        self._ids: List[int] = []
        self._matrix: Optional[np.ndarray] = None  # shape: (n, dim), rows L2-normalized
        self._loaded = False  # True after the first load_from_db attempt (success or not)

    # ------------------------------------------------------------------
    # Index management
    # ------------------------------------------------------------------

    def load_from_db(self):
        """Rebuild the index from every stored embedding in the database.

        Rows whose JSON embedding is missing, empty, or malformed are
        skipped. On failure the index is marked loaded anyway so every
        subsequent search does not retry the DB.
        """
        try:
            with db_manager.get_session() as session:
                entries = session.query(
                    KnowledgeEntry.id, KnowledgeEntry.vector_embedding
                ).filter(
                    KnowledgeEntry.is_active == True,
                    KnowledgeEntry.vector_embedding.isnot(None),
                    KnowledgeEntry.vector_embedding != ''
                ).all()

                ids = []
                vectors = []
                for entry_id, vec_json in entries:
                    try:
                        vec = json.loads(vec_json)
                        if isinstance(vec, list) and len(vec) > 0:
                            ids.append(entry_id)
                            vectors.append(vec)
                    except (json.JSONDecodeError, TypeError):
                        continue  # skip malformed embeddings

                with self._lock:
                    if vectors:
                        self._ids = ids
                        self._matrix = np.array(vectors, dtype=np.float32)
                        # L2-normalize so dot products equal cosine similarity.
                        norms = np.linalg.norm(self._matrix, axis=1, keepdims=True)
                        norms[norms == 0] = 1.0  # avoid division by zero for degenerate rows
                        self._matrix = self._matrix / norms
                    else:
                        self._ids = []
                        self._matrix = None
                    self._loaded = True

            logger.info(f"向量索引加载完成: {len(ids)} 条记录")

        except Exception as e:
            logger.error(f"加载向量索引失败: {e}")
            self._loaded = True  # mark loaded to avoid retrying on every call

    def add(self, entry_id: int, vector: List[float]):
        """Incrementally add one vector (normalized before insertion)."""
        with self._lock:
            vec = np.array(vector, dtype=np.float32).reshape(1, -1)
            norm = np.linalg.norm(vec)
            if norm > 0:
                vec = vec / norm

            if self._matrix is not None:
                self._ids.append(entry_id)
                self._matrix = np.vstack([self._matrix, vec])
            else:
                self._ids = [entry_id]
                self._matrix = vec

    def remove(self, entry_id: int):
        """Remove one vector; no-op if the id is not in the index."""
        with self._lock:
            if entry_id in self._ids:
                idx = self._ids.index(entry_id)
                self._ids.pop(idx)
                if self._matrix is not None and len(self._ids) > 0:
                    self._matrix = np.delete(self._matrix, idx, axis=0)
                else:
                    self._matrix = None  # last row removed: empty index

    def update(self, entry_id: int, vector: List[float]):
        """Replace one vector (remove then re-add; the row moves to the end)."""
        self.remove(entry_id)
        self.add(entry_id, vector)

    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------

    def search(
        self,
        query_vector: List[float],
        top_k: int = 5,
        threshold: float = 0.0
    ) -> List[Tuple[int, float]]:
        """
        Vector similarity search.

        Args:
            query_vector: query embedding (any norm; normalized here).
            top_k: maximum number of hits to return.
            threshold: minimum cosine similarity to include.

        Returns:
            [(entry_id, similarity_score), ...] in descending similarity.
        """
        if not self._loaded:
            self.load_from_db()

        with self._lock:
            if self._matrix is None or len(self._ids) == 0:
                return []

            q = np.array(query_vector, dtype=np.float32).reshape(1, -1)
            norm = np.linalg.norm(q)
            if norm == 0:
                # Fix: cosine similarity is undefined for a zero (or empty)
                # query vector — the original fell through with q = 0 and
                # returned arbitrary entries, all scored 0.0, under the
                # default threshold of 0.0.
                return []
            q = q / norm

            # Cosine similarity = dot product of normalized vectors.
            similarities = (self._matrix @ q.T).flatten()

            # Keep only entries meeting the threshold.
            valid_mask = similarities >= threshold
            valid_indices = np.where(valid_mask)[0]

            if len(valid_indices) == 0:
                return []

            # Take the top_k by descending similarity.
            if len(valid_indices) > top_k:
                top_indices = valid_indices[np.argsort(-similarities[valid_indices])[:top_k]]
            else:
                top_indices = valid_indices[np.argsort(-similarities[valid_indices])]

            results = []
            for idx in top_indices:
                results.append((self._ids[idx], float(similarities[idx])))

            return results

    @property
    def size(self) -> int:
        """Number of vectors currently indexed."""
        with self._lock:
            return len(self._ids)


# Global singleton
vector_store = VectorStore()
|
||||
Reference in New Issue
Block a user