Merge branch 'main' of http://jeason.online:3000/zhaojie/iov_data_analysis_agent
This commit is contained in:
536
web/main.py
536
web/main.py
@@ -19,11 +19,16 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from data_analysis_agent import DataAnalysisAgent
|
||||
from config.llm_config import LLMConfig
|
||||
from utils.create_session_dir import create_session_output_dir
|
||||
from config.llm_config import LLMConfig
|
||||
from utils.create_session_dir import create_session_output_dir
|
||||
from utils.logger import PrintCapture
|
||||
|
||||
app = FastAPI(title="IOV Data Analysis Agent")
|
||||
|
||||
|
||||
def _to_web_path(fs_path: str) -> str:
    """Convert a filesystem path into a forward-slash path that is safe in URLs.

    Works around Windows paths, whose backslash separators would otherwise
    leak into the web URLs built from them.

    Args:
        fs_path: Path string, possibly containing backslashes.

    Returns:
        The same string with every backslash replaced by a forward slash.
    """
    return fs_path.replace("\\", "/")
|
||||
|
||||
|
||||
# CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -168,8 +173,6 @@ class SessionManager:
|
||||
"max_rounds": session.max_rounds,
|
||||
"created_at": session.created_at,
|
||||
"last_updated": session.last_updated,
|
||||
"created_at": session.created_at,
|
||||
"last_updated": session.last_updated,
|
||||
"user_requirement": session.user_requirement[:100] + "..." if len(session.user_requirement) > 100 else session.user_requirement,
|
||||
"script_path": session.reusable_script # 新增:返回脚本路径
|
||||
}
|
||||
@@ -188,9 +191,7 @@ app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
|
||||
# --- Helper Functions ---
|
||||
|
||||
def run_analysis_task(session_id: str, files: list, user_requirement: str, is_followup: bool = False):
|
||||
"""
|
||||
Runs the analysis agent in a background thread for a specific session.
|
||||
"""
|
||||
"""在后台线程中运行分析任务"""
|
||||
session = session_manager.get_session(session_id)
|
||||
if not session:
|
||||
print(f"Error: Session {session_id} not found in background task.")
|
||||
@@ -198,119 +199,58 @@ def run_analysis_task(session_id: str, files: list, user_requirement: str, is_fo
|
||||
|
||||
session.is_running = True
|
||||
try:
|
||||
# Create session directory if not exists (for follow-up it should accept existing)
|
||||
base_output_dir = "outputs"
|
||||
|
||||
|
||||
if not session.output_dir:
|
||||
session.output_dir = create_session_output_dir(base_output_dir, user_requirement)
|
||||
|
||||
session.output_dir = create_session_output_dir(base_output_dir, user_requirement)
|
||||
|
||||
session_output_dir = session.output_dir
|
||||
|
||||
# Initialize Log capturing
|
||||
session.log_file = os.path.join(session_output_dir, "process.log")
|
||||
|
||||
# Thread-safe logging requires a bit of care.
|
||||
# Since we are running in a thread, redirecting sys.stdout globally is BAD for multi-session.
|
||||
# However, for this MVP, if we run multiple sessions concurrently, their logs will mix in stdout.
|
||||
# BUT we are writing to specific log files.
|
||||
# We need a logger that writes to the session's log file.
|
||||
# And the Agent needs to use that logger.
|
||||
# Currently the Agent uses print().
|
||||
# To support true concurrent logging without mixing, we'd need to refactor Agent to use a logger instance.
|
||||
# LIMITATION: For now, we accept that stdout redirection intercepts EVERYTHING.
|
||||
# So multiple concurrent sessions is risky with global stdout redirection.
|
||||
# A safer approach for now: We won't redirect stdout globally for multi-session support
|
||||
# unless we lock execution to one at a time.
|
||||
# OR: We just rely on the fact that we might only run one analysis at a time mostly.
|
||||
# Let's try to just write to the log file explicitly if we could, but we can't change Agent easily right now.
|
||||
# Compromise: We will continue to use global redirection but acknowledge it's not thread-safe for output.
|
||||
# A better way: Modify Agent to accept a 'log_callback'.
|
||||
# For this refactor, let's stick to the existing pattern but bind it to the thread if possible? No.
|
||||
|
||||
# We will wrap the execution with a simple File Logger that appends to the distinct file.
|
||||
# But sys.stdout is global.
|
||||
# We will assume single concurrent analysis for safety, or accept mixed terminal output but separate file logs?
|
||||
# Actually, if we swap sys.stdout, it affects all threads.
|
||||
# So we MUST NOT swap sys.stdout if we want concurrency.
|
||||
# If we don't swap stdout, we don't capture logs to file unless Agent does it.
|
||||
# The Agent code has `print`.
|
||||
# Correct fix: Refactor Agent to use `logging` module or pass a printer.
|
||||
# Given the scope, let's just hold the lock (serialize execution) OR allow mixing in terminal
|
||||
# but try to capture to file?
|
||||
# Let's just write to the file.
|
||||
|
||||
# Let's just write to the file.
|
||||
|
||||
with open(session.log_file, "a" if is_followup else "w", encoding="utf-8") as f:
|
||||
|
||||
# 使用 PrintCapture 替代全局 FileLogger,退出 with 块后自动恢复 stdout
|
||||
with PrintCapture(session.log_file):
|
||||
if is_followup:
|
||||
f.write(f"\n--- Follow-up Session {session_id} Continued ---\n")
|
||||
print(f"\n--- Follow-up Session {session_id} Continued ---")
|
||||
else:
|
||||
f.write(f"--- Session {session_id} Started ---\n")
|
||||
print(f"--- Session {session_id} Started ---")
|
||||
|
||||
# We will create a custom print function that writes to the file
|
||||
# And monkeypatch builtins.print? No, that's too hacky.
|
||||
# Let's just use the stdout redirector, but acknowledge only one active session at a time is safe.
|
||||
# We can implement a crude lock for now.
|
||||
|
||||
class FileLogger:
|
||||
def __init__(self, filename):
|
||||
self.terminal = sys.__stdout__
|
||||
self.log = open(filename, "a", encoding="utf-8", buffering=1)
|
||||
|
||||
def write(self, message):
|
||||
self.terminal.write(message)
|
||||
self.log.write(message)
|
||||
|
||||
def flush(self):
|
||||
self.terminal.flush()
|
||||
self.log.flush()
|
||||
|
||||
def close(self):
|
||||
self.log.close()
|
||||
try:
|
||||
if not is_followup:
|
||||
llm_config = LLMConfig()
|
||||
agent = DataAnalysisAgent(llm_config, force_max_rounds=False, output_dir=base_output_dir)
|
||||
session.agent = agent
|
||||
|
||||
logger = FileLogger(session.log_file)
|
||||
sys.stdout = logger # Global hijack!
|
||||
|
||||
try:
|
||||
if not is_followup:
|
||||
llm_config = LLMConfig()
|
||||
agent = DataAnalysisAgent(llm_config, force_max_rounds=False, output_dir=base_output_dir)
|
||||
session.agent = agent
|
||||
|
||||
result = agent.analyze(
|
||||
user_input=user_requirement,
|
||||
files=files,
|
||||
session_output_dir=session_output_dir,
|
||||
reset_session=True
|
||||
)
|
||||
else:
|
||||
agent = session.agent
|
||||
if not agent:
|
||||
print("Error: Agent not initialized for follow-up.")
|
||||
return
|
||||
result = agent.analyze(
|
||||
user_input=user_requirement,
|
||||
files=files,
|
||||
session_output_dir=session_output_dir,
|
||||
reset_session=True,
|
||||
)
|
||||
else:
|
||||
agent = session.agent
|
||||
if not agent:
|
||||
print("Error: Agent not initialized for follow-up.")
|
||||
return
|
||||
|
||||
result = agent.analyze(
|
||||
user_input=user_requirement,
|
||||
files=None,
|
||||
session_output_dir=session_output_dir,
|
||||
reset_session=False,
|
||||
max_rounds=10,
|
||||
)
|
||||
|
||||
session.generated_report = result.get("report_file_path", None)
|
||||
session.analysis_results = result.get("analysis_results", [])
|
||||
session.reusable_script = result.get("reusable_script_path", None)
|
||||
|
||||
# 持久化结果
|
||||
with open(os.path.join(session_output_dir, "results.json"), "w") as f:
|
||||
json.dump(session.analysis_results, f, default=str)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during analysis: {e}")
|
||||
|
||||
result = agent.analyze(
|
||||
user_input=user_requirement,
|
||||
files=None,
|
||||
session_output_dir=session_output_dir,
|
||||
reset_session=False,
|
||||
max_rounds=10
|
||||
)
|
||||
|
||||
session.generated_report = result.get("report_file_path", None)
|
||||
session.analysis_results = result.get("analysis_results", [])
|
||||
session.reusable_script = result.get("reusable_script_path", None) # 新增:保存脚本路径
|
||||
|
||||
# Save results to json for persistence
|
||||
with open(os.path.join(session_output_dir, "results.json"), "w") as f:
|
||||
json.dump(session.analysis_results, f, default=str)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during analysis: {e}")
|
||||
finally:
|
||||
sys.stdout = logger.terminal
|
||||
logger.close()
|
||||
|
||||
except Exception as e:
|
||||
print(f"System Error: {e}")
|
||||
finally:
|
||||
@@ -390,24 +330,37 @@ async def get_status(session_id: str = Query(..., description="Session ID")):
|
||||
|
||||
@app.get("/api/export")
|
||||
async def export_session(session_id: str = Query(..., description="Session ID")):
|
||||
"""导出会话数据为ZIP"""
|
||||
session = session_manager.get_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
if not session.output_dir or not os.path.exists(session.output_dir):
|
||||
raise HTTPException(status_code=404, detail="No data available for export")
|
||||
|
||||
# Create a zip file
|
||||
import shutil
|
||||
|
||||
# We want to zip the contents of session_output_dir
|
||||
# Zip path should be outside to avoid recursive zipping if inside
|
||||
zip_base_name = os.path.join("outputs", f"export_{session_id}")
|
||||
import zipfile
|
||||
from datetime import datetime as dt
|
||||
|
||||
# shutil.make_archive expects base_name (without extension) and root_dir
|
||||
archive_path = shutil.make_archive(zip_base_name, 'zip', session.output_dir)
|
||||
timestamp = dt.now().strftime("%Y%m%d_%H%M%S")
|
||||
zip_filename = f"report_{timestamp}.zip"
|
||||
|
||||
return FileResponse(archive_path, media_type='application/zip', filename=f"analysis_export_{session_id}.zip")
|
||||
export_dir = "outputs"
|
||||
os.makedirs(export_dir, exist_ok=True)
|
||||
temp_zip_path = os.path.join(export_dir, zip_filename)
|
||||
|
||||
with zipfile.ZipFile(temp_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for root, dirs, files in os.walk(session.output_dir):
|
||||
for file in files:
|
||||
if file.endswith(('.md', '.png', '.csv', '.log', '.json', '.yaml')):
|
||||
abs_path = os.path.join(root, file)
|
||||
rel_path = os.path.relpath(abs_path, session.output_dir)
|
||||
zf.write(abs_path, arcname=rel_path)
|
||||
|
||||
return FileResponse(
|
||||
path=temp_zip_path,
|
||||
filename=zip_filename,
|
||||
media_type='application/zip'
|
||||
)
|
||||
|
||||
@app.get("/api/report")
|
||||
async def get_report(session_id: str = Query(..., description="Session ID")):
|
||||
@@ -416,33 +369,33 @@ async def get_report(session_id: str = Query(..., description="Session ID")):
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
if not session.generated_report or not os.path.exists(session.generated_report):
|
||||
return {"content": "Report not ready."}
|
||||
return {"content": "Report not ready.", "paragraphs": []}
|
||||
|
||||
with open(session.generated_report, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# Fix image paths
|
||||
relative_session_path = os.path.relpath(session.output_dir, os.getcwd())
|
||||
relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
|
||||
web_base_path = f"/{relative_session_path}"
|
||||
|
||||
# Robust image path replacement
|
||||
# 1. Replace explicit relative paths ./image.png
|
||||
content = content.replace("](./", f"]({web_base_path}/")
|
||||
|
||||
# 2. Replace naked paths that might be generated like ](image.png) but NOT ](http...) or ](/...)
|
||||
import re
|
||||
def replace_link(match):
|
||||
alt = match.group(1)
|
||||
url = match.group(2)
|
||||
if url.startswith("http") or url.startswith("/") or url.startswith("data:"):
|
||||
return match.group(0)
|
||||
# Remove ./ if exists again just in case
|
||||
clean_url = url.lstrip("./")
|
||||
return f""
|
||||
|
||||
content = re.sub(r'!\[(.*?)\]\((.*?)\)', replace_link, content)
|
||||
|
||||
# 将报告按段落拆分,为前端润色功能提供结构化数据
|
||||
paragraphs = _split_report_to_paragraphs(content)
|
||||
|
||||
return {"content": content, "base_path": web_base_path}
|
||||
return {"content": content, "base_path": web_base_path, "paragraphs": paragraphs}
|
||||
|
||||
@app.get("/api/figures")
|
||||
async def get_figures(session_id: str = Query(..., description="Session ID")):
|
||||
@@ -473,7 +426,7 @@ async def get_figures(session_id: str = Query(..., description="Session ID")):
|
||||
if session.output_dir:
|
||||
# Assume filename is present
|
||||
fname = fig.get("filename")
|
||||
relative_session_path = os.path.relpath(session.output_dir, os.getcwd())
|
||||
relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
|
||||
fig["web_url"] = f"/{relative_session_path}/{fname}"
|
||||
figures.append(fig)
|
||||
|
||||
@@ -486,7 +439,7 @@ async def get_figures(session_id: str = Query(..., description="Session ID")):
|
||||
pngs = glob.glob(os.path.join(session.output_dir, "*.png"))
|
||||
for p in pngs:
|
||||
fname = os.path.basename(p)
|
||||
relative_session_path = os.path.relpath(session.output_dir, os.getcwd())
|
||||
relative_session_path = _to_web_path(os.path.relpath(session.output_dir, os.getcwd()))
|
||||
figures.append({
|
||||
"filename": fname,
|
||||
"description": "Auto-discovered image",
|
||||
@@ -496,37 +449,6 @@ async def get_figures(session_id: str = Query(..., description="Session ID")):
|
||||
|
||||
return {"figures": figures}
|
||||
|
||||
@app.get("/api/export")
|
||||
async def export_report(session_id: str = Query(..., description="Session ID")):
|
||||
session = session_manager.get_session(session_id)
|
||||
if not session or not session.output_dir:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
import zipfile
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
zip_filename = f"report_{timestamp}.zip"
|
||||
|
||||
export_dir = "outputs"
|
||||
os.makedirs(export_dir, exist_ok=True)
|
||||
temp_zip_path = os.path.join(export_dir, zip_filename)
|
||||
|
||||
with zipfile.ZipFile(temp_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for root, dirs, files in os.walk(session.output_dir):
|
||||
for file in files:
|
||||
if file.endswith(('.md', '.png', '.csv', '.log', '.json', '.yaml')):
|
||||
abs_path = os.path.join(root, file)
|
||||
rel_path = os.path.relpath(abs_path, session.output_dir)
|
||||
zf.write(abs_path, arcname=rel_path)
|
||||
|
||||
return FileResponse(
|
||||
path=temp_zip_path,
|
||||
filename=zip_filename,
|
||||
media_type='application/zip'
|
||||
)
|
||||
|
||||
@app.get("/api/download_script")
|
||||
async def download_script(session_id: str = Query(..., description="Session ID")):
|
||||
"""下载生成的Python脚本"""
|
||||
@@ -580,7 +502,301 @@ async def delete_specific_session(session_id: str):
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return {"status": "deleted", "session_id": session_id}
|
||||
|
||||
return {"status": "deleted", "session_id": session_id}
|
||||
|
||||
# --- Report Polishing API ---
|
||||
|
||||
import re as _re
|
||||
|
||||
def _split_report_to_paragraphs(markdown_content: str) -> list:
    """Split a Markdown report into semantic paragraphs.

    The front end uses the result for paragraph-level selection and polishing.

    Splitting rules, applied line by line:
      * a fenced code block (delimited by lines starting with three
        backticks) is one ``"code"`` paragraph, fences included;
      * an ATX heading line (``#`` .. ``######`` followed by whitespace) is
        its own ``"heading"`` paragraph;
      * a line starting with an image link ``![alt](url)`` is its own
        ``"image"`` paragraph;
      * consecutive lines starting with ``|`` form one ``"table"`` paragraph;
      * any other run of non-blank lines is a ``"text"`` paragraph; blank
        lines separate paragraphs.

    Args:
        markdown_content: Full Markdown text of the report.

    Returns:
        A list of dicts in document order, each with the keys ``"id"``
        (``"p-0"``, ``"p-1"``, ...), ``"type"`` (one of ``"heading"``,
        ``"text"``, ``"table"``, ``"image"``, ``"code"``) and ``"content"``
        (the block's Markdown with surrounding whitespace stripped).
        Whitespace-only blocks are dropped and do not consume an id.
    """
    paragraphs = []
    current_block = []
    current_type = "text"
    in_table = False
    in_code = False

    def flush_block():
        """Emit the pending block (if non-empty) and reset the block state."""
        nonlocal current_block, current_type, in_table
        text = "\n".join(current_block).strip()
        if text:
            paragraphs.append({
                "id": f"p-{len(paragraphs)}",
                "type": current_type,
                "content": text,
            })
        current_block = []
        current_type = "text"
        # Every flush ends the current table. Without this reset, a table
        # that directly follows a heading/image/code fence which itself
        # interrupted a table would be appended as a "text" block.
        in_table = False

    for line in markdown_content.split("\n"):
        stripped = line.strip()

        # Code fence boundary: opens or closes a fenced block.
        if stripped.startswith("```"):
            if in_code:
                current_block.append(line)
                flush_block()
                in_code = False
            else:
                flush_block()
                current_block.append(line)
                current_type = "code"
                in_code = True
            continue

        # Inside a fence everything is kept verbatim, blank lines included.
        if in_code:
            current_block.append(line)
            continue

        # Heading line: always a paragraph of its own.
        if _re.match(r"^#{1,6}\s", stripped):
            flush_block()
            current_block.append(line)
            current_type = "heading"
            flush_block()
            continue

        # Image line: always a paragraph of its own.
        if _re.match(r"^!\[.*\]\(.*\)", stripped):
            flush_block()
            current_block.append(line)
            current_type = "image"
            flush_block()
            continue

        # Table row: consecutive rows are merged into one block.
        if stripped.startswith("|"):
            if not in_table:
                flush_block()
                in_table = True
                current_type = "table"
            current_block.append(line)
            continue

        # First non-table line after a table closes the table.
        if in_table:
            flush_block()

        # Blank line: paragraph separator.
        if not stripped:
            flush_block()
            continue

        # Plain text line.
        current_block.append(line)

    flush_block()
    return paragraphs
|
||||
|
||||
|
||||
class PolishRequest(BaseModel):
    """Request body for ``POST /api/report/polish``."""

    # Session whose generated report contains the paragraph.
    session_id: str
    # Paragraph id as produced by _split_report_to_paragraphs, e.g. "p-3".
    paragraph_id: str
    # Any value other than "data" or "custom" is handled as "context".
    mode: str = "context"  # "context" | "data" | "custom"
    # Free-form instruction; only inserted into the prompt when mode == "custom".
    custom_instruction: str = ""
|
||||
|
||||
|
||||
@app.post("/api/report/polish")
async def polish_paragraph(request: PolishRequest):
    """Polish one paragraph of a session's generated report with the LLM.

    The report file is re-read and split with ``_split_report_to_paragraphs``;
    the paragraph whose id equals ``request.paragraph_id`` is the target.
    Nothing is written back to the report here: the polished text is only
    returned to the caller.

    Modes (``request.mode``):
      - "data": the prompt contains up to five successful execution outputs
        from ``session.analysis_results`` (each cut to 2000 characters) plus
        the collected figure information.
      - "custom": the prompt contains ``request.custom_instruction``, the
        neighbouring paragraphs and the figure information.
      - anything else (default "context"): the prompt contains the
        neighbouring paragraphs and the figure information.

    Returns:
        A dict with ``paragraph_id``, ``original`` (the target paragraph),
        ``polished`` (the LLM output with a wrapping code fence removed) and
        ``mode``.

    Raises:
        HTTPException: 404 if the session, the report file or the paragraph
            is not found; 500 if the LLM call or its post-processing fails.
    """
    session = session_manager.get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if not session.generated_report or not os.path.exists(session.generated_report):
        raise HTTPException(status_code=404, detail="Report not found")

    # Read the report and split it into paragraphs.
    with open(session.generated_report, "r", encoding="utf-8") as f:
        report_content = f.read()

    paragraphs = _split_report_to_paragraphs(report_content)

    # Locate the target paragraph.
    target = None
    target_idx = -1
    for i, p in enumerate(paragraphs):
        if p["id"] == request.paragraph_id:
            target = p
            target_idx = i
            break

    if target is None:
        raise HTTPException(status_code=404, detail=f"Paragraph {request.paragraph_id} not found")

    # Context window: up to two paragraphs before and two after the target.
    neighbours = (
        paragraphs[max(0, target_idx - 2):target_idx]
        + paragraphs[target_idx + 1:target_idx + 3]
    )
    context_text = "\n\n".join(p["content"] for p in neighbours)

    # analysis_results may be unset/empty for a session; normalise once so
    # every mode below can iterate it safely.
    analysis_results = session.analysis_results or []

    # Collect figure information.
    fig_parts = []
    for item in analysis_results:
        if item.get("action") == "collect_figures":
            for fig in item.get("collected_figures") or []:
                fig_parts.append(f"- {fig.get('filename', '?')}: {fig.get('description', '')} / {fig.get('analysis', '')}")
    figures_info = "\n".join(fig_parts)

    # Build the polishing prompt.
    if request.mode == "data":
        # Summarise code-execution outputs.
        data_summary_parts = []
        for item in analysis_results:
            # "result" may be present but None, so do not rely on the default.
            result = item.get("result") or {}
            if result.get("success") and result.get("output"):
                data_summary_parts.append(result["output"][:2000])
        data_summary = "\n---\n".join(data_summary_parts[:5])

        polish_prompt = f"""你是一位资深数据分析专家。请基于以下分析数据,重写下方段落,使其包含更精确的数据引用和更深入的业务洞察。

## 分析数据摘要
{data_summary}

## 图表信息
{figures_info}

## 需要润色的段落
{target['content']}

## 要求
- 保持原有的 Markdown 格式(标题级别、表格结构等)
- 用具体数据替换模糊描述
- 增加业务洞察和趋势判断
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""

    elif request.mode == "custom":
        polish_prompt = f"""你是一位资深数据分析专家。请根据用户的指令润色以下段落。

## 用户指令
{request.custom_instruction}

## 上下文
{context_text}

## 图表信息
{figures_info}

## 需要润色的段落
{target['content']}

## 要求
- 保持原有的 Markdown 格式
- 严格遵循用户指令
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""

    else:  # context mode
        polish_prompt = f"""你是一位资深数据分析专家。请润色以下段落,使其表述更专业、更有洞察力。

## 上下文(前后段落)
{context_text}

## 图表信息
{figures_info}

## 需要润色的段落
{target['content']}

## 要求
- 保持原有的 Markdown 格式(标题级别、表格结构等)
- 提升专业性:使用同比、环比、占比等术语
- 增加洞察:不仅描述现象,还要分析原因和影响
- 禁止使用第一人称
- 直接输出润色后的 Markdown 内容,不要包裹在代码块中"""

    # Call the LLM.
    # NOTE(review): llm.call is not awaited, so it presumably blocks the
    # event loop for the duration of the request — confirm and consider
    # off-loading it to a worker thread.
    try:
        from utils.llm_helper import LLMHelper
        llm = LLMHelper(LLMConfig())
        polished_content = llm.call(
            prompt=polish_prompt,
            system_prompt="你是一位专业的数据分析报告润色专家。直接输出润色后的内容,不要添加任何解释或包裹。",
            max_tokens=4096,
        )

        # Remove a code fence the model may have wrapped around its answer.
        polished_content = polished_content.strip()
        if polished_content.startswith("```markdown"):
            polished_content = polished_content[len("```markdown"):].strip()
        if polished_content.startswith("```"):
            polished_content = polished_content[3:].strip()
        if polished_content.endswith("```"):
            polished_content = polished_content[:-3].strip()

        return {
            "paragraph_id": request.paragraph_id,
            "original": target["content"],
            "polished": polished_content,
            "mode": request.mode,
        }

    except Exception as e:
        # Boundary handler: surface any failure as a 500, keeping the cause
        # chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=f"Polish failed: {str(e)}") from e
|
||||
|
||||
|
||||
class ApplyPolishRequest(BaseModel):
    """Request body for ``POST /api/report/apply``."""

    # Session whose generated report file is rewritten.
    session_id: str
    # Paragraph id as produced by _split_report_to_paragraphs, e.g. "p-3".
    paragraph_id: str
    # Markdown that replaces the paragraph's current content.
    new_content: str
|
||||
|
||||
|
||||
def _replace_paragraph_in_report(report_content: str, paragraphs: list, paragraph_id: str, new_content: str):
    """Return ``report_content`` with one paragraph replaced, or None if the id is unknown.

    ``paragraphs`` must be the document-ordered result of splitting
    ``report_content``. Each paragraph is located by searching from the end
    of the previous one, so the replacement hits the paragraph at that
    position even when the same text also occurs earlier in the report.
    """
    cursor = 0
    for para in paragraphs:
        body = para["content"]
        start = report_content.find(body, cursor)
        if para["id"] == paragraph_id:
            if start == -1:
                # Not expected for paragraphs derived from this content;
                # fall back to replacing the first textual occurrence.
                return report_content.replace(body, new_content, 1)
            return report_content[:start] + new_content + report_content[start + len(body):]
        if start != -1:
            cursor = start + len(body)
    return None


@app.post("/api/report/apply")
async def apply_polish(request: ApplyPolishRequest):
    """Write polished content into the report file, replacing one paragraph.

    The report is re-read and split with ``_split_report_to_paragraphs``;
    the paragraph whose id equals ``request.paragraph_id`` is replaced by
    ``request.new_content`` and the file is overwritten.

    Returns:
        ``{"status": "applied", "paragraph_id": <id>}``.

    Raises:
        HTTPException: 404 if the session, the report file or the paragraph
            is not found.
    """
    session = session_manager.get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if not session.generated_report or not os.path.exists(session.generated_report):
        raise HTTPException(status_code=404, detail="Report not found")

    with open(session.generated_report, "r", encoding="utf-8") as f:
        report_content = f.read()

    paragraphs = _split_report_to_paragraphs(report_content)

    # Replace the target paragraph at its own position in the document.
    new_report = _replace_paragraph_in_report(
        report_content, paragraphs, request.paragraph_id, request.new_content
    )
    if new_report is None:
        raise HTTPException(status_code=404, detail=f"Paragraph {request.paragraph_id} not found")

    # Write the updated report back to disk.
    with open(session.generated_report, "w", encoding="utf-8") as f:
        f.write(new_report)

    return {"status": "applied", "paragraph_id": request.paragraph_id}
|
||||
|
||||
|
||||
# --- History API ---
|
||||
|
||||
Reference in New Issue
Block a user