154 lines
4.7 KiB
Python
154 lines
4.7 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
分析模板系统 - 从 config/templates/*.yaml 加载模板

模板文件格式:

    name: 模板显示名称
    description: 模板描述
    steps:
      - name: 步骤名称
        description: 步骤描述
        prompt: 给LLM的指令
|
|
"""
|
|
|
|
import os
|
|
import glob
|
|
import yaml
|
|
from typing import List, Dict, Any
|
|
from dataclasses import dataclass
|
|
|
|
# Templates live in <root>/config/templates, where <root> is the parent of the
# directory that contains this file.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
TEMPLATES_DIR = os.path.join(_PROJECT_ROOT, "config", "templates")
|
|
|
|
|
|
@dataclass
class AnalysisStep:
    """One step of an analysis template.

    Attributes:
        name: Title of the step (rendered as a numbered heading).
        description: Short human-readable explanation of the step.
        prompt: Instruction text for the LLM (rendered in a code block).
    """

    name: str  # step title
    description: str  # explanation shown under the title
    prompt: str  # LLM instruction for this step
|
|
|
|
|
|
class AnalysisTemplate:
    """An analysis template, normally built from one YAML file.

    Attributes:
        name: Registry key (the loader uses the YAML file name without its
            extension).
        display_name: Title rendered at the top of the full prompt.
        description: Free-text description of the template.
        steps: Ordered steps; each provides ``name``, ``description`` and
            ``prompt``.
        filepath: Source file path; defaults to ``""``.
    """

    def __init__(
        self,
        name: str,
        display_name: str,
        description: str,
        steps: "List[AnalysisStep]",  # string annotation: not evaluated at def time
        filepath: str = "",
    ):
        self.name = name
        self.display_name = display_name
        self.description = description
        self.steps = steps
        self.filepath = filepath

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"display_name={self.display_name!r}, steps={len(self.steps)})"
        )

    @staticmethod
    def _code_fence(text: str) -> str:
        """Return a backtick fence that does not occur anywhere in *text*.

        A fenced block is closed by a fence at least as long as the opening
        one, so a prompt that itself contains ``` would end a plain ``` block
        early. Lengthening the fence until it is absent from the text keeps
        the whole prompt inside the block.
        """
        fence = "```"
        while fence in text:
            fence += "`"
        return fence

    def get_full_prompt(self) -> str:
        """Render the template as one Markdown document.

        Layout: a ``#`` title and the description, a ``## 分析步骤:`` header,
        then one ``###`` section per step (numbered from 1) containing the
        step description followed by its prompt in a fenced code block.
        """
        parts = [
            f"# {self.display_name}\n\n{self.description}\n\n",
            "## 分析步骤:\n\n",
        ]
        for i, step in enumerate(self.steps, 1):
            prompt_text = f"{step.prompt}"
            fence = self._code_fence(prompt_text)
            parts.append(f"### {i}. {step.name}\n")
            parts.append(f"{step.description}\n\n")
            parts.append(f"{fence}\n{prompt_text}\n{fence}\n\n")
        return "".join(parts)

    def to_dict(self) -> Dict[str, Any]:
        """Return the template as plain dicts/lists (steps flattened to dicts)."""
        return {
            "name": self.name,
            "display_name": self.display_name,
            "description": self.description,
            "steps": [
                {"name": s.name, "description": s.description, "prompt": s.prompt}
                for s in self.steps
            ],
        }
|
|
|
|
|
|
def _load_template_from_file(filepath: str) -> AnalysisTemplate:
    """Load a single template from a YAML file.

    The template's registry key is the file name without its extension.
    Top-level keys ``name`` (display name), ``description`` and ``steps`` are
    read; each step may define ``name``, ``description`` and ``prompt``.
    Missing or null values fall back to defaults: the display name falls back
    to the file stem, every other text field to ``""``, and a missing/null
    ``steps`` means no steps. Non-null values are converted with ``str()``.

    Raises:
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
        ValueError: if the document is not a mapping, ``steps`` is not a
            list, or a step is not a mapping.
    """

    def text_field(mapping: Dict[str, Any], key: str, default: str) -> str:
        # "key:" with no value parses to None; treat it like a missing key
        # instead of leaking None (rendered as "None") into the prompt.
        value = mapping.get(key)
        return default if value is None else str(value)

    with open(filepath, "r", encoding="utf-8") as f:
        # safe_load: template files are hand-editable on disk, so never let
        # them construct arbitrary Python objects.
        data = yaml.safe_load(f)

    template_name = os.path.splitext(os.path.basename(filepath))[0]

    # An empty file parses to None and a list/scalar document has no fields;
    # fail with a clear message instead of an AttributeError on .get().
    if not isinstance(data, dict):
        raise ValueError(
            f"template file must contain a YAML mapping, got {type(data).__name__}"
        )

    raw_steps = data.get("steps") or []
    if not isinstance(raw_steps, list):
        raise ValueError(f"'steps' must be a list, got {type(raw_steps).__name__}")

    steps = []
    for index, raw in enumerate(raw_steps, 1):
        if not isinstance(raw, dict):
            raise ValueError(f"step {index} must be a mapping, got {type(raw).__name__}")
        steps.append(AnalysisStep(
            name=text_field(raw, "name", ""),
            description=text_field(raw, "description", ""),
            prompt=text_field(raw, "prompt", ""),
        ))

    return AnalysisTemplate(
        name=template_name,
        display_name=text_field(data, "name", template_name),
        description=text_field(data, "description", ""),
        steps=steps,
        filepath=filepath,
    )
|
|
|
|
|
|
def _scan_templates() -> Dict[str, AnalysisTemplate]:
    """Load every ``*.yaml`` template found directly under ``TEMPLATES_DIR``.

    Returns a dict keyed by template name, filled in sorted file-path order.
    If the directory does not exist it is created and an empty dict is
    returned. A file that fails to load is reported with a ``[WARN]`` line
    on stdout and skipped, so one broken template does not hide the others.
    """
    if os.path.exists(TEMPLATES_DIR):
        pattern = os.path.join(TEMPLATES_DIR, "*.yaml")
    else:
        os.makedirs(TEMPLATES_DIR, exist_ok=True)
        return {}

    loaded: Dict[str, AnalysisTemplate] = {}
    for path in sorted(glob.glob(pattern)):
        try:
            template = _load_template_from_file(path)
        except Exception as exc:  # best-effort: skip templates that fail to load
            print(f"[WARN] 加载模板失败 {path}: {exc}")
            continue
        loaded[template.name] = template
    return loaded
|
|
|
|
|
|
def _get_registry() -> Dict[str, AnalysisTemplate]:
    """Return a freshly scanned template registry.

    Nothing is cached: the templates directory is re-read on every call, so
    edits to the YAML files are visible immediately (hot editing).
    """
    fresh = _scan_templates()
    return fresh
|
|
|
|
|
|
# Registry snapshot taken eagerly when this module is imported (it is not
# lazy). save_template() and delete_template() rebind it after changing files.
# Kept for backward compatibility (e.g. tests); get_template() and
# list_templates() always rescan the disk instead of reading this.
TEMPLATE_REGISTRY: Dict[str, AnalysisTemplate] = _scan_templates()
|
|
|
|
|
|
def get_template(template_name: str) -> AnalysisTemplate:
    """Look up a template by name, rescanning the templates directory first.

    The registry is rebuilt from disk on each call, so YAML edits take effect
    without a restart.

    Raises:
        ValueError: if no template has that name; the message lists the
            available names.
    """
    registry = _get_registry()
    if template_name not in registry:
        available = list(registry.keys())
        raise ValueError(f"未找到模板: {template_name}。可用模板: {available}")
    return registry[template_name]
|
|
|
|
|
|
def list_templates() -> List[Dict[str, str]]:
    """Summarize every available template (rescanned from disk).

    Returns:
        One dict per template with ``name``, ``display_name`` and
        ``description`` keys, in registry order.
    """
    summaries: List[Dict[str, str]] = []
    for tpl in _get_registry().values():
        summaries.append({
            "name": tpl.name,
            "display_name": tpl.display_name,
            "description": tpl.description,
        })
    return summaries
|
|
|
|
|
|
def save_template(template_name: str, data: Dict[str, Any]) -> str:
    """Create or overwrite ``<TEMPLATES_DIR>/<template_name>.yaml``.

    Args:
        template_name: Registry key / file stem. Must be non-empty and must
            not contain a path separator, so the file always lands directly
            inside ``TEMPLATES_DIR``.
        data: Template fields. ``display_name`` (falling back to ``name``,
            then to *template_name*) is written as the YAML ``name``;
            ``description`` and ``steps`` are copied as-is.

    Returns:
        Path of the written file.

    Raises:
        ValueError: if *template_name* is empty or contains a path separator.
        yaml.representer.RepresenterError: if *data* contains values that
            plain YAML (and thus ``yaml.safe_load``) cannot represent.
        OSError: if the file cannot be written.
    """
    # Names may be user-supplied; reject anything that could escape
    # TEMPLATES_DIR via os.path.join (e.g. "../x" or an absolute "/etc/x").
    if not template_name or any(sep and sep in template_name for sep in (os.sep, os.altsep)):
        raise ValueError(f"非法模板名称: {template_name!r}")

    os.makedirs(TEMPLATES_DIR, exist_ok=True)
    filepath = os.path.join(TEMPLATES_DIR, f"{template_name}.yaml")

    yaml_data = {
        "name": data.get("display_name", data.get("name", template_name)),
        "description": data.get("description", ""),
        "steps": data.get("steps", []),
    }

    # safe_dump mirrors the loader's yaml.safe_load: plain yaml.dump would emit
    # !!python/... tags for non-plain objects, producing a file that the loader
    # then fails to read. Serializing before opening the file also means an
    # error here cannot truncate an existing template.
    text = yaml.safe_dump(yaml_data, allow_unicode=True, default_flow_style=False, sort_keys=False)

    # Write to a sibling temp file and swap it in with os.replace so the
    # target is never left half-written; "*.yaml.tmp" does not match "*.yaml".
    tmp_path = filepath + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(text)
        os.replace(tmp_path, filepath)
    except BaseException:
        # Clean up the temp file, then propagate the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

    # Refresh the import-time registry snapshot
    global TEMPLATE_REGISTRY
    TEMPLATE_REGISTRY = _scan_templates()

    return filepath
|
|
|
|
|
|
def delete_template(template_name: str) -> bool:
    """Delete ``<TEMPLATES_DIR>/<template_name>.yaml`` and refresh the registry.

    Returns:
        True if the file existed and was removed. False if there is no such
        template, including names that are empty or contain a path separator
        (such names could otherwise point outside ``TEMPLATES_DIR``).
    """
    global TEMPLATE_REGISTRY

    # Names may be user-supplied; never let them escape TEMPLATES_DIR.
    if not template_name or any(sep and sep in template_name for sep in (os.sep, os.altsep)):
        return False

    filepath = os.path.join(TEMPLATES_DIR, f"{template_name}.yaml")
    try:
        os.remove(filepath)
    except FileNotFoundError:
        # EAFP avoids the race between an exists() check and the removal.
        return False

    TEMPLATE_REGISTRY = _scan_templates()
    return True
|