大更新,架构调整,数据分析能力提升,
This commit is contained in:
212
web/main.py
212
web/main.py
@@ -5,6 +5,7 @@ import threading
|
||||
import glob
|
||||
import uuid
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, List
|
||||
from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException, Query
|
||||
@@ -12,6 +13,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from pydantic import BaseModel
|
||||
import pandas as pd
|
||||
|
||||
# Add parent directory to path to import agent modules
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -48,6 +50,8 @@ class SessionData:
|
||||
self.generated_report: Optional[str] = None
|
||||
self.log_file: Optional[str] = None
|
||||
self.analysis_results: List[Dict] = [] # Store analysis results for gallery
|
||||
self.rounds: List[Dict] = [] # Structured Round_Data objects
|
||||
self.data_files: List[Dict] = [] # File metadata dicts
|
||||
self.agent: Optional[DataAnalysisAgent] = None # Store the agent instance for follow-up
|
||||
|
||||
# 新增:进度跟踪
|
||||
@@ -128,7 +132,15 @@ class SessionManager:
|
||||
if os.path.exists(results_json):
|
||||
try:
|
||||
with open(results_json, "r") as f:
|
||||
session.analysis_results = json.load(f)
|
||||
data = json.load(f)
|
||||
# Support both old format (plain list) and new format (dict with rounds/data_files)
|
||||
if isinstance(data, dict):
|
||||
session.analysis_results = data.get("analysis_results", [])
|
||||
session.rounds = data.get("rounds", [])
|
||||
session.data_files = data.get("data_files", [])
|
||||
else:
|
||||
# Legacy format: data is the analysis_results list directly
|
||||
session.analysis_results = data
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -190,7 +202,7 @@ app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
|
||||
|
||||
# --- Helper Functions ---
|
||||
|
||||
def run_analysis_task(session_id: str, files: list, user_requirement: str, is_followup: bool = False):
|
||||
def run_analysis_task(session_id: str, files: list, user_requirement: str, is_followup: bool = False, template_name: str = None):
|
||||
"""在后台线程中运行分析任务"""
|
||||
session = session_manager.get_session(session_id)
|
||||
if not session:
|
||||
@@ -220,11 +232,22 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str, is_fo
|
||||
agent = DataAnalysisAgent(llm_config, force_max_rounds=False, output_dir=base_output_dir)
|
||||
session.agent = agent
|
||||
|
||||
# Wire progress callback to update session progress fields
|
||||
def progress_cb(current, total, message):
|
||||
session.current_round = current
|
||||
session.max_rounds = total
|
||||
session.progress_percentage = round((current / total) * 100, 1) if total > 0 else 0
|
||||
session.status_message = message
|
||||
|
||||
agent.set_progress_callback(progress_cb)
|
||||
agent.set_session_ref(session)
|
||||
|
||||
result = agent.analyze(
|
||||
user_input=user_requirement,
|
||||
files=files,
|
||||
session_output_dir=session_output_dir,
|
||||
reset_session=True,
|
||||
template_name=template_name,
|
||||
)
|
||||
else:
|
||||
agent = session.agent
|
||||
@@ -232,6 +255,16 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str, is_fo
|
||||
print("Error: Agent not initialized for follow-up.")
|
||||
return
|
||||
|
||||
# Wire progress callback for follow-up sessions
|
||||
def progress_cb_followup(current, total, message):
|
||||
session.current_round = current
|
||||
session.max_rounds = total
|
||||
session.progress_percentage = round((current / total) * 100, 1) if total > 0 else 0
|
||||
session.status_message = message
|
||||
|
||||
agent.set_progress_callback(progress_cb_followup)
|
||||
agent.set_session_ref(session)
|
||||
|
||||
result = agent.analyze(
|
||||
user_input=user_requirement,
|
||||
files=None,
|
||||
@@ -246,7 +279,11 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str, is_fo
|
||||
|
||||
# 持久化结果
|
||||
with open(os.path.join(session_output_dir, "results.json"), "w") as f:
|
||||
json.dump(session.analysis_results, f, default=str)
|
||||
json.dump({
|
||||
"analysis_results": session.analysis_results,
|
||||
"rounds": session.rounds,
|
||||
"data_files": session.data_files,
|
||||
}, f, default=str)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during analysis: {e}")
|
||||
@@ -260,6 +297,7 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str, is_fo
|
||||
|
||||
class StartRequest(BaseModel):
|
||||
requirement: str
|
||||
template: Optional[str] = None
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
session_id: str
|
||||
@@ -294,7 +332,7 @@ async def start_analysis(request: StartRequest, background_tasks: BackgroundTask
|
||||
|
||||
files = [os.path.abspath(f) for f in files] # Only use absolute paths
|
||||
|
||||
background_tasks.add_task(run_analysis_task, session_id, files, request.requirement, is_followup=False)
|
||||
background_tasks.add_task(run_analysis_task, session_id, files, request.requirement, is_followup=False, template_name=request.template)
|
||||
return {"status": "started", "session_id": session_id}
|
||||
|
||||
@app.post("/api/chat")
|
||||
@@ -309,6 +347,19 @@ async def chat_analysis(request: ChatRequest, background_tasks: BackgroundTasks)
|
||||
background_tasks.add_task(run_analysis_task, request.session_id, [], request.message, is_followup=True)
|
||||
return {"status": "started"}
|
||||
|
||||
import math as _math
|
||||
|
||||
def _sanitize_value(v):
|
||||
"""Replace NaN/inf with None for JSON safety."""
|
||||
if isinstance(v, float) and (_math.isnan(v) or _math.isinf(v)):
|
||||
return None
|
||||
if isinstance(v, dict):
|
||||
return {k: _sanitize_value(val) for k, val in v.items()}
|
||||
if isinstance(v, list):
|
||||
return [_sanitize_value(item) for item in v]
|
||||
return v
|
||||
|
||||
|
||||
@app.get("/api/status")
|
||||
async def get_status(session_id: str = Query(..., description="Session ID")):
|
||||
session = session_manager.get_session(session_id)
|
||||
@@ -320,13 +371,19 @@ async def get_status(session_id: str = Query(..., description="Session ID")):
|
||||
with open(session.log_file, "r", encoding="utf-8") as f:
|
||||
log_content = f.read()
|
||||
|
||||
return {
|
||||
response_data = {
|
||||
"is_running": session.is_running,
|
||||
"log": log_content,
|
||||
"has_report": session.generated_report is not None,
|
||||
"report_path": session.generated_report,
|
||||
"script_path": session.reusable_script # 新增:返回脚本路径
|
||||
"script_path": session.reusable_script,
|
||||
"current_round": session.current_round,
|
||||
"max_rounds": session.max_rounds,
|
||||
"progress_percentage": session.progress_percentage,
|
||||
"status_message": session.status_message,
|
||||
"rounds": _sanitize_value(session.rounds),
|
||||
}
|
||||
return JSONResponse(content=response_data)
|
||||
|
||||
@app.get("/api/export")
|
||||
async def export_session(session_id: str = Query(..., description="Session ID")):
|
||||
@@ -394,8 +451,11 @@ async def get_report(session_id: str = Query(..., description="Session ID")):
|
||||
|
||||
# 将报告按段落拆分,为前端润色功能提供结构化数据
|
||||
paragraphs = _split_report_to_paragraphs(content)
|
||||
|
||||
# Extract evidence annotations and build supporting_data mapping
|
||||
supporting_data = _extract_evidence_annotations(paragraphs, session)
|
||||
|
||||
return {"content": content, "base_path": web_base_path, "paragraphs": paragraphs}
|
||||
return {"content": content, "base_path": web_base_path, "paragraphs": paragraphs, "supporting_data": supporting_data}
|
||||
|
||||
@app.get("/api/figures")
|
||||
async def get_figures(session_id: str = Query(..., description="Session ID")):
|
||||
@@ -467,6 +527,120 @@ async def download_script(session_id: str = Query(..., description="Session ID")
|
||||
|
||||
# --- Tools API ---
|
||||
|
||||
@app.get("/api/templates")
async def list_available_templates():
    """Return the built-in analysis templates available to the frontend."""
    # Imported lazily so the templates module is only loaded on demand.
    from utils.analysis_templates import list_templates

    available = list_templates()
    return {"templates": available}
|
||||
|
||||
|
||||
# --- Data Files API ---
|
||||
|
||||
@app.get("/api/data-files")
async def list_data_files(session_id: str = Query(..., description="Session ID")):
    """List a session's data files.

    Combines the metadata the agent recorded in ``session.data_files`` with a
    fallback scan of the session output directory for CSV/XLSX files, so
    generated files are listed even when the agent did not register them.
    Original uploaded files are excluded from the scan results.
    """
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # Files the agent recorded, keyed by basename; these take precedence.
    merged = {entry["filename"]: entry for entry in session.data_files}

    # Supplement with any CSV/XLSX found on disk that was not recorded.
    if session.output_dir and os.path.exists(session.output_dir):
        # Basenames of the original uploads — they are inputs, not outputs.
        uploaded = {
            os.path.basename(p) for p in getattr(session, "file_list", [])
        }

        for pattern in ("*.csv", "*.xlsx"):
            for path in glob.glob(os.path.join(session.output_dir, pattern)):
                name = os.path.basename(path)
                if name in uploaded or name in merged:
                    continue
                try:
                    size = os.path.getsize(path)
                except OSError:
                    size = 0
                # Placeholder metadata: rows/cols unknown for scanned files.
                merged[name] = {
                    "filename": name,
                    "description": "",
                    "rows": 0,
                    "cols": 0,
                    "size_bytes": size,
                }

    return {"files": list(merged.values())}
|
||||
|
||||
|
||||
@app.get("/api/data-files/preview")
async def preview_data_file(
    session_id: str = Query(..., description="Session ID"),
    filename: str = Query(..., description="File name"),
):
    """Preview the first 5 rows of a session CSV/XLSX file.

    Returns ``{"columns": [...], "rows": [...]}`` with NaN/inf replaced by
    None so the payload is JSON-serializable.

    Raises:
        HTTPException 404: unknown session or missing file.
        HTTPException 500: pandas failed to parse the file.
    """
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if not session.output_dir:
        # Fix: interpolate the requested name (was a garbled literal placeholder).
        raise HTTPException(status_code=404, detail=f"File not found: {filename}")

    # Strip directory components so a crafted filename (e.g. "../secret")
    # cannot escape the session's output directory (path traversal).
    safe_name = os.path.basename(filename)
    file_path = os.path.join(session.output_dir, safe_name)
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail=f"File not found: {filename}")

    try:
        if safe_name.lower().endswith(".xlsx"):
            df = pd.read_excel(file_path, nrows=5)
        else:
            # Try utf-8-sig first (common for Chinese CSV exports), then
            # plain utf-8, finally gbk as a last resort.
            try:
                df = pd.read_csv(file_path, nrows=5, encoding="utf-8-sig")
            except UnicodeDecodeError:
                try:
                    df = pd.read_csv(file_path, nrows=5, encoding="utf-8")
                except UnicodeDecodeError:
                    df = pd.read_csv(file_path, nrows=5, encoding="gbk")

        columns = list(df.columns)
        rows = df.head(5).to_dict(orient="records")
        # Sanitize NaN/inf for JSON serialization (JSON has no such values).
        rows = [
            {k: (None if isinstance(v, float) and (_math.isnan(v) or _math.isinf(v)) else v)
             for k, v in row.items()}
            for row in rows
        ]
        return {"columns": columns, "rows": rows}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to read file: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/api/data-files/download")
async def download_data_file(
    session_id: str = Query(..., description="Session ID"),
    filename: str = Query(..., description="File name"),
):
    """Download a session data file with the correct MIME type.

    XLSX files get the OOXML spreadsheet MIME type; everything else is
    served as ``text/csv``.

    Raises:
        HTTPException 404: unknown session or missing file.
    """
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if not session.output_dir:
        # Fix: interpolate the requested name (was a garbled literal placeholder).
        raise HTTPException(status_code=404, detail=f"File not found: {filename}")

    # Strip directory components so a crafted filename (e.g. "../secret")
    # cannot escape the session's output directory (path traversal).
    safe_name = os.path.basename(filename)
    file_path = os.path.join(session.output_dir, safe_name)
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail=f"File not found: {filename}")

    if safe_name.lower().endswith(".xlsx"):
        media_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    else:
        media_type = "text/csv"

    return FileResponse(path=file_path, filename=safe_name, media_type=media_type)
|
||||
|
||||
|
||||
|
||||
# --- 新增API端点 ---
|
||||
@@ -597,6 +771,30 @@ def _split_report_to_paragraphs(markdown_content: str) -> list:
|
||||
return paragraphs
|
||||
|
||||
|
||||
def _extract_evidence_annotations(paragraphs: list, session) -> dict:
|
||||
"""Parse <!-- evidence:round_N --> annotations from paragraph content.
|
||||
|
||||
For each paragraph containing an evidence annotation, look up
|
||||
session.rounds[N-1].evidence_rows and build a supporting_data mapping
|
||||
keyed by paragraph ID.
|
||||
"""
|
||||
supporting_data = {}
|
||||
evidence_pattern = re.compile(r"<!--\s*evidence:round_(\d+)\s*-->")
|
||||
|
||||
for para in paragraphs:
|
||||
content = para.get("content", "")
|
||||
match = evidence_pattern.search(content)
|
||||
if match:
|
||||
round_num = int(match.group(1))
|
||||
# rounds are 1-indexed, list is 0-indexed
|
||||
idx = round_num - 1
|
||||
if 0 <= idx < len(session.rounds):
|
||||
evidence_rows = session.rounds[idx].get("evidence_rows", [])
|
||||
if evidence_rows:
|
||||
supporting_data[para["id"]] = evidence_rows
|
||||
return supporting_data
|
||||
|
||||
|
||||
class PolishRequest(BaseModel):
|
||||
session_id: str
|
||||
paragraph_id: str
|
||||
|
||||
Reference in New Issue
Block a user