大更新,架构调整,数据分析能力提升,
This commit is contained in:
@@ -10,20 +10,33 @@
|
||||
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import yaml
|
||||
from typing import Dict, Any, List, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from utils.create_session_dir import create_session_output_dir
|
||||
from utils.format_execution_result import format_execution_result
|
||||
from utils.extract_code import extract_code_from_response
|
||||
from utils.data_loader import load_and_profile_data
|
||||
from utils.data_loader import load_and_profile_data, load_data_chunked, load_and_profile_data_smart
|
||||
from utils.llm_helper import LLMHelper
|
||||
from utils.code_executor import CodeExecutor
|
||||
from utils.script_generator import generate_reusable_script
|
||||
from utils.data_privacy import build_safe_profile, build_local_profile, sanitize_execution_feedback
|
||||
from utils.data_privacy import build_safe_profile, build_local_profile, sanitize_execution_feedback, generate_enriched_hint
|
||||
from config.llm_config import LLMConfig
|
||||
from config.app_config import app_config
|
||||
from prompts import data_analysis_system_prompt, final_report_system_prompt, data_analysis_followup_prompt
|
||||
|
||||
|
||||
# Regex patterns that indicate a data-context error (column/variable/DataFrame issues)
# Matched case-insensitively by _classify_error; a hit means the failure is
# likely caused by the data itself (missing columns, undefined data variables,
# empty results) rather than a generic coding mistake.
DATA_CONTEXT_PATTERNS = [
    r"KeyError:\s*['\"](.+?)['\"]",  # missing dict key / DataFrame column
    r"ValueError.*(?:column|col|field)",  # ValueError mentioning a column/field
    r"NameError.*(?:df|data|frame)",  # undefined DataFrame-like variable
    r"(?:empty|no\s+data|0\s+rows)",  # empty result set / no rows returned
    r"IndexError.*(?:out of range|out of bounds)",  # index past available data
]
|
||||
|
||||
|
||||
class DataAnalysisAgent:
|
||||
"""
|
||||
数据分析智能体
|
||||
@@ -66,6 +79,53 @@ class DataAnalysisAgent:
|
||||
self.data_profile_safe = "" # 存储安全画像(发给LLM)
|
||||
self.data_files = [] # 存储数据文件列表
|
||||
self.user_requirement = "" # 存储用户需求
|
||||
self._progress_callback = None # 进度回调函数
|
||||
self._session_ref = None # Reference to SessionData for round tracking
|
||||
|
||||
def set_session_ref(self, session):
    """Attach the SessionData object that per-round results are written to.

    Args:
        session: SessionData instance backing the current analysis session.
    """
    self._session_ref = session
|
||||
|
||||
def set_progress_callback(self, callback):
    """Register a progress hook, invoked as callback(current_round, max_rounds, message)."""
    self._progress_callback = callback
|
||||
|
||||
def _summarize_result(self, result: Dict[str, Any]) -> str:
|
||||
"""Produce a one-line summary from a code execution result.
|
||||
|
||||
Args:
|
||||
result: The execution result dict from CodeExecutor.
|
||||
|
||||
Returns:
|
||||
A concise summary string, e.g. "执行成功,输出 DataFrame (150行×8列)"
|
||||
or "执行失败: KeyError: 'col_x'".
|
||||
"""
|
||||
if result.get("success"):
|
||||
evidence_rows = result.get("evidence_rows", [])
|
||||
if evidence_rows:
|
||||
num_rows = len(evidence_rows)
|
||||
num_cols = len(evidence_rows[0]) if evidence_rows else 0
|
||||
# Check auto_exported_files for more accurate row/col counts
|
||||
auto_files = result.get("auto_exported_files", [])
|
||||
if auto_files:
|
||||
last_file = auto_files[-1]
|
||||
num_rows = last_file.get("rows", num_rows)
|
||||
num_cols = last_file.get("cols", num_cols)
|
||||
return f"执行成功,输出 DataFrame ({num_rows}行×{num_cols}列)"
|
||||
output = result.get("output", "")
|
||||
if output:
|
||||
first_line = output.strip().split("\n")[0][:80]
|
||||
return f"执行成功: {first_line}"
|
||||
return "执行成功"
|
||||
else:
|
||||
error = result.get("error", "未知错误")
|
||||
if len(error) > 100:
|
||||
error = error[:100] + "..."
|
||||
return f"执行失败: {error}"
|
||||
|
||||
def _process_response(self, response: str) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -181,6 +241,7 @@ class DataAnalysisAgent:
|
||||
"""处理代码生成和执行动作"""
|
||||
# 从YAML数据中获取代码(更准确)
|
||||
code = yaml_data.get("code", "")
|
||||
reasoning = yaml_data.get("reasoning", "")
|
||||
|
||||
# 如果YAML中没有代码,尝试从响应中提取
|
||||
if not code:
|
||||
@@ -190,7 +251,6 @@ class DataAnalysisAgent:
|
||||
if code:
|
||||
code = code.strip()
|
||||
if code.startswith("```"):
|
||||
import re
|
||||
# 去除开头的 ```python 或 ```
|
||||
code = re.sub(r"^```[a-zA-Z]*\n", "", code)
|
||||
# 去除结尾的 ```
|
||||
@@ -211,6 +271,7 @@ class DataAnalysisAgent:
|
||||
return {
|
||||
"action": "generate_code",
|
||||
"code": code,
|
||||
"reasoning": reasoning,
|
||||
"result": result,
|
||||
"feedback": feedback,
|
||||
"response": response,
|
||||
@@ -221,12 +282,146 @@ class DataAnalysisAgent:
|
||||
print("[WARN] 未从响应中提取到可执行代码,要求LLM重新生成")
|
||||
return {
|
||||
"action": "invalid_response",
|
||||
"reasoning": reasoning,
|
||||
"error": "响应中缺少可执行代码",
|
||||
"response": response,
|
||||
"continue": True,
|
||||
}
|
||||
|
||||
def analyze(self, user_input: str, files: List[str] = None, session_output_dir: str = None, reset_session: bool = True, max_rounds: int = None) -> Dict[str, Any]:
|
||||
def _classify_error(self, error_message: str) -> str:
    """Classify an execution error as data-context related or not.

    Scans the error text against DATA_CONTEXT_PATTERNS (case-insensitive)
    to decide whether the failure stems from the data context itself —
    missing columns, undefined data variables, empty DataFrames, etc.

    Args:
        error_message: The error message string from code execution.

    Returns:
        "data_context" when any pattern matches, "other" otherwise.
    """
    hit = any(
        re.search(pattern, error_message, re.IGNORECASE)
        for pattern in DATA_CONTEXT_PATTERNS
    )
    return "data_context" if hit else "other"
|
||||
|
||||
def _trim_conversation_history(self):
    """Apply sliding-window trimming to the conversation history.

    Always retains the first user message (original requirement +
    Safe_Profile) at index 0, compresses older messages into a single
    summary message, and keeps only the most recent
    ``app_config.conversation_window_size`` message pairs in full.
    """
    window_size = app_config.conversation_window_size
    max_messages = window_size * 2  # each round is a user+assistant pair

    if len(self.conversation_history) <= max_messages:
        return  # History still fits the window — nothing to trim.

    first_message = self.conversation_history[0]  # Always retained.

    # Skip over a summary message inserted by a previous trim, if present.
    start_idx = 1
    existing_summary = None
    if (
        len(self.conversation_history) > 1
        and self.conversation_history[1]["role"] == "user"
        and self.conversation_history[1]["content"].startswith("[分析摘要]")
    ):
        existing_summary = self.conversation_history[1]["content"]
        start_idx = 2

    # Messages to trim vs keep.
    messages_to_consider = self.conversation_history[start_idx:]
    messages_to_trim = messages_to_consider[:-max_messages]
    messages_to_keep = messages_to_consider[-max_messages:]

    if not messages_to_trim:
        return

    # Generate summary of trimmed messages.
    summary = self._compress_trimmed_messages(messages_to_trim)

    # Bug fix: a pre-existing summary message used to be silently dropped
    # here (neither kept nor re-summarized), permanently losing the record
    # of earlier rounds. Merge it with the new summary instead, dropping
    # the new summary's header line to avoid a duplicate "[分析摘要]" prefix.
    if existing_summary:
        if summary:
            parts = summary.split("\n", 1)
            summary = existing_summary + ("\n" + parts[1] if len(parts) > 1 else "")
        else:
            summary = existing_summary

    # Rebuild history: first_message + (merged) summary + recent messages.
    self.conversation_history = [first_message]
    if summary:
        self.conversation_history.append({"role": "user", "content": summary})
    self.conversation_history.extend(messages_to_keep)
|
||||
|
||||
def _compress_trimmed_messages(self, messages: list) -> str:
|
||||
"""Compress trimmed messages into a concise summary string.
|
||||
|
||||
Extracts the action type from each assistant message and the execution
|
||||
outcome (success / failure) from the subsequent user feedback message.
|
||||
Code blocks and raw execution output are excluded.
|
||||
|
||||
Args:
|
||||
messages: List of conversation message dicts to compress.
|
||||
|
||||
Returns:
|
||||
A summary string prefixed with ``[分析摘要]``.
|
||||
"""
|
||||
summary_parts = ["[分析摘要] 以下是之前分析轮次的概要:"]
|
||||
round_num = 0
|
||||
|
||||
for msg in messages:
|
||||
content = msg["content"]
|
||||
if msg["role"] == "assistant":
|
||||
round_num += 1
|
||||
# Extract action type from YAML-like content
|
||||
action = "generate_code"
|
||||
if "action: \"collect_figures\"" in content or "action: collect_figures" in content:
|
||||
action = "collect_figures"
|
||||
elif "action: \"analysis_complete\"" in content or "action: analysis_complete" in content:
|
||||
action = "analysis_complete"
|
||||
summary_parts.append(f"- 轮次{round_num}: 动作={action}")
|
||||
elif msg["role"] == "user" and "代码执行反馈" in content:
|
||||
success = "失败" if "[ERROR]" in content or "执行错误" in content else "成功"
|
||||
if summary_parts and summary_parts[-1].startswith("- 轮次"):
|
||||
summary_parts[-1] += f", 执行结果={success}"
|
||||
|
||||
return "\n".join(summary_parts)
|
||||
|
||||
def _profile_files_parallel(self, file_paths: list) -> tuple:
    """Profile multiple files concurrently using ThreadPoolExecutor.

    Each file is profiled independently via ``build_safe_profile`` and
    ``build_local_profile``. If profiling an individual file fails, an
    error entry is recorded for that file and the remaining files are
    still processed.

    Bug fix: results were previously appended in ``as_completed``
    completion order, making the merged profile order nondeterministic.
    Results are now stored by input index and reassembled in the
    caller-supplied order (this also handles duplicate paths correctly).

    Args:
        file_paths: List of file paths to profile.

    Returns:
        A tuple ``(safe_profile, local_profile)`` of merged markdown strings.
    """
    max_workers = app_config.max_parallel_profiles
    safe_results = [None] * len(file_paths)
    local_results = [None] * len(file_paths)

    def profile_single(path):
        # Runs in a worker thread; profiles one file at both privacy levels.
        return build_safe_profile([path]), build_local_profile([path])

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(profile_single, path): idx
            for idx, path in enumerate(file_paths)
        }
        for future in as_completed(futures):
            idx = futures[future]
            try:
                safe_results[idx], local_results[idx] = future.result()
            except Exception as e:
                error_entry = f"## 文件: {os.path.basename(file_paths[idx])}\n[ERROR] 分析失败: {e}\n\n"
                safe_results[idx] = error_entry
                local_results[idx] = error_entry

    return "\n".join(safe_results), "\n".join(local_results)
|
||||
|
||||
def analyze(self, user_input: str, files: List[str] = None, session_output_dir: str = None, reset_session: bool = True, max_rounds: int = None, template_name: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
开始分析流程
|
||||
|
||||
@@ -236,6 +431,7 @@ class DataAnalysisAgent:
|
||||
session_output_dir: 指定的会话输出目录(可选)
|
||||
reset_session: 是否重置会话 (True: 新开启分析; False: 在现有上下文中继续)
|
||||
max_rounds: 本次分析的最大轮数 (可选,如果不填则使用默认值)
|
||||
template_name: 分析模板名称 (可选,如果提供则使用模板引导分析)
|
||||
|
||||
Returns:
|
||||
分析结果字典
|
||||
@@ -244,6 +440,13 @@ class DataAnalysisAgent:
|
||||
# 确定本次运行的轮数限制
|
||||
current_max_rounds = max_rounds if max_rounds is not None else self.max_rounds
|
||||
|
||||
# Template integration: prepend template prompt to user input if provided
|
||||
if template_name:
|
||||
from utils.analysis_templates import get_template
|
||||
template = get_template(template_name) # Raises ValueError if invalid
|
||||
template_prompt = template.get_full_prompt()
|
||||
user_input = f"{template_prompt}\n\n{user_input}"
|
||||
|
||||
if reset_session:
|
||||
# --- 初始化新会话 ---
|
||||
self.conversation_history = []
|
||||
@@ -272,11 +475,28 @@ class DataAnalysisAgent:
|
||||
if files:
|
||||
print("[SEARCH] 正在生成数据画像...")
|
||||
try:
|
||||
data_profile_safe = build_safe_profile(files)
|
||||
data_profile_local = build_local_profile(files)
|
||||
if len(files) > 1:
|
||||
# Parallel profiling for multiple files
|
||||
data_profile_safe, data_profile_local = self._profile_files_parallel(files)
|
||||
else:
|
||||
data_profile_safe = build_safe_profile(files)
|
||||
data_profile_local = build_local_profile(files)
|
||||
print("[OK] 数据画像生成完毕(安全级 + 本地级)")
|
||||
except Exception as e:
|
||||
print(f"[WARN] 数据画像生成失败: {e}")
|
||||
|
||||
# Expose chunked iterators for large files in the Code_Executor namespace
|
||||
for fp in files:
|
||||
try:
|
||||
if os.path.exists(fp):
|
||||
file_size_mb = os.path.getsize(fp) / (1024 * 1024)
|
||||
if file_size_mb > app_config.max_file_size_mb:
|
||||
var_name = "chunked_iter_" + os.path.splitext(os.path.basename(fp))[0]
|
||||
# Store a factory so the iterator can be re-created
|
||||
self.executor.set_variable(var_name, lambda p=fp: load_data_chunked(p))
|
||||
print(f"[OK] 大文件 {os.path.basename(fp)} 的分块迭代器已注入为 {var_name}()")
|
||||
except Exception as e:
|
||||
print(f"[WARN] 注入分块迭代器失败 ({os.path.basename(fp)}): {e}")
|
||||
|
||||
# 安全画像发给LLM,完整画像留给最终报告生成
|
||||
self.data_profile = data_profile_local # 本地完整版用于最终报告
|
||||
@@ -327,9 +547,23 @@ class DataAnalysisAgent:
|
||||
|
||||
# 初始化连续失败计数器
|
||||
consecutive_failures = 0
|
||||
# Per-round data-context retry counter
|
||||
data_context_retries = 0
|
||||
last_retry_round = 0
|
||||
|
||||
while self.current_round < self.max_rounds:
|
||||
self.current_round += 1
|
||||
# Notify progress callback
|
||||
if self._progress_callback:
|
||||
self._progress_callback(self.current_round, self.max_rounds, f"第{self.current_round}/{self.max_rounds}轮分析中...")
|
||||
# Reset data-context retry counter when entering a new round
|
||||
if self.current_round != last_retry_round:
|
||||
data_context_retries = 0
|
||||
|
||||
# Trim conversation history after the first round to bound token usage
|
||||
if self.current_round > 1:
|
||||
self._trim_conversation_history()
|
||||
|
||||
print(f"\n[LOOP] 第 {self.current_round} 轮分析")
|
||||
# 调用LLM生成响应
|
||||
try: # 获取当前执行环境的变量信息
|
||||
@@ -388,7 +622,40 @@ class DataAnalysisAgent:
|
||||
# 根据动作类型添加不同的反馈
|
||||
if process_result["action"] == "generate_code":
|
||||
feedback = process_result.get("feedback", "")
|
||||
# 对执行反馈进行脱敏,移除真实数据值后再发给LLM
|
||||
result = process_result.get("result", {})
|
||||
execution_failed = not result.get("success", True)
|
||||
|
||||
# --- Data-context retry logic ---
|
||||
if execution_failed:
|
||||
error_output = result.get("error", "") or feedback
|
||||
error_class = self._classify_error(error_output)
|
||||
|
||||
if error_class == "data_context" and data_context_retries < app_config.max_data_context_retries:
|
||||
data_context_retries += 1
|
||||
last_retry_round = self.current_round
|
||||
print(f"[RETRY] 数据上下文错误,重试 {data_context_retries}/{app_config.max_data_context_retries}")
|
||||
# Generate enriched hint from safe profile
|
||||
enriched_hint = generate_enriched_hint(error_output, self.data_profile_safe)
|
||||
# Add enriched hint to conversation history (assistant response already added above)
|
||||
self.conversation_history.append(
|
||||
{"role": "user", "content": enriched_hint}
|
||||
)
|
||||
# Record the failed attempt
|
||||
self.analysis_results.append(
|
||||
{
|
||||
"round": self.current_round,
|
||||
"code": process_result.get("code", ""),
|
||||
"result": result,
|
||||
"response": response,
|
||||
"retry": True,
|
||||
}
|
||||
)
|
||||
# Retry within the same round: decrement round counter so the
|
||||
# outer loop's increment brings us back to the same round number
|
||||
self.current_round -= 1
|
||||
continue
|
||||
|
||||
# Normal feedback path (no retry or non-data-context error or at limit)
|
||||
safe_feedback = sanitize_execution_feedback(feedback)
|
||||
self.conversation_history.append(
|
||||
{"role": "user", "content": f"代码执行反馈:\n{safe_feedback}"}
|
||||
@@ -403,6 +670,45 @@ class DataAnalysisAgent:
|
||||
"response": response,
|
||||
}
|
||||
)
|
||||
|
||||
# --- Construct Round_Data and append to session ---
|
||||
result = process_result.get("result", {})
|
||||
round_data = {
|
||||
"round": self.current_round,
|
||||
"reasoning": process_result.get("reasoning", ""),
|
||||
"code": process_result.get("code", ""),
|
||||
"result_summary": self._summarize_result(result),
|
||||
"evidence_rows": result.get("evidence_rows", []),
|
||||
"raw_log": feedback,
|
||||
"auto_exported_files": result.get("auto_exported_files", []),
|
||||
"prompt_saved_files": result.get("prompt_saved_files", []),
|
||||
}
|
||||
|
||||
if self._session_ref:
|
||||
self._session_ref.rounds.append(round_data)
|
||||
# Merge file metadata into SessionData.data_files
|
||||
for f in round_data.get("auto_exported_files", []):
|
||||
if f.get("skipped"):
|
||||
continue # Large DataFrame — not written to disk
|
||||
self._session_ref.data_files.append({
|
||||
"filename": f.get("filename", ""),
|
||||
"description": f"自动导出: {f.get('variable_name', '')}",
|
||||
"rows": f.get("rows", 0),
|
||||
"cols": f.get("cols", 0),
|
||||
"columns": f.get("columns", []),
|
||||
"size_bytes": 0,
|
||||
"source": "auto",
|
||||
})
|
||||
for f in round_data.get("prompt_saved_files", []):
|
||||
self._session_ref.data_files.append({
|
||||
"filename": f.get("filename", ""),
|
||||
"description": f.get("description", ""),
|
||||
"rows": f.get("rows", 0),
|
||||
"cols": 0,
|
||||
"columns": [],
|
||||
"size_bytes": 0,
|
||||
"source": "prompt",
|
||||
})
|
||||
elif process_result["action"] == "collect_figures":
|
||||
# 记录图片收集结果
|
||||
collected_figures = process_result.get("collected_figures", [])
|
||||
@@ -596,6 +902,23 @@ class DataAnalysisAgent:
|
||||
f"输出: {exec_result.get('output')[:]}\n\n"
|
||||
)
|
||||
|
||||
# 构建各轮次证据数据摘要
|
||||
evidence_summary = ""
|
||||
if self._session_ref and self._session_ref.rounds:
|
||||
evidence_parts = []
|
||||
for rd in self._session_ref.rounds:
|
||||
round_num = rd.get("round", 0)
|
||||
summary = rd.get("result_summary", "")
|
||||
evidence = rd.get("evidence_rows", [])
|
||||
reasoning = rd.get("reasoning", "")
|
||||
part = f"第{round_num}轮: {summary}"
|
||||
if reasoning:
|
||||
part += f"\n 推理: {reasoning[:200]}"
|
||||
if evidence:
|
||||
part += f"\n 数据样本({len(evidence)}行): {json.dumps(evidence[:3], ensure_ascii=False, default=str)}"
|
||||
evidence_parts.append(part)
|
||||
evidence_summary = "\n".join(evidence_parts)
|
||||
|
||||
# 使用 prompts.py 中的统一提示词模板,并添加相对路径使用说明
|
||||
prompt = final_report_system_prompt.format(
|
||||
current_round=self.current_round,
|
||||
@@ -605,14 +928,24 @@ class DataAnalysisAgent:
|
||||
code_results_summary=code_results_summary,
|
||||
)
|
||||
|
||||
# Append evidence data from all rounds for evidence annotation
|
||||
if evidence_summary:
|
||||
prompt += f"""
|
||||
|
||||
**各轮次分析证据数据 (Evidence by Round)**:
|
||||
以下是每轮分析的结果摘要和数据样本,请在报告中使用 `<!-- evidence:round_N -->` 标注引用了哪一轮的数据:
|
||||
|
||||
{evidence_summary}
|
||||
"""
|
||||
|
||||
# 在提示词中明确要求使用相对路径
|
||||
prompt += """
|
||||
|
||||
[FOLDER] **图片路径使用说明**:
|
||||
报告和图片都在同一目录下,请在报告中使用相对路径引用图片:
|
||||
- 格式:
|
||||
- 格式:
|
||||
- 示例:
|
||||
- 这样可以确保报告在不同环境下都能正确显示图片
|
||||
- 注意:必须使用实际生成的图片文件名,严禁使用占位符
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
Reference in New Issue
Block a user