Files
iov_data_analysis_agent/web/main.py
Jeason c7224153b1 YAML 反斜杠修复扩大范围 — 之前只匹配 "D:\..." 格式,现在匹配所有双引号内含反斜杠的字符串。"outputs\session_20260420..." 会被正确转成 "outputs/session_20260420...",不再导致 YAML 解析失败。这直接解决了第 10-19 轮的死循环。
_process_response 的 analysis_complete 检测已经在上一轮修好了,配合反斜杠修复,YAML 能正确解析出 action: "analysis_complete",不会再 fallback 到代码执行。

文件选择改为只用最近一次上传的文件 — app.state.last_uploaded_files 记录上传的文件列表,/api/start 优先使用它,不再 glob("uploads/*.csv") 把所有历史文件都拿来分析。
2026-04-20 13:09:54 +08:00

1200 lines
44 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import sys
import os
import threading
import glob
import uuid
import json
import re
from datetime import datetime
from typing import Optional, Dict, List
from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
import pandas as pd
# Add parent directory to path to import agent modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data_analysis_agent import DataAnalysisAgent
from config.llm_config import LLMConfig
from utils.create_session_dir import create_session_output_dir
from utils.logger import PrintCapture
# FastAPI application instance; every route and mount below attaches to it.
app = FastAPI(title="IOV Data Analysis Agent")
def _to_web_path(fs_path: str) -> str:
"""将文件系统路径转为 URL 安全的正斜杠路径(修复 Windows 反斜杠问题)"""
return fs_path.replace("\\", "/")
# CORS: allow any origin/method/header so the bundled frontend (or a local
# dev server on another port) can call the API without preflight failures.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Session Management ---
class SessionData:
    """Mutable state bag for one analysis session (live or reconstructed)."""

    def __init__(self, session_id: str):
        self.session_id = session_id
        self.is_running = False
        # Filesystem artifacts produced by the analysis run.
        self.output_dir: Optional[str] = None
        self.generated_report: Optional[str] = None
        self.log_file: Optional[str] = None
        self.reusable_script: Optional[str] = None  # path of the generated reusable script
        # Structured results consumed by the frontend gallery / rounds view.
        self.analysis_results: List[Dict] = []
        self.rounds: List[Dict] = []
        self.data_files: List[Dict] = []
        # Agent instance kept alive so follow-up questions reuse its context.
        self.agent: Optional["DataAnalysisAgent"] = None
        # Progress tracking fields mirrored by the agent's progress callback.
        self.current_round: int = 0
        self.max_rounds: int = 20
        self.progress_percentage: float = 0.0
        self.status_message: str = "等待开始"
        # History metadata.
        self.created_at: str = ""
        self.last_updated: str = ""
        self.user_requirement: str = ""
        self.file_list: List[str] = []
