94 lines
2.7 KiB
Python
94 lines
2.7 KiB
Python
|
|
"""
|
||
|
|
公共工具 —— JSON 提取、LLM 客户端单例
|
||
|
|
"""
|
||
|
|
import json
|
||
|
|
import re
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
import openai
|
||
|
|
|
||
|
|
# ── LLM client singleton ──────────────────────────────────
# Module-level cache populated by get_llm_client() on its first call.
# Until then the model name is "" and the client is None.
_llm_model: str = ""
_llm_client: openai.OpenAI | None = None
|
||
|
|
|
||
|
|
|
||
|
|
def get_llm_client(config: dict) -> tuple[openai.OpenAI, str]:
    """Return the shared LLM client and model name, creating them on first use.

    Keeps every component on one client instead of each building its own.
    On the first call the client is built from ``config["api_key"]`` and
    ``config["base_url"]``, and the model name is taken from
    ``config["model"]``. Later calls return the cached pair and ignore
    their ``config`` argument.

    Args:
        config: Mapping with the keys ``"api_key"``, ``"base_url"`` and
            ``"model"`` (only read on the first successful call).

    Returns:
        ``(client, model_name)``.

    Raises:
        KeyError: On the first call, if a required key is missing. Nothing
            is cached in that case, so a later call with a complete config
            can still succeed.
    """
    global _llm_client, _llm_model
    # NOTE(review): no lock here; concurrent first calls could each build a
    # client (the last assignment wins).
    if _llm_client is None:
        # Read every required key BEFORE touching the globals. Otherwise a
        # missing "model" would leave _llm_client set with _llm_model == "",
        # and all later calls would silently return an empty model name.
        api_key = config["api_key"]
        base_url = config["base_url"]
        model = config["model"]
        _llm_client = openai.OpenAI(api_key=api_key, base_url=base_url)
        _llm_model = model
    return _llm_client, _llm_model
|
||
|
|
|
||
|
|
|
||
|
|
# ── JSON 提取 ────────────────────────────────────────
|
||
|
|
|
||
|
|
def extract_json_object(text: str) -> dict:
    """Extract a JSON object (dict) from LLM output.

    The text is first normalized with ``_clean_json_text`` (comments and
    trailing commas removed). Then, in order:

    1. the whole text is parsed;
    2. the body of a fenced code block is parsed (```json first, then a
       bare ```);
    3. the text is scanned for the first ``{`` from which a complete JSON
       object can be decoded, at any nesting depth.

    Steps 1 and 2 only accept a result that is a dict, so a top-level array
    or scalar is never returned.

    Returns:
        The extracted dict, or ``{}`` if no object could be parsed.
    """
    text = _clean_json_text(text)

    try:
        result = json.loads(text)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(result, dict):
            return result

    for pattern in (r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```'):
        match = re.search(pattern, text, re.DOTALL)
        if not match:
            continue
        try:
            result = json.loads(_clean_json_text(match.group(1)))
        except json.JSONDecodeError:
            continue
        if isinstance(result, dict):
            return result

    # raw_decode parses one complete value starting at `start` and ignores
    # whatever follows. Unlike a depth-limited brace regex, this handles
    # arbitrarily nested objects and returns the outermost one that parses.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            result, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            return result  # a value starting with '{' decodes to a dict
        start = text.find("{", start + 1)

    return {}
|
||
|
|
|
||
|
|
|
||
|
|
def extract_json_array(text: str) -> list[dict]:
    """Extract a JSON array from LLM output (tolerates trailing commas, comments).

    The text is first normalized with ``_clean_json_text``. Then, in order:

    1. the whole text is parsed;
    2. the body of a fenced code block is parsed (```json first, then a
       bare ```);
    3. the text is scanned for the first ``[`` from which a complete JSON
       array can be decoded.

    Steps 1 and 2 only accept a result that is a list. Element types are not
    validated, despite the ``list[dict]`` annotation.

    Returns:
        The extracted list, or ``[]`` if no array could be parsed.
    """
    text = _clean_json_text(text)

    try:
        result = json.loads(text)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(result, list):
            return result

    for pattern in (r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```'):
        match = re.search(pattern, text, re.DOTALL)
        if not match:
            continue
        try:
            result = json.loads(_clean_json_text(match.group(1)))
        except json.JSONDecodeError:
            continue
        if isinstance(result, list):
            return result

    # raw_decode stops at the end of the first complete value. A later ']'
    # in the surrounding prose therefore cannot break the match, as it did
    # with the greedy r'\[.*\]' (e.g. "[1] ... [2]").
    decoder = json.JSONDecoder()
    start = text.find("[")
    while start != -1:
        try:
            result, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            return result  # a value starting with '[' decodes to a list
        start = text.find("[", start + 1)

    return []
|
||
|
|
|
||
|
|
|
||
|
|
def _clean_json_text(s: str) -> str:
|
||
|
|
"""清理 LLM 常见的非标准 JSON"""
|
||
|
|
s = re.sub(r'//.*?\n', '\n', s)
|
||
|
|
s = re.sub(r'/\*.*?\*/', '', s, flags=re.DOTALL)
|
||
|
|
s = re.sub(r',\s*([}\]])', r'\1', s)
|
||
|
|
return s.strip()
|