858 lines
29 KiB
Python
858 lines
29 KiB
Python
|
||
import sys
|
||
import os
|
||
import threading
|
||
import glob
|
||
import uuid
|
||
import json
|
||
from datetime import datetime
|
||
from typing import Optional, Dict, List
|
||
from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException, Query
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.staticfiles import StaticFiles
|
||
from fastapi.responses import FileResponse, JSONResponse
|
||
from pydantic import BaseModel
|
||
|
||
# Add parent directory to sys.path so the sibling agent modules
# (data_analysis_agent, config, utils) import when this file runs from web/.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from data_analysis_agent import DataAnalysisAgent
|
||
from config.llm_config import LLMConfig
|
||
from utils.create_session_dir import create_session_output_dir
|
||
from utils.logger import PrintCapture
|
||
|
||
# Single FastAPI application instance; every route below attaches to it.
app = FastAPI(title="IOV Data Analysis Agent")
|
||
|
||
|
||
def _to_web_path(fs_path: str) -> str:
|
||
"""将文件系统路径转为 URL 安全的正斜杠路径(修复 Windows 反斜杠问题)"""
|
||
return fs_path.replace("\\", "/")
|
||
|
||
|
||
# CORS: wide-open policy (any origin, method, header). Acceptable for a
# local single-user tool; tighten before any shared deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
# --- Session Management ---
|
||
|
||
class SessionData:
    """Mutable per-session state: run flag, artifacts, progress and metadata."""

    def __init__(self, session_id: str):
        self.session_id = session_id
        self.is_running = False

        # Filesystem artifacts produced by a run.
        self.output_dir: Optional[str] = None
        self.generated_report: Optional[str] = None
        self.log_file: Optional[str] = None
        self.reusable_script: Optional[str] = None  # path of the reusable script

        # Analysis results collected for the figure gallery.
        self.analysis_results: List[Dict] = []
        # Live agent instance kept around for follow-up requests.
        self.agent: Optional[DataAnalysisAgent] = None

        # Progress tracking.
        self.current_round: int = 0
        self.max_rounds: int = 20
        self.progress_percentage: float = 0.0
        self.status_message: str = "等待开始"

        # History metadata.
        self.created_at: str = ""
        self.last_updated: str = ""
        self.user_requirement: str = ""
        self.file_list: List[str] = []