class SessionManager:
    """Registry of SessionData objects with disk-based recovery.

    Live sessions live in memory; historical sessions (e.g. after a server
    restart) are lazily reconstructed from their output directory.
    """

    def __init__(self):
        self.sessions: Dict[str, SessionData] = {}
        self.lock = threading.Lock()

    def create_session(self) -> str:
        """Create a new empty session and return its UUID."""
        with self.lock:
            session_id = str(uuid.uuid4())
            self.sessions[session_id] = SessionData(session_id)
            return session_id

    def get_session(self, session_id: str) -> Optional[SessionData]:
        """Return a session by id, reconstructing it from disk if needed.

        Lookup order:
        1. in-memory cache;
        2. legacy directory convention ``outputs/session_{uuid}``;
        3. scan every ``outputs/session_*`` directory for a
           ``session_meta.json`` whose "session_id" matches (directories may
           be named by timestamp rather than UUID).
        """
        if session_id in self.sessions:
            return self.sessions[session_id]
        output_dir = os.path.join("outputs", f"session_{session_id}")
        if os.path.exists(output_dir) and os.path.isdir(output_dir):
            return self._reconstruct_session(session_id, output_dir)
        outputs_root = "outputs"
        if os.path.exists(outputs_root):
            for dirname in os.listdir(outputs_root):
                dir_path = os.path.join(outputs_root, dirname)
                if not os.path.isdir(dir_path) or not dirname.startswith("session_"):
                    continue
                meta_path = os.path.join(dir_path, "session_meta.json")
                if os.path.exists(meta_path):
                    try:
                        with open(meta_path, "r", encoding="utf-8") as f:
                            meta = json.load(f)
                        if meta.get("session_id") == session_id:
                            return self._reconstruct_session(session_id, dir_path)
                    except Exception:
                        # Unreadable/corrupt meta file — keep scanning.
                        continue
        return None

    def _reconstruct_session(self, session_id: str, output_dir: str) -> SessionData:
        """Rebuild a SessionData object from an on-disk session directory."""
        session = SessionData(session_id)
        session.output_dir = output_dir
        session.is_running = False
        session.current_round = session.max_rounds
        session.progress_percentage = 100.0
        session.status_message = "已完成 (历史记录)"
        # Read session_meta.json if available.
        meta = {}
        meta_path = os.path.join(output_dir, "session_meta.json")
        if os.path.exists(meta_path):
            try:
                with open(meta_path, "r", encoding="utf-8") as f:
                    meta = json.load(f)
            except Exception:
                pass
        # Recover the log file.
        log_path = os.path.join(output_dir, "process.log")
        if os.path.exists(log_path):
            session.log_file = log_path
        # Recover the report — prefer meta, then pick a .md file, favouring
        # names containing "report" / "报告".
        report_path = meta.get("report_path")
        if report_path and os.path.exists(report_path):
            session.generated_report = report_path
        else:
            md_files = glob.glob(os.path.join(output_dir, "*.md"))
            if md_files:
                chosen = md_files[0]
                for md in md_files:
                    fname = os.path.basename(md).lower()
                    if "report" in fname or "报告" in fname:
                        chosen = md
                        break
                session.generated_report = chosen
        # Recover the script — prefer meta, then the Chinese-named scripts this
        # system generates, then a few common fallback names.
        script_path = meta.get("script_path")
        if script_path and os.path.exists(script_path):
            session.reusable_script = script_path
        else:
            script_files = glob.glob(os.path.join(output_dir, "分析脚本_*.py"))
            if not script_files:
                for s in ["data_analysis_script.py", "script.py", "analysis_script.py"]:
                    p = os.path.join(output_dir, s)
                    if os.path.exists(p):
                        script_files = [p]
                        break
            if script_files:
                session.reusable_script = script_files[0]
        # Recover results (figures etc). Supports both the legacy plain-list
        # format and the newer dict format with rounds/data_files.
        results_json = os.path.join(output_dir, "results.json")
        if os.path.exists(results_json):
            try:
                # json.dump writes ASCII by default, but be explicit anyway.
                with open(results_json, "r", encoding="utf-8") as f:
                    data = json.load(f)
                if isinstance(data, dict):
                    session.analysis_results = data.get("analysis_results", [])
                    session.rounds = data.get("rounds", [])
                    session.data_files = data.get("data_files", [])
                else:
                    session.analysis_results = data
            except Exception:
                pass
        # Recover creation time from the directory metadata.
        try:
            stat = os.stat(output_dir)
            dt = datetime.fromtimestamp(stat.st_ctime)
            session.created_at = dt.strftime("%Y-%m-%d %H:%M:%S")
        except Exception:
            pass
        # Cache the reconstructed session.
        with self.lock:
            self.sessions[session_id] = session
        return session

    def list_sessions(self):
        """Return the ids of all in-memory sessions."""
        return list(self.sessions.keys())

    def delete_session(self, session_id: str) -> bool:
        """Remove a session from memory, resetting its agent if present."""
        with self.lock:
            if session_id in self.sessions:
                session = self.sessions[session_id]
                if session.agent:
                    session.agent.reset()
                del self.sessions[session_id]
                return True
            return False

    def get_session_info(self, session_id: str) -> Optional[Dict]:
        """Return a JSON-friendly summary of one session, or None if unknown."""
        session = self.get_session(session_id)
        if session:
            return {
                "session_id": session.session_id,
                "is_running": session.is_running,
                "progress": session.progress_percentage,
                "status": session.status_message,
                "current_round": session.current_round,
                "max_rounds": session.max_rounds,
                "created_at": session.created_at,
                "last_updated": session.last_updated,
                # Truncate long requirements for list views.
                "user_requirement": session.user_requirement[:100] + "..." if len(session.user_requirement) > 100 else session.user_requirement,
                "script_path": session.reusable_script,
            }
        return None
