import sys
import os
import re
import glob
import json
import uuid
import zipfile
import threading
from datetime import datetime
from typing import Optional, Dict, List

from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel

# Add parent directory to path to import agent modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data_analysis_agent import DataAnalysisAgent
from config.llm_config import LLMConfig
from utils.create_session_dir import create_session_output_dir
from utils.logger import PrintCapture

app = FastAPI(title="IOV Data Analysis Agent")


def _to_web_path(fs_path: str) -> str:
    """Convert a filesystem path to a URL-safe forward-slash path (fixes Windows backslashes)."""
    return fs_path.replace("\\", "/")


# CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# --- Session Management ---
class SessionData:
    """Mutable per-session state (progress, artifacts, agent handle)."""

    def __init__(self, session_id: str):
        self.session_id = session_id
        self.is_running = False
        self.output_dir: Optional[str] = None
        self.generated_report: Optional[str] = None
        self.log_file: Optional[str] = None
        self.analysis_results: List[Dict] = []  # Store analysis results for the gallery
        self.agent: Optional[DataAnalysisAgent] = None  # Agent instance kept for follow-up turns
        # Progress tracking
        self.current_round: int = 0
        self.max_rounds: int = 20
        self.progress_percentage: float = 0.0
        self.status_message: str = "等待开始"
        # History metadata
        self.created_at: str = ""
        self.last_updated: str = ""
        self.user_requirement: str = ""
        self.file_list: List[str] = []
        self.reusable_script: Optional[str] = None  # Path to the generated reusable script


class SessionManager:
    """Thread-safe registry of sessions, with disk-based recovery for history sessions."""

    def __init__(self):
        self.sessions: Dict[str, SessionData] = {}
        self.lock = threading.Lock()

    def create_session(self) -> str:
        """Create a new in-memory session and return its id."""
        with self.lock:
            session_id = str(uuid.uuid4())
            self.sessions[session_id] = SessionData(session_id)
            return session_id

    def get_session(self, session_id: str) -> Optional[SessionData]:
        """Return the in-memory session, or reconstruct it from disk for history sessions."""
        if session_id in self.sessions:
            return self.sessions[session_id]
        # Fallback: Try to reconstruct from disk for history sessions
        output_dir = os.path.join("outputs", f"session_{session_id}")
        if os.path.exists(output_dir) and os.path.isdir(output_dir):
            return self._reconstruct_session(session_id, output_dir)
        return None

    def _reconstruct_session(self, session_id: str, output_dir: str) -> SessionData:
        """Rebuild a session object from its on-disk output directory."""
        session = SessionData(session_id)
        session.output_dir = output_dir
        session.is_running = False
        session.current_round = session.max_rounds
        session.progress_percentage = 100.0
        session.status_message = "已完成 (历史记录)"

        # Recover log
        log_path = os.path.join(output_dir, "process.log")
        if os.path.exists(log_path):
            session.log_file = log_path

        # Recover report: scan all .md files, preferring names containing "report" / "报告"
        md_files = glob.glob(os.path.join(output_dir, "*.md"))
        if md_files:
            chosen = md_files[0]  # default to the first one
            for md in md_files:
                fname = os.path.basename(md).lower()
                if "report" in fname or "报告" in fname:
                    chosen = md
                    break
            session.generated_report = chosen

        # Recover script (look for known candidate filenames)
        possible_scripts = ["data_analysis_script.py", "script.py", "analysis_script.py"]
        for s in possible_scripts:
            p = os.path.join(output_dir, s)
            if os.path.exists(p):
                session.reusable_script = p
                break

        # Recover results (figures etc.); best-effort — a corrupt file is ignored
        results_json = os.path.join(output_dir, "results.json")
        if os.path.exists(results_json):
            try:
                with open(results_json, "r", encoding="utf-8") as f:
                    session.analysis_results = json.load(f)
            except Exception:
                pass

        # Recover metadata from the directory's creation time (best-effort)
        try:
            stat = os.stat(output_dir)
            dt = datetime.fromtimestamp(stat.st_ctime)
            session.created_at = dt.strftime("%Y-%m-%d %H:%M:%S")
        except Exception:
            pass

        # Cache it
        with self.lock:
            self.sessions[session_id] = session
        return session

    def list_sessions(self):
        """Return ids of all in-memory sessions."""
        return list(self.sessions.keys())

    def delete_session(self, session_id: str) -> bool:
        """Delete the given session; returns False if it is unknown."""
        with self.lock:
            if session_id in self.sessions:
                session = self.sessions[session_id]
                if session.agent:
                    session.agent.reset()
                del self.sessions[session_id]
                return True
            return False

    def get_session_info(self, session_id: str) -> Optional[Dict]:
        """Return a JSON-serializable summary of the session, or None if unknown."""
        session = self.get_session(session_id)
        if session:
            return {
                "session_id": session.session_id,
                "is_running": session.is_running,
                "progress": session.progress_percentage,
                "status": session.status_message,
                "current_round": session.current_round,
                "max_rounds": session.max_rounds,
                "created_at": session.created_at,
                "last_updated": session.last_updated,
                "user_requirement": session.user_requirement[:100] + "..."
                if len(session.user_requirement) > 100
                else session.user_requirement,
                "script_path": session.reusable_script,  # Path to the generated script
            }
        return None