|
||
|
||
|
||
class SessionManager:
    """Thread-safe registry of sessions, with disk-based reconstruction of
    historical sessions found under outputs/session_<id>/."""

    def __init__(self):
        # session_id -> SessionData; all access to this dict goes through
        # self.lock so background threads and request handlers don't race.
        self.sessions: Dict[str, SessionData] = {}
        self.lock = threading.Lock()

    def create_session(self) -> str:
        """Create a new session keyed by a random UUID and return its id."""
        with self.lock:
            session_id = str(uuid.uuid4())
            self.sessions[session_id] = SessionData(session_id)
            return session_id

    def get_session(self, session_id: str) -> Optional[SessionData]:
        """Return the live session, or rebuild it from disk if a matching
        outputs/session_<id> directory exists. None when unknown."""
        with self.lock:
            if session_id in self.sessions:
                return self.sessions[session_id]

        # Fallback: reconstruct a history session from its output directory.
        output_dir = os.path.join("outputs", f"session_{session_id}")
        if os.path.isdir(output_dir):
            return self._reconstruct_session(session_id, output_dir)

        return None

    def _reconstruct_session(self, session_id: str, output_dir: str) -> SessionData:
        """Rebuild a session object from the files a past run left on disk."""
        session = SessionData(session_id)
        session.output_dir = output_dir
        session.is_running = False
        session.current_round = session.max_rounds
        session.progress_percentage = 100.0
        session.status_message = "已完成 (历史记录)"

        # Recover log file.
        log_path = os.path.join(output_dir, "process.log")
        if os.path.exists(log_path):
            session.log_file = log_path

        # Recover report: lenient scan of all .md files, preferring names
        # containing "report" / "报告", falling back to the first match.
        md_files = glob.glob(os.path.join(output_dir, "*.md"))
        if md_files:
            chosen = md_files[0]
            for md in md_files:
                fname = os.path.basename(md).lower()
                if "report" in fname or "报告" in fname:
                    chosen = md
                    break
            session.generated_report = chosen

        # Recover the reusable analysis script, if any of the known names exist.
        possible_scripts = ["data_analysis_script.py", "script.py", "analysis_script.py"]
        for s in possible_scripts:
            p = os.path.join(output_dir, s)
            if os.path.exists(p):
                session.reusable_script = p
                break

        # Recover analysis results (figures etc.) from results.json.
        results_json = os.path.join(output_dir, "results.json")
        if os.path.exists(results_json):
            try:
                with open(results_json, "r", encoding="utf-8") as f:
                    session.analysis_results = json.load(f)
            except (OSError, json.JSONDecodeError):
                # Corrupt/unreadable results are non-fatal; keep defaults.
                pass

        # Recover creation time from directory metadata (best effort).
        try:
            stat = os.stat(output_dir)
            dt = datetime.fromtimestamp(stat.st_ctime)
            session.created_at = dt.strftime("%Y-%m-%d %H:%M:%S")
        except OSError:
            pass

        # Cache the reconstructed session for subsequent lookups.
        with self.lock:
            self.sessions[session_id] = session

        return session

    def list_sessions(self):
        """Return a snapshot list of known session ids."""
        with self.lock:
            return list(self.sessions.keys())

    def delete_session(self, session_id: str) -> bool:
        """Delete a session, resetting its agent first. True if it existed."""
        with self.lock:
            if session_id in self.sessions:
                session = self.sessions[session_id]
                if session.agent:
                    session.agent.reset()
                del self.sessions[session_id]
                return True
            return False

    def get_session_info(self, session_id: str) -> Optional[Dict]:
        """Return a JSON-serializable summary of a session, or None."""
        session = self.get_session(session_id)
        if not session:
            return None
        # Truncate long requirements for list views.
        requirement = session.user_requirement
        if len(requirement) > 100:
            requirement = requirement[:100] + "..."
        return {
            "session_id": session.session_id,
            "is_running": session.is_running,
            "progress": session.progress_percentage,
            "status": session.status_message,
            "current_round": session.current_round,
            "max_rounds": session.max_rounds,
            "created_at": session.created_at,
            "last_updated": session.last_updated,
            "user_requirement": requirement,
            "script_path": session.reusable_script,  # reusable script path
        }
|
||
|
||
# Process-wide session registry shared by all request handlers.
session_manager = SessionManager()

# Ensure working directories exist before mounting them as static routes.
os.makedirs("web/static", exist_ok=True)
os.makedirs("uploads", exist_ok=True)
os.makedirs("outputs", exist_ok=True)

app.mount("/static", StaticFiles(directory="web/static"), name="static")
# /outputs exposes generated reports and figures directly to the browser.
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
|
||
|
||
# --- Helper Functions ---
|
||
|
||
def run_analysis_task(session_id: str, files: list, user_requirement: str, is_followup: bool = False):
    """Run an analysis (or a follow-up turn) in a background thread.

    All prints inside the PrintCapture block are captured into the session's
    process.log; results are persisted to results.json so history sessions
    can be reconstructed after a restart.
    """
    import traceback

    session = session_manager.get_session(session_id)
    if not session:
        print(f"Error: Session {session_id} not found in background task.")
        return

    session.is_running = True
    try:
        base_output_dir = "outputs"

        if not session.output_dir:
            session.output_dir = create_session_output_dir(base_output_dir, user_requirement)

        session_output_dir = session.output_dir
        session.log_file = os.path.join(session_output_dir, "process.log")

        # PrintCapture restores stdout automatically when the block exits
        # (unlike a global FileLogger).
        with PrintCapture(session.log_file):
            if is_followup:
                print(f"\n--- Follow-up Session {session_id} Continued ---")
            else:
                print(f"--- Session {session_id} Started ---")

            try:
                if not is_followup:
                    # Fresh run: build a new agent and reset its session.
                    llm_config = LLMConfig()
                    agent = DataAnalysisAgent(llm_config, force_max_rounds=False, output_dir=base_output_dir)
                    session.agent = agent

                    result = agent.analyze(
                        user_input=user_requirement,
                        files=files,
                        session_output_dir=session_output_dir,
                        reset_session=True,
                    )
                else:
                    # Follow-up: reuse the existing agent and its context.
                    agent = session.agent
                    if not agent:
                        print("Error: Agent not initialized for follow-up.")
                        return

                    result = agent.analyze(
                        user_input=user_requirement,
                        files=None,
                        session_output_dir=session_output_dir,
                        reset_session=False,
                        max_rounds=10,
                    )

                session.generated_report = result.get("report_file_path", None)
                session.analysis_results = result.get("analysis_results", [])
                session.reusable_script = result.get("reusable_script_path", None)

                # Persist results; encoding pinned to utf-8 to match every
                # reader of results.json in this file.
                with open(os.path.join(session_output_dir, "results.json"), "w", encoding="utf-8") as f:
                    json.dump(session.analysis_results, f, default=str)

            except Exception as e:
                # Keep the full traceback in the captured log instead of
                # swallowing it with a one-line message.
                print(f"Error during analysis: {e}")
                print(traceback.format_exc())

    except Exception as e:
        print(f"System Error: {e}")
        print(traceback.format_exc())
    finally:
        session.is_running = False