# Global session registry shared by all endpoints.
session_manager = SessionManager()
# Ensure working directories exist before mounting static routes.
os.makedirs("web/static", exist_ok=True)
os.makedirs("uploads", exist_ok=True)
os.makedirs("outputs", exist_ok=True)
# Serve the frontend assets and the generated analysis artifacts.
app.mount("/static", StaticFiles(directory="web/static"), name="static")
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
# --- Helper Functions ---
def run_analysis_task(session_id: str, files: list, user_requirement: str, is_followup: bool = False, template_name: Optional[str] = None):
    """Run one analysis (or follow-up) in a background thread.

    Creates/reuses the session output directory, captures all prints into the
    session log, drives the DataAnalysisAgent, and persists the results plus
    a session_meta.json mapping so the session can be recovered after a
    server restart.
    """
    session = session_manager.get_session(session_id)
    if not session:
        print(f"Error: Session {session_id} not found in background task.")
        return
    session.is_running = True

    def _progress_cb(current, total, message):
        # Mirror agent progress into the session so /api/status can report it.
        # (Previously duplicated verbatim for the fresh and follow-up paths.)
        session.current_round = current
        session.max_rounds = total
        session.progress_percentage = round((current / total) * 100, 1) if total > 0 else 0
        session.status_message = message

    try:
        base_output_dir = "outputs"
        if not session.output_dir:
            session.output_dir = create_session_output_dir(base_output_dir, user_requirement)
        session_output_dir = session.output_dir
        session.log_file = os.path.join(session_output_dir, "process.log")
        # Persist the session-to-directory mapping immediately so recovery
        # works even if the server restarts mid-analysis. Explicit utf-8 so
        # the bytes match the utf-8 reads in SessionManager on any platform.
        try:
            with open(os.path.join(session_output_dir, "session_meta.json"), "w", encoding="utf-8") as f:
                json.dump({"session_id": session_id, "user_requirement": user_requirement}, f, default=str)
        except Exception:
            pass
        # PrintCapture redirects stdout into the log file and restores it
        # automatically when the with-block exits.
        with PrintCapture(session.log_file):
            if is_followup:
                print(f"\n--- Follow-up Session {session_id} Continued ---")
            else:
                print(f"--- Session {session_id} Started ---")
            try:
                if not is_followup:
                    llm_config = LLMConfig()
                    agent = DataAnalysisAgent(llm_config, force_max_rounds=False, output_dir=base_output_dir)
                    session.agent = agent
                    agent.set_progress_callback(_progress_cb)
                    agent.set_session_ref(session)
                    result = agent.analyze(
                        user_input=user_requirement,
                        files=files,
                        session_output_dir=session_output_dir,
                        reset_session=True,
                        template_name=template_name,
                    )
                else:
                    # Follow-up questions reuse the existing agent/context.
                    agent = session.agent
                    if not agent:
                        print("Error: Agent not initialized for follow-up.")
                        return
                    agent.set_progress_callback(_progress_cb)
                    agent.set_session_ref(session)
                    result = agent.analyze(
                        user_input=user_requirement,
                        files=None,
                        session_output_dir=session_output_dir,
                        reset_session=False,
                        max_rounds=10,
                    )
                session.generated_report = result.get("report_file_path", None)
                session.analysis_results = result.get("analysis_results", [])
                session.reusable_script = result.get("reusable_script_path", None)
                # Persist structured results for gallery/rounds recovery.
                with open(os.path.join(session_output_dir, "results.json"), "w", encoding="utf-8") as f:
                    json.dump({
                        "analysis_results": session.analysis_results,
                        "rounds": session.rounds,
                        "data_files": session.data_files,
                    }, f, default=str)
                # Re-write the meta file now that report/script paths are known.
                try:
                    with open(os.path.join(session_output_dir, "session_meta.json"), "w", encoding="utf-8") as f:
                        json.dump({
                            "session_id": session_id,
                            "user_requirement": user_requirement,
                            "report_path": session.generated_report,
                            "script_path": session.reusable_script,
                        }, f, default=str)
                except Exception:
                    pass
            except Exception as e:
                # Analysis failure: recorded in the session log via PrintCapture.
                print(f"Error during analysis: {e}")
    except Exception as e:
        print(f"System Error: {e}")
    finally:
        session.is_running = False
# --- Pydantic Models ---
class StartRequest(BaseModel):
    # Body of POST /api/start: the free-text analysis requirement plus an
    # optional analysis-template name.
    requirement: str
    template: Optional[str] = None
class ChatRequest(BaseModel):
    # Body of POST /api/chat: a follow-up message for an existing session.
    session_id: str
    message: str
# --- API Endpoints ---
@app.get("/")
async def read_root():
    """Serve the single-page frontend."""
    return FileResponse("web/static/index.html")
@app.post("/api/upload")
async def upload_files(files: list[UploadFile] = File(...)):
    """Save uploaded files into uploads/ and remember them for the next run.

    The client-supplied filename is reduced to its basename so a crafted
    name like "../../etc/passwd" cannot escape the uploads directory
    (path-traversal hardening); files with an empty name are skipped.
    """
    saved_files = []
    for file in files:
        safe_name = os.path.basename(file.filename or "")
        if not safe_name:
            continue
        file_location = f"uploads/{safe_name}"
        with open(file_location, "wb+") as file_object:
            file_object.write(file.file.read())
        saved_files.append(file_location)
    # Track the most recent upload batch; /api/start prefers these files
    # over a blanket scan of uploads/.
    app.state.last_uploaded_files = saved_files
    return {"info": f"Saved {len(saved_files)} files", "paths": saved_files}
@app.post("/api/start")
async def start_analysis(request: StartRequest, background_tasks: BackgroundTasks):
    """Kick off a new background analysis and return its session id."""
    session_id = session_manager.create_session()
    # Prefer the most recently uploaded batch over everything in uploads/.
    files = getattr(app.state, 'last_uploaded_files', None)
    if not files:
        # Fallback: scan the uploads directory for data files.
        files = glob.glob("uploads/*.csv") + glob.glob("uploads/*.xlsx")
        if not files:
            raise HTTPException(status_code=400, detail="No data files found. Please upload files first.")
    files = [os.path.abspath(f) for f in files]  # the agent expects absolute paths
    background_tasks.add_task(run_analysis_task, session_id, files, request.requirement, is_followup=False, template_name=request.template)
    return {"status": "started", "session_id": session_id}
