import glob
import json
import math as _math
import os
import re
import sys
import threading
import uuid
import zipfile
from datetime import datetime
from typing import Optional, Dict, List

import pandas as pd
from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel

# Add parent directory to path to import agent modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data_analysis_agent import DataAnalysisAgent
from config.llm_config import LLMConfig
from utils.create_session_dir import create_session_output_dir
from utils.logger import PrintCapture

app = FastAPI(title="IOV Data Analysis Agent")


def _to_web_path(fs_path: str) -> str:
    """Convert a filesystem path to a forward-slash path usable in URLs.

    Only backslashes are replaced (this works around Windows separators);
    no URL escaping is performed.
    """
    return fs_path.replace("\\", "/")


# CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# --- Session Management ---

class SessionData:
    """In-memory state of one analysis session."""

    def __init__(self, session_id: str):
        self.session_id = session_id
        self.is_running = False
        self.output_dir: Optional[str] = None
        self.generated_report: Optional[str] = None
        self.log_file: Optional[str] = None
        self.analysis_results: List[Dict] = []  # Store analysis results for gallery
        self.rounds: List[Dict] = []  # Structured Round_Data objects
        self.data_files: List[Dict] = []  # File metadata dicts
        self.agent: Optional[DataAnalysisAgent] = None  # Store the agent instance for follow-up

        # Progress tracking
        self.current_round: int = 0
        self.max_rounds: int = 20
        self.progress_percentage: float = 0.0
        self.status_message: str = "等待开始"

        # History metadata
        self.created_at: str = ""
        self.last_updated: str = ""
        self.user_requirement: str = ""
        self.file_list: List[str] = []
        self.reusable_script: Optional[str] = None  # Path of the reusable script


class SessionManager:
    """Registry of sessions, with lazy reconstruction from the outputs directory."""

    def __init__(self):
        self.sessions: Dict[str, SessionData] = {}
        self.lock = threading.Lock()

    def create_session(self) -> str:
        """Create an empty session under a fresh UUID and return that id."""
        with self.lock:
            session_id = str(uuid.uuid4())
            self.sessions[session_id] = SessionData(session_id)
            return session_id

    def get_session(self, session_id: str) -> Optional[SessionData]:
        """Return the session for *session_id*, rebuilding it from disk if needed.

        Lookup order: in-memory cache, then ``outputs/session_<id>``, then any
        ``outputs/session_*`` directory whose ``session_meta.json`` names this id.
        Returns ``None`` when nothing matches.
        """
        cached = self.sessions.get(session_id)
        if cached is not None:
            return cached

        # The id is about to be joined into a filesystem path; an id carrying
        # path separators could escape the outputs directory.
        if not session_id or "/" in session_id or "\\" in session_id:
            return None

        # Fallback: Try to reconstruct from disk for history sessions
        # First try the old convention: outputs/session_{uuid}
        output_dir = os.path.join("outputs", f"session_{session_id}")
        if os.path.isdir(output_dir):
            return self._reconstruct_session(session_id, output_dir)

        # Scan all session directories for session_meta.json matching this session_id
        # This handles the case where output_dir uses a timestamp name, not the UUID
        outputs_root = "outputs"
        if os.path.exists(outputs_root):
            for dirname in os.listdir(outputs_root):
                dir_path = os.path.join(outputs_root, dirname)
                if not os.path.isdir(dir_path) or not dirname.startswith("session_"):
                    continue
                meta_path = os.path.join(dir_path, "session_meta.json")
                if not os.path.exists(meta_path):
                    continue
                try:
                    with open(meta_path, "r", encoding="utf-8") as f:
                        meta = json.load(f)
                except (OSError, ValueError):
                    # Unreadable or corrupt metadata: skip this directory.
                    continue
                if isinstance(meta, dict) and meta.get("session_id") == session_id:
                    return self._reconstruct_session(session_id, dir_path)

        return None

    def _reconstruct_session(self, session_id: str, output_dir: str) -> SessionData:
        """Rebuild a finished session object from its output directory and cache it."""
        session = SessionData(session_id)
        session.output_dir = output_dir
        session.is_running = False
        session.current_round = session.max_rounds
        session.progress_percentage = 100.0
        session.status_message = "已完成 (历史记录)"

        # Read session_meta.json if available
        meta = {}
        meta_path = os.path.join(output_dir, "session_meta.json")
        if os.path.exists(meta_path):
            try:
                with open(meta_path, "r", encoding="utf-8") as f:
                    loaded = json.load(f)
                if isinstance(loaded, dict):
                    meta = loaded
            except (OSError, ValueError):
                pass  # best effort: fall back to directory scanning below

        # Recover Log
        log_path = os.path.join(output_dir, "process.log")
        if os.path.exists(log_path):
            session.log_file = log_path

        # Recover Report — prefer meta, then scan .md files
        report_path = meta.get("report_path")
        if report_path and os.path.exists(report_path):
            session.generated_report = report_path
        else:
            md_files = glob.glob(os.path.join(output_dir, "*.md"))
            if md_files:
                chosen = md_files[0]
                for md in md_files:
                    fname = os.path.basename(md).lower()
                    if "report" in fname or "报告" in fname:
                        chosen = md
                        break
                session.generated_report = chosen

        # Recover Script — prefer meta, then scan for known script names
        script_path = meta.get("script_path")
        if script_path and os.path.exists(script_path):
            session.reusable_script = script_path
        else:
            # Try Chinese-named scripts first (generated by this system)
            script_files = glob.glob(os.path.join(output_dir, "分析脚本_*.py"))
            if not script_files:
                for s in ["data_analysis_script.py", "script.py", "analysis_script.py"]:
                    p = os.path.join(output_dir, s)
                    if os.path.exists(p):
                        script_files = [p]
                        break
            if script_files:
                session.reusable_script = script_files[0]

        # Recover Results (images etc)
        results_json = os.path.join(output_dir, "results.json")
        if os.path.exists(results_json):
            try:
                with open(results_json, "r", encoding="utf-8") as f:
                    data = json.load(f)
                # Support both old format (plain list) and new format (dict with rounds/data_files)
                if isinstance(data, dict):
                    session.analysis_results = data.get("analysis_results", [])
                    session.rounds = data.get("rounds", [])
                    session.data_files = data.get("data_files", [])
                elif isinstance(data, list):
                    # Legacy format: data is the analysis_results list directly
                    session.analysis_results = data
            except (OSError, ValueError):
                pass  # corrupt results file: leave the defaults

        # Recover Metadata
        session.file_list = meta.get("file_list", [])
        session.user_requirement = meta.get("user_requirement", "")
        try:
            stat = os.stat(output_dir)
            dt = datetime.fromtimestamp(stat.st_ctime)
            session.created_at = dt.strftime("%Y-%m-%d %H:%M:%S")
        except (OSError, ValueError, OverflowError):
            pass

        # Cache it
        with self.lock:
            self.sessions[session_id] = session

        return session

    def list_sessions(self):
        """Return the ids of the sessions currently held in memory."""
        return list(self.sessions.keys())

    def delete_session(self, session_id: str) -> bool:
        """Drop an in-memory session (resetting its agent); files on disk are untouched."""
        with self.lock:
            if session_id in self.sessions:
                session = self.sessions[session_id]
                if session.agent:
                    session.agent.reset()
                del self.sessions[session_id]
                return True
            return False

    def get_session_info(self, session_id: str) -> Optional[Dict]:
        """Return a summary dict for the session, or ``None`` if it does not exist."""
        session = self.get_session(session_id)
        if session:
            return {
                "session_id": session.session_id,
                "is_running": session.is_running,
                "progress": session.progress_percentage,
                "status": session.status_message,
                "current_round": session.current_round,
                "max_rounds": session.max_rounds,
                "created_at": session.created_at,
                "last_updated": session.last_updated,
                "user_requirement": session.user_requirement[:100] + "..." if len(session.user_requirement) > 100 else session.user_requirement,
                "script_path": session.reusable_script,  # path of the reusable script
            }
        return None


