修改前端显示逻辑
This commit is contained in:
113
web/main.py
113
web/main.py
@@ -42,6 +42,8 @@ class SessionData:
|
||||
self.generated_report: Optional[str] = None
|
||||
self.log_file: Optional[str] = None
|
||||
self.analysis_results: List[Dict] = [] # Store analysis results for gallery
|
||||
self.agent: Optional[DataAnalysisAgent] = None # Store the agent instance for follow-up
|
||||
|
||||
|
||||
class SessionManager:
|
||||
def __init__(self):
|
||||
@@ -72,7 +74,7 @@ app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
|
||||
|
||||
# --- Helper Functions ---
|
||||
|
||||
def run_analysis_task(session_id: str, files: list, user_requirement: str):
|
||||
def run_analysis_task(session_id: str, files: list, user_requirement: str, is_followup: bool = False):
|
||||
"""
|
||||
Runs the analysis agent in a background thread for a specific session.
|
||||
"""
|
||||
@@ -83,14 +85,13 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str):
|
||||
|
||||
session.is_running = True
|
||||
try:
|
||||
# Create session directory
|
||||
# Create session directory if not exists (for follow-up it should accept existing)
|
||||
base_output_dir = "outputs"
|
||||
# We enforce a specific directory naming convention or let the util handle it
|
||||
# ideally we map session_id to the directory
|
||||
# For now, let's use the standard utility but we might lose the direct mapping if not careful
|
||||
# Let's trust the return value
|
||||
session_output_dir = create_session_output_dir(base_output_dir, user_requirement)
|
||||
session.output_dir = session_output_dir
|
||||
|
||||
if not session.output_dir:
|
||||
session.output_dir = create_session_output_dir(base_output_dir, user_requirement)
|
||||
|
||||
session_output_dir = session.output_dir
|
||||
|
||||
# Initialize Log capturing
|
||||
session.log_file = os.path.join(session_output_dir, "process.log")
|
||||
@@ -125,8 +126,13 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str):
|
||||
# but try to capture to file?
|
||||
# Let's just write to the file.
|
||||
|
||||
with open(session.log_file, "w", encoding="utf-8") as f:
|
||||
f.write(f"--- Session {session_id} Started ---\n")
|
||||
# Let's just write to the file.
|
||||
|
||||
with open(session.log_file, "a" if is_followup else "w", encoding="utf-8") as f:
|
||||
if is_followup:
|
||||
f.write(f"\n--- Follow-up Session {session_id} Continued ---\n")
|
||||
else:
|
||||
f.write(f"--- Session {session_id} Started ---\n")
|
||||
|
||||
# We will create a custom print function that writes to the file
|
||||
# And monkeypatch builtins.print? No, that's too hacky.
|
||||
@@ -153,14 +159,30 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str):
|
||||
sys.stdout = logger # Global hijack!
|
||||
|
||||
try:
|
||||
llm_config = LLMConfig()
|
||||
agent = DataAnalysisAgent(llm_config, force_max_rounds=False, output_dir=base_output_dir)
|
||||
|
||||
result = agent.analyze(
|
||||
user_input=user_requirement,
|
||||
files=files,
|
||||
session_output_dir=session_output_dir
|
||||
)
|
||||
if not is_followup:
|
||||
llm_config = LLMConfig()
|
||||
agent = DataAnalysisAgent(llm_config, force_max_rounds=False, output_dir=base_output_dir)
|
||||
session.agent = agent
|
||||
|
||||
result = agent.analyze(
|
||||
user_input=user_requirement,
|
||||
files=files,
|
||||
session_output_dir=session_output_dir,
|
||||
reset_session=True
|
||||
)
|
||||
else:
|
||||
agent = session.agent
|
||||
if not agent:
|
||||
print("Error: Agent not initialized for follow-up.")
|
||||
return
|
||||
|
||||
result = agent.analyze(
|
||||
user_input=user_requirement,
|
||||
files=None,
|
||||
session_output_dir=session_output_dir,
|
||||
reset_session=False,
|
||||
max_rounds=10
|
||||
)
|
||||
|
||||
session.generated_report = result.get("report_file_path", None)
|
||||
session.analysis_results = result.get("analysis_results", [])
|
||||
@@ -185,6 +207,10 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str):
|
||||
class StartRequest(BaseModel):
|
||||
requirement: str
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
session_id: str
|
||||
message: str
|
||||
|
||||
# --- API Endpoints ---
|
||||
|
||||
@app.get("/")
|
||||
@@ -214,9 +240,21 @@ async def start_analysis(request: StartRequest, background_tasks: BackgroundTask
|
||||
|
||||
files = [os.path.abspath(f) for f in files] # Only use absolute paths
|
||||
|
||||
background_tasks.add_task(run_analysis_task, session_id, files, request.requirement)
|
||||
background_tasks.add_task(run_analysis_task, session_id, files, request.requirement, is_followup=False)
|
||||
return {"status": "started", "session_id": session_id}
|
||||
|
||||
@app.post("/api/chat")
|
||||
async def chat_analysis(request: ChatRequest, background_tasks: BackgroundTasks):
|
||||
session = session_manager.get_session(request.session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
if session.is_running:
|
||||
raise HTTPException(status_code=400, detail="Analysis already in progress")
|
||||
|
||||
background_tasks.add_task(run_analysis_task, request.session_id, [], request.message, is_followup=True)
|
||||
return {"status": "started"}
|
||||
|
||||
@app.get("/api/status")
|
||||
async def get_status(session_id: str = Query(..., description="Session ID")):
|
||||
session = session_manager.get_session(session_id)
|
||||
@@ -235,6 +273,27 @@ async def get_status(session_id: str = Query(..., description="Session ID")):
|
||||
"report_path": session.generated_report
|
||||
}
|
||||
|
||||
@app.get("/api/export")
|
||||
async def export_session(session_id: str = Query(..., description="Session ID")):
|
||||
session = session_manager.get_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
if not session.output_dir or not os.path.exists(session.output_dir):
|
||||
raise HTTPException(status_code=404, detail="No data available for export")
|
||||
|
||||
# Create a zip file
|
||||
import shutil
|
||||
|
||||
# We want to zip the contents of session_output_dir
|
||||
# Zip path should be outside to avoid recursive zipping if inside
|
||||
zip_base_name = os.path.join("outputs", f"export_{session_id}")
|
||||
|
||||
# shutil.make_archive expects base_name (without extension) and root_dir
|
||||
archive_path = shutil.make_archive(zip_base_name, 'zip', session.output_dir)
|
||||
|
||||
return FileResponse(archive_path, media_type='application/zip', filename=f"analysis_export_{session_id}.zip")
|
||||
|
||||
@app.get("/api/report")
|
||||
async def get_report(session_id: str = Query(..., description="Session ID")):
|
||||
session = session_manager.get_session(session_id)
|
||||
@@ -250,8 +309,24 @@ async def get_report(session_id: str = Query(..., description="Session ID")):
|
||||
# Fix image paths
|
||||
relative_session_path = os.path.relpath(session.output_dir, os.getcwd())
|
||||
web_base_path = f"/{relative_session_path}"
|
||||
|
||||
# Robust image path replacement
|
||||
# 1. Replace explicit relative paths ./image.png
|
||||
content = content.replace("](./", f"]({web_base_path}/")
|
||||
|
||||
# 2. Replace naked paths that might be generated like ](image.png) but NOT ](http...) or ](/...)
|
||||
import re
|
||||
def replace_link(match):
|
||||
alt = match.group(1)
|
||||
url = match.group(2)
|
||||
if url.startswith("http") or url.startswith("/") or url.startswith("data:"):
|
||||
return match.group(0)
|
||||
# Remove ./ if exists again just in case
|
||||
clean_url = url.lstrip("./")
|
||||
return f""
|
||||
|
||||
content = re.sub(r'!\[(.*?)\]\((.*?)\)', replace_link, content)
|
||||
|
||||
return {"content": content, "base_path": web_base_path}
|
||||
|
||||
@app.get("/api/figures")
|
||||
|
||||
Reference in New Issue
Block a user