@app.post("/api/chat")
async def chat_analysis(request: ChatRequest, background_tasks: BackgroundTasks):
    """Queue a follow-up question against an existing session."""
    session = session_manager.get_session(request.session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    if session.is_running:
        raise HTTPException(status_code=400, detail="Analysis already in progress")
    background_tasks.add_task(run_analysis_task, request.session_id, [], request.message, is_followup=True)
    return {"status": "started"}
import math as _math
def _sanitize_value(v):
"""Make any value JSON-serializable.
Handles: NaN/inf floats → None, pandas Timestamp/Timedelta → str,
numpy integers/floats → Python int/float, dicts and lists recursively.
"""
if v is None:
return None
if isinstance(v, float):
if _math.isnan(v) or _math.isinf(v):
return None
return v
if isinstance(v, (int, bool, str)):
return v
if isinstance(v, dict):
return {k: _sanitize_value(val) for k, val in v.items()}
if isinstance(v, list):
return [_sanitize_value(item) for item in v]
# pandas Timestamp, Timedelta, NaT
try:
if pd.isna(v):
return None
except (TypeError, ValueError):
pass
if hasattr(v, 'isoformat'): # datetime, Timestamp
return v.isoformat()
# numpy scalar types
if hasattr(v, 'item'):
return v.item()
# Fallback: convert to string
return str(v)
@app.get("/api/status")
async def get_status(session_id: str = Query(..., description="Session ID")):
    """Return live progress, log text and artifact flags for a session."""
    session = session_manager.get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    log_content = ""
    if session.log_file and os.path.exists(session.log_file):
        with open(session.log_file, "r", encoding="utf-8") as f:
            log_content = f.read()
    # rounds may contain NaN/numpy/pandas values — sanitize before JSON encoding.
    return JSONResponse(content={
        "is_running": session.is_running,
        "log": log_content,
        "has_report": session.generated_report is not None,
        "report_path": session.generated_report,
        "script_path": session.reusable_script,
        "current_round": session.current_round,
        "max_rounds": session.max_rounds,
        "progress_percentage": session.progress_percentage,
        "status_message": session.status_message,
        "rounds": _sanitize_value(session.rounds),
    })
@app.get("/api/export")
async def export_session(session_id: str = Query(..., description="Session ID")):
    """Bundle a session's artifacts (.md/.png/.csv/.log/.json/.yaml) into a ZIP."""
    session = session_manager.get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.output_dir or not os.path.exists(session.output_dir):
        raise HTTPException(status_code=404, detail="No data available for export")
    import zipfile
    from datetime import datetime as dt
    zip_filename = f"report_{dt.now().strftime('%Y%m%d_%H%M%S')}.zip"
    export_dir = "outputs"
    os.makedirs(export_dir, exist_ok=True)
    temp_zip_path = os.path.join(export_dir, zip_filename)
    wanted = ('.md', '.png', '.csv', '.log', '.json', '.yaml')
    with zipfile.ZipFile(temp_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, filenames in os.walk(session.output_dir):
            for name in filenames:
                if not name.endswith(wanted):
                    continue
                abs_path = os.path.join(root, name)
                # Store paths relative to the session dir inside the archive.
                zf.write(abs_path, arcname=os.path.relpath(abs_path, session.output_dir))
    return FileResponse(
        path=temp_zip_path,
        filename=zip_filename,
        media_type='application/zip'
    )
@app.get("/api/report")
async def get_report(session_id: str = Query(..., description="Session ID")):
    """Return the Markdown report with image links rewritten to web URLs.

    Also returns the report split into paragraphs and an evidence mapping
    used by the frontend polish/annotation features.
    """
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.generated_report or not os.path.exists(session.generated_report):
        return {"content": "Report not ready.", "paragraphs": []}
    with open(session.generated_report, "r", encoding="utf-8") as f:
        content = f.read()
    # Rewrite relative image paths so they resolve under the /outputs mount.
    relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
    web_base_path = f"/{relative_session_path}"
    content = content.replace("](./", f"]({web_base_path}/")

    def replace_link(match):
        alt = match.group(1)
        url = match.group(2)
        # Absolute, already-rooted, or inline data URLs are left untouched.
        if url.startswith("http") or url.startswith("/") or url.startswith("data:"):
            return match.group(0)
        # Strip leading "./" segments. NOTE: the previous lstrip("./") removed
        # a *character set*, corrupting names such as ".hidden.png".
        clean_url = url
        while clean_url.startswith("./"):
            clean_url = clean_url[2:]
        return f"![{alt}]({web_base_path}/{clean_url})"

    # `re` is already imported at module level; the old local `import re`
    # was redundant.
    content = re.sub(r'!\[(.*?)\]\((.*?)\)', replace_link, content)
    # Split into paragraphs for the frontend polish feature.
    paragraphs = _split_report_to_paragraphs(content)
    # Map evidence annotations to supporting data rows.
    supporting_data = _extract_evidence_annotations(paragraphs, session)
    return {"content": content, "base_path": web_base_path, "paragraphs": paragraphs, "supporting_data": supporting_data}
@app.get("/api/figures")
async def get_figures(session_id: str = Query(..., description="Session ID")):
    """Return collected figures for a session, enriched with web URLs.

    Figures come from 'collect_figures' entries in the analysis results; if
    none are recorded, falls back to scanning the output dir for PNG files.
    """
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    # Prefer in-memory results.
    results = session.analysis_results
    # If memory is empty (e.g. after a server restart), try the persisted JSON.
    if not results and session.output_dir:
        json_path = os.path.join(session.output_dir, "results.json")
        if os.path.exists(json_path):
            with open(json_path, 'r', encoding="utf-8") as f:
                results = json.load(f)
    # results.json is now written as a dict ({"analysis_results": [...], ...});
    # unwrap it to the legacy list. The old code iterated the dict's *keys*
    # and crashed on item.get().
    if isinstance(results, dict):
        results = results.get("analysis_results", [])
    figures = []
    # 'collect_figures' actions are the reliable source per agent design.
    if results:
        for item in results:
            if item.get("action") == "collect_figures":
                for fig in item.get("collected_figures", []):
                    if session.output_dir:
                        fname = fig.get("filename")
                        relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
                        fig["web_url"] = f"/{relative_session_path}/{fname}"
                    figures.append(fig)
    # Fallback: auto-discover PNGs when nothing was explicitly collected.
    if not figures and session.output_dir:
        pngs = glob.glob(os.path.join(session.output_dir, "*.png"))
        for p in pngs:
            fname = os.path.basename(p)
            relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
            figures.append({
                "filename": fname,
                "description": "Auto-discovered image",
                "analysis": "No analysis available",
                "web_url": f"/{relative_session_path}/{fname}"
            })
    return {"figures": figures}
@app.get("/api/download_script")
async def download_script(session_id: str = Query(..., description="Session ID")):
    """Download the reusable Python analysis script generated for a session."""
    session = session_manager.get_session(session_id)
    if session is None or not session.reusable_script:
        raise HTTPException(status_code=404, detail="Script not found")
    script_path = session.reusable_script
    if not os.path.exists(script_path):
        raise HTTPException(status_code=404, detail="Script file missing on server")
    return FileResponse(
        path=script_path,
        filename=os.path.basename(script_path),
        media_type='text/x-python'
    )
# --- Tools API ---
@app.get("/api/templates")
async def list_available_templates():
    """Return the catalogue of saved analysis templates."""
    from utils.analysis_templates import list_templates
    templates = list_templates()
    return {"templates": templates}
@app.get("/api/templates/{template_name}")
async def get_template_detail(template_name: str):
    """Fetch the full content (including steps) of one template; 404 if unknown."""
    from utils.analysis_templates import get_template
    try:
        return get_template(template_name).to_dict()
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
@app.put("/api/templates/{template_name}")
async def update_template(template_name: str, body: dict):
    """Create a new template or overwrite an existing one."""
    from utils.analysis_templates import save_template
    try:
        saved_path = save_template(template_name, body)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "saved", "filepath": saved_path}
@app.delete("/api/templates/{template_name}")
async def remove_template(template_name: str):
    """Delete a template; 404 when it does not exist."""
    from utils.analysis_templates import delete_template
    if not delete_template(template_name):
        raise HTTPException(status_code=404, detail=f"Template not found: {template_name}")
    return {"status": "deleted"}
# --- Data Files API ---
@app.get("/api/data-files")
async def list_data_files(session_id: str = Query(..., description="Session ID")):
    """Merge session-recorded data files with a scan of the output directory.

    Files the user originally uploaded are excluded from the directory scan,
    so only derived/intermediate CSV/XLSX files are added.
    """
    session = session_manager.get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    # Start from the data files the session already knows about.
    known_files = {f["filename"]: f for f in session.data_files}
    if session.output_dir and os.path.exists(session.output_dir):
        uploaded_basenames = set()
        if hasattr(session, "file_list"):
            uploaded_basenames = {os.path.basename(fp) for fp in session.file_list}
        for pattern in ("*.csv", "*.xlsx"):
            for fpath in glob.glob(os.path.join(session.output_dir, pattern)):
                fname = os.path.basename(fpath)
                if fname in uploaded_basenames or fname in known_files:
                    continue
                try:
                    size_bytes = os.path.getsize(fpath)
                except OSError:
                    size_bytes = 0
                # Scanned files carry placeholder metadata (no row/col counts).
                known_files[fname] = {
                    "filename": fname,
                    "description": "",
                    "rows": 0,
                    "cols": 0,
                    "size_bytes": size_bytes,
                }
    return {"files": list(known_files.values())}
@app.get("/api/data-files/preview")
async def preview_data_file(
    session_id: str = Query(..., description="Session ID"),
    filename: str = Query(..., description="File name"),
):
    """Preview the first 5 rows of a CSV/XLSX file in the session directory.

    Security: the client-supplied filename is confined to the session output
    directory via a realpath prefix check, blocking path traversal such as
    "../../secret.csv".
    """
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.output_dir:
        raise HTTPException(status_code=404, detail=f"File not found: (unknown)")
    file_path = os.path.join(session.output_dir, filename)
    base_dir = os.path.realpath(session.output_dir)
    real_path = os.path.realpath(file_path)
    if real_path != base_dir and not real_path.startswith(base_dir + os.sep):
        raise HTTPException(status_code=404, detail=f"File not found: (unknown)")
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail=f"File not found: (unknown)")
    try:
        if filename.lower().endswith(".xlsx"):
            df = pd.read_excel(file_path, nrows=5)
        else:
            # Try utf-8-sig first (common for Chinese CSV exports), then
            # utf-8, then gbk as a last resort.
            try:
                df = pd.read_csv(file_path, nrows=5, encoding="utf-8-sig")
            except UnicodeDecodeError:
                try:
                    df = pd.read_csv(file_path, nrows=5, encoding="utf-8")
                except UnicodeDecodeError:
                    df = pd.read_csv(file_path, nrows=5, encoding="gbk")
        columns = list(df.columns)
        rows = df.head(5).to_dict(orient="records")
        # NaN/inf are not valid JSON; replace them with null.
        rows = [
            {k: (None if isinstance(v, float) and (_math.isnan(v) or _math.isinf(v)) else v)
             for k, v in row.items()}
            for row in rows
        ]
        return {"columns": columns, "rows": rows}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to read file: {str(e)}")