session_manager = SessionManager()

# Mount static files
os.makedirs("web/static", exist_ok=True)
os.makedirs("uploads", exist_ok=True)
os.makedirs("outputs", exist_ok=True)
app.mount("/static", StaticFiles(directory="web/static"), name="static")
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")


# --- Helper Functions ---
def run_analysis_task(session_id: str, files: list, user_requirement: str, is_followup: bool = False):
    """Run an analysis task in a background thread.

    Creates (or reuses) the session output directory, captures stdout to the
    session log, and runs either a fresh analysis or a follow-up turn on the
    cached agent. Results are persisted to results.json for later recovery.
    """
    session = session_manager.get_session(session_id)
    if not session:
        print(f"Error: Session {session_id} not found in background task.")
        return

    session.is_running = True
    try:
        base_output_dir = "outputs"
        if not session.output_dir:
            session.output_dir = create_session_output_dir(base_output_dir, user_requirement)
        session_output_dir = session.output_dir
        session.log_file = os.path.join(session_output_dir, "process.log")

        # PrintCapture (instead of a global FileLogger) restores stdout when the with-block exits
        with PrintCapture(session.log_file):
            if is_followup:
                print(f"\n--- Follow-up Session {session_id} Continued ---")
            else:
                print(f"--- Session {session_id} Started ---")
            try:
                if not is_followup:
                    llm_config = LLMConfig()
                    agent = DataAnalysisAgent(llm_config, force_max_rounds=False, output_dir=base_output_dir)
                    session.agent = agent
                    result = agent.analyze(
                        user_input=user_requirement,
                        files=files,
                        session_output_dir=session_output_dir,
                        reset_session=True,
                    )
                else:
                    agent = session.agent
                    if not agent:
                        print("Error: Agent not initialized for follow-up.")
                        return
                    result = agent.analyze(
                        user_input=user_requirement,
                        files=None,
                        session_output_dir=session_output_dir,
                        reset_session=False,
                        max_rounds=10,
                    )

                session.generated_report = result.get("report_file_path", None)
                session.analysis_results = result.get("analysis_results", [])
                session.reusable_script = result.get("reusable_script_path", None)

                # Persist results so history sessions can be reconstructed after a restart
                with open(os.path.join(session_output_dir, "results.json"), "w", encoding="utf-8") as f:
                    json.dump(session.analysis_results, f, default=str)
            except Exception as e:
                print(f"Error during analysis: {e}")
    except Exception as e:
        print(f"System Error: {e}")
    finally:
        session.is_running = False


# --- Pydantic Models ---
class StartRequest(BaseModel):
    requirement: str


class ChatRequest(BaseModel):
    session_id: str
    message: str


# --- API Endpoints ---
@app.get("/")
async def read_root():
    return FileResponse("web/static/index.html")