|
||
|
||
# --- Pydantic Models ---
|
||
|
||
class StartRequest(BaseModel):
    """Request body for POST /api/start."""
    # Natural-language analysis requirement supplied by the user.
    requirement: str
|
||
|
||
class ChatRequest(BaseModel):
    """Request body for POST /api/chat (follow-up on an existing session)."""
    # Session id previously returned by /api/start.
    session_id: str
    # Follow-up instruction forwarded to the agent.
    message: str
|
||
|
||
# --- API Endpoints ---
|
||
|
||
@app.get("/")
async def read_root():
    """Serve the single-page front-end."""
    index_page = "web/static/index.html"
    return FileResponse(index_page)
|
||
|
||
@app.post("/api/upload")
async def upload_files(files: list[UploadFile] = File(...)):
    """Save uploaded files into uploads/ and return their paths.

    Client filenames are untrusted: they are reduced to their basename so a
    name like "../../etc/cron.d/x" cannot escape the uploads directory.
    Uploads are streamed to disk instead of read fully into memory.
    """
    import shutil

    saved_files = []
    for file in files:
        # Strip any directory components from the untrusted filename.
        safe_name = os.path.basename(file.filename or "")
        if not safe_name:
            # Skip nameless parts rather than writing to "uploads/".
            continue
        file_location = f"uploads/{safe_name}"
        with open(file_location, "wb") as file_object:
            shutil.copyfileobj(file.file, file_object)
        saved_files.append(file_location)
    return {"info": f"Saved {len(saved_files)} files", "paths": saved_files}
|
||
|
||
@app.post("/api/start")
async def start_analysis(request: StartRequest, background_tasks: BackgroundTasks):
    """Create a session and launch a background analysis run.

    Uses every CSV in uploads/, falling back to cleaned_data.csv; 400 when
    neither is available.
    """
    session_id = session_manager.create_session()

    csv_files = glob.glob("uploads/*.csv")
    if not csv_files:
        if not os.path.exists("cleaned_data.csv"):
            raise HTTPException(status_code=400, detail="No CSV files found")
        csv_files = ["cleaned_data.csv"]

    # The agent may run from a different working dir: hand it absolute paths.
    absolute_files = [os.path.abspath(f) for f in csv_files]

    background_tasks.add_task(run_analysis_task, session_id, absolute_files, request.requirement, is_followup=False)
    return {"status": "started", "session_id": session_id}