@app.get("/api/data-files/download")
async def download_data_file(
    session_id: str = Query(..., description="Session ID"),
    filename: str = Query(..., description="File name"),
):
    """Download a CSV/XLSX file from the session directory with a correct MIME type.

    Security: the client-supplied filename is confined to the session output
    directory via a realpath prefix check, blocking "../" path traversal.
    """
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.output_dir:
        raise HTTPException(status_code=404, detail=f"File not found: (unknown)")
    file_path = os.path.join(session.output_dir, filename)
    base_dir = os.path.realpath(session.output_dir)
    real_path = os.path.realpath(file_path)
    if real_path != base_dir and not real_path.startswith(base_dir + os.sep):
        raise HTTPException(status_code=404, detail=f"File not found: (unknown)")
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail=f"File not found: (unknown)")
    if filename.lower().endswith(".xlsx"):
        media_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    else:
        media_type = "text/csv"
    return FileResponse(path=file_path, filename=filename, media_type=media_type)
# --- 新增API端点 ---
@app.get("/api/sessions/progress")
async def get_session_progress(session_id: str = Query(..., description="Session ID")):
    """Return the progress summary for one session."""
    info = session_manager.get_session_info(session_id)
    if info is None:
        raise HTTPException(status_code=404, detail="Session not found")
    return info