@app.post("/api/upload")
async def upload_files(files: list[UploadFile] = File(...)):
    """Save uploaded files into uploads/ and return their paths."""
    saved_files = []
    for file in files:
        # basename() guards against path traversal in the client-supplied filename
        safe_name = os.path.basename(file.filename or "")
        if not safe_name:
            continue
        file_location = f"uploads/{safe_name}"
        with open(file_location, "wb+") as file_object:
            file_object.write(file.file.read())
        saved_files.append(file_location)
    return {"info": f"Saved {len(saved_files)} files", "paths": saved_files}


@app.post("/api/start")
async def start_analysis(request: StartRequest, background_tasks: BackgroundTasks):
    """Create a session and kick off a fresh analysis on the uploaded CSV files."""
    session_id = session_manager.create_session()
    files = glob.glob("uploads/*.csv")
    if not files:
        if os.path.exists("cleaned_data.csv"):
            files = ["cleaned_data.csv"]
        else:
            raise HTTPException(status_code=400, detail="No CSV files found")
    files = [os.path.abspath(f) for f in files]  # Only use absolute paths
    background_tasks.add_task(run_analysis_task, session_id, files, request.requirement, is_followup=False)
    return {"status": "started", "session_id": session_id}


@app.post("/api/chat")
async def chat_analysis(request: ChatRequest, background_tasks: BackgroundTasks):
    """Run a follow-up analysis turn on an existing session."""
    session = session_manager.get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if session.is_running:
        raise HTTPException(status_code=400, detail="Analysis already in progress")
    background_tasks.add_task(run_analysis_task, request.session_id, [], request.message, is_followup=True)
    return {"status": "started"}


