This commit is contained in:
2026-04-19 16:29:59 +08:00
22 changed files with 2060 additions and 916 deletions

View File

@@ -19,11 +19,16 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data_analysis_agent import DataAnalysisAgent
from config.llm_config import LLMConfig
from utils.create_session_dir import create_session_output_dir
from config.llm_config import LLMConfig
from utils.create_session_dir import create_session_output_dir
from utils.logger import PrintCapture
app = FastAPI(title="IOV Data Analysis Agent")
def _to_web_path(fs_path: str) -> str:
"""将文件系统路径转为 URL 安全的正斜杠路径(修复 Windows 反斜杠问题)"""
return fs_path.replace("\\", "/")
# CORS
app.add_middleware(
CORSMiddleware,
@@ -168,8 +173,6 @@ class SessionManager:
"max_rounds": session.max_rounds,
"created_at": session.created_at,
"last_updated": session.last_updated,
"created_at": session.created_at,
"last_updated": session.last_updated,
"user_requirement": session.user_requirement[:100] + "..." if len(session.user_requirement) > 100 else session.user_requirement,
"script_path": session.reusable_script # 新增:返回脚本路径
}
@@ -188,9 +191,7 @@ app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
# --- Helper Functions ---
def run_analysis_task(session_id: str, files: list, user_requirement: str, is_followup: bool = False):
"""
Runs the analysis agent in a background thread for a specific session.
"""
"""在后台线程中运行分析任务"""
session = session_manager.get_session(session_id)
if not session:
print(f"Error: Session {session_id} not found in background task.")
@@ -198,119 +199,58 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str, is_fo
session.is_running = True
try:
# Create session directory if not exists (for follow-up it should accept existing)
base_output_dir = "outputs"
if not session.output_dir:
session.output_dir = create_session_output_dir(base_output_dir, user_requirement)
session.output_dir = create_session_output_dir(base_output_dir, user_requirement)
session_output_dir = session.output_dir
# Initialize Log capturing
session.log_file = os.path.join(session_output_dir, "process.log")
# Thread-safe logging requires a bit of care.
# Since we are running in a thread, redirecting sys.stdout globally is BAD for multi-session.
# However, for this MVP, if we run multiple sessions concurrently, their logs will mix in stdout.
# BUT we are writing to specific log files.
# We need a logger that writes to the session's log file.
# And the Agent needs to use that logger.
# Currently the Agent uses print().
# To support true concurrent logging without mixing, we'd need to refactor Agent to use a logger instance.
# LIMITATION: For now, we accept that stdout redirection intercepts EVERYTHING.
# So multiple concurrent sessions is risky with global stdout redirection.
# A safer approach for now: We won't redirect stdout globally for multi-session support
# unless we lock execution to one at a time.
# OR: We just rely on the fact that we might only run one analysis at a time mostly.
# Let's try to just write to the log file explicitly if we could, but we can't change Agent easily right now.
# Compromise: We will continue to use global redirection but acknowledge it's not thread-safe for output.
# A better way: Modify Agent to accept a 'log_callback'.
# For this refactor, let's stick to the existing pattern but bind it to the thread if possible? No.
# We will wrap the execution with a simple File Logger that appends to the distinct file.
# But sys.stdout is global.
# We will assume single concurrent analysis for safety, or accept mixed terminal output but separate file logs?
# Actually, if we swap sys.stdout, it affects all threads.
# So we MUST NOT swap sys.stdout if we want concurrency.
# If we don't swap stdout, we don't capture logs to file unless Agent does it.
# The Agent code has `print`.
# Correct fix: Refactor Agent to use `logging` module or pass a printer.
# Given the scope, let's just hold the lock (serialize execution) OR allow mixing in terminal
# but try to capture to file?
# Let's just write to the file.
# Let's just write to the file.
with open(session.log_file, "a" if is_followup else "w", encoding="utf-8") as f:
# 使用 PrintCapture 替代全局 FileLogger退出 with 块后自动恢复 stdout
with PrintCapture(session.log_file):
if is_followup:
f.write(f"\n--- Follow-up Session {session_id} Continued ---\n")
print(f"\n--- Follow-up Session {session_id} Continued ---")
else:
f.write(f"--- Session {session_id} Started ---\n")
print(f"--- Session {session_id} Started ---")
# We will create a custom print function that writes to the file
# And monkeypatch builtins.print? No, that's too hacky.
# Let's just use the stdout redirector, but acknowledge only one active session at a time is safe.
# We can implement a crude lock for now.
class FileLogger:
def __init__(self, filename):
self.terminal = sys.__stdout__
self.log = open(filename, "a", encoding="utf-8", buffering=1)
def write(self, message):
self.terminal.write(message)
self.log.write(message)
def flush(self):
self.terminal.flush()
self.log.flush()
def close(self):
self.log.close()
try:
if not is_followup:
llm_config = LLMConfig()
agent = DataAnalysisAgent(llm_config, force_max_rounds=False, output_dir=base_output_dir)
session.agent = agent
logger = FileLogger(session.log_file)
sys.stdout = logger # Global hijack!
try:
if not is_followup:
llm_config = LLMConfig()
agent = DataAnalysisAgent(llm_config, force_max_rounds=False, output_dir=base_output_dir)
session.agent = agent
result = agent.analyze(
user_input=user_requirement,
files=files,
session_output_dir=session_output_dir,
reset_session=True
)
else:
agent = session.agent
if not agent:
print("Error: Agent not initialized for follow-up.")
return
result = agent.analyze(
user_input=user_requirement,
files=files,
session_output_dir=session_output_dir,
reset_session=True,
)
else:
agent = session.agent
if not agent:
print("Error: Agent not initialized for follow-up.")
return
result = agent.analyze(
user_input=user_requirement,
files=None,
session_output_dir=session_output_dir,
reset_session=False,
max_rounds=10,
)
session.generated_report = result.get("report_file_path", None)
session.analysis_results = result.get("analysis_results", [])
session.reusable_script = result.get("reusable_script_path", None)
# 持久化结果
with open(os.path.join(session_output_dir, "results.json"), "w") as f:
json.dump(session.analysis_results, f, default=str)
except Exception as e:
print(f"Error during analysis: {e}")
result = agent.analyze(
user_input=user_requirement,
files=None,
session_output_dir=session_output_dir,
reset_session=False,
max_rounds=10
)
session.generated_report = result.get("report_file_path", None)
session.analysis_results = result.get("analysis_results", [])
session.reusable_script = result.get("reusable_script_path", None) # 新增:保存脚本路径
# Save results to json for persistence
with open(os.path.join(session_output_dir, "results.json"), "w") as f:
json.dump(session.analysis_results, f, default=str)
except Exception as e:
print(f"Error during analysis: {e}")
finally:
sys.stdout = logger.terminal
logger.close()
except Exception as e:
print(f"System Error: {e}")
finally:
@@ -390,24 +330,37 @@ async def get_status(session_id: str = Query(..., description="Session ID")):
@app.get("/api/export")
async def export_session(session_id: str = Query(..., description="Session ID")):
    """Export the session's output directory as a ZIP archive.

    Only report/figure/data artifacts (.md, .png, .csv, .log, .json, .yaml)
    are included; entries are stored with paths relative to the session
    output directory so the archive unpacks cleanly.

    Raises:
        HTTPException 404: unknown session, or no output produced yet.
    """
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.output_dir or not os.path.exists(session.output_dir):
        raise HTTPException(status_code=404, detail="No data available for export")
    import zipfile
    from datetime import datetime as dt
    # Timestamped filename avoids collisions between repeated exports.
    timestamp = dt.now().strftime("%Y%m%d_%H%M%S")
    zip_filename = f"report_{timestamp}.zip"
    # Write the archive OUTSIDE the session directory to avoid recursively
    # zipping the archive into itself.
    export_dir = "outputs"
    os.makedirs(export_dir, exist_ok=True)
    temp_zip_path = os.path.join(export_dir, zip_filename)
    with zipfile.ZipFile(temp_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, files in os.walk(session.output_dir):
            for file in files:
                # Whitelist artifact extensions; skip temp/binary leftovers.
                if file.endswith(('.md', '.png', '.csv', '.log', '.json', '.yaml')):
                    abs_path = os.path.join(root, file)
                    rel_path = os.path.relpath(abs_path, session.output_dir)
                    zf.write(abs_path, arcname=rel_path)
    return FileResponse(
        path=temp_zip_path,
        filename=zip_filename,
        media_type='application/zip'
    )
@app.get("/api/report")
async def get_report(session_id: str = Query(..., description="Session ID")):
@@ -416,33 +369,33 @@ async def get_report(session_id: str = Query(..., description="Session ID")):
raise HTTPException(status_code=404, detail="Session not found")
if not session.generated_report or not os.path.exists(session.generated_report):
return {"content": "Report not ready."}
return {"content": "Report not ready.", "paragraphs": []}
with open(session.generated_report, "r", encoding="utf-8") as f:
content = f.read()
# Fix image paths
relative_session_path = os.path.relpath(session.output_dir, os.getcwd())
relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
web_base_path = f"/{relative_session_path}"
# Robust image path replacement
# 1. Replace explicit relative paths ./image.png
content = content.replace("](./", f"]({web_base_path}/")
# 2. Replace naked paths that might be generated like ](image.png) but NOT ](http...) or ](/...)
import re
def replace_link(match):
alt = match.group(1)
url = match.group(2)
if url.startswith("http") or url.startswith("/") or url.startswith("data:"):
return match.group(0)
# Remove ./ if exists again just in case
clean_url = url.lstrip("./")
return f"![{alt}]({web_base_path}/{clean_url})"
content = re.sub(r'!\[(.*?)\]\((.*?)\)', replace_link, content)
# 将报告按段落拆分,为前端润色功能提供结构化数据
paragraphs = _split_report_to_paragraphs(content)
return {"content": content, "base_path": web_base_path}
return {"content": content, "base_path": web_base_path, "paragraphs": paragraphs}
@app.get("/api/figures")
async def get_figures(session_id: str = Query(..., description="Session ID")):
@@ -473,7 +426,7 @@ async def get_figures(session_id: str = Query(..., description="Session ID")):
if session.output_dir:
# Assume filename is present
fname = fig.get("filename")
relative_session_path = os.path.relpath(session.output_dir, os.getcwd())
relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
fig["web_url"] = f"/{relative_session_path}/{fname}"
figures.append(fig)
@@ -486,7 +439,7 @@ async def get_figures(session_id: str = Query(..., description="Session ID")):
pngs = glob.glob(os.path.join(session.output_dir, "*.png"))
for p in pngs:
fname = os.path.basename(p)
relative_session_path = os.path.relpath(session.output_dir, os.getcwd())
relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
figures.append({
"filename": fname,
"description": "Auto-discovered image",
@@ -496,37 +449,6 @@ async def get_figures(session_id: str = Query(..., description="Session ID")):
return {"figures": figures}
@app.get("/api/export")
async def export_report(session_id: str = Query(..., description="Session ID")):
    """Bundle the session's output artifacts into a timestamped ZIP.

    Includes only .md/.png/.csv/.log/.json/.yaml files, stored with paths
    relative to the session output directory.

    Raises:
        HTTPException 404: unknown session or no output directory recorded.
    """
    session = session_manager.get_session(session_id)
    if not session or not session.output_dir:
        raise HTTPException(status_code=404, detail="Session not found")
    import zipfile
    from datetime import datetime
    # Timestamped filename avoids collisions between repeated exports.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    zip_filename = f"report_{timestamp}.zip"
    # Place the archive in the shared outputs dir, outside the session folder.
    export_dir = "outputs"
    os.makedirs(export_dir, exist_ok=True)
    temp_zip_path = os.path.join(export_dir, zip_filename)
    with zipfile.ZipFile(temp_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, files in os.walk(session.output_dir):
            for file in files:
                # Whitelist artifact extensions; skip temp/binary leftovers.
                if file.endswith(('.md', '.png', '.csv', '.log', '.json', '.yaml')):
                    abs_path = os.path.join(root, file)
                    rel_path = os.path.relpath(abs_path, session.output_dir)
                    zf.write(abs_path, arcname=rel_path)
    return FileResponse(
        path=temp_zip_path,
        filename=zip_filename,
        media_type='application/zip'
    )
@app.get("/api/download_script")
async def download_script(session_id: str = Query(..., description="Session ID")):
"""下载生成的Python脚本"""
@@ -580,7 +502,301 @@ async def delete_specific_session(session_id: str):
raise HTTPException(status_code=404, detail="Session not found")
return {"status": "deleted", "session_id": session_id}
return {"status": "deleted", "session_id": session_id}
# --- Report Polishing API ---
import re as _re
def _split_report_to_paragraphs(markdown_content: str) -> list:
"""
将 Markdown 报告按语义段落拆分。
每个段落包含 id、类型heading/text/table/image、原始内容。
前端可据此实现段落级选择与润色。
"""
lines = markdown_content.split("\n")
paragraphs = []
current_block = []
current_type = "text"
para_id = 0
def flush_block():
nonlocal para_id, current_block, current_type
text = "\n".join(current_block).strip()
if text:
paragraphs.append({
"id": f"p-{para_id}",
"type": current_type,
"content": text,
})
para_id += 1
current_block = []
current_type = "text"
in_table = False
in_code = False
for line in lines:
stripped = line.strip()
# 代码块边界
if stripped.startswith("```"):
if in_code:
current_block.append(line)
flush_block()
in_code = False
continue
else:
flush_block()
current_block.append(line)
current_type = "code"
in_code = True
continue
if in_code:
current_block.append(line)
continue
# 标题行 — 独立成段
if _re.match(r"^#{1,6}\s", stripped):
flush_block()
current_block.append(line)
current_type = "heading"
flush_block()
continue
# 图片行
if _re.match(r"^!\[.*\]\(.*\)", stripped):
flush_block()
current_block.append(line)
current_type = "image"
flush_block()
continue
# 表格行
if stripped.startswith("|"):
if not in_table:
flush_block()
in_table = True
current_type = "table"
current_block.append(line)
continue
else:
if in_table:
flush_block()
in_table = False
# 空行 — 段落分隔
if not stripped:
flush_block()
continue
# 普通文本
current_block.append(line)
flush_block()
return paragraphs
class PolishRequest(BaseModel):
    # Request body for POST /api/report/polish.
    session_id: str  # id of the analysis session owning the report
    paragraph_id: str  # id ("p-<n>") of the paragraph to polish, as returned by /api/report
    mode: str = "context"  # "context" | "data" | "custom"
    custom_instruction: str = ""  # free-form instruction; only used when mode == "custom"
@app.post("/api/report/polish")
async def polish_paragraph(request: PolishRequest):
    """
    Polish a specified report paragraph with the LLM.

    mode:
        - context: polish using surrounding paragraphs and figure info for a
          more professional, insightful wording
        - data: regenerate the paragraph content from the raw analysis data
        - custom: follow a user-supplied polishing instruction

    Returns the original and polished Markdown without modifying the report
    file; /api/report/apply performs the actual write-back.

    Raises 404 when the session, report, or paragraph cannot be found and
    500 when the LLM call fails.
    """
    session = session_manager.get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.generated_report or not os.path.exists(session.generated_report):
        raise HTTPException(status_code=404, detail="Report not found")
    # Read the report and split it into paragraphs (same splitter as /api/report,
    # so paragraph ids line up with what the frontend was served)
    with open(session.generated_report, "r", encoding="utf-8") as f:
        report_content = f.read()
    paragraphs = _split_report_to_paragraphs(report_content)
    # Locate the target paragraph by id
    target = None
    target_idx = -1
    for i, p in enumerate(paragraphs):
        if p["id"] == request.paragraph_id:
            target = p
            target_idx = i
            break
    if not target:
        raise HTTPException(status_code=404, detail=f"Paragraph {request.paragraph_id} not found")
    # Build a context window (up to 2 paragraphs before and after the target)
    context_window = []
    for j in range(max(0, target_idx - 2), min(len(paragraphs), target_idx + 3)):
        if j != target_idx:
            context_window.append(paragraphs[j]["content"])
    context_text = "\n\n".join(context_window)
    # Collect figure information from the recorded analysis results
    figures_info = ""
    if session.analysis_results:
        fig_parts = []
        for item in session.analysis_results:
            if item.get("action") == "collect_figures":
                for fig in item.get("collected_figures", []):
                    fig_parts.append(f"- {fig.get('filename', '?')}: {fig.get('description', '')} / {fig.get('analysis', '')}")
        if fig_parts:
            figures_info = "\n".join(fig_parts)
    # Build the polishing prompt for the selected mode
    if request.mode == "data":
        # Collect summaries of successful code-execution outputs
        # (each capped at 2000 chars, at most 5 snippets total)
        data_summary_parts = []
        for item in session.analysis_results:
            result = item.get("result", {})
            if result.get("success") and result.get("output"):
                output_text = result["output"][:2000]
                data_summary_parts.append(output_text)
        data_summary = "\n---\n".join(data_summary_parts[:5])
        polish_prompt = f"""你是一位资深数据分析专家。请基于以下分析数据,重写下方段落,使其包含更精确的数据引用和更深入的业务洞察。
## 分析数据摘要
{data_summary}
## 图表信息
{figures_info}
## 需要润色的段落
{target['content']}
## 要求
- 保持原有的 Markdown 格式(标题级别、表格结构等)
- 用具体数据替换模糊描述
- 增加业务洞察和趋势判断
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""
    elif request.mode == "custom":
        polish_prompt = f"""你是一位资深数据分析专家。请根据用户的指令润色以下段落。
## 用户指令
{request.custom_instruction}
## 上下文
{context_text}
## 图表信息
{figures_info}
## 需要润色的段落
{target['content']}
## 要求
- 保持原有的 Markdown 格式
- 严格遵循用户指令
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""
    else:  # context mode
        polish_prompt = f"""你是一位资深数据分析专家。请润色以下段落,使其表述更专业、更有洞察力。
## 上下文(前后段落)
{context_text}
## 图表信息
{figures_info}
## 需要润色的段落
{target['content']}
## 要求
- 保持原有的 Markdown 格式(标题级别、表格结构等)
- 提升专业性:使用同比、环比、占比等术语
- 增加洞察:不仅描述现象,还要分析原因和影响
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""
    # Call the LLM to polish the paragraph
    try:
        from utils.llm_helper import LLMHelper
        llm = LLMHelper(LLMConfig())
        polished_content = llm.call(
            prompt=polish_prompt,
            system_prompt="你是一位专业的数据分析报告润色专家。直接输出润色后的内容,不要添加任何解释或包裹。",
            max_tokens=4096,
        )
        # Strip a possible code-fence wrapper around the model output
        polished_content = polished_content.strip()
        if polished_content.startswith("```markdown"):
            polished_content = polished_content[len("```markdown"):].strip()
        if polished_content.startswith("```"):
            polished_content = polished_content[3:].strip()
        if polished_content.endswith("```"):
            polished_content = polished_content[:-3].strip()
        return {
            "paragraph_id": request.paragraph_id,
            "original": target["content"],
            "polished": polished_content,
            "mode": request.mode,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Polish failed: {str(e)}")
class ApplyPolishRequest(BaseModel):
    # Request body for POST /api/report/apply.
    session_id: str  # id of the analysis session owning the report
    paragraph_id: str  # id ("p-<n>") of the paragraph being replaced
    new_content: str  # polished Markdown to write in place of the paragraph
@app.post("/api/report/apply")
async def apply_polish(request: ApplyPolishRequest):
    """Write a polished paragraph back into the report file.

    Re-derives paragraph ids with the same splitter used by /api/report,
    substitutes the first occurrence of the target paragraph's text with
    the supplied replacement, and persists the updated report.
    """
    session = session_manager.get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    report_path = session.generated_report
    if not report_path or not os.path.exists(report_path):
        raise HTTPException(status_code=404, detail="Report not found")

    with open(report_path, "r", encoding="utf-8") as fh:
        original_text = fh.read()

    # Look up the paragraph the client asked us to replace.
    target = next(
        (p for p in _split_report_to_paragraphs(original_text)
         if p["id"] == request.paragraph_id),
        None,
    )
    if not target:
        raise HTTPException(status_code=404, detail=f"Paragraph {request.paragraph_id} not found")

    # Replace only the first occurrence so identical text elsewhere in the
    # report is left untouched.
    updated_text = original_text.replace(target["content"], request.new_content, 1)
    with open(report_path, "w", encoding="utf-8") as fh:
        fh.write(updated_text)
    return {"status": "applied", "paragraph_id": request.paragraph_id}
# --- History API ---