@app.get("/api/sessions/list")
async def list_all_sessions():
    """Summarize every session currently known to the manager."""
    summaries = [session_manager.get_session_info(sid) for sid in session_manager.list_sessions()]
    summaries = [info for info in summaries if info]
    return {"sessions": summaries, "total": len(summaries)}
@app.delete("/api/sessions/{session_id}")
async def delete_specific_session(session_id: str):
    """Delete the given session; 404 when it does not exist."""
    if not session_manager.delete_session(session_id):
        raise HTTPException(status_code=404, detail="Session not found")
    return {"status": "deleted", "session_id": session_id}
# --- Report Polishing API ---
import re as _re
def _split_report_to_paragraphs(markdown_content: str) -> list:
"""
将 Markdown 报告按语义段落拆分。
每个段落包含 id、类型heading/text/table/image、原始内容。
前端可据此实现段落级选择与润色。
"""
lines = markdown_content.split("\n")
paragraphs = []
current_block = []
current_type = "text"
para_id = 0
def flush_block():
nonlocal para_id, current_block, current_type
text = "\n".join(current_block).strip()
if text:
paragraphs.append({
"id": f"p-{para_id}",
"type": current_type,
"content": text,
})
para_id += 1
current_block = []
current_type = "text"
in_table = False
in_code = False
for line in lines:
stripped = line.strip()
# 代码块边界
if stripped.startswith("```"):
if in_code:
current_block.append(line)
flush_block()
in_code = False
continue
else:
flush_block()
current_block.append(line)
current_type = "code"
in_code = True
continue
if in_code:
current_block.append(line)
continue
# 标题行 — 独立成段
if _re.match(r"^#{1,6}\s", stripped):
flush_block()
current_block.append(line)
current_type = "heading"
flush_block()
continue
# 图片行
if _re.match(r"^!\[.*\]\(.*\)", stripped):
flush_block()
current_block.append(line)
current_type = "image"
flush_block()
continue
# 表格行
if stripped.startswith("|"):
if not in_table:
flush_block()
in_table = True
current_type = "table"
current_block.append(line)
continue
else:
if in_table:
flush_block()
in_table = False
# 空行 — 段落分隔
if not stripped:
flush_block()
continue
# 普通文本
current_block.append(line)
flush_block()
return paragraphs
def _extract_evidence_annotations(paragraphs: list, session) -> dict:
"""Parse <!-- evidence:round_N --> annotations from paragraph content.
For each paragraph containing an evidence annotation, look up
session.rounds[N-1].evidence_rows and build a supporting_data mapping
keyed by paragraph ID.
"""
supporting_data = {}
evidence_pattern = re.compile(r"<!--\s*evidence:round_(\d+)\s*-->")
for para in paragraphs:
content = para.get("content", "")
match = evidence_pattern.search(content)
if match:
round_num = int(match.group(1))
# rounds are 1-indexed, list is 0-indexed
idx = round_num - 1
if 0 <= idx < len(session.rounds):
evidence_rows = session.rounds[idx].get("evidence_rows", [])
if evidence_rows:
supporting_data[para["id"]] = evidence_rows
return supporting_data
class PolishRequest(BaseModel):
    # Body of POST /api/report/polish.
    session_id: str
    paragraph_id: str
    mode: str = "context"  # "context" | "data" | "custom"
    custom_instruction: str = ""  # only used when mode == "custom"