session_manager = SessionManager()

# Mount static files
os.makedirs("web/static", exist_ok=True)
os.makedirs("uploads", exist_ok=True)
os.makedirs("outputs", exist_ok=True)

app.mount("/static", StaticFiles(directory="web/static"), name="static")
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")


# --- Helper Functions ---

def _make_progress_callback(session: SessionData):
    """Build the agent progress callback that mirrors progress onto *session*."""
    def progress_cb(current, total, message):
        session.current_round = current
        session.max_rounds = total
        session.progress_percentage = round((current / total) * 100, 1) if total > 0 else 0
        session.status_message = message
    return progress_cb


def _persist_session_meta(session: SessionData, output_dir: str, requirement: str) -> None:
    """Best-effort write of session_meta.json (used to recover sessions after a restart)."""
    payload = {
        "session_id": session.session_id,
        "user_requirement": requirement,
        "file_list": session.file_list,
        "report_path": session.generated_report,
        "script_path": session.reusable_script,
    }
    try:
        with open(os.path.join(output_dir, "session_meta.json"), "w", encoding="utf-8") as f:
            json.dump(payload, f, default=str)
    except (OSError, TypeError, ValueError) as e:
        # Metadata is a convenience for recovery; never fail the analysis over it.
        print(f"[WARN] Could not write session_meta.json: {e}")


