大改,未验证

This commit is contained in:
2026-03-20 16:50:26 +08:00
parent c7ee292c4f
commit e14e3ee7a5
36 changed files with 1419 additions and 4805 deletions

View File

@@ -0,0 +1,152 @@
# -*- coding: utf-8 -*-
"""
Embedding vector client (local-model approach).

Runs a lightweight Chinese embedding model locally via sentence-transformers:
zero API calls, zero cost, low latency.
"""
import logging
import hashlib
import threading
from typing import List, Optional
from src.config.unified_config import get_config
from src.core.cache_manager import cache_manager
logger = logging.getLogger(__name__)
class EmbeddingClient:
    """Local embedding client.

    Encodes text with a lazily loaded sentence-transformers model and
    memoizes every vector through ``cache_manager`` keyed by an MD5 hash
    of the text. When disabled (or when the model cannot be loaded) all
    embed calls degrade to returning ``None``.
    """

    def __init__(self):
        cfg = get_config()
        self.enabled = cfg.embedding.enabled
        self.model_name = cfg.embedding.model
        self.dimension = cfg.embedding.dimension
        self.cache_ttl = cfg.embedding.cache_ttl
        self._model = None                 # loaded lazily on first use
        self._lock = threading.Lock()      # guards one-time model loading
        if not self.enabled:
            logger.info("Embedding 功能已禁用,将使用关键词匹配降级")
        else:
            logger.info(f"Embedding 客户端初始化: model={self.model_name} (本地模式)")

    def _get_model(self):
        """Return the model, loading it on first call (double-checked lock)."""
        loaded = self._model
        if loaded is not None:
            return loaded
        with self._lock:
            if self._model is None:
                self._model = self._load_model_locked()
            return self._model

    def _load_model_locked(self):
        """Actually import and load the model; caller must hold ``self._lock``.

        On any failure the client disables itself and returns ``None`` so
        callers fall back to keyword matching.
        """
        try:
            import os
            # Point HuggingFace at a mirror to work around blocked downloads.
            os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
            from sentence_transformers import SentenceTransformer
            logger.info(f"正在加载 embedding 模型: {self.model_name} ...")
            model = SentenceTransformer(self.model_name)
            logger.info(f"Embedding 模型加载完成: {self.model_name}")
            return model
        except ImportError:
            logger.error(
                "sentence-transformers 未安装,请运行: pip install sentence-transformers"
            )
            self.enabled = False
            return None
        except Exception as e:
            logger.error(f"加载 embedding 模型失败: {e}")
            self.enabled = False
            return None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def embed_text(self, text: str) -> Optional[List[float]]:
        """Embed a single text, preferring a cached vector when available."""
        if not self.enabled:
            return None
        if not text.strip():
            return None
        key = self._cache_key(text)
        hit = cache_manager.get(key)
        if hit is not None:
            return hit
        model = self._get_model()
        if model is None:
            return None
        try:
            vector = model.encode(text, normalize_embeddings=True).tolist()
            cache_manager.set(key, vector, self.cache_ttl)
            return vector
        except Exception as e:
            logger.error(f"Embedding 生成失败: {e}")
            return None

    def embed_batch(self, texts: List[str]) -> List[Optional[List[float]]]:
        """Embed many texts at once; cache hits skip the model entirely."""
        if not self.enabled:
            return [None for _ in texts]
        results: List[Optional[List[float]]] = [None for _ in texts]
        pending = []  # (index, text) pairs that missed the cache
        # Phase 1: satisfy what we can from the cache.
        for idx, item in enumerate(texts):
            if not item.strip():
                continue
            hit = cache_manager.get(self._cache_key(item))
            if hit is None:
                pending.append((idx, item))
            else:
                results[idx] = hit
        if not pending:
            return results
        # Phase 2: batch-encode the misses.
        model = self._get_model()
        if model is None:
            return results
        batch = [item for _, item in pending]
        try:
            encoded = model.encode(
                batch, normalize_embeddings=True, batch_size=32
            ).tolist()
            for (idx, item), vec in zip(pending, encoded):
                results[idx] = vec
                cache_manager.set(self._cache_key(item), vec, self.cache_ttl)
        except Exception as e:
            logger.error(f"批量 embedding 生成失败: {e}")
        return results

    def test_connection(self) -> bool:
        """Smoke-test the pipeline by embedding a short probe string."""
        try:
            probe = self.embed_text("测试连接")
            return probe is not None and len(probe) > 0
        except Exception as e:
            logger.error(f"Embedding 模型测试失败: {e}")
            return False

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------
    @staticmethod
    def _cache_key(text: str) -> str:
        """Cache key derived from the MD5 digest of the UTF-8 text."""
        digest = hashlib.md5(text.encode("utf-8")).hexdigest()
        return f"emb:{digest}"