|
||
|
||
@app.post("/api/chat")
async def chat_analysis(request: ChatRequest, background_tasks: BackgroundTasks):
    """Queue a follow-up analysis turn against an existing session."""
    session = session_manager.get_session(request.session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    if session.is_running:
        raise HTTPException(status_code=400, detail="Analysis already in progress")

    background_tasks.add_task(run_analysis_task, request.session_id, [], request.message, is_followup=True)
    return {"status": "started"}
|
||
|
||
@app.get("/api/status")
async def get_status(session_id: str = Query(..., description="Session ID")):
    """Return run state, captured log text and artifact paths for a session."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    log_content = ""
    if session.log_file and os.path.exists(session.log_file):
        # errors="replace": the log is being written concurrently, so a
        # partially-flushed multi-byte character must not raise
        # UnicodeDecodeError and turn a status poll into a 500.
        with open(session.log_file, "r", encoding="utf-8", errors="replace") as f:
            log_content = f.read()

    return {
        "is_running": session.is_running,
        "log": log_content,
        "has_report": session.generated_report is not None,
        "report_path": session.generated_report,
        "script_path": session.reusable_script,  # reusable script, if produced
    }
|
||
|
||
@app.get("/api/export")
async def export_session(session_id: str = Query(..., description="Session ID")):
    """Export the session's output directory as a ZIP download."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if not session.output_dir or not os.path.exists(session.output_dir):
        raise HTTPException(status_code=404, detail="No data available for export")

    import zipfile
    from datetime import datetime as dt

    timestamp = dt.now().strftime("%Y%m%d_%H%M%S")
    zip_filename = f"report_{timestamp}.zip"

    export_dir = "outputs"
    os.makedirs(export_dir, exist_ok=True)
    temp_zip_path = os.path.join(export_dir, zip_filename)

    # Whitelist of artifact types worth exporting. '.py' is included so the
    # reusable analysis script (see /api/download_script) ships with the
    # report instead of being silently dropped from the archive.
    allowed_exts = ('.md', '.png', '.csv', '.log', '.json', '.yaml', '.py')

    with zipfile.ZipFile(temp_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, dirs, files in os.walk(session.output_dir):
            for file in files:
                if file.endswith(allowed_exts):
                    abs_path = os.path.join(root, file)
                    # Store paths relative to the session dir so the archive
                    # unpacks cleanly.
                    rel_path = os.path.relpath(abs_path, session.output_dir)
                    zf.write(abs_path, arcname=rel_path)

    return FileResponse(
        path=temp_zip_path,
        filename=zip_filename,
        media_type='application/zip'
    )
|
||
|
||
@app.get("/api/report")
async def get_report(session_id: str = Query(..., description="Session ID")):
    """Return the report markdown with image links rewritten to web URLs,
    plus a paragraph breakdown used by the front-end polishing UI."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if not session.generated_report or not os.path.exists(session.generated_report):
        return {"content": "Report not ready.", "paragraphs": []}

    with open(session.generated_report, "r", encoding="utf-8") as f:
        content = f.read()

    # Rebase relative image paths onto the mounted /outputs static route.
    relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
    web_base_path = f"/{relative_session_path}"

    # Fast path for the common "./file.png" form.
    content = content.replace("](./", f"]({web_base_path}/")

    import re

    def replace_link(match):
        alt = match.group(1)
        url = match.group(2)
        # Leave absolute, external and inline-data URLs untouched.
        if url.startswith("http") or url.startswith("/") or url.startswith("data:"):
            return match.group(0)
        # removeprefix, NOT lstrip("./"): lstrip strips a *character set*,
        # which would also eat the leading dots/slashes of real filenames
        # (e.g. "..data.png" -> "data.png").
        clean_url = url.removeprefix("./")
        return f"![{alt}]({web_base_path}/{clean_url})"

    content = re.sub(r'!\[(.*?)\]\((.*?)\)', replace_link, content)

    # Paragraph-level structure for the front-end polishing feature.
    paragraphs = _split_report_to_paragraphs(content)

    return {"content": content, "base_path": web_base_path, "paragraphs": paragraphs}
|
||
|
||
@app.get("/api/figures")
async def get_figures(session_id: str = Query(..., description="Session ID")):
    """Return collected figures for a session, enriched with web URLs."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # Prefer in-memory results.
    results = session.analysis_results

    # If empty (e.g. server restarted), fall back to persisted results.json.
    if not results and session.output_dir:
        json_path = os.path.join(session.output_dir, "results.json")
        if os.path.exists(json_path):
            with open(json_path, 'r', encoding="utf-8") as f:
                results = json.load(f)

    figures = []

    # The session->web path prefix is loop-invariant: compute it once
    # instead of once per figure.
    web_prefix = None
    if session.output_dir:
        web_prefix = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))

    # 'collect_figures' actions are the authoritative figure source per the
    # agent's design; other actions are ignored here.
    if results:
        for item in results:
            if item.get("action") == "collect_figures":
                for fig in item.get("collected_figures", []):
                    if web_prefix:
                        fname = fig.get("filename")
                        fig["web_url"] = f"/{web_prefix}/{fname}"
                    figures.append(fig)

    # Fallback: auto-discover PNGs when nothing was explicitly collected.
    if not figures and session.output_dir:
        for p in glob.glob(os.path.join(session.output_dir, "*.png")):
            fname = os.path.basename(p)
            figures.append({
                "filename": fname,
                "description": "Auto-discovered image",
                "analysis": "No analysis available",
                "web_url": f"/{web_prefix}/{fname}"
            })

    return {"figures": figures}