def run_analysis_task(session_id: str, files: list, user_requirement: str, is_followup: bool = False, template_name: Optional[str] = None):
    """Run an analysis (or a follow-up round) for a session; meant to run as a background task.

    Args:
        session_id: Session to run; must be resolvable by ``session_manager``.
        files: Data file paths. Follow-up calls pass an empty list, in which
            case the file list already stored on the session is kept.
        user_requirement: The requirement text (or follow-up message).
        is_followup: Continue with the session's existing agent instead of
            starting a fresh analysis.
        template_name: Optional analysis template forwarded to the agent.

    Errors are printed (inside the log capture where possible) rather than
    raised; ``session.is_running`` is always cleared on exit.
    """
    session = session_manager.get_session(session_id)
    if not session:
        print(f"Error: Session {session_id} not found in background task.")
        return

    session.is_running = True
    # Only replace the stored file list when files were actually supplied;
    # follow-ups pass [] and must not wipe the list needed to reload data.
    if files:
        session.file_list = list(files)
    if not is_followup:
        session.user_requirement = user_requirement
    if not session.created_at:
        session.created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Keep the session's original requirement in the metadata across follow-ups.
    meta_requirement = session.user_requirement or user_requirement

    try:
        base_output_dir = "outputs"

        if not session.output_dir:
            session.output_dir = create_session_output_dir(base_output_dir, user_requirement)

        session_output_dir = session.output_dir
        session.log_file = os.path.join(session_output_dir, "process.log")

        # Persist session-to-directory mapping immediately so recovery works
        # even if the server restarts mid-analysis
        _persist_session_meta(session, session_output_dir, meta_requirement)

        # PrintCapture redirects prints to the log file and restores stdout on exit
        with PrintCapture(session.log_file):
            if is_followup:
                print(f"\n--- Follow-up Session {session_id} Continued ---")
            else:
                print(f"--- Session {session_id} Started ---")

            try:
                if is_followup:
                    agent = session.agent
                    if not agent:
                        # Agent lost (server restart). Recreate and run as new session
                        # with the follow-up requirement, reusing the same output dir.
                        print("[WARN] Agent not initialized, recreating for follow-up.")
                        agent = DataAnalysisAgent(LLMConfig(), force_max_rounds=False, output_dir=base_output_dir)
                        session.agent = agent
                else:
                    agent = DataAnalysisAgent(LLMConfig(), force_max_rounds=False, output_dir=base_output_dir)
                    session.agent = agent

                # Wire progress callback to update session progress fields
                agent.set_progress_callback(_make_progress_callback(session))
                agent.set_session_ref(session)

                if is_followup:
                    # If agent was just recreated, load data files so it has context
                    data_files = None
                    if not agent.data_files and session.file_list:
                        data_files = session.file_list

                    result = agent.analyze(
                        user_input=user_requirement,
                        files=data_files,
                        session_output_dir=session_output_dir,
                        reset_session=not bool(agent.conversation_history),
                        max_rounds=10,
                    )
                else:
                    result = agent.analyze(
                        user_input=user_requirement,
                        files=files,
                        session_output_dir=session_output_dir,
                        reset_session=True,
                        template_name=template_name,
                    )

                session.generated_report = result.get("report_file_path", None)
                session.analysis_results = result.get("analysis_results", [])
                session.reusable_script = result.get("reusable_script_path", None)

                # Persist results
                with open(os.path.join(session_output_dir, "results.json"), "w", encoding="utf-8") as f:
                    json.dump({
                        "analysis_results": session.analysis_results,
                        "rounds": session.rounds,
                        "data_files": session.data_files,
                    }, f, default=str)

                # Persist session-to-directory mapping for recovery after server restart
                _persist_session_meta(session, session_output_dir, meta_requirement)

            except Exception as e:
                print(f"Error during analysis: {e}")

    except Exception as e:
        print(f"System Error: {e}")
    finally:
        session.is_running = False
        session.last_updated = datetime.now().strftime("%Y-%m-%d %H:%M:%S")


# --- Pydantic Models ---

class StartRequest(BaseModel):
    requirement: str
    template: Optional[str] = None
    files: Optional[List[str]] = None


class ChatRequest(BaseModel):
    session_id: str
    message: str


# --- API Endpoints ---

@app.get("/")
async def read_root():
    """Serve the single-page frontend."""
    return FileResponse("web/static/index.html")


@app.post("/api/upload")
async def upload_files(files: list[UploadFile] = File(...)):
    """Save uploaded files under ``uploads/`` and remember them for the next analysis.

    Raises:
        HTTPException: 400 when a file has no usable name.
    """
    saved_files = []
    for file in files:
        # The client controls the filename; keep only its final component so
        # names like "../../x" cannot write outside the uploads directory.
        safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
        if not safe_name or safe_name in (".", ".."):
            raise HTTPException(status_code=400, detail="Invalid file name")
        file_location = f"uploads/{safe_name}"
        with open(file_location, "wb+") as file_object:
            file_object.write(file.file.read())
        saved_files.append(file_location)
    # Track the most recently uploaded files for the next analysis
    app.state.last_uploaded_files = saved_files
    return {"info": f"Saved {len(saved_files)} files", "paths": saved_files}


@app.post("/api/start")
async def start_analysis(request: StartRequest, background_tasks: BackgroundTasks):
    """Create a session and schedule its analysis as a background task.

    Raises:
        HTTPException: 400 when no data file can be found.
    """
    session_id = session_manager.create_session()

    # Priority: request.files (from frontend) > last_uploaded > scan uploads/
    # NOTE(review): request.files are client-supplied paths and are accepted
    # wherever they point on the server; consider confining them to uploads/.
    files = None
    if request.files:
        files = [f for f in request.files if os.path.exists(f)]
    if not files:
        files = getattr(app.state, 'last_uploaded_files', None)
        if files:
            files = [f for f in files if os.path.exists(f)]
    if not files:
        files = glob.glob("uploads/*.csv") + glob.glob("uploads/*.xlsx")

    if not files:
        raise HTTPException(status_code=400, detail="No data files found. Please upload files first.")

    files = [os.path.abspath(f) for f in files]  # Only use absolute paths

    background_tasks.add_task(run_analysis_task, session_id, files, request.requirement, is_followup=False, template_name=request.template)
    return {"status": "started", "session_id": session_id}


