大更新,架构调整,数据分析能力提升,
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import ast
|
||||
import traceback
|
||||
@@ -15,6 +16,7 @@ from IPython.utils.capture import capture_output
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.font_manager as fm
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class CodeExecutor:
|
||||
@@ -82,8 +84,27 @@ class CodeExecutor:
|
||||
"PIL",
|
||||
"random",
|
||||
"networkx",
|
||||
"platform",
|
||||
}
|
||||
|
||||
# Maximum rows for auto-export; DataFrames larger than this are skipped
# to avoid heavy disk I/O on large datasets.
AUTO_EXPORT_MAX_ROWS = 50000

# Variable names to skip during DataFrame auto-export
# (common import aliases and built-in namespace names)
_SKIP_EXPORT_NAMES = {
    "pd", "np", "plt", "sns", "os", "json", "sys", "re", "io",
    "csv", "glob", "duckdb", "display", "math", "datetime", "time",
    "warnings", "logging", "copy", "pickle", "pathlib", "collections",
    "itertools", "functools", "operator", "random", "networkx",
}

# Regex for parsing DATA_FILE_SAVED markers printed by executed code, e.g.
# "[DATA_FILE_SAVED] filename: out.csv, rows: 12, description: summary".
# Capture groups: (1) filename, (2) row count, (3) free-text description.
_DATA_FILE_SAVED_RE = re.compile(
    r"\[DATA_FILE_SAVED\]\s*filename:\s*(.+?),\s*rows:\s*(\d+),\s*description:\s*(.+)"
)
|
||||
|
||||
def __init__(self, output_dir: str = "outputs"):
|
||||
"""
|
||||
初始化代码执行器
|
||||
@@ -318,6 +339,142 @@ from IPython.display import display
|
||||
|
||||
return str(obj)
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_for_json(rows: List[Dict]) -> List[Dict]:
|
||||
"""Replace NaN/inf/-inf with None so the data is JSON-serializable."""
|
||||
import math
|
||||
sanitized = []
|
||||
for row in rows:
|
||||
clean = {}
|
||||
for k, v in row.items():
|
||||
if isinstance(v, float) and (math.isnan(v) or math.isinf(v)):
|
||||
clean[k] = None
|
||||
else:
|
||||
clean[k] = v
|
||||
sanitized.append(clean)
|
||||
return sanitized
|
||||
|
||||
def _capture_evidence_rows(self, result, shell) -> List[Dict]:
|
||||
"""
|
||||
Capture up to 10 evidence rows from the execution result.
|
||||
First checks result.result, then falls back to the last DataFrame in namespace.
|
||||
"""
|
||||
try:
|
||||
# Primary: check if result.result is a DataFrame
|
||||
if result.result is not None and isinstance(result.result, pd.DataFrame):
|
||||
return self._sanitize_for_json(
|
||||
result.result.head(10).to_dict(orient="records")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: find the last-assigned DataFrame variable in namespace
|
||||
try:
|
||||
last_df = None
|
||||
for name, obj in shell.user_ns.items():
|
||||
if (
|
||||
not name.startswith("_")
|
||||
and name not in self._SKIP_EXPORT_NAMES
|
||||
and isinstance(obj, pd.DataFrame)
|
||||
):
|
||||
last_df = obj
|
||||
if last_df is not None:
|
||||
return self._sanitize_for_json(
|
||||
last_df.head(10).to_dict(orient="records")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return []
|
||||
|
||||
def _snapshot_dataframes(self, shell) -> Dict[str, int]:
|
||||
"""Snapshot current DataFrame variables as {name: id(obj)}."""
|
||||
snapshot = {}
|
||||
try:
|
||||
for name, obj in shell.user_ns.items():
|
||||
if (
|
||||
not name.startswith("_")
|
||||
and name not in self._SKIP_EXPORT_NAMES
|
||||
and isinstance(obj, pd.DataFrame)
|
||||
):
|
||||
snapshot[name] = id(obj)
|
||||
except Exception:
|
||||
pass
|
||||
return snapshot
|
||||
|
||||
def _detect_new_dataframes(
|
||||
self, before: Dict[str, int], after: Dict[str, int]
|
||||
) -> List[str]:
|
||||
"""Return variable names of new or changed DataFrames."""
|
||||
new_or_changed = []
|
||||
for name, obj_id in after.items():
|
||||
if name not in before or before[name] != obj_id:
|
||||
new_or_changed.append(name)
|
||||
return new_or_changed
|
||||
|
||||
def _export_dataframe(self, var_name: str, df) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Export a DataFrame to CSV with dedup suffix. Returns metadata dict or None.
|
||||
Skips export for DataFrames exceeding AUTO_EXPORT_MAX_ROWS to avoid
|
||||
heavy disk I/O on large datasets; only metadata is recorded.
|
||||
"""
|
||||
try:
|
||||
rows_count = len(df)
|
||||
cols_count = len(df.columns)
|
||||
col_names = list(df.columns)
|
||||
|
||||
# Skip writing large DataFrames to disk — record metadata only
|
||||
if rows_count > self.AUTO_EXPORT_MAX_ROWS:
|
||||
return {
|
||||
"variable_name": var_name,
|
||||
"filename": f"(skipped: {var_name} has {rows_count} rows)",
|
||||
"rows": rows_count,
|
||||
"cols": cols_count,
|
||||
"columns": col_names,
|
||||
"skipped": True,
|
||||
}
|
||||
|
||||
base_filename = f"{var_name}.csv"
|
||||
filepath = os.path.join(self.output_dir, base_filename)
|
||||
|
||||
# Dedup: if file exists, try _1, _2, ...
|
||||
if os.path.exists(filepath):
|
||||
suffix = 1
|
||||
while True:
|
||||
dedup_filename = f"{var_name}_{suffix}.csv"
|
||||
filepath = os.path.join(self.output_dir, dedup_filename)
|
||||
if not os.path.exists(filepath):
|
||||
base_filename = dedup_filename
|
||||
break
|
||||
suffix += 1
|
||||
|
||||
df.to_csv(filepath, index=False)
|
||||
return {
|
||||
"variable_name": var_name,
|
||||
"filename": base_filename,
|
||||
"rows": rows_count,
|
||||
"cols": cols_count,
|
||||
"columns": col_names,
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _parse_data_file_saved_markers(self, stdout_text: str) -> List[Dict[str, Any]]:
|
||||
"""Parse [DATA_FILE_SAVED] marker lines from captured stdout."""
|
||||
results = []
|
||||
try:
|
||||
for line in stdout_text.splitlines():
|
||||
m = self._DATA_FILE_SAVED_RE.search(line)
|
||||
if m:
|
||||
results.append({
|
||||
"filename": m.group(1).strip(),
|
||||
"rows": int(m.group(2)),
|
||||
"description": m.group(3).strip(),
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
return results
|
||||
|
||||
def execute_code(self, code: str) -> Dict[str, Any]:
|
||||
"""
|
||||
执行代码并返回结果
|
||||
@@ -330,7 +487,10 @@ from IPython.display import display
|
||||
'success': bool,
|
||||
'output': str,
|
||||
'error': str,
|
||||
'variables': Dict[str, Any] # 新生成的重要变量
|
||||
'variables': Dict[str, Any], # 新生成的重要变量
|
||||
'evidence_rows': List[Dict], # up to 10 evidence rows
|
||||
'auto_exported_files': List[Dict], # auto-detected DataFrame exports
|
||||
'prompt_saved_files': List[Dict], # parsed DATA_FILE_SAVED markers
|
||||
}
|
||||
"""
|
||||
# 检查代码安全性
|
||||
@@ -341,12 +501,18 @@ from IPython.display import display
|
||||
"output": "",
|
||||
"error": f"代码安全检查失败: {safety_error}",
|
||||
"variables": {},
|
||||
"evidence_rows": [],
|
||||
"auto_exported_files": [],
|
||||
"prompt_saved_files": [],
|
||||
}
|
||||
|
||||
# 记录执行前的变量
|
||||
vars_before = set(self.shell.user_ns.keys())
|
||||
|
||||
try:
|
||||
# --- Task 6.1: Snapshot DataFrame variables before execution ---
|
||||
df_snapshot_before = self._snapshot_dataframes(self.shell)
|
||||
|
||||
# 使用IPython的capture_output来捕获所有输出
|
||||
with capture_output() as captured:
|
||||
result = self.shell.run_cell(code)
|
||||
@@ -359,6 +525,9 @@ from IPython.display import display
|
||||
"output": captured.stdout,
|
||||
"error": f"执行前错误: {error_msg}",
|
||||
"variables": {},
|
||||
"evidence_rows": [],
|
||||
"auto_exported_files": [],
|
||||
"prompt_saved_files": self._parse_data_file_saved_markers(captured.stdout),
|
||||
}
|
||||
|
||||
if result.error_in_exec:
|
||||
@@ -368,6 +537,9 @@ from IPython.display import display
|
||||
"output": captured.stdout,
|
||||
"error": f"执行错误: {error_msg}",
|
||||
"variables": {},
|
||||
"evidence_rows": [],
|
||||
"auto_exported_files": [],
|
||||
"prompt_saved_files": self._parse_data_file_saved_markers(captured.stdout),
|
||||
}
|
||||
|
||||
# 获取输出
|
||||
@@ -423,11 +595,36 @@ from IPython.display import display
|
||||
print(f"[WARN] [Auto-Save Global] 异常: {e}")
|
||||
# --- 自动保存机制 end ---
|
||||
|
||||
# --- Task 5: Evidence capture ---
|
||||
evidence_rows = self._capture_evidence_rows(result, self.shell)
|
||||
|
||||
# --- Task 6.2-6.4: DataFrame auto-detection and export ---
|
||||
auto_exported_files = []
|
||||
try:
|
||||
df_snapshot_after = self._snapshot_dataframes(self.shell)
|
||||
new_df_names = self._detect_new_dataframes(df_snapshot_before, df_snapshot_after)
|
||||
for var_name in new_df_names:
|
||||
try:
|
||||
df_obj = self.shell.user_ns[var_name]
|
||||
meta = self._export_dataframe(var_name, df_obj)
|
||||
if meta is not None:
|
||||
auto_exported_files.append(meta)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# --- Task 7: DATA_FILE_SAVED marker parsing ---
|
||||
prompt_saved_files = self._parse_data_file_saved_markers(captured.stdout)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"output": output,
|
||||
"error": "",
|
||||
"variables": important_new_vars,
|
||||
"evidence_rows": evidence_rows,
|
||||
"auto_exported_files": auto_exported_files,
|
||||
"prompt_saved_files": prompt_saved_files,
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
@@ -435,6 +632,9 @@ from IPython.display import display
|
||||
"output": captured.stdout if "captured" in locals() else "",
|
||||
"error": f"执行异常: {str(e)}\n{traceback.format_exc()}",
|
||||
"variables": {},
|
||||
"evidence_rows": [],
|
||||
"auto_exported_files": [],
|
||||
"prompt_saved_files": [],
|
||||
}
|
||||
|
||||
def reset_environment(self):
|
||||
|
||||
@@ -154,6 +154,111 @@ def load_data_chunked(file_path: str, chunksize: Optional[int] = None) -> Iterat
|
||||
print(f"[ERROR] 读取Excel文件失败: {e}")
|
||||
|
||||
|
||||
def _profile_chunked(file_path: str) -> str:
    """Build a markdown profile of a large file from a streamed sample.

    The file is read via ``load_data_chunked()``: the first chunk is kept in
    full and every 5th subsequent chunk contributes up to 100 head rows. The
    generated profile (dimensions, column list, per-column distribution)
    therefore describes the sample, not the whole file.

    Args:
        file_path: Path to the data file.

    Returns:
        A markdown string containing the sampled profile for this file.
    """
    file_name = os.path.basename(file_path)
    chunk_stream = load_data_chunked(file_path)
    head_chunk = next(chunk_stream, None)
    if head_chunk is None:
        return f"## 文件: {file_name}\n\n[ERROR] 无法读取文件: {file_path}\n\n"

    pieces = [head_chunk]
    for idx, part in enumerate(chunk_stream):
        # Sample every 5th subsequent chunk, at most 100 rows each.
        if idx % 5 == 0:
            pieces.append(part.head(min(100, len(part))))

    sample = pd.concat(pieces, ignore_index=True)
    n_rows, n_cols = sample.shape

    parts = [
        f"## 文件: {file_name}\n\n",
        f"- **注意**: 此画像基于抽样数据生成(首块 + 每5块采样100行)\n",
        f"- **样本维度**: {n_rows} 行 x {n_cols} 列\n",
        f"- **列名**: `{', '.join(sample.columns)}`\n\n",
        "### 列详细分布:\n",
    ]

    for column in sample.columns:
        col_dtype = sample[column].dtype
        nulls = sample[column].isnull().sum()
        null_pct = (nulls / n_rows) * 100 if n_rows > 0 else 0

        parts.append(f"#### {column} ({col_dtype})\n")
        if nulls > 0:
            parts.append(f"- [WARN] 空值: {nulls} ({null_pct:.1f}%)\n")

        if pd.api.types.is_numeric_dtype(col_dtype):
            stats = sample[column].describe()
            parts.append(
                f"- 统计: Min={stats['min']:.2f}, Max={stats['max']:.2f}, Mean={stats['mean']:.2f}\n"
            )
        elif pd.api.types.is_object_dtype(col_dtype) or pd.api.types.is_categorical_dtype(col_dtype):
            distinct = sample[column].nunique()
            parts.append(f"- 唯一值数量: {distinct}\n")
            if distinct > 0:
                top5 = sample[column].value_counts().head(5)
                top_str = ", ".join([f"{k}({v})" for k, v in top5.items()])
                parts.append(f"- **TOP 5 高频值**: {top_str}\n")
        elif pd.api.types.is_datetime64_any_dtype(col_dtype):
            parts.append(f"- 范围: {sample[column].min()} 至 {sample[column].max()}\n")

        parts.append("\n")

    return "".join(parts)