@app.post("/api/report/polish")
async def polish_paragraph(request: PolishRequest):
    """
    AI-polish one selected paragraph of the report.

    mode:
    - context: polish using surrounding paragraphs and figure info for more
      professional, insightful wording
    - data: regenerate the paragraph content grounded in the raw analysis data
    - custom: follow a user-supplied polishing instruction

    Returns the original and polished Markdown plus the ids of every
    paragraph that was included in the polish (adjacent tables may be
    pulled in automatically).
    """
    session = session_manager.get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.generated_report or not os.path.exists(session.generated_report):
        raise HTTPException(status_code=404, detail="Report not found")
    # Read the report and split it into paragraphs.
    with open(session.generated_report, "r", encoding="utf-8") as f:
        report_content = f.read()
    paragraphs = _split_report_to_paragraphs(report_content)
    # Locate the target paragraph.
    target = None
    target_idx = -1
    for i, p in enumerate(paragraphs):
        if p["id"] == request.paragraph_id:
            target = p
            target_idx = i
            break
    if not target:
        raise HTTPException(status_code=404, detail=f"Paragraph {request.paragraph_id} not found")
    # Build the actual content to polish: include adjacent table paragraphs
    # so that when the user clicks on text below a table, the table gets
    # polished too.
    polish_para_ids = [target["id"]]
    polish_content_parts = [target["content"]]
    # If the previous paragraph is a table — include it.
    if target_idx > 0 and paragraphs[target_idx - 1]["type"] == "table":
        polish_para_ids.insert(0, paragraphs[target_idx - 1]["id"])
        polish_content_parts.insert(0, paragraphs[target_idx - 1]["content"])
    # If the next paragraph is a table — include it.
    if target_idx + 1 < len(paragraphs) and paragraphs[target_idx + 1]["type"] == "table":
        polish_para_ids.append(paragraphs[target_idx + 1]["id"])
        polish_content_parts.append(paragraphs[target_idx + 1]["content"])
    # If the target itself is a table, include adjacent text too.
    if target["type"] == "table":
        if target_idx + 1 < len(paragraphs) and paragraphs[target_idx + 1]["type"] == "text":
            polish_para_ids.append(paragraphs[target_idx + 1]["id"])
            polish_content_parts.append(paragraphs[target_idx + 1]["content"])
        if target_idx > 0 and paragraphs[target_idx - 1]["type"] == "text":
            polish_para_ids.insert(0, paragraphs[target_idx - 1]["id"])
            polish_content_parts.insert(0, paragraphs[target_idx - 1]["content"])
    combined_content = "\n\n".join(polish_content_parts)
    # Context window: two paragraphs before/after, excluding the ones that
    # are already part of the polish set.
    context_window = []
    for j in range(max(0, target_idx - 2), min(len(paragraphs), target_idx + 3)):
        if paragraphs[j]["id"] not in polish_para_ids:
            context_window.append(paragraphs[j]["content"])
    context_text = "\n\n".join(context_window)
    # Collect figure descriptions/analyses as extra context for the LLM.
    figures_info = ""
    if session.analysis_results:
        fig_parts = []
        for item in session.analysis_results:
            if item.get("action") == "collect_figures":
                for fig in item.get("collected_figures", []):
                    fig_parts.append(f"- {fig.get('filename', '?')}: {fig.get('description', '')} / {fig.get('analysis', '')}")
        if fig_parts:
            figures_info = "\n".join(fig_parts)
    # Build the polishing prompt (the prompts themselves stay in Chinese —
    # they are runtime strings sent to the LLM).
    if request.mode == "data":
        # Summarize successful code-execution outputs (truncated) to ground
        # the rewrite in the actual analysis data.
        data_summary_parts = []
        for item in session.analysis_results:
            result = item.get("result", {})
            if result.get("success") and result.get("output"):
                output_text = result["output"][:2000]
                data_summary_parts.append(output_text)
        data_summary = "\n---\n".join(data_summary_parts[:5])
        polish_prompt = f"""你是一位资深数据分析专家。请基于以下分析数据,重写下方段落,使其包含更精确的数据引用和更深入的业务洞察。
## 分析数据摘要
{data_summary}
## 图表信息
{figures_info}
## 需要润色的段落(可能包含表格和文字)
{combined_content}
## 要求
- 保持原有的 Markdown 格式(标题级别、表格结构等)
- 如果包含表格,必须同时润色表格内容(补充数据、修正数值)
- 用具体数据替换模糊描述
- 增加业务洞察和趋势判断
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""
    elif request.mode == "custom":
        polish_prompt = f"""你是一位资深数据分析专家。请根据用户的指令润色以下段落。