@app.post("/api/chat")
async def chat_analysis(request: ChatRequest, background_tasks: BackgroundTasks):
    """Schedule a follow-up analysis round on an existing, idle session."""
    session = session_manager.get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if session.is_running:
        raise HTTPException(status_code=400, detail="Analysis already in progress")

    background_tasks.add_task(run_analysis_task, request.session_id, [], request.message, is_followup=True)
    return {"status": "started"}


def _sanitize_value(v):
    """Make any value JSON-serializable.

    Handles: NaN/inf floats → None, pandas Timestamp/Timedelta → str,
    numpy scalars/arrays → Python values (re-sanitized, so a non-finite
    numpy float also becomes None), dicts, lists and tuples recursively.
    Anything else falls back to ``str(v)``.
    """
    if v is None:
        return None
    if isinstance(v, float):
        return v if _math.isfinite(v) else None
    if isinstance(v, (int, bool, str)):
        return v
    if isinstance(v, dict):
        return {k: _sanitize_value(val) for k, val in v.items()}
    if isinstance(v, (list, tuple)):
        return [_sanitize_value(item) for item in v]
    # pandas Timestamp, Timedelta, NaT
    try:
        missing = pd.isna(v)
    except (TypeError, ValueError):
        missing = False
    # pd.isna returns an array for array-likes; only a scalar answer counts.
    if getattr(missing, "ndim", 0) == 0 and bool(missing):
        return None
    if hasattr(v, 'isoformat'):  # datetime, Timestamp
        return v.isoformat()
    # numpy scalars and arrays (and pandas Series) expose tolist()/item()
    for converter in ("tolist", "item"):
        if hasattr(v, converter):
            try:
                converted = getattr(v, converter)()
            except (TypeError, ValueError):
                continue
            if type(converted) is type(v):
                break  # no plain-Python equivalent; avoid infinite recursion
            return _sanitize_value(converted)
    # Fallback: convert to string
    return str(v)


@app.get("/api/status")
async def get_status(session_id: str = Query(..., description="Session ID")):
    """Return running state, the full log text, progress fields and rounds."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    log_content = ""
    if session.log_file and os.path.exists(session.log_file):
        # errors="replace": a partially written multi-byte character must not
        # turn a status poll into a 500.
        with open(session.log_file, "r", encoding="utf-8", errors="replace") as f:
            log_content = f.read()

    response_data = {
        "is_running": session.is_running,
        "log": log_content,
        "has_report": session.generated_report is not None,
        "report_path": session.generated_report,
        "script_path": session.reusable_script,
        "current_round": session.current_round,
        "max_rounds": session.max_rounds,
        "progress_percentage": session.progress_percentage,
        "status_message": session.status_message,
        "rounds": _sanitize_value(session.rounds),
    }
    return JSONResponse(content=response_data)


@app.get("/api/export")
async def export_session(session_id: str = Query(..., description="Session ID")):
    """Export the session's output files as a ZIP download."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if not session.output_dir or not os.path.exists(session.output_dir):
        raise HTTPException(status_code=404, detail="No data available for export")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    zip_filename = f"report_{timestamp}.zip"

    export_dir = "outputs"
    os.makedirs(export_dir, exist_ok=True)
    # A random suffix keeps two exports within the same second from writing
    # to one file; the download name sent to the client is unchanged.
    # NOTE(review): archives are never removed and live in the publicly
    # mounted outputs/ directory.
    temp_zip_path = os.path.join(export_dir, f"report_{timestamp}_{uuid.uuid4().hex[:8]}.zip")

    with zipfile.ZipFile(temp_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, files in os.walk(session.output_dir):
            for file in files:
                if file.endswith(('.md', '.png', '.csv', '.log', '.json', '.yaml')):
                    abs_path = os.path.join(root, file)
                    rel_path = os.path.relpath(abs_path, session.output_dir)
                    zf.write(abs_path, arcname=rel_path)

    return FileResponse(
        path=temp_zip_path,
        filename=zip_filename,
        media_type='application/zip'
    )


@app.get("/api/report")
async def get_report(session_id: str = Query(..., description="Session ID")):
    """Return the report markdown with image links rewritten to web paths, plus paragraphs."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if not session.generated_report or not os.path.exists(session.generated_report):
        return {"content": "Report not ready.", "paragraphs": []}

    with open(session.generated_report, "r", encoding="utf-8") as f:
        content = f.read()

    # Fix image paths
    relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
    web_base_path = f"/{relative_session_path}"

    # Robust image path replacement
    content = content.replace("](./", f"]({web_base_path}/")

    def replace_link(match):
        alt = match.group(1)
        url = match.group(2)
        # Absolute, rooted and inline-data URLs are already resolvable.
        if url.startswith("http") or url.startswith("/") or url.startswith("data:"):
            return match.group(0)
        clean_url = url.lstrip("./")
        return f"![{alt}]({web_base_path}/{clean_url})"

    content = re.sub(r'!\[(.*?)\]\((.*?)\)', replace_link, content)

    # Split the report into paragraphs so the frontend can polish them individually
    paragraphs = _split_report_to_paragraphs(content)

    # Extract evidence annotations and build supporting_data mapping
    supporting_data = _extract_evidence_annotations(paragraphs, session)

    return {"content": content, "base_path": web_base_path, "paragraphs": paragraphs, "supporting_data": supporting_data}


@app.get("/api/figures")
async def get_figures(session_id: str = Query(..., description="Session ID")):
    """Return figures collected by the agent, falling back to PNGs found on disk."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # We can try to get from memory first
    results = session.analysis_results

    # If empty in memory (maybe server restarted but files exist?), try load json
    if not results and session.output_dir:
        json_path = os.path.join(session.output_dir, "results.json")
        if os.path.exists(json_path):
            with open(json_path, 'r', encoding="utf-8") as f:
                loaded = json.load(f)
            # results.json is either a plain list (legacy) or a dict that
            # carries the list under "analysis_results".
            if isinstance(loaded, dict):
                loaded = loaded.get("analysis_results", [])
            results = loaded if isinstance(loaded, list) else []

    web_prefix = None
    if session.output_dir:
        web_prefix = "/" + _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))

    # Extract collected figures: the 'collect_figures' action is the reliable
    # source as per agent design
    figures = []
    for item in results or []:
        if not isinstance(item, dict) or item.get("action") != "collect_figures":
            continue
        for fig in item.get("collected_figures", []):
            # Enrich with web path
            fname = fig.get("filename")
            if web_prefix is not None and fname:
                fig["web_url"] = f"{web_prefix}/{fname}"
            figures.append(fig)

    # Auto-discovery fallback: nothing collected but PNGs exist on disk
    if not figures and web_prefix is not None:
        for p in glob.glob(os.path.join(session.output_dir, "*.png")):
            fname = os.path.basename(p)
            figures.append({
                "filename": fname,
                "description": "Auto-discovered image",
                "analysis": "No analysis available",
                "web_url": f"{web_prefix}/{fname}"
            })

    return {"figures": figures}