|
||||
|
||||
|
||||
def load_and_profile_data_smart(file_paths: list, max_file_size_mb: int = None) -> str:
    """Profile each file, choosing chunked sampling for large ones.

    Files whose on-disk size exceeds the threshold are profiled via
    ``_profile_chunked()`` (streamed sample); smaller files reuse the full
    ``load_and_profile_data()`` path. Missing files and read errors are
    reported inline in the markdown rather than raised.

    Args:
        file_paths: List of file paths to profile.
        max_file_size_mb: Size threshold in MB. Files larger than this use
            chunked profiling. Defaults to ``app_config.max_file_size_mb``.

    Returns:
        A markdown string containing the combined data profile.
    """
    threshold_mb = app_config.max_file_size_mb if max_file_size_mb is None else max_file_size_mb

    report = "# 数据画像报告 (Data Profile)\n\n"
    if not file_paths:
        return report + "未提供数据文件。"

    for path in file_paths:
        if not os.path.exists(path):
            report += f"## 文件: {os.path.basename(path)}\n\n"
            report += f"[WARN] 文件不存在: {path}\n\n"
            continue

        try:
            size_mb = os.path.getsize(path) / (1024 * 1024)
            if size_mb > threshold_mb:
                report += _profile_chunked(path)
            else:
                # Reuse the single-file full profiler, stripping its own
                # report title so the combined output has exactly one.
                single = load_and_profile_data([path])
                report += single.replace("# 数据画像报告 (Data Profile)\n\n", "")
        except Exception as e:
            report += f"## 文件: {os.path.basename(path)}\n\n"
            report += f"[ERROR] 读取或分析文件失败: {str(e)}\n\n"

    return report