|
||
|
||
@app.get("/api/download_script")
async def download_script(session_id: str = Query(..., description="Session ID")):
    """Download the generated reusable Python analysis script."""
    session = session_manager.get_session(session_id)
    if session is None or not session.reusable_script:
        raise HTTPException(status_code=404, detail="Script not found")

    script_path = session.reusable_script
    if not os.path.exists(script_path):
        raise HTTPException(status_code=404, detail="Script file missing on server")

    return FileResponse(
        path=script_path,
        filename=os.path.basename(script_path),
        media_type='text/x-python'
    )
|
||
|
||
# --- Tools API ---
|
||
|
||
|
||
|
||
# --- 新增API端点 ---
|
||
|
||
@app.get("/api/sessions/progress")
async def get_session_progress(session_id: str = Query(..., description="Session ID")):
    """Return progress/status metadata for one session."""
    info = session_manager.get_session_info(session_id)
    if info is None:
        raise HTTPException(status_code=404, detail="Session not found")
    return info
|
||
|
||
|
||
@app.get("/api/sessions/list")
async def list_all_sessions():
    """Return summaries for every known session."""
    sessions_info = [
        info
        for sid in session_manager.list_sessions()
        if (info := session_manager.get_session_info(sid))
    ]
    return {"sessions": sessions_info, "total": len(sessions_info)}
|
||
|
||
|
||
@app.delete("/api/sessions/{session_id}")
async def delete_specific_session(session_id: str):
    """Remove a session from the manager; 404 when it does not exist."""
    if not session_manager.delete_session(session_id):
        raise HTTPException(status_code=404, detail="Session not found")
    return {"status": "deleted", "session_id": session_id}
|
||
|
||
|
||
# --- Report Polishing API ---
|
||
|
||
import re as _re
|
||
|
||
def _split_report_to_paragraphs(markdown_content: str) -> list:
|
||
"""
|
||
将 Markdown 报告按语义段落拆分。
|
||
每个段落包含 id、类型(heading/text/table/image)、原始内容。
|
||
前端可据此实现段落级选择与润色。
|
||
"""
|
||
lines = markdown_content.split("\n")
|
||
paragraphs = []
|
||
current_block = []
|
||
current_type = "text"
|
||
para_id = 0
|
||
|
||
def flush_block():
|
||
nonlocal para_id, current_block, current_type
|
||
text = "\n".join(current_block).strip()
|
||
if text:
|
||
paragraphs.append({
|
||
"id": f"p-{para_id}",
|
||
"type": current_type,
|
||
"content": text,
|
||
})
|
||
para_id += 1
|
||
current_block = []
|
||
current_type = "text"
|
||
|
||
in_table = False
|
||
in_code = False
|
||
|
||
for line in lines:
|
||
stripped = line.strip()
|
||
|
||
# 代码块边界
|
||
if stripped.startswith("```"):
|
||
if in_code:
|
||
current_block.append(line)
|
||
flush_block()
|
||
in_code = False
|
||
continue
|
||
else:
|
||
flush_block()
|
||
current_block.append(line)
|
||
current_type = "code"
|
||
in_code = True
|
||
continue
|
||
|
||
if in_code:
|
||
current_block.append(line)
|
||
continue
|
||
|
||
# 标题行 — 独立成段
|
||
if _re.match(r"^#{1,6}\s", stripped):
|
||
flush_block()
|
||
current_block.append(line)
|
||
current_type = "heading"
|
||
flush_block()
|
||
continue
|
||
|
||
# 图片行
|
||
if _re.match(r"^!\[.*\]\(.*\)", stripped):
|
||
flush_block()
|
||
current_block.append(line)
|
||
current_type = "image"
|
||
flush_block()
|
||
continue
|
||
|
||
# 表格行
|
||
if stripped.startswith("|"):
|
||
if not in_table:
|
||
flush_block()
|
||
in_table = True
|
||
current_type = "table"
|
||
current_block.append(line)
|
||
continue
|
||
else:
|
||
if in_table:
|
||
flush_block()
|
||
in_table = False
|
||
|
||
# 空行 — 段落分隔
|
||
if not stripped:
|
||
flush_block()
|
||
continue
|
||
|
||
# 普通文本
|
||
current_block.append(line)
|
||
|
||
flush_block()
|
||
return paragraphs
|
||
|
||
|
||
class PolishRequest(BaseModel):
    """Request body for POST /api/report/polish."""
    session_id: str
    # Paragraph id as produced by _split_report_to_paragraphs ("p-0", ...).
    paragraph_id: str
    mode: str = "context" # "context" | "data" | "custom"
    custom_instruction: str = ""  # only consulted when mode == "custom"