@app.get("/api/download_script")
async def download_script(session_id: str = Query(..., description="Session ID")):
    """Download the Python script generated for the session."""
    session = session_manager.get_session(session_id)
    if not session or not session.reusable_script:
        raise HTTPException(status_code=404, detail="Script not found")

    if not os.path.exists(session.reusable_script):
        raise HTTPException(status_code=404, detail="Script file missing on server")

    return FileResponse(
        path=session.reusable_script,
        filename=os.path.basename(session.reusable_script),
        media_type='text/x-python'
    )


# --- Tools API ---

@app.get("/api/templates")
async def list_available_templates():
    """List the available analysis templates."""
    from utils.analysis_templates import list_templates
    return {"templates": list_templates()}


@app.get("/api/templates/{template_name}")
async def get_template_detail(template_name: str):
    """Return the full content of one template (including its steps)."""
    from utils.analysis_templates import get_template
    try:
        tpl = get_template(template_name)
        return tpl.to_dict()
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))


@app.put("/api/templates/{template_name}")
async def update_template(template_name: str, body: dict):
    """Create or update a template."""
    from utils.analysis_templates import save_template
    try:
        filepath = save_template(template_name, body)
        return {"status": "saved", "filepath": filepath}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.delete("/api/templates/{template_name}")
async def remove_template(template_name: str):
    """Delete a template; 404 when it does not exist."""
    from utils.analysis_templates import delete_template
    if delete_template(template_name):
        return {"status": "deleted"}
    raise HTTPException(status_code=404, detail=f"Template not found: {template_name}")


# --- Data Files API ---

def _resolve_session_file(output_dir: str, filename: str) -> str:
    """Return the real path of *filename* inside *output_dir*.

    *filename* comes straight from the query string, so the joined path is
    resolved and must stay inside the session directory; this blocks
    ``../`` sequences and absolute paths.

    Raises:
        HTTPException: 404 when the path escapes the directory or is not a file.
    """
    base = os.path.realpath(output_dir)
    candidate = os.path.realpath(os.path.join(base, filename))
    try:
        inside = os.path.commonpath([base, candidate]) == base
    except ValueError:
        # e.g. paths on different drives on Windows
        inside = False
    if not inside or not os.path.isfile(candidate):
        raise HTTPException(status_code=404, detail=f"File not found: {filename}")
    return candidate


@app.get("/api/data-files")
async def list_data_files(session_id: str = Query(..., description="Session ID")):
    """Return session.data_files merged with fallback directory scan for CSV/XLSX files."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # Start with known data_files from session (entries without a name are skipped)
    known_files = {
        f["filename"]: f
        for f in session.data_files
        if isinstance(f, dict) and f.get("filename")
    }

    # Fallback directory scan for CSV/XLSX in output_dir
    if session.output_dir and os.path.exists(session.output_dir):
        # Collect original uploaded file basenames to exclude them
        uploaded_basenames = {os.path.basename(fp) for fp in session.file_list}

        for pattern in ("*.csv", "*.xlsx"):
            for fpath in glob.glob(os.path.join(session.output_dir, pattern)):
                fname = os.path.basename(fpath)
                if fname in uploaded_basenames or fname in known_files:
                    continue
                try:
                    size_bytes = os.path.getsize(fpath)
                except OSError:
                    size_bytes = 0
                known_files[fname] = {
                    "filename": fname,
                    "description": "",
                    "rows": 0,
                    "cols": 0,
                    "size_bytes": size_bytes,
                }

    return {"files": list(known_files.values())}


@app.get("/api/data-files/preview")
async def preview_data_file(
    session_id: str = Query(..., description="Session ID"),
    filename: str = Query(..., description="File name"),
):
    """Read CSV/XLSX via pandas, return {columns, rows (first 5)}."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if not session.output_dir:
        raise HTTPException(status_code=404, detail=f"File not found: {filename}")

    file_path = _resolve_session_file(session.output_dir, filename)

    try:
        if filename.lower().endswith(".xlsx"):
            df = pd.read_excel(file_path, nrows=5)
        else:
            # Try utf-8-sig first (common for Chinese CSV exports), then utf-8, then gbk
            try:
                df = pd.read_csv(file_path, nrows=5, encoding="utf-8-sig")
            except UnicodeDecodeError:
                try:
                    df = pd.read_csv(file_path, nrows=5, encoding="utf-8")
                except UnicodeDecodeError:
                    df = pd.read_csv(file_path, nrows=5, encoding="gbk")

        columns = list(df.columns)
        # Sanitize NaN/inf/NaT and pandas/numpy scalars for JSON serialization
        rows = _sanitize_value(df.head(5).to_dict(orient="records"))
        return {"columns": columns, "rows": rows}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to read file: {str(e)}")