View File

@@ -58,11 +58,49 @@ class WorkOrder(Base):
# 关联处理过程记录
process_history = relationship("WorkOrderProcessHistory", back_populates="work_order", order_by="WorkOrderProcessHistory.process_time")
class ChatSession(Base):
    """Chat session model: groups a multi-turn conversation into one session.

    Linked one-to-many to Conversation rows via ``messages`` and optionally
    tied to a work order via ``work_order_id``.
    """
    __tablename__ = "chat_sessions"
    id = Column(Integer, primary_key=True)
    session_id = Column(String(100), unique=True, nullable=False)  # unique session identifier
    user_id = Column(String(100), nullable=True)  # user identifier
    work_order_id = Column(Integer, ForeignKey("work_orders.id"), nullable=True)
    title = Column(String(200), nullable=True)  # session title (summary of the first message)
    status = Column(String(20), default="active")  # active, ended
    message_count = Column(Integer, default=0)  # number of message rounds
    source = Column(String(50), nullable=True)  # origin: websocket, api, feishu
    ip_address = Column(String(45), nullable=True)  # client IP address
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
    ended_at = Column(DateTime, nullable=True)
    # Associated messages, ordered by their timestamp
    messages = relationship("Conversation", back_populates="chat_session", order_by="Conversation.timestamp")

    def to_dict(self):
        """Serialize the session to a JSON-friendly dict (datetimes as ISO strings, or None)."""
        return {
            'id': self.id,
            'session_id': self.session_id,
            'user_id': self.user_id,
            'work_order_id': self.work_order_id,
            'title': self.title,
            'status': self.status,
            'message_count': self.message_count,
            'source': self.source,
            'ip_address': self.ip_address,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'updated_at': self.updated_at.isoformat() if self.updated_at else None,
            'ended_at': self.ended_at.isoformat() if self.ended_at else None,
        }
class Conversation(Base):
"""对话记录模型"""
__tablename__ = "conversations"
id = Column(Integer, primary_key=True)
session_id = Column(String(100), ForeignKey("chat_sessions.session_id"), nullable=True) # 关联会话
work_order_id = Column(Integer, ForeignKey("work_orders.id"))
user_message = Column(Text, nullable=False)
assistant_response = Column(Text, nullable=False)
@@ -79,6 +117,7 @@ class Conversation(Base):
cpu_usage = Column(Float) # CPU使用率
work_order = relationship("WorkOrder", back_populates="conversations")
chat_session = relationship("ChatSession", back_populates="messages")
class KnowledgeEntry(Base):
"""知识库条目模型"""

View File