|
||
|
||
|
||
@app.post("/api/report/polish")
async def polish_paragraph(request: PolishRequest):
    """
    Polish one report paragraph with the LLM.

    mode:
    - context: rewrite using surrounding paragraphs and figure info for a
      more professional, insightful phrasing
    - data: regenerate the paragraph grounded in the raw analysis outputs
    - custom: follow a user-supplied polishing instruction
    """
    session = session_manager.get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if not session.generated_report or not os.path.exists(session.generated_report):
        raise HTTPException(status_code=404, detail="Report not found")

    # Read the report and split it into addressable paragraphs.
    with open(session.generated_report, "r", encoding="utf-8") as f:
        report_content = f.read()

    paragraphs = _split_report_to_paragraphs(report_content)

    # Locate the target paragraph by id.
    target = None
    target_idx = -1
    for i, p in enumerate(paragraphs):
        if p["id"] == request.paragraph_id:
            target = p
            target_idx = i
            break

    if not target:
        raise HTTPException(status_code=404, detail=f"Paragraph {request.paragraph_id} not found")

    # Build a context window of up to 2 paragraphs on each side of the target.
    context_window = []
    for j in range(max(0, target_idx - 2), min(len(paragraphs), target_idx + 3)):
        if j != target_idx:
            context_window.append(paragraphs[j]["content"])
    context_text = "\n\n".join(context_window)

    # Collect figure descriptions to ground the rewrite.
    figures_info = ""
    if session.analysis_results:
        fig_parts = []
        for item in session.analysis_results:
            if item.get("action") == "collect_figures":
                for fig in item.get("collected_figures", []):
                    fig_parts.append(f"- {fig.get('filename', '?')}: {fig.get('description', '')} / {fig.get('analysis', '')}")
        if fig_parts:
            figures_info = "\n".join(fig_parts)

    # Build the polishing prompt for the selected mode.
    if request.mode == "data":
        # Summarize successful code-execution outputs (first 5 entries,
        # 2000 chars each) as grounding data.
        data_summary_parts = []
        for item in session.analysis_results:
            result = item.get("result", {})
            if result.get("success") and result.get("output"):
                output_text = result["output"][:2000]
                data_summary_parts.append(output_text)
        data_summary = "\n---\n".join(data_summary_parts[:5])

        polish_prompt = f"""你是一位资深数据分析专家。请基于以下分析数据,重写下方段落,使其包含更精确的数据引用和更深入的业务洞察。

## 分析数据摘要
{data_summary}

## 图表信息
{figures_info}

## 需要润色的段落
{target['content']}

## 要求
- 保持原有的 Markdown 格式(标题级别、表格结构等)
- 用具体数据替换模糊描述
- 增加业务洞察和趋势判断
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""

    elif request.mode == "custom":
        polish_prompt = f"""你是一位资深数据分析专家。请根据用户的指令润色以下段落。

## 用户指令
{request.custom_instruction}

## 上下文
{context_text}

## 图表信息
{figures_info}

## 需要润色的段落
{target['content']}