@app.get("/api/data-files/download")
async def download_data_file(
    session_id: str = Query(..., description="Session ID"),
    filename: str = Query(..., description="File name"),
):
    """Return FileResponse with correct MIME type."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if not session.output_dir:
        raise HTTPException(status_code=404, detail=f"File not found: {filename}")

    file_path = _resolve_session_file(session.output_dir, filename)

    if filename.lower().endswith(".xlsx"):
        media_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    else:
        media_type = "text/csv"

    # Offer only the final name component as the download filename.
    return FileResponse(path=file_path, filename=os.path.basename(file_path), media_type=media_type)


# --- Session progress / listing API ---

@app.get("/api/sessions/progress")
async def get_session_progress(session_id: str = Query(..., description="Session ID")):
    """Return the progress summary of one session."""
    session_info = session_manager.get_session_info(session_id)
    if not session_info:
        raise HTTPException(status_code=404, detail="Session not found")
    return session_info


@app.get("/api/sessions/list")
async def list_all_sessions():
    """Return summaries of all sessions currently known to the session manager."""
    sessions_info = []
    for sid in session_manager.list_sessions():
        info = session_manager.get_session_info(sid)
        if info:
            sessions_info.append(info)

    return {"sessions": sessions_info, "total": len(sessions_info)}


@app.delete("/api/sessions/{session_id}")
async def delete_specific_session(session_id: str):
    """Delete the given session from the session manager."""
    success = session_manager.delete_session(session_id)
    if not success:
        raise HTTPException(status_code=404, detail="Session not found")
    return {"status": "deleted", "session_id": session_id}


# --- Report Polishing API ---

import re as _re


def _split_report_to_paragraphs(markdown_content: str) -> list:
    """Split a Markdown report into semantic paragraphs.

    Each paragraph is a dict with ``id`` ("p-<n>", numbered in document
    order), ``type`` (heading / text / table / image / code) and ``content``
    (the stripped source text). Headings and image lines always form their
    own paragraph; fenced code blocks and runs of table rows are kept
    together; blank lines separate text paragraphs.
    """
    lines = markdown_content.split("\n")
    paragraphs = []
    current_block = []
    current_type = "text"
    para_id = 0
    in_table = False
    in_code = False

    def flush_block():
        nonlocal para_id, current_block, current_type, in_table
        text = "\n".join(current_block).strip()
        if text:
            paragraphs.append({
                "id": f"p-{para_id}",
                "type": current_type,
                "content": text,
            })
            para_id += 1
        current_block = []
        current_type = "text"
        # Any flush ends the current table. Without this, a table closed by a
        # heading/image/code fence left the flag set and the next table was
        # typed "text".
        in_table = False

    for line in lines:
        stripped = line.strip()

        # Code fence boundary
        if stripped.startswith("```"):
            if in_code:
                current_block.append(line)
                flush_block()
                in_code = False
            else:
                flush_block()
                current_block.append(line)
                current_type = "code"
                in_code = True
            continue

        if in_code:
            current_block.append(line)
            continue

        # Heading line — always its own paragraph
        if _re.match(r"^#{1,6}\s", stripped):
            flush_block()
            current_block.append(line)
            current_type = "heading"
            flush_block()
            continue

        # Image line
        if _re.match(r"^!\[.*\]\(.*\)", stripped):
            flush_block()
            current_block.append(line)
            current_type = "image"
            flush_block()
            continue

        # Table row
        if stripped.startswith("|"):
            if not in_table:
                flush_block()
                in_table = True
                current_type = "table"
            current_block.append(line)
            continue
        if in_table:
            flush_block()

        # Blank line — paragraph separator
        if not stripped:
            flush_block()
            continue

        # Plain text
        current_block.append(line)

    flush_block()
    return paragraphs


# NOTE(review): the pattern in the received source was empty (the annotation
# text appears to have been stripped as an HTML comment), which made
# match.group(1) raise IndexError for every paragraph. The form below,
# "<!-- evidence:round_N -->", is a reconstruction — confirm it against the
# prompt/code that writes these annotations into the report.
_EVIDENCE_PATTERN = _re.compile(r"<!--\s*evidence:round_(\d+)\s*-->")


def _extract_evidence_annotations(paragraphs: list, session) -> dict:
    """Map paragraph ids to the evidence rows of the round they cite.

    For each paragraph containing an evidence annotation for round N, look up
    ``session.rounds[N-1]["evidence_rows"]``; paragraphs whose round is out
    of range or has no evidence rows are omitted.
    """
    supporting_data = {}

    for para in paragraphs:
        match = _EVIDENCE_PATTERN.search(para.get("content", ""))
        if not match:
            continue
        # rounds are 1-indexed, list is 0-indexed
        idx = int(match.group(1)) - 1
        if 0 <= idx < len(session.rounds):
            round_data = session.rounds[idx]
            evidence_rows = round_data.get("evidence_rows", []) if isinstance(round_data, dict) else []
            if evidence_rows:
                supporting_data[para["id"]] = evidence_rows

    return supporting_data


class PolishRequest(BaseModel):
    session_id: str
    paragraph_id: str
    mode: str = "context"  # "context" | "data" | "custom"
    custom_instruction: str = ""


@app.post("/api/report/polish")
async def polish_paragraph(request: PolishRequest):
    """Polish one paragraph of the report with the LLM.

    mode:
      - context: polish using the surrounding paragraphs and figure info
      - data: rewrite the paragraph using the raw analysis outputs
      - custom: follow the user's own instruction

    Tables adjacent to the target (and text adjacent to a target table) are
    polished together with it; their ids are returned in
    ``affected_paragraph_ids``. The report file itself is not modified.
    """
    session = session_manager.get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if not session.generated_report or not os.path.exists(session.generated_report):
        raise HTTPException(status_code=404, detail="Report not found")

    # Read the report and split it into paragraphs
    with open(session.generated_report, "r", encoding="utf-8") as f:
        report_content = f.read()

    paragraphs = _split_report_to_paragraphs(report_content)

    # Locate the target paragraph
    target = None
    target_idx = -1
    for i, p in enumerate(paragraphs):
        if p["id"] == request.paragraph_id:
            target = p
            target_idx = i
            break

    if not target:
        raise HTTPException(status_code=404, detail=f"Paragraph {request.paragraph_id} not found")

    # Build the actual content to polish: include adjacent table paragraphs
    # so that when user clicks on text below a table, the table gets polished too
    polish_para_ids = [target["id"]]
    polish_content_parts = [target["content"]]
    has_prev = target_idx > 0
    has_next = target_idx + 1 < len(paragraphs)

    # Check if previous paragraph is a table — include it
    if has_prev and paragraphs[target_idx - 1]["type"] == "table":
        polish_para_ids.insert(0, paragraphs[target_idx - 1]["id"])
        polish_content_parts.insert(0, paragraphs[target_idx - 1]["content"])

    # Check if next paragraph is a table — include it
    if has_next and paragraphs[target_idx + 1]["type"] == "table":
        polish_para_ids.append(paragraphs[target_idx + 1]["id"])
        polish_content_parts.append(paragraphs[target_idx + 1]["content"])

    # If the target itself is a table, include adjacent text too
    if target["type"] == "table":
        if has_next and paragraphs[target_idx + 1]["type"] == "text":
            polish_para_ids.append(paragraphs[target_idx + 1]["id"])
            polish_content_parts.append(paragraphs[target_idx + 1]["content"])
        if has_prev and paragraphs[target_idx - 1]["type"] == "text":
            polish_para_ids.insert(0, paragraphs[target_idx - 1]["id"])
            polish_content_parts.insert(0, paragraphs[target_idx - 1]["content"])

    combined_content = "\n\n".join(polish_content_parts)

    # Context window: two paragraphs on each side, minus those being polished
    context_window = [
        paragraphs[j]["content"]
        for j in range(max(0, target_idx - 2), min(len(paragraphs), target_idx + 3))
        if paragraphs[j]["id"] not in polish_para_ids
    ]
    context_text = "\n\n".join(context_window)

    # Collect figure information
    figures_info = ""
    if session.analysis_results:
        fig_parts = []
        for item in session.analysis_results:
            if item.get("action") == "collect_figures":
                for fig in item.get("collected_figures", []):
                    fig_parts.append(f"- {fig.get('filename', '?')}: {fig.get('description', '')} / {fig.get('analysis', '')}")
        if fig_parts:
            figures_info = "\n".join(fig_parts)

    # Build the polishing prompt
    if request.mode == "data":
        # Summarize code execution outputs
        data_summary_parts = []
        for item in session.analysis_results:
            # "result" may be present but None; treat that like a missing result
            result = item.get("result") or {}
            if result.get("success") and result.get("output"):
                output_text = result["output"][:2000]
                data_summary_parts.append(output_text)

        data_summary = "\n---\n".join(data_summary_parts[:5])

        polish_prompt = f"""你是一位资深数据分析专家。请基于以下分析数据,重写下方段落,使其包含更精确的数据引用和更深入的业务洞察。

