大更新:架构调整,数据分析能力提升

This commit is contained in:
2026-04-19 21:30:08 +08:00
parent 9d01f004d4
commit 00bd48e7e7
26 changed files with 4375 additions and 252 deletions

View File

@@ -5,6 +5,7 @@ import threading
import glob
import uuid
import json
import re
from datetime import datetime
from typing import Optional, Dict, List
from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException, Query
@@ -12,6 +13,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
import pandas as pd
# Add parent directory to path to import agent modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -48,6 +50,8 @@ class SessionData:
self.generated_report: Optional[str] = None
self.log_file: Optional[str] = None
self.analysis_results: List[Dict] = [] # Store analysis results for gallery
self.rounds: List[Dict] = [] # Structured Round_Data objects
self.data_files: List[Dict] = [] # File metadata dicts
self.agent: Optional[DataAnalysisAgent] = None # Store the agent instance for follow-up
# 新增:进度跟踪
@@ -128,7 +132,15 @@ class SessionManager:
if os.path.exists(results_json):
try:
with open(results_json, "r") as f:
session.analysis_results = json.load(f)
data = json.load(f)
# Support both old format (plain list) and new format (dict with rounds/data_files)
if isinstance(data, dict):
session.analysis_results = data.get("analysis_results", [])
session.rounds = data.get("rounds", [])
session.data_files = data.get("data_files", [])
else:
# Legacy format: data is the analysis_results list directly
session.analysis_results = data
except:
pass
@@ -190,7 +202,7 @@ app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
# --- Helper Functions ---
def run_analysis_task(session_id: str, files: list, user_requirement: str, is_followup: bool = False):
def run_analysis_task(session_id: str, files: list, user_requirement: str, is_followup: bool = False, template_name: str = None):
"""在后台线程中运行分析任务"""
session = session_manager.get_session(session_id)
if not session:
@@ -220,11 +232,22 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str, is_fo
agent = DataAnalysisAgent(llm_config, force_max_rounds=False, output_dir=base_output_dir)
session.agent = agent
# Wire progress callback to update session progress fields
def progress_cb(current, total, message):
session.current_round = current
session.max_rounds = total
session.progress_percentage = round((current / total) * 100, 1) if total > 0 else 0
session.status_message = message
agent.set_progress_callback(progress_cb)
agent.set_session_ref(session)
result = agent.analyze(
user_input=user_requirement,
files=files,
session_output_dir=session_output_dir,
reset_session=True,
template_name=template_name,
)
else:
agent = session.agent
@@ -232,6 +255,16 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str, is_fo
print("Error: Agent not initialized for follow-up.")
return
# Wire progress callback for follow-up sessions
def progress_cb_followup(current, total, message):
session.current_round = current
session.max_rounds = total
session.progress_percentage = round((current / total) * 100, 1) if total > 0 else 0
session.status_message = message
agent.set_progress_callback(progress_cb_followup)
agent.set_session_ref(session)
result = agent.analyze(
user_input=user_requirement,
files=None,
@@ -246,7 +279,11 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str, is_fo
# 持久化结果
with open(os.path.join(session_output_dir, "results.json"), "w") as f:
json.dump(session.analysis_results, f, default=str)
json.dump({
"analysis_results": session.analysis_results,
"rounds": session.rounds,
"data_files": session.data_files,
}, f, default=str)
except Exception as e:
print(f"Error during analysis: {e}")
@@ -260,6 +297,7 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str, is_fo
class StartRequest(BaseModel):
    """Request body for starting a new analysis session."""

    # Natural-language description of the analysis the user wants performed.
    requirement: str
    # Optional analysis-template name (see /api/templates); None means no template.
    template: Optional[str] = None
class ChatRequest(BaseModel):
session_id: str
@@ -294,7 +332,7 @@ async def start_analysis(request: StartRequest, background_tasks: BackgroundTask
files = [os.path.abspath(f) for f in files] # Only use absolute paths
background_tasks.add_task(run_analysis_task, session_id, files, request.requirement, is_followup=False)
background_tasks.add_task(run_analysis_task, session_id, files, request.requirement, is_followup=False, template_name=request.template)
return {"status": "started", "session_id": session_id}
@app.post("/api/chat")
@@ -309,6 +347,19 @@ async def chat_analysis(request: ChatRequest, background_tasks: BackgroundTasks)
background_tasks.add_task(run_analysis_task, request.session_id, [], request.message, is_followup=True)
return {"status": "started"}
import math as _math
def _sanitize_value(v):
"""Replace NaN/inf with None for JSON safety."""
if isinstance(v, float) and (_math.isnan(v) or _math.isinf(v)):
return None
if isinstance(v, dict):
return {k: _sanitize_value(val) for k, val in v.items()}
if isinstance(v, list):
return [_sanitize_value(item) for item in v]
return v
@app.get("/api/status")
async def get_status(session_id: str = Query(..., description="Session ID")):
session = session_manager.get_session(session_id)
@@ -320,13 +371,19 @@ async def get_status(session_id: str = Query(..., description="Session ID")):
with open(session.log_file, "r", encoding="utf-8") as f:
log_content = f.read()
return {
response_data = {
"is_running": session.is_running,
"log": log_content,
"has_report": session.generated_report is not None,
"report_path": session.generated_report,
"script_path": session.reusable_script # 新增:返回脚本路径
"script_path": session.reusable_script,
"current_round": session.current_round,
"max_rounds": session.max_rounds,
"progress_percentage": session.progress_percentage,
"status_message": session.status_message,
"rounds": _sanitize_value(session.rounds),
}
return JSONResponse(content=response_data)
@app.get("/api/export")
async def export_session(session_id: str = Query(..., description="Session ID")):
@@ -394,8 +451,11 @@ async def get_report(session_id: str = Query(..., description="Session ID")):
# 将报告按段落拆分,为前端润色功能提供结构化数据
paragraphs = _split_report_to_paragraphs(content)
# Extract evidence annotations and build supporting_data mapping
supporting_data = _extract_evidence_annotations(paragraphs, session)
return {"content": content, "base_path": web_base_path, "paragraphs": paragraphs}
return {"content": content, "base_path": web_base_path, "paragraphs": paragraphs, "supporting_data": supporting_data}
@app.get("/api/figures")
async def get_figures(session_id: str = Query(..., description="Session ID")):
@@ -467,6 +527,120 @@ async def download_script(session_id: str = Query(..., description="Session ID")
# --- Tools API ---
@app.get("/api/templates")
async def list_available_templates():
    """Expose the names of the built-in analysis templates to the frontend."""
    from utils.analysis_templates import list_templates

    available = list_templates()
    return {"templates": available}
# --- Data Files API ---
@app.get("/api/data-files")
async def list_data_files(session_id: str = Query(..., description="Session ID")):
    """Return session.data_files merged with fallback directory scan for CSV/XLSX files."""
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # Files the agent already registered, keyed by filename for de-duplication.
    merged = {meta["filename"]: meta for meta in session.data_files}

    output_dir = session.output_dir
    if output_dir and os.path.exists(output_dir):
        # Basenames of the originally uploaded inputs — these are not outputs
        # and must not appear in the listing.
        originals = set()
        if hasattr(session, "file_list"):
            originals = {os.path.basename(p) for p in session.file_list}

        candidates = (
            glob.glob(os.path.join(output_dir, "*.csv"))
            + glob.glob(os.path.join(output_dir, "*.xlsx"))
        )
        for path in candidates:
            name = os.path.basename(path)
            if name in originals or name in merged:
                continue
            try:
                size = os.path.getsize(path)
            except OSError:
                size = 0
            # Files found only on disk get placeholder metadata.
            merged[name] = {
                "filename": name,
                "description": "",
                "rows": 0,
                "cols": 0,
                "size_bytes": size,
            }
    return {"files": list(merged.values())}
@app.get("/api/data-files/preview")
async def preview_data_file(
    session_id: str = Query(..., description="Session ID"),
    filename: str = Query(..., description="File name"),
):
    """Read CSV/XLSX via pandas, return {columns, rows (first 5)}.

    Raises 404 for an unknown session or missing file, 400 for an unsafe
    filename, and 500 when pandas cannot parse the file.
    """
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.output_dir:
        raise HTTPException(status_code=404, detail="File not found: (unknown)")
    # Security: filename comes straight from the query string. Accept only a
    # bare basename so "../" sequences cannot escape session.output_dir.
    if filename != os.path.basename(filename) or filename in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid filename")
    file_path = os.path.join(session.output_dir, filename)
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="File not found: (unknown)")
    try:
        if filename.lower().endswith(".xlsx"):
            df = pd.read_excel(file_path, nrows=5)
        else:
            # Try utf-8-sig first (common for Chinese CSV exports), then plain
            # utf-8, then gbk; if every decode fails, re-raise the last error.
            df = None
            last_err = None
            for enc in ("utf-8-sig", "utf-8", "gbk"):
                try:
                    df = pd.read_csv(file_path, nrows=5, encoding=enc)
                    break
                except UnicodeDecodeError as err:
                    last_err = err
            if df is None:
                raise last_err
        columns = list(df.columns)
        rows = df.head(5).to_dict(orient="records")
        # NaN/inf are not valid JSON; reuse the module's shared sanitizer.
        rows = [_sanitize_value(row) for row in rows]
        return {"columns": columns, "rows": rows}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to read file: {str(e)}")
@app.get("/api/data-files/download")
async def download_data_file(
    session_id: str = Query(..., description="Session ID"),
    filename: str = Query(..., description="File name"),
):
    """Return FileResponse with correct MIME type.

    Raises 404 for an unknown session or missing file and 400 for an
    unsafe filename.
    """
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.output_dir:
        raise HTTPException(status_code=404, detail="File not found: (unknown)")
    # Security: filename is attacker-controlled. Allow only a bare basename so
    # path traversal ("../") cannot serve files outside session.output_dir.
    if filename != os.path.basename(filename) or filename in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid filename")
    file_path = os.path.join(session.output_dir, filename)
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="File not found: (unknown)")
    if filename.lower().endswith(".xlsx"):
        media_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    else:
        media_type = "text/csv"
    return FileResponse(path=file_path, filename=filename, media_type=media_type)
# --- 新增API端点 ---
@@ -597,6 +771,30 @@ def _split_report_to_paragraphs(markdown_content: str) -> list:
return paragraphs
def _extract_evidence_annotations(paragraphs: list, session) -> dict:
"""Parse <!-- evidence:round_N --> annotations from paragraph content.
For each paragraph containing an evidence annotation, look up
session.rounds[N-1].evidence_rows and build a supporting_data mapping
keyed by paragraph ID.
"""
supporting_data = {}
evidence_pattern = re.compile(r"<!--\s*evidence:round_(\d+)\s*-->")
for para in paragraphs:
content = para.get("content", "")
match = evidence_pattern.search(content)
if match:
round_num = int(match.group(1))
# rounds are 1-indexed, list is 0-indexed
idx = round_num - 1
if 0 <= idx < len(session.rounds):
evidence_rows = session.rounds[idx].get("evidence_rows", [])
if evidence_rows:
supporting_data[para["id"]] = evidence_rows
return supporting_data
class PolishRequest(BaseModel):
session_id: str
paragraph_id: str