修改前端显示逻辑

This commit is contained in:
2026-01-22 22:26:04 +08:00
parent b1d0cc5462
commit 162f5c4da4
10 changed files with 828 additions and 581 deletions

View File

@@ -42,6 +42,8 @@ class SessionData:
self.generated_report: Optional[str] = None
self.log_file: Optional[str] = None
self.analysis_results: List[Dict] = [] # Store analysis results for gallery
self.agent: Optional[DataAnalysisAgent] = None # Store the agent instance for follow-up
class SessionManager:
def __init__(self):
@@ -72,7 +74,7 @@ app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
# --- Helper Functions ---
def run_analysis_task(session_id: str, files: list, user_requirement: str):
def run_analysis_task(session_id: str, files: list, user_requirement: str, is_followup: bool = False):
"""
Runs the analysis agent in a background thread for a specific session.
"""
@@ -83,14 +85,13 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str):
session.is_running = True
try:
# Create session directory
# Create session directory if not exists (for follow-up it should accept existing)
base_output_dir = "outputs"
# We enforce a specific directory naming convention or let the util handle it
# ideally we map session_id to the directory
# For now, let's use the standard utility but we might lose the direct mapping if not careful
# Let's trust the return value
session_output_dir = create_session_output_dir(base_output_dir, user_requirement)
session.output_dir = session_output_dir
if not session.output_dir:
session.output_dir = create_session_output_dir(base_output_dir, user_requirement)
session_output_dir = session.output_dir
# Initialize Log capturing
session.log_file = os.path.join(session_output_dir, "process.log")
@@ -125,8 +126,13 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str):
# but try to capture to file?
# Let's just write to the file.
with open(session.log_file, "w", encoding="utf-8") as f:
f.write(f"--- Session {session_id} Started ---\n")
# Let's just write to the file.
with open(session.log_file, "a" if is_followup else "w", encoding="utf-8") as f:
if is_followup:
f.write(f"\n--- Follow-up Session {session_id} Continued ---\n")
else:
f.write(f"--- Session {session_id} Started ---\n")
# We will create a custom print function that writes to the file
# And monkeypatch builtins.print? No, that's too hacky.
@@ -153,14 +159,30 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str):
sys.stdout = logger # Global hijack!
try:
llm_config = LLMConfig()
agent = DataAnalysisAgent(llm_config, force_max_rounds=False, output_dir=base_output_dir)
result = agent.analyze(
user_input=user_requirement,
files=files,
session_output_dir=session_output_dir
)
if not is_followup:
llm_config = LLMConfig()
agent = DataAnalysisAgent(llm_config, force_max_rounds=False, output_dir=base_output_dir)
session.agent = agent
result = agent.analyze(
user_input=user_requirement,
files=files,
session_output_dir=session_output_dir,
reset_session=True
)
else:
agent = session.agent
if not agent:
print("Error: Agent not initialized for follow-up.")
return
result = agent.analyze(
user_input=user_requirement,
files=None,
session_output_dir=session_output_dir,
reset_session=False,
max_rounds=10
)
session.generated_report = result.get("report_file_path", None)
session.analysis_results = result.get("analysis_results", [])
@@ -185,6 +207,10 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str):
class StartRequest(BaseModel):
requirement: str
class ChatRequest(BaseModel):
session_id: str
message: str
# --- API Endpoints ---
@app.get("/")
@@ -214,9 +240,21 @@ async def start_analysis(request: StartRequest, background_tasks: BackgroundTask
files = [os.path.abspath(f) for f in files] # Only use absolute paths
background_tasks.add_task(run_analysis_task, session_id, files, request.requirement)
background_tasks.add_task(run_analysis_task, session_id, files, request.requirement, is_followup=False)
return {"status": "started", "session_id": session_id}
@app.post("/api/chat")
async def chat_analysis(request: ChatRequest, background_tasks: BackgroundTasks):
session = session_manager.get_session(request.session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if session.is_running:
raise HTTPException(status_code=400, detail="Analysis already in progress")
background_tasks.add_task(run_analysis_task, request.session_id, [], request.message, is_followup=True)
return {"status": "started"}
@app.get("/api/status")
async def get_status(session_id: str = Query(..., description="Session ID")):
session = session_manager.get_session(session_id)
@@ -235,6 +273,27 @@ async def get_status(session_id: str = Query(..., description="Session ID")):
"report_path": session.generated_report
}
@app.get("/api/export")
async def export_session(session_id: str = Query(..., description="Session ID")):
session = session_manager.get_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if not session.output_dir or not os.path.exists(session.output_dir):
raise HTTPException(status_code=404, detail="No data available for export")
# Create a zip file
import shutil
# We want to zip the contents of session_output_dir
# Zip path should be outside to avoid recursive zipping if inside
zip_base_name = os.path.join("outputs", f"export_{session_id}")
# shutil.make_archive expects base_name (without extension) and root_dir
archive_path = shutil.make_archive(zip_base_name, 'zip', session.output_dir)
return FileResponse(archive_path, media_type='application/zip', filename=f"analysis_export_{session_id}.zip")
@app.get("/api/report")
async def get_report(session_id: str = Query(..., description="Session ID")):
session = session_manager.get_session(session_id)
@@ -250,8 +309,24 @@ async def get_report(session_id: str = Query(..., description="Session ID")):
# Fix image paths
relative_session_path = os.path.relpath(session.output_dir, os.getcwd())
web_base_path = f"/{relative_session_path}"
# Robust image path replacement
# 1. Replace explicit relative paths ./image.png
content = content.replace("](./", f"]({web_base_path}/")
# 2. Replace naked paths that might be generated like ](image.png) but NOT ](http...) or ](/...)
import re
def replace_link(match):
alt = match.group(1)
url = match.group(2)
if url.startswith("http") or url.startswith("/") or url.startswith("data:"):
return match.group(0)
# Remove ./ if exists again just in case
clean_url = url.lstrip("./")
return f"![{alt}]({web_base_path}/{clean_url})"
content = re.sub(r'!\[(.*?)\]\((.*?)\)', replace_link, content)
return {"content": content, "base_path": web_base_path}
@app.get("/api/figures")