## 分析数据摘要
{data_summary}

## 图表信息
{figures_info}

## 需要润色的段落(可能包含表格和文字)
{combined_content}

## 要求
- 保持原有的 Markdown 格式(标题级别、表格结构等)
- 如果包含表格,必须同时润色表格内容(补充数据、修正数值)
- 用具体数据替换模糊描述
- 增加业务洞察和趋势判断
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""

    elif request.mode == "custom":
        polish_prompt = f"""你是一位资深数据分析专家。请根据用户的指令润色以下段落。

## 用户指令
{request.custom_instruction}

## 上下文
{context_text}

## 图表信息
{figures_info}

## 需要润色的段落(可能包含表格和文字)
{combined_content}

## 要求
- 保持原有的 Markdown 格式
- 如果包含表格,必须同时润色表格内容
- 严格遵循用户指令
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""

    else:  # context mode
        polish_prompt = f"""你是一位资深数据分析专家。请润色以下段落,使其表述更专业、更有洞察力。

## 上下文(前后段落)
{context_text}

## 图表信息
{figures_info}

## 需要润色的段落(可能包含表格和文字)
{combined_content}

## 要求
- 保持原有的 Markdown 格式(标题级别、表格结构等)
- 如果包含表格,必须同时润色表格内容(补充数据、修正数值)
- 提升专业性:使用同比、环比、占比等术语
- 增加洞察:不仅描述现象,还要分析原因和影响
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""

    # Call the LLM
    try:
        from utils.llm_helper import LLMHelper
        llm = LLMHelper(LLMConfig())
        polished_content = llm.call(
            prompt=polish_prompt,
            system_prompt="你是一位专业的数据分析报告润色专家。直接输出润色后的内容,不要添加任何解释或包裹。",
            max_tokens=4096,
        )

        # Strip a code fence the model may have wrapped around the answer
        polished_content = polished_content.strip()
        if polished_content.startswith("```markdown"):
            polished_content = polished_content[len("```markdown"):].strip()
        if polished_content.startswith("```"):
            polished_content = polished_content[3:].strip()
        if polished_content.endswith("```"):
            polished_content = polished_content[:-3].strip()

        return {
            "paragraph_id": request.paragraph_id,
            "original": combined_content,
            "polished": polished_content,
            "mode": request.mode,
            "affected_paragraph_ids": polish_para_ids,
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Polish failed: {str(e)}")


class ApplyPolishRequest(BaseModel):
    session_id: str
    paragraph_id: str
    new_content: str


def _find_paragraph_span(report_content: str, paragraphs: list, paragraph_id: str):
    """Return ``(start, end)`` offsets of a paragraph inside the report, or ``None``.

    Paragraphs are located one after another in document order, so a
    paragraph whose text also occurs earlier in the report (repeated
    headings, identical sentences) is found at its own position rather than
    at the first textual occurrence.
    """
    cursor = 0
    for para in paragraphs:
        start = report_content.find(para["content"], cursor)
        if start < 0:
            return None
        end = start + len(para["content"])
        if para["id"] == paragraph_id:
            return start, end
        cursor = end
    return None


@app.post("/api/report/apply")
async def apply_polish(request: ApplyPolishRequest):
    """Write polished content into the report file, replacing the given paragraph."""
    session = session_manager.get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if not session.generated_report or not os.path.exists(session.generated_report):
        raise HTTPException(status_code=404, detail="Report not found")

    with open(session.generated_report, "r", encoding="utf-8") as f:
        report_content = f.read()

    paragraphs = _split_report_to_paragraphs(report_content)

    # Find the target paragraph
    target = next((p for p in paragraphs if p["id"] == request.paragraph_id), None)
    if not target:
        raise HTTPException(status_code=404, detail=f"Paragraph {request.paragraph_id} not found")

    # Replace the paragraph at its own position in the original text
    span = _find_paragraph_span(report_content, paragraphs, request.paragraph_id)
    if span is not None:
        start, end = span
        new_report = report_content[:start] + request.new_content + report_content[end:]
    else:
        # Should not happen (paragraph text is a substring of the report);
        # fall back to replacing the first occurrence.
        new_report = report_content.replace(target["content"], request.new_content, 1)

    # Write the file back
    with open(session.generated_report, "w", encoding="utf-8") as f:
        f.write(new_report)

    return {"status": "applied", "paragraph_id": request.paragraph_id}


# --- History API ---

@app.get("/api/history")
async def get_history():
    """
    Get list of past analysis sessions from outputs directory
    """
    history = []
    output_base = "outputs"

    if not os.path.exists(output_base):
        return {"history": []}

    try:
        # Scan for session_* directories
        for entry in os.scandir(output_base):
            if not (entry.is_dir() and entry.name.startswith("session_")):
                continue
            # The id is the folder name without its "session_" prefix, e.g.
            # session_20250101_120000 -> 20250101_120000. Only the prefix is
            # removed so the id maps back to the same folder.
            session_id = entry.name.removeprefix("session_")

            try:
                # Try to parse a timestamp from the id (format: YYYYMMDD_HHMMSS)
                dt = datetime.strptime(session_id, "%Y%m%d_%H%M%S")
                display_time = dt.strftime("%Y-%m-%d %H:%M:%S")
                sort_key = dt.timestamp()
            except ValueError:
                # Fallback to file creation time
                sort_key = entry.stat().st_ctime
                display_time = datetime.fromtimestamp(sort_key).strftime("%Y-%m-%d %H:%M:%S")

            history.append({
                "id": session_id,
                "timestamp": display_time,
                "sort_key": sort_key,
                "name": f"Session {display_time}"
            })

        # Sort by latest first
        history.sort(key=lambda x: x["sort_key"], reverse=True)

        # Cleanup internal sort key
        for item in history:
            del item["sort_key"]

        return {"history": history}

    except Exception as e:
        print(f"Error scanning history: {e}")
        return {"history": []}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)