## 用户指令
{request.custom_instruction}
## 上下文
{context_text}
## 图表信息
{figures_info}
## 需要润色的段落(可能包含表格和文字)
{combined_content}
## 要求
- 保持原有的 Markdown 格式
- 如果包含表格,必须同时润色表格内容
- 严格遵循用户指令
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""
    else:  # context mode
        polish_prompt = f"""你是一位资深数据分析专家。请润色以下段落,使其表述更专业、更有洞察力。
## 上下文(前后段落)
{context_text}
## 图表信息
{figures_info}
## 需要润色的段落(可能包含表格和文字)
{combined_content}
## 要求
- 保持原有的 Markdown 格式(标题级别、表格结构等)
- 如果包含表格,必须同时润色表格内容(补充数据、修正数值)
- 提升专业性:使用同比、环比、占比等术语
- 增加洞察:不仅描述现象,还要分析原因和影响
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""
    # Call the LLM to do the polishing.
    try:
        from utils.llm_helper import LLMHelper
        llm = LLMHelper(LLMConfig())
        polished_content = llm.call(
            prompt=polish_prompt,
            system_prompt="你是一位专业的数据分析报告润色专家。直接输出润色后的内容,不要添加任何解释或包裹。",
            max_tokens=4096,
        )
        # Strip a possible code-block wrapper from the LLM output.
        polished_content = polished_content.strip()
        if polished_content.startswith("```markdown"):
            polished_content = polished_content[len("```markdown"):].strip()
        if polished_content.startswith("```"):
            polished_content = polished_content[3:].strip()
        if polished_content.endswith("```"):
            polished_content = polished_content[:-3].strip()
        return {
            "paragraph_id": request.paragraph_id,
            "original": combined_content,
            "polished": polished_content,
            "mode": request.mode,
            "affected_paragraph_ids": polish_para_ids,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Polish failed: {str(e)}")
class ApplyPolishRequest(BaseModel):
    # Body of POST /api/report/apply: the paragraph to replace and the
    # polished Markdown to write back into the report file.
    session_id: str
    paragraph_id: str
    new_content: str
@app.post("/api/report/apply")
async def apply_polish(request: ApplyPolishRequest):
    """Write polished content back into the report, replacing one paragraph."""
    session = session_manager.get_session(request.session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.generated_report or not os.path.exists(session.generated_report):
        raise HTTPException(status_code=404, detail="Report not found")
    with open(session.generated_report, "r", encoding="utf-8") as f:
        report_content = f.read()
    # Locate the paragraph by re-splitting the current report content.
    target = None
    for p in _split_report_to_paragraphs(report_content):
        if p["id"] == request.paragraph_id:
            target = p
            break
    if target is None:
        raise HTTPException(status_code=404, detail=f"Paragraph {request.paragraph_id} not found")
    # Replace only the first occurrence of the original paragraph text.
    updated = report_content.replace(target["content"], request.new_content, 1)
    with open(session.generated_report, "w", encoding="utf-8") as f:
        f.write(updated)
    return {"status": "applied", "paragraph_id": request.paragraph_id}
# --- History API ---
@app.get("/api/history")
async def get_history():
    """List past analysis sessions found under outputs/, newest first.

    The session id is the part of the directory name after the "session_"
    prefix; when it parses as YYYYMMDD_HHMMSS that timestamp is used for
    display and sorting, otherwise the directory's creation time.
    """
    history = []
    output_base = "outputs"
    if not os.path.exists(output_base):
        return {"history": []}
    try:
        for entry in os.scandir(output_base):
            if entry.is_dir() and entry.name.startswith("session_"):
                # removeprefix (not .replace) so a "session_" occurring later
                # in the name is left untouched.
                session_id = entry.name.removeprefix("session_")
                try:
                    # Try to parse a YYYYMMDD_HHMMSS timestamp from the id.
                    dt = datetime.strptime(session_id, "%Y%m%d_%H%M%S")
                    display_time = dt.strftime("%Y-%m-%d %H:%M:%S")
                    sort_key = dt.timestamp()
                except ValueError:
                    # Fall back to the directory's creation time.
                    sort_key = entry.stat().st_ctime
                    display_time = datetime.fromtimestamp(sort_key).strftime("%Y-%m-%d %H:%M:%S")
                history.append({
                    "id": session_id,
                    "timestamp": display_time,
                    "sort_key": sort_key,
                    "name": f"Session {display_time}"
                })
        # Latest first.
        history.sort(key=lambda x: x["sort_key"], reverse=True)
        # Drop the internal sort key before returning.
        for item in history:
            del item["sort_key"]
        return {"history": history}
    except Exception as e:
        print(f"Error scanning history: {e}")
        return {"history": []}
# Run a local development server when executed directly; production should
# launch via an ASGI server (e.g. `uvicorn web.main:app`).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)