@@ -86,6 +86,7 @@ class QueryOptimizer:
for conv in conversations:
conversation_list.append({
'id': conv.id,
'session_id': conv.session_id,
'user_message': conv.user_message,
'assistant_response': conv.assistant_response,
'timestamp': conv.timestamp.isoformat() if conv.timestamp else None,

164
src/core/vector_store.py Normal file
View File

@@ -0,0 +1,164 @@
# -*- coding: utf-8 -*-
"""
Vector storage and retrieval.

Implements a lightweight in-memory vector index with numpy (no extra
dependencies). Supports building the index from embeddings already stored
in the DB, plus incremental add/remove/update.
"""
import logging
import json
import threading
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from src.core.database import db_manager
from src.core.models import KnowledgeEntry
logger = logging.getLogger(__name__)
class VectorStore:
    """Lightweight in-memory vector store using numpy cosine similarity.

    All rows are L2-normalized on insertion, so cosine similarity reduces
    to a dot product at query time. Thread-safe via a reentrant lock.

    Fixes vs. previous revision:
    - ``update()`` now holds the lock across remove+add, so concurrent
      readers never observe the entry momentarily missing.
    - ``load_from_db()`` skips vectors whose dimension differs from the
      first valid one, so a single bad row can no longer make the whole
      ``np.array(...)`` build fail and lose the index.
    """

    def __init__(self):
        self._lock = threading.RLock()
        # Parallel index data: _ids[i] is the entry id of row i in _matrix
        self._ids: List[int] = []
        self._matrix: Optional[np.ndarray] = None  # shape: (n, dim), rows L2-normalized
        self._loaded = False

    # ------------------------------------------------------------------
    # Index management
    # ------------------------------------------------------------------
    def load_from_db(self):
        """Build the index from every active embedding stored in the DB.

        Invalid JSON, empty vectors and dimension-mismatched vectors are
        skipped individually instead of aborting the whole load.
        """
        try:
            with db_manager.get_session() as session:
                entries = session.query(
                    KnowledgeEntry.id, KnowledgeEntry.vector_embedding
                ).filter(
                    KnowledgeEntry.is_active == True,
                    KnowledgeEntry.vector_embedding.isnot(None),
                    KnowledgeEntry.vector_embedding != ''
                ).all()
            ids = []
            vectors = []
            dim = None  # dimension of the first valid vector; the rest must match
            for entry_id, vec_json in entries:
                try:
                    vec = json.loads(vec_json)
                except (json.JSONDecodeError, TypeError):
                    continue
                if not isinstance(vec, list) or len(vec) == 0:
                    continue
                if dim is None:
                    dim = len(vec)
                elif len(vec) != dim:
                    # A ragged row would make np.array() fail for everyone
                    continue
                ids.append(entry_id)
                vectors.append(vec)
            with self._lock:
                if vectors:
                    self._ids = ids
                    self._matrix = np.array(vectors, dtype=np.float32)
                    # L2-normalize so cosine similarity is a plain dot product
                    norms = np.linalg.norm(self._matrix, axis=1, keepdims=True)
                    norms[norms == 0] = 1.0
                    self._matrix = self._matrix / norms
                else:
                    self._ids = []
                    self._matrix = None
                self._loaded = True
            logger.info(f"向量索引加载完成: {len(ids)} 条记录")
        except Exception as e:
            logger.error(f"加载向量索引失败: {e}")
            self._loaded = True  # mark as loaded to avoid retrying on every search

    def add(self, entry_id: int, vector: List[float]):
        """Append one vector (L2-normalized) to the index."""
        with self._lock:
            row = np.array(vector, dtype=np.float32).reshape(1, -1)
            norm = np.linalg.norm(row)
            if norm > 0:
                row = row / norm
            if self._matrix is None:
                self._ids = [entry_id]
                self._matrix = row
            else:
                self._ids.append(entry_id)
                self._matrix = np.vstack([self._matrix, row])

    def remove(self, entry_id: int):
        """Remove one vector; no-op if the id is not indexed."""
        with self._lock:
            if entry_id not in self._ids:
                return
            idx = self._ids.index(entry_id)
            self._ids.pop(idx)
            if self._matrix is not None and len(self._ids) > 0:
                self._matrix = np.delete(self._matrix, idx, axis=0)
            else:
                self._matrix = None

    def update(self, entry_id: int, vector: List[float]):
        """Replace one vector atomically.

        The reentrant lock is held across both steps so no reader can see
        the entry missing between remove and add.
        """
        with self._lock:
            self.remove(entry_id)
            self.add(entry_id, vector)

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------
    def search(
        self,
        query_vector: List[float],
        top_k: int = 5,
        threshold: float = 0.0
    ) -> List[Tuple[int, float]]:
        """Cosine-similarity search.

        Args:
            query_vector: query embedding (need not be normalized).
            top_k: maximum number of hits to return.
            threshold: minimum similarity (inclusive) for a hit.

        Returns:
            ``[(entry_id, similarity), ...]`` sorted by similarity, descending.
        """
        if not self._loaded:
            self.load_from_db()
        with self._lock:
            if self._matrix is None or len(self._ids) == 0:
                return []
            q = np.array(query_vector, dtype=np.float32).reshape(1, -1)
            norm = np.linalg.norm(q)
            if norm > 0:
                q = q / norm
            # Cosine similarity == dot product of normalized vectors
            similarities = (self._matrix @ q.T).flatten()
            valid_indices = np.where(similarities >= threshold)[0]
            if len(valid_indices) == 0:
                return []
            # argsort handles both the > top_k and <= top_k cases uniformly
            order = np.argsort(-similarities[valid_indices])[:top_k]
            top_indices = valid_indices[order]
            return [(self._ids[i], float(similarities[i])) for i in top_indices]

    @property
    def size(self) -> int:
        """Number of vectors currently indexed."""
        with self._lock:
            return len(self._ids)
# Module-level singleton shared by all importers
vector_store = VectorStore()