@app.get("/api/status")
async def get_status(session_id: str = Query(..., description="Session ID")):
    """Return run state, log tail, and artifact paths for a session."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    log_content = ""
    if session.log_file and os.path.exists(session.log_file):
        with open(session.log_file, "r", encoding="utf-8") as f:
            log_content = f.read()

    return {
        "is_running": session.is_running,
        "log": log_content,
        "has_report": session.generated_report is not None,
        "report_path": session.generated_report,
        "script_path": session.reusable_script,  # Path to the generated script
    }


@app.get("/api/export")
async def export_session(session_id: str = Query(..., description="Session ID")):
    """Export session artifacts as a ZIP archive."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.output_dir or not os.path.exists(session.output_dir):
        raise HTTPException(status_code=404, detail="No data available for export")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    zip_filename = f"report_{timestamp}.zip"
    export_dir = "outputs"
    os.makedirs(export_dir, exist_ok=True)
    temp_zip_path = os.path.join(export_dir, zip_filename)

    with zipfile.ZipFile(temp_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, dirs, files in os.walk(session.output_dir):
            for file in files:
                # Only ship artifact types the frontend understands
                if file.endswith(('.md', '.png', '.csv', '.log', '.json', '.yaml')):
                    abs_path = os.path.join(root, file)
                    rel_path = os.path.relpath(abs_path, session.output_dir)
                    zf.write(abs_path, arcname=rel_path)

    return FileResponse(
        path=temp_zip_path,
        filename=zip_filename,
        media_type='application/zip'
    )


@app.get("/api/report")
async def get_report(session_id: str = Query(..., description="Session ID")):
    """Return the report Markdown with image links rewritten to web URLs."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.generated_report or not os.path.exists(session.generated_report):
        return {"content": "Report not ready.", "paragraphs": []}

    with open(session.generated_report, "r", encoding="utf-8") as f:
        content = f.read()

    # Fix image paths
    relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
    web_base_path = f"/{relative_session_path}"

    # Robust image path replacement
    content = content.replace("](./", f"]({web_base_path}/")

    def replace_link(match):
        alt = match.group(1)
        url = match.group(2)
        if url.startswith("http") or url.startswith("/") or url.startswith("data:"):
            return match.group(0)
        # removeprefix, not lstrip: lstrip("./") would also eat "../" and leading dots
        clean_url = url.removeprefix("./")
        return f"![{alt}]({web_base_path}/{clean_url})"

    content = re.sub(r'!\[(.*?)\]\((.*?)\)', replace_link, content)

    # Split the report into paragraphs so the frontend can offer per-paragraph polishing
    paragraphs = _split_report_to_paragraphs(content)
    return {"content": content, "base_path": web_base_path, "paragraphs": paragraphs}


@app.get("/api/figures")
async def get_figures(session_id: str = Query(..., description="Session ID")):
    """Return collected figures for the gallery, with web URLs."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # Prefer in-memory results
    results = session.analysis_results

    # If empty in memory (maybe server restarted but files exist?), try loading the JSON
    if not results and session.output_dir:
        json_path = os.path.join(session.output_dir, "results.json")
        if os.path.exists(json_path):
            with open(json_path, 'r', encoding="utf-8") as f:
                results = json.load(f)

    # Extract collected figures: iterate results looking for 'collect_figures' actions
    figures = []
    if results:
        for item in results:
            if item.get("action") == "collect_figures":
                collected = item.get("collected_figures", [])
                for fig in collected:
                    # Enrich with web path
                    if session.output_dir:
                        # Assume filename is present
                        fname = fig.get("filename")
                        relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
                        fig["web_url"] = f"/{relative_session_path}/{fname}"
                    figures.append(fig)
    # 'collect_figures' is the reliable source per agent design; 'generate_code'
    # results might contain implicit figures but are not parsed here.

    # Auto-discovery fallback: scan for PNGs if no figures were recorded
    if not figures and session.output_dir:
        pngs = glob.glob(os.path.join(session.output_dir, "*.png"))
        for p in pngs:
            fname = os.path.basename(p)
            relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
            figures.append({
                "filename": fname,
                "description": "Auto-discovered image",
                "analysis": "No analysis available",
                "web_url": f"/{relative_session_path}/{fname}"
            })

    return {"figures": figures}


@app.get("/api/download_script")
async def download_script(session_id: str = Query(..., description="Session ID")):
    """Download the generated Python script."""
    session = session_manager.get_session(session_id)
    if not session or not session.reusable_script:
        raise HTTPException(status_code=404, detail="Script not found")
    if not os.path.exists(session.reusable_script):
        raise HTTPException(status_code=404, detail="Script file missing on server")
    return FileResponse(
        path=session.reusable_script,
        filename=os.path.basename(session.reusable_script),
        media_type='text/x-python'
    )


# --- Tools API ---


@app.get("/api/sessions/progress")
async def get_session_progress(session_id: str = Query(..., description="Session ID")):
    """Get analysis progress for a session."""
    session_info = session_manager.get_session_info(session_id)
    if not session_info:
        raise HTTPException(status_code=404, detail="Session not found")
    return session_info


@app.get("/api/sessions/list")
async def list_all_sessions():
    """List all known sessions with their summary info."""
    session_ids = session_manager.list_sessions()
    sessions_info = []
    for sid in session_ids:
        info = session_manager.get_session_info(sid)
        if info:
            sessions_info.append(info)
    return {"sessions": sessions_info, "total": len(sessions_info)}


@app.delete("/api/sessions/{session_id}")
async def delete_specific_session(session_id: str):
    """Delete the given session."""
    success = session_manager.delete_session(session_id)
    if not success:
        raise HTTPException(status_code=404, detail="Session not found")
    return {"status": "deleted", "session_id": session_id}


# --- Report Polishing API ---
def _split_report_to_paragraphs(markdown_content: str) -> list:
    """Split a Markdown report into semantic paragraphs.

    Each paragraph carries an id, a type (heading/text/table/image/code), and
    its raw content, so the frontend can implement paragraph-level selection
    and polishing.
    """
    lines = markdown_content.split("\n")
    paragraphs = []
    current_block = []
    current_type = "text"
    para_id = 0

    def flush_block():
        nonlocal para_id, current_block, current_type
        text = "\n".join(current_block).strip()
        if text:
            paragraphs.append({
                "id": f"p-{para_id}",
                "type": current_type,
                "content": text,
            })
            para_id += 1
        current_block = []
        current_type = "text"

    in_table = False
    in_code = False
    for line in lines:
        stripped = line.strip()

        # Code fence boundaries
        if stripped.startswith("```"):
            if in_code:
                current_block.append(line)
                flush_block()
                in_code = False
                continue
            else:
                flush_block()
                current_block.append(line)
                current_type = "code"
                in_code = True
                continue
        if in_code:
            current_block.append(line)
            continue

        # Heading line — its own paragraph
        if re.match(r"^#{1,6}\s", stripped):
            flush_block()
            current_block.append(line)
            current_type = "heading"
            flush_block()
            continue

        # Image line
        if re.match(r"^!\[.*\]\(.*\)", stripped):
            flush_block()
            current_block.append(line)
            current_type = "image"
            flush_block()
            continue

        # Table line
        if stripped.startswith("|"):
            if not in_table:
                flush_block()
                in_table = True
                current_type = "table"
            current_block.append(line)
            continue
        else:
            if in_table:
                flush_block()
                in_table = False

        # Blank line — paragraph separator
        if not stripped:
            flush_block()
            continue

        # Plain text
        current_block.append(line)

    flush_block()
    return paragraphs


class PolishRequest(BaseModel):
    session_id: str
    paragraph_id: str
    mode: str = "context"  # "context" | "data" | "custom"
    custom_instruction: str = ""


@app.post("/api/report/polish")
async def polish_paragraph(request: PolishRequest):
    """AI-polish a specific paragraph of the report.

    mode:
    - context: polish using surrounding paragraphs and figure info for a more
      professional, insightful wording
    - data: regenerate the paragraph grounded in the raw analysis output
    - custom: follow a user-supplied polishing instruction
    """
    session = session_manager.get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.generated_report or not os.path.exists(session.generated_report):
        raise HTTPException(status_code=404, detail="Report not found")

    # Load the report and split it into paragraphs
    with open(session.generated_report, "r", encoding="utf-8") as f:
        report_content = f.read()
    paragraphs = _split_report_to_paragraphs(report_content)

    # Locate the target paragraph
    target = None
    target_idx = -1
    for i, p in enumerate(paragraphs):
        if p["id"] == request.paragraph_id:
            target = p
            target_idx = i
            break
    if not target:
        raise HTTPException(status_code=404, detail=f"Paragraph {request.paragraph_id} not found")

    # Build a context window (2 paragraphs on each side)
    context_window = []
    for j in range(max(0, target_idx - 2), min(len(paragraphs), target_idx + 3)):
        if j != target_idx:
            context_window.append(paragraphs[j]["content"])
    context_text = "\n\n".join(context_window)

    # Collect figure information
    figures_info = ""
    if session.analysis_results:
        fig_parts = []
        for item in session.analysis_results:
            if item.get("action") == "collect_figures":
                for fig in item.get("collected_figures", []):
                    fig_parts.append(
                        f"- {fig.get('filename', '?')}: {fig.get('description', '')} / {fig.get('analysis', '')}"
                    )
        if fig_parts:
            figures_info = "\n".join(fig_parts)

    # Build the polishing prompt
    if request.mode == "data":
        # Summarize successful code-execution output
        data_summary_parts = []
        for item in session.analysis_results:
            result = item.get("result", {})
            if result.get("success") and result.get("output"):
                output_text = result["output"][:2000]
                data_summary_parts.append(output_text)
        data_summary = "\n---\n".join(data_summary_parts[:5])

        polish_prompt = f"""你是一位资深数据分析专家。请基于以下分析数据,重写下方段落,使其包含更精确的数据引用和更深入的业务洞察。

## 分析数据摘要
{data_summary}

## 图表信息
{figures_info}

## 需要润色的段落
{target['content']}

## 要求
- 保持原有的 Markdown 格式(标题级别、表格结构等)
- 用具体数据替换模糊描述
- 增加业务洞察和趋势判断
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""
    elif request.mode == "custom":
        polish_prompt = f"""你是一位资深数据分析专家。请根据用户的指令润色以下段落。

## 用户指令
{request.custom_instruction}

## 上下文
{context_text}

## 图表信息
{figures_info}

## 需要润色的段落
{target['content']}

## 要求
- 保持原有的 Markdown 格式
- 严格遵循用户指令
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""
    else:  # context mode
        polish_prompt = f"""你是一位资深数据分析专家。请润色以下段落,使其表述更专业、更有洞察力。

## 上下文(前后段落)
{context_text}

## 图表信息
{figures_info}

## 需要润色的段落
{target['content']}

## 要求
- 保持原有的 Markdown 格式(标题级别、表格结构等)
- 提升专业性:使用同比、环比、占比等术语
- 增加洞察:不仅描述现象,还要分析原因和影响
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""

    # Call the LLM to polish
    try:
        from utils.llm_helper import LLMHelper
        llm = LLMHelper(LLMConfig())
        polished_content = llm.call(
            prompt=polish_prompt,
            system_prompt="你是一位专业的数据分析报告润色专家。直接输出润色后的内容,不要添加任何解释或包裹。",
            max_tokens=4096,
        )

        # Strip a possible code-fence wrapper
        polished_content = polished_content.strip()
        if polished_content.startswith("```markdown"):
            polished_content = polished_content[len("```markdown"):].strip()
        if polished_content.startswith("```"):
            polished_content = polished_content[3:].strip()
        if polished_content.endswith("```"):
            polished_content = polished_content[:-3].strip()

        return {
            "paragraph_id": request.paragraph_id,
            "original": target["content"],
            "polished": polished_content,
            "mode": request.mode,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Polish failed: {str(e)}")


class ApplyPolishRequest(BaseModel):
    session_id: str
    paragraph_id: str
    new_content: str


@app.post("/api/report/apply")
async def apply_polish(request: ApplyPolishRequest):
    """Apply polished content to the report file, replacing the given paragraph."""
    session = session_manager.get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.generated_report or not os.path.exists(session.generated_report):
        raise HTTPException(status_code=404, detail="Report not found")

    with open(session.generated_report, "r", encoding="utf-8") as f:
        report_content = f.read()

    paragraphs = _split_report_to_paragraphs(report_content)

    # Locate the target paragraph
    target = None
    for p in paragraphs:
        if p["id"] == request.paragraph_id:
            target = p
            break
    if not target:
        raise HTTPException(status_code=404, detail=f"Paragraph {request.paragraph_id} not found")

    # Replace in the original text (first occurrence only)
    new_report = report_content.replace(target["content"], request.new_content, 1)

    # Write back
    with open(session.generated_report, "w", encoding="utf-8") as f:
        f.write(new_report)

    return {"status": "applied", "paragraph_id": request.paragraph_id}


# --- History API ---
@app.get("/api/history")
async def get_history():
    """Get the list of past analysis sessions from the outputs directory."""
    history = []
    output_base = "outputs"
    if not os.path.exists(output_base):
        return {"history": []}

    try:
        # Scan for session_* directories
        for entry in os.scandir(output_base):
            if entry.is_dir() and entry.name.startswith("session_"):
                # Extract timestamp from folder name: session_20250101_120000
                session_id = entry.name.replace("session_", "")
                try:
                    # Try to parse a YYYYMMDD_HHMMSS timestamp from the id
                    timestamp_str = session_id
                    dt = datetime.strptime(timestamp_str, "%Y%m%d_%H%M%S")
                    display_time = dt.strftime("%Y-%m-%d %H:%M:%S")
                    sort_key = dt.timestamp()
                except ValueError:
                    # Fallback to the directory's creation time
                    sort_key = entry.stat().st_ctime
                    display_time = datetime.fromtimestamp(sort_key).strftime("%Y-%m-%d %H:%M:%S")

                history.append({
                    "id": session_id,
                    "timestamp": display_time,
                    "sort_key": sort_key,
                    "name": f"Session {display_time}"
                })

        # Sort by latest first
        history.sort(key=lambda x: x["sort_key"], reverse=True)

        # Remove the internal sort key before returning
        for item in history:
            del item["sort_key"]

        return {"history": history}
    except Exception as e:
        print(f"Error scanning history: {e}")
        return {"history": []}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)