# iov_data_analysis_agent/web/main.py
import sys
import os
import threading
import glob
import uuid
import json
from datetime import datetime
from typing import Optional, Dict, List
from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
# Add parent directory to path to import agent modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data_analysis_agent import DataAnalysisAgent
from config.llm_config import LLMConfig
from utils.create_session_dir import create_session_output_dir
from utils.logger import PrintCapture
app = FastAPI(title="IOV Data Analysis Agent")
def _to_web_path(fs_path: str) -> str:
"""将文件系统路径转为 URL 安全的正斜杠路径(修复 Windows 反斜杠问题)"""
return fs_path.replace("\\", "/")
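# Illustrative example: _to_web_path(r"outputs\session_x\fig.png") -> "outputs/session_x/fig.png",
# which can then be served under the "/outputs" static mount configured below.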
# CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
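# Development-friendly CORS: all origins, methods, and headers are allowed.
# A deployment would likely pin allow_origins to the frontend's actual URL.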
# --- Session Management ---
class SessionData:
def __init__(self, session_id: str):
self.session_id = session_id
self.is_running = False
self.output_dir: Optional[str] = None
self.generated_report: Optional[str] = None
self.log_file: Optional[str] = None
        self.analysis_results: List[Dict] = []  # Store analysis results for gallery
        self.agent: Optional[DataAnalysisAgent] = None  # Store the agent instance for follow-up
        # Progress tracking
        self.current_round: int = 0
        self.max_rounds: int = 20
        self.progress_percentage: float = 0.0
        self.status_message: str = "等待开始"
        # History metadata
        self.created_at: str = ""
        self.last_updated: str = ""
        self.user_requirement: str = ""
        self.file_list: List[str] = []
        self.reusable_script: Optional[str] = None  # Path to the reusable generated script
class SessionManager:
def __init__(self):
self.sessions: Dict[str, SessionData] = {}
self.lock = threading.Lock()
def create_session(self) -> str:
with self.lock:
session_id = str(uuid.uuid4())
self.sessions[session_id] = SessionData(session_id)
return session_id
def get_session(self, session_id: str) -> Optional[SessionData]:
if session_id in self.sessions:
return self.sessions[session_id]
# Fallback: Try to reconstruct from disk for history sessions
output_dir = os.path.join("outputs", f"session_{session_id}")
if os.path.exists(output_dir) and os.path.isdir(output_dir):
return self._reconstruct_session(session_id, output_dir)
return None
def _reconstruct_session(self, session_id: str, output_dir: str) -> SessionData:
"""从磁盘目录重建会话对象"""
session = SessionData(session_id)
session.output_dir = output_dir
session.is_running = False
session.current_round = session.max_rounds
session.progress_percentage = 100.0
session.status_message = "已完成 (历史记录)"
# Recover Log
log_path = os.path.join(output_dir, "process.log")
if os.path.exists(log_path):
session.log_file = log_path
# Recover Report
        # Lenient lookup: scan all .md files, preferring names containing "report" or "报告"
md_files = glob.glob(os.path.join(output_dir, "*.md"))
if md_files:
            # Default to the first file
chosen = md_files[0]
            # Then look for a better match
for md in md_files:
fname = os.path.basename(md).lower()
if "report" in fname or "报告" in fname:
chosen = md
break
session.generated_report = chosen
        # Recover script (check likely generated-script filenames)
possible_scripts = ["data_analysis_script.py", "script.py", "analysis_script.py"]
for s in possible_scripts:
p = os.path.join(output_dir, s)
if os.path.exists(p):
session.reusable_script = p
break
# Recover Results (images etc)
results_json = os.path.join(output_dir, "results.json")
if os.path.exists(results_json):
            try:
                with open(results_json, "r", encoding="utf-8") as f:
                    session.analysis_results = json.load(f)
            except (OSError, json.JSONDecodeError):
                pass
# Recover Metadata
        try:
            stat = os.stat(output_dir)
            dt = datetime.fromtimestamp(stat.st_ctime)
            session.created_at = dt.strftime("%Y-%m-%d %H:%M:%S")
        except OSError:
            pass
# Cache it
with self.lock:
self.sessions[session_id] = session
return session
def list_sessions(self):
return list(self.sessions.keys())
def delete_session(self, session_id: str) -> bool:
"""删除指定会话"""
with self.lock:
if session_id in self.sessions:
session = self.sessions[session_id]
if session.agent:
session.agent.reset()
del self.sessions[session_id]
return True
return False
def get_session_info(self, session_id: str) -> Optional[Dict]:
"""获取会话详细信息"""
session = self.get_session(session_id)
if session:
return {
"session_id": session.session_id,
"is_running": session.is_running,
"progress": session.progress_percentage,
"status": session.status_message,
"current_round": session.current_round,
"max_rounds": session.max_rounds,
"created_at": session.created_at,
"last_updated": session.last_updated,
"user_requirement": session.user_requirement[:100] + "..." if len(session.user_requirement) > 100 else session.user_requirement,
"script_path": session.reusable_script # 新增:返回脚本路径
}
return None
session_manager = SessionManager()
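# Note: sessions live in this process's memory. After a restart, get_session()
# falls back to _reconstruct_session() to rebuild completed sessions from disk,
# but the agent instance itself is not recoverable (see /api/chat below).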
# Mount static files
os.makedirs("web/static", exist_ok=True)
os.makedirs("uploads", exist_ok=True)
os.makedirs("outputs", exist_ok=True)
app.mount("/static", StaticFiles(directory="web/static"), name="static")
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
# --- Helper Functions ---
def run_analysis_task(session_id: str, files: list, user_requirement: str, is_followup: bool = False):
"""在后台线程中运行分析任务"""
session = session_manager.get_session(session_id)
if not session:
print(f"Error: Session {session_id} not found in background task.")
return
session.is_running = True
try:
base_output_dir = "outputs"
if not session.output_dir:
session.output_dir = create_session_output_dir(base_output_dir, user_requirement)
session_output_dir = session.output_dir
session.log_file = os.path.join(session_output_dir, "process.log")
        # Use PrintCapture instead of a global FileLogger; stdout is restored automatically when the with block exits
with PrintCapture(session.log_file):
if is_followup:
print(f"\n--- Follow-up Session {session_id} Continued ---")
else:
print(f"--- Session {session_id} Started ---")
try:
if not is_followup:
llm_config = LLMConfig()
agent = DataAnalysisAgent(llm_config, force_max_rounds=False, output_dir=base_output_dir)
session.agent = agent
result = agent.analyze(
user_input=user_requirement,
files=files,
session_output_dir=session_output_dir,
reset_session=True,
)
else:
agent = session.agent
if not agent:
print("Error: Agent not initialized for follow-up.")
return
result = agent.analyze(
user_input=user_requirement,
files=None,
session_output_dir=session_output_dir,
reset_session=False,
max_rounds=10,
)
session.generated_report = result.get("report_file_path", None)
session.analysis_results = result.get("analysis_results", [])
session.reusable_script = result.get("reusable_script_path", None)
                # Persist results so history sessions can be reconstructed
                with open(os.path.join(session_output_dir, "results.json"), "w", encoding="utf-8") as f:
                    json.dump(session.analysis_results, f, default=str)
except Exception as e:
print(f"Error during analysis: {e}")
except Exception as e:
print(f"System Error: {e}")
finally:
session.is_running = False
# --- Pydantic Models ---
class StartRequest(BaseModel):
requirement: str
class ChatRequest(BaseModel):
session_id: str
message: str
# --- API Endpoints ---
@app.get("/")
async def read_root():
return FileResponse("web/static/index.html")
@app.post("/api/upload")
async def upload_files(files: list[UploadFile] = File(...)):
saved_files = []
for file in files:
        # basename() guards against path traversal via a crafted upload filename
        file_location = os.path.join("uploads", os.path.basename(file.filename))
        with open(file_location, "wb+") as file_object:
            file_object.write(file.file.read())
saved_files.append(file_location)
return {"info": f"Saved {len(saved_files)} files", "paths": saved_files}
@app.post("/api/start")
async def start_analysis(request: StartRequest, background_tasks: BackgroundTasks):
session_id = session_manager.create_session()
files = glob.glob("uploads/*.csv")
if not files:
if os.path.exists("cleaned_data.csv"):
files = ["cleaned_data.csv"]
else:
raise HTTPException(status_code=400, detail="No CSV files found")
files = [os.path.abspath(f) for f in files] # Only use absolute paths
background_tasks.add_task(run_analysis_task, session_id, files, request.requirement, is_followup=False)
return {"status": "started", "session_id": session_id}
@app.post("/api/chat")
async def chat_analysis(request: ChatRequest, background_tasks: BackgroundTasks):
session = session_manager.get_session(request.session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if session.is_running:
raise HTTPException(status_code=400, detail="Analysis already in progress")
background_tasks.add_task(run_analysis_task, request.session_id, [], request.message, is_followup=True)
return {"status": "started"}
@app.get("/api/status")
async def get_status(session_id: str = Query(..., description="Session ID")):
session = session_manager.get_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
log_content = ""
if session.log_file and os.path.exists(session.log_file):
with open(session.log_file, "r", encoding="utf-8") as f:
log_content = f.read()
return {
"is_running": session.is_running,
"log": log_content,
"has_report": session.generated_report is not None,
"report_path": session.generated_report,
"script_path": session.reusable_script # 新增:返回脚本路径
}
@app.get("/api/export")
async def export_session(session_id: str = Query(..., description="Session ID")):
"""导出会话数据为ZIP"""
session = session_manager.get_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if not session.output_dir or not os.path.exists(session.output_dir):
raise HTTPException(status_code=404, detail="No data available for export")
    import zipfile
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
zip_filename = f"report_{timestamp}.zip"
export_dir = "outputs"
os.makedirs(export_dir, exist_ok=True)
temp_zip_path = os.path.join(export_dir, zip_filename)
with zipfile.ZipFile(temp_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
for root, dirs, files in os.walk(session.output_dir):
for file in files:
if file.endswith(('.md', '.png', '.csv', '.log', '.json', '.yaml')):
abs_path = os.path.join(root, file)
rel_path = os.path.relpath(abs_path, session.output_dir)
zf.write(abs_path, arcname=rel_path)
return FileResponse(
path=temp_zip_path,
filename=zip_filename,
media_type='application/zip'
)
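# Note: the archive is written to outputs/report_<timestamp>.zip and never
# deleted, so exported ZIPs accumulate until cleaned up manually.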
@app.get("/api/report")
async def get_report(session_id: str = Query(..., description="Session ID")):
session = session_manager.get_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if not session.generated_report or not os.path.exists(session.generated_report):
return {"content": "Report not ready.", "paragraphs": []}
with open(session.generated_report, "r", encoding="utf-8") as f:
content = f.read()
# Fix image paths
relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
web_base_path = f"/{relative_session_path}"
# Robust image path replacement
content = content.replace("](./", f"]({web_base_path}/")
import re
def replace_link(match):
alt = match.group(1)
url = match.group(2)
if url.startswith("http") or url.startswith("/") or url.startswith("data:"):
return match.group(0)
        # Strip only a leading "./" (str.lstrip("./") would strip any run of '.' and '/' characters)
        clean_url = url[2:] if url.startswith("./") else url
return f"![{alt}]({web_base_path}/{clean_url})"
content = re.sub(r'!\[(.*?)\]\((.*?)\)', replace_link, content)
    # Split the report into paragraphs so the frontend can offer paragraph-level polishing
paragraphs = _split_report_to_paragraphs(content)
return {"content": content, "base_path": web_base_path, "paragraphs": paragraphs}
@app.get("/api/figures")
async def get_figures(session_id: str = Query(..., description="Session ID")):
session = session_manager.get_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
    # Prefer the in-memory results
    results = session.analysis_results
    # If memory is empty (e.g. the server restarted but files remain on disk), load the persisted JSON
if not results and session.output_dir:
json_path = os.path.join(session.output_dir, "results.json")
if os.path.exists(json_path):
            with open(json_path, "r", encoding="utf-8") as f:
results = json.load(f)
# Extract collected figures
figures = []
# We iterate over analysis results to find 'collect_figures' actions
if results:
for item in results:
if item.get("action") == "collect_figures":
collected = item.get("collected_figures", [])
for fig in collected:
# Enrich with web path
if session.output_dir:
# Assume filename is present
fname = fig.get("filename")
relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
fig["web_url"] = f"/{relative_session_path}/{fname}"
figures.append(fig)
    # 'generate_code' results could reference figures implicitly, but per the agent
    # design 'collect_figures' is the reliable source, so only that action is parsed.
    # Fallback: auto-discover PNGs on disk when nothing was collected
if not figures and session.output_dir:
# Simple scan
pngs = glob.glob(os.path.join(session.output_dir, "*.png"))
for p in pngs:
fname = os.path.basename(p)
relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
figures.append({
"filename": fname,
"description": "Auto-discovered image",
"analysis": "No analysis available",
"web_url": f"/{relative_session_path}/{fname}"
})
return {"figures": figures}
@app.get("/api/download_script")
async def download_script(session_id: str = Query(..., description="Session ID")):
"""下载生成的Python脚本"""
session = session_manager.get_session(session_id)
if not session or not session.reusable_script:
raise HTTPException(status_code=404, detail="Script not found")
if not os.path.exists(session.reusable_script):
raise HTTPException(status_code=404, detail="Script file missing on server")
return FileResponse(
path=session.reusable_script,
filename=os.path.basename(session.reusable_script),
media_type='text/x-python'
)
# --- Session Management API ---
@app.get("/api/sessions/progress")
async def get_session_progress(session_id: str = Query(..., description="Session ID")):
"""获取会话分析进度"""
session_info = session_manager.get_session_info(session_id)
if not session_info:
raise HTTPException(status_code=404, detail="Session not found")
return session_info
@app.get("/api/sessions/list")
async def list_all_sessions():
"""获取所有会话列表"""
session_ids = session_manager.list_sessions()
sessions_info = []
for sid in session_ids:
info = session_manager.get_session_info(sid)
if info:
sessions_info.append(info)
return {"sessions": sessions_info, "total": len(sessions_info)}
@app.delete("/api/sessions/{session_id}")
async def delete_specific_session(session_id: str):
"""删除指定会话"""
success = session_manager.delete_session(session_id)
if not success:
raise HTTPException(status_code=404, detail="Session not found")
return {"status": "deleted", "session_id": session_id}
# --- Report Polishing API ---
import re as _re
def _split_report_to_paragraphs(markdown_content: str) -> list:
"""
将 Markdown 报告按语义段落拆分。
每个段落包含 id、类型heading/text/table/image、原始内容。
前端可据此实现段落级选择与润色。
"""
lines = markdown_content.split("\n")
paragraphs = []
current_block = []
current_type = "text"
para_id = 0
    def flush_block():
        nonlocal para_id, current_block, current_type, in_table
        text = "\n".join(current_block).strip()
        if text:
            paragraphs.append({
                "id": f"p-{para_id}",
                "type": current_type,
                "content": text,
            })
            para_id += 1
        current_block = []
        current_type = "text"
        in_table = False  # any flush also closes an open table run (e.g. a heading right after a table)
in_table = False
in_code = False
for line in lines:
stripped = line.strip()
        # Code-fence boundaries
if stripped.startswith("```"):
if in_code:
current_block.append(line)
flush_block()
in_code = False
continue
else:
flush_block()
current_block.append(line)
current_type = "code"
in_code = True
continue
if in_code:
current_block.append(line)
continue
        # Heading line: always its own paragraph
if _re.match(r"^#{1,6}\s", stripped):
flush_block()
current_block.append(line)
current_type = "heading"
flush_block()
continue
        # Image line
if _re.match(r"^!\[.*\]\(.*\)", stripped):
flush_block()
current_block.append(line)
current_type = "image"
flush_block()
continue
        # Table row
if stripped.startswith("|"):
if not in_table:
flush_block()
in_table = True
current_type = "table"
current_block.append(line)
continue
else:
if in_table:
flush_block()
in_table = False
        # Blank line: paragraph separator
if not stripped:
flush_block()
continue
        # Plain text
current_block.append(line)
flush_block()
return paragraphs
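# Worked example (illustrative):
#   _split_report_to_paragraphs("# Title\n\nSome text.\n\n![fig](a.png)")
# returns three paragraphs:
#   [{"id": "p-0", "type": "heading", "content": "# Title"},
#    {"id": "p-1", "type": "text",    "content": "Some text."},
#    {"id": "p-2", "type": "image",   "content": "![fig](a.png)"}]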
class PolishRequest(BaseModel):
session_id: str
paragraph_id: str
mode: str = "context" # "context" | "data" | "custom"
custom_instruction: str = ""
@app.post("/api/report/polish")
async def polish_paragraph(request: PolishRequest):
"""
对报告中指定段落进行 AI 润色。
mode:
- context: 根据上下文和图表信息润色,使表述更专业、更有洞察
- data: 结合原始分析数据重新生成该段落内容
- custom: 用户自定义润色指令
"""
session = session_manager.get_session(request.session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if not session.generated_report or not os.path.exists(session.generated_report):
raise HTTPException(status_code=404, detail="Report not found")
    # Read the report and split it into paragraphs
with open(session.generated_report, "r", encoding="utf-8") as f:
report_content = f.read()
paragraphs = _split_report_to_paragraphs(report_content)
    # Locate the target paragraph
target = None
target_idx = -1
for i, p in enumerate(paragraphs):
if p["id"] == request.paragraph_id:
target = p
target_idx = i
break
if not target:
raise HTTPException(status_code=404, detail=f"Paragraph {request.paragraph_id} not found")
    # Build a context window (two paragraphs before and after the target)
context_window = []
for j in range(max(0, target_idx - 2), min(len(paragraphs), target_idx + 3)):
if j != target_idx:
context_window.append(paragraphs[j]["content"])
context_text = "\n\n".join(context_window)
    # Gather figure metadata
figures_info = ""
if session.analysis_results:
fig_parts = []
for item in session.analysis_results:
if item.get("action") == "collect_figures":
for fig in item.get("collected_figures", []):
fig_parts.append(f"- {fig.get('filename', '?')}: {fig.get('description', '')} / {fig.get('analysis', '')}")
if fig_parts:
figures_info = "\n".join(fig_parts)
    # Build the polishing prompt
if request.mode == "data":
        # Summarize code-execution outputs
data_summary_parts = []
for item in session.analysis_results:
result = item.get("result", {})
if result.get("success") and result.get("output"):
output_text = result["output"][:2000]
data_summary_parts.append(output_text)
data_summary = "\n---\n".join(data_summary_parts[:5])
polish_prompt = f"""你是一位资深数据分析专家。请基于以下分析数据,重写下方段落,使其包含更精确的数据引用和更深入的业务洞察。
## 分析数据摘要
{data_summary}
## 图表信息
{figures_info}
## 需要润色的段落
{target['content']}
## 要求
- 保持原有的 Markdown 格式(标题级别、表格结构等)
- 用具体数据替换模糊描述
- 增加业务洞察和趋势判断
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""
elif request.mode == "custom":
polish_prompt = f"""你是一位资深数据分析专家。请根据用户的指令润色以下段落。
## 用户指令
{request.custom_instruction}
## 上下文
{context_text}
## 图表信息
{figures_info}
## 需要润色的段落
{target['content']}
## 要求
- 保持原有的 Markdown 格式
- 严格遵循用户指令
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""
else: # context mode
polish_prompt = f"""你是一位资深数据分析专家。请润色以下段落,使其表述更专业、更有洞察力。
## 上下文(前后段落)
{context_text}
## 图表信息
{figures_info}
## 需要润色的段落
{target['content']}
## 要求
- 保持原有的 Markdown 格式(标题级别、表格结构等)
- 提升专业性:使用同比、环比、占比等术语
- 增加洞察:不仅描述现象,还要分析原因和影响
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""
    # Call the LLM to polish the paragraph
try:
from utils.llm_helper import LLMHelper
llm = LLMHelper(LLMConfig())
polished_content = llm.call(
prompt=polish_prompt,
system_prompt="你是一位专业的数据分析报告润色专家。直接输出润色后的内容,不要添加任何解释或包裹。",
max_tokens=4096,
)
        # Strip a possible code-fence wrapper
polished_content = polished_content.strip()
if polished_content.startswith("```markdown"):
polished_content = polished_content[len("```markdown"):].strip()
if polished_content.startswith("```"):
polished_content = polished_content[3:].strip()
if polished_content.endswith("```"):
polished_content = polished_content[:-3].strip()
return {
"paragraph_id": request.paragraph_id,
"original": target["content"],
"polished": polished_content,
"mode": request.mode,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Polish failed: {str(e)}")
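# Illustrative request body (the ids and instruction are assumptions):
#   {"session_id": "<uuid>", "paragraph_id": "p-3", "mode": "custom",
#    "custom_instruction": "Make the tone more formal"}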
class ApplyPolishRequest(BaseModel):
session_id: str
paragraph_id: str
new_content: str
@app.post("/api/report/apply")
async def apply_polish(request: ApplyPolishRequest):
"""
将润色后的内容应用到报告文件中,替换指定段落。
"""
session = session_manager.get_session(request.session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if not session.generated_report or not os.path.exists(session.generated_report):
raise HTTPException(status_code=404, detail="Report not found")
with open(session.generated_report, "r", encoding="utf-8") as f:
report_content = f.read()
paragraphs = _split_report_to_paragraphs(report_content)
    # Locate the target paragraph
target = None
for p in paragraphs:
if p["id"] == request.paragraph_id:
target = p
break
if not target:
raise HTTPException(status_code=404, detail=f"Paragraph {request.paragraph_id} not found")
    # Replace in the original text (first occurrence only)
new_report = report_content.replace(target["content"], request.new_content, 1)
    # Write back to the file
with open(session.generated_report, "w", encoding="utf-8") as f:
f.write(new_report)
return {"status": "applied", "paragraph_id": request.paragraph_id}
# --- History API ---
@app.get("/api/history")
async def get_history():
"""
Get list of past analysis sessions from outputs directory
"""
history = []
output_base = "outputs"
if not os.path.exists(output_base):
return {"history": []}
try:
# Scan for session_* directories
for entry in os.scandir(output_base):
if entry.is_dir() and entry.name.startswith("session_"):
# Extract timestamp from folder name: session_20250101_120000
session_id = entry.name.replace("session_", "")
# Check creation time or extract from name
try:
# Try to parse timestamp from ID if it matches format
# Format: YYYYMMDD_HHMMSS
timestamp_str = session_id
dt = datetime.strptime(timestamp_str, "%Y%m%d_%H%M%S")
display_time = dt.strftime("%Y-%m-%d %H:%M:%S")
sort_key = dt.timestamp()
except ValueError:
# Fallback to file creation time
sort_key = entry.stat().st_ctime
display_time = datetime.fromtimestamp(sort_key).strftime("%Y-%m-%d %H:%M:%S")
history.append({
"id": session_id,
"timestamp": display_time,
"sort_key": sort_key,
"name": f"Session {display_time}"
})
# Sort by latest first
history.sort(key=lambda x: x["sort_key"], reverse=True)
# Cleanup internal sort key
for item in history:
del item["sort_key"]
return {"history": history}
except Exception as e:
print(f"Error scanning history: {e}")
return {"history": []}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
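# Launch from the package root so the relative "web/static", "uploads", and
# "outputs" paths resolve correctly, e.g.:  python web/main.py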