## 要求
- 保持原有的 Markdown 格式
- 严格遵循用户指令
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""

    else: # context mode
        polish_prompt = f"""你是一位资深数据分析专家。请润色以下段落,使其表述更专业、更有洞察力。

## 上下文(前后段落)
{context_text}

## 图表信息
{figures_info}

## 需要润色的段落
{target['content']}

## 要求
- 保持原有的 Markdown 格式(标题级别、表格结构等)
- 提升专业性:使用同比、环比、占比等术语
- 增加洞察:不仅描述现象,还要分析原因和影响
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""

    # Call the LLM to polish the paragraph.
    try:
        from utils.llm_helper import LLMHelper
        llm = LLMHelper(LLMConfig())
        polished_content = llm.call(
            prompt=polish_prompt,
            system_prompt="你是一位专业的数据分析报告润色专家。直接输出润色后的内容,不要添加任何解释或包裹。",
            max_tokens=4096,
        )

        # Strip a possible code-fence wrapper from the LLM output.
        polished_content = polished_content.strip()
        if polished_content.startswith("```markdown"):
            polished_content = polished_content[len("```markdown"):].strip()
        if polished_content.startswith("```"):
            polished_content = polished_content[3:].strip()
        if polished_content.endswith("```"):
            polished_content = polished_content[:-3].strip()

        return {
            "paragraph_id": request.paragraph_id,
            "original": target["content"],
            "polished": polished_content,
            "mode": request.mode,
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Polish failed: {str(e)}")
|
||
|
||
|
||
class ApplyPolishRequest(BaseModel):
    """Request body for POST /api/report/apply."""
    session_id: str
    # Paragraph id as produced by _split_report_to_paragraphs ("p-0", ...).
    paragraph_id: str
    # Full replacement text for the selected paragraph.
    new_content: str
||
|
||
|
||
@app.post("/api/report/apply")
async def apply_polish(request: ApplyPolishRequest):
    """Write a polished paragraph back into the report file in place.

    NOTE(review): replacement targets the first occurrence of the paragraph
    text in the document — two byte-identical paragraphs would update the
    earlier one. Confirm whether that matters for real reports.
    """
    session = session_manager.get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if not session.generated_report or not os.path.exists(session.generated_report):
        raise HTTPException(status_code=404, detail="Report not found")

    with open(session.generated_report, "r", encoding="utf-8") as f:
        report_content = f.read()

    paragraphs = _split_report_to_paragraphs(report_content)

    # Locate the paragraph being replaced.
    target = next((p for p in paragraphs if p["id"] == request.paragraph_id), None)

    if not target:
        raise HTTPException(status_code=404, detail=f"Paragraph {request.paragraph_id} not found")

    # Swap the paragraph text inside the original document.
    updated = report_content.replace(target["content"], request.new_content, 1)

    # Persist the edited report.
    with open(session.generated_report, "w", encoding="utf-8") as f:
        f.write(updated)

    return {"status": "applied", "paragraph_id": request.paragraph_id}
|
||
|
||
|
||
# --- History API ---
|
||
|
||
@app.get("/api/history")
async def get_history():
    """
    List past analysis sessions discovered under the outputs directory.
    """
    entries = []
    output_base = "outputs"

    if not os.path.exists(output_base):
        return {"history": []}

    try:
        # Only session_* directories are history candidates.
        for entry in os.scandir(output_base):
            if not (entry.is_dir() and entry.name.startswith("session_")):
                continue

            # Folder naming pattern: session_YYYYMMDD_HHMMSS
            session_id = entry.name.replace("session_", "")

            try:
                # Prefer the timestamp encoded in the directory name.
                dt = datetime.strptime(session_id, "%Y%m%d_%H%M%S")
                display_time = dt.strftime("%Y-%m-%d %H:%M:%S")
                sort_key = dt.timestamp()
            except ValueError:
                # Name is not a timestamp: fall back to directory ctime.
                sort_key = entry.stat().st_ctime
                display_time = datetime.fromtimestamp(sort_key).strftime("%Y-%m-%d %H:%M:%S")

            entries.append({
                "id": session_id,
                "timestamp": display_time,
                "sort_key": sort_key,
                "name": f"Session {display_time}"
            })

        # Newest first.
        entries.sort(key=lambda x: x["sort_key"], reverse=True)

        # Drop the internal sort key before returning.
        for item in entries:
            del item["sort_key"]

        return {"history": entries}

    except Exception as e:
        print(f"Error scanning history: {e}")
        return {"history": []}
|
||
|
||
# Script entry point: run the dev server on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|