|
||||
|
||||
|
||||
def load_data_with_cache(file_path: str, force_reload: bool = False) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
带缓存的数据加载
|
||||
|
||||
@@ -154,6 +154,82 @@ def sanitize_execution_feedback(feedback: str, max_lines: int = 30) -> str:
|
||||
return "\n".join(safe_lines)
|
||||
|
||||
|
||||
def _extract_column_from_error(error_message: str) -> Optional[str]:
|
||||
"""Extract column name from error message patterns like KeyError: 'col_name'.
|
||||
|
||||
Supports:
|
||||
- KeyError: 'column_name' or KeyError: "column_name"
|
||||
- column 'column_name' or column "column_name" (case-insensitive)
|
||||
|
||||
Returns:
|
||||
The extracted column name, or None if no column reference is found.
|
||||
"""
|
||||
match = re.search(r"KeyError:\s*['\"](.+?)['\"]", error_message)
|
||||
if match:
|
||||
return match.group(1)
|
||||
match = re.search(r"column\s+['\"](.+?)['\"]", error_message, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
|
||||
def _lookup_column_in_profile(column_name: Optional[str], safe_profile: str) -> Optional[dict]:
|
||||
"""Look up column metadata in the safe profile markdown table.
|
||||
|
||||
Parses the markdown table rows produced by build_safe_profile() and returns
|
||||
a dict with keys: dtype, null_rate, unique_count, description.
|
||||
|
||||
Args:
|
||||
column_name: The column name to look up (may be None).
|
||||
safe_profile: The safe profile markdown string.
|
||||
|
||||
Returns:
|
||||
A dict of column metadata, or None if not found.
|
||||
"""
|
||||
if not column_name:
|
||||
return None
|
||||
for line in safe_profile.split("\n"):
|
||||
if line.startswith("|") and column_name in line:
|
||||
parts = [p.strip() for p in line.split("|") if p.strip()]
|
||||
if len(parts) >= 5 and parts[0] == column_name:
|
||||
return {
|
||||
"dtype": parts[1],
|
||||
"null_rate": parts[2],
|
||||
"unique_count": parts[3],
|
||||
"description": parts[4],
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def generate_enriched_hint(error_message: str, safe_profile: str) -> str:
    """Generate a retry hint for a data-context error using the safe profile.

    Extracts the referenced column name from the error, looks up its row in
    the safe-profile markdown table, and assembles a hint containing only
    schema-level metadata — no real data values are ever included.

    Args:
        error_message: The error message from code execution.
        safe_profile: The safe profile markdown string.

    Returns:
        A hint string with retry context and column metadata (if found).
    """
    col = _extract_column_from_error(error_message)
    meta = _lookup_column_in_profile(col, safe_profile)

    pieces = [
        "[RETRY CONTEXT] 上一次代码执行因数据上下文错误失败。\n",
        f"错误信息: {error_message}\n",
    ]
    if meta:
        pieces.append(f"相关列 '{col}' 的结构信息:\n")
        pieces.append(f" - 数据类型: {meta['dtype']}\n")
        pieces.append(f" - 唯一值数量: {meta['unique_count']}\n")
        pieces.append(f" - 空值率: {meta['null_rate']}\n")
        pieces.append(f" - 特征描述: {meta['description']}\n")
    pieces.append("请根据以上结构信息修正代码,不要假设具体的数据值。")
    return "".join(pieces)
|
||||
|
||||
|
||||
def _load_dataframe(file_path: str):
|
||||
"""加载 DataFrame,支持多种格式和编码"""
|
||||
import os
|
||||
|
||||
Reference in New Issue
Block a user