Compare commits
30 Commits
11f14ede90
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 8222c8fd7c | |||
| d1fb498579 | |||
| c0879765fd | |||
| c7224153b1 | |||
| 7303008f48 | |||
| 3e1ecf2549 | |||
| 00bd48e7e7 | |||
| 9d01f004d4 | |||
| b256aa27d9 | |||
| c5083736e2 | |||
| b033eb61cc | |||
| c8fe5e6d6f | |||
| 3585ba6932 | |||
| ad90cd29d3 | |||
| e9644360ce | |||
| 5eb13324c2 | |||
| 674f48c74b | |||
| fbbb5a2470 | |||
| 162f5c4da4 | |||
| b1d0cc5462 | |||
| e51cdfea6f | |||
| 621e546b43 | |||
| 3a2f90aef5 | |||
|
|
fae233b10d | ||
|
|
8d90f029e1 | ||
|
|
1f420b1b6e | ||
|
|
fcbdec1298 | ||
| 8115abb6d6 | |||
| ca134e94c8 | |||
| 24870ba497 |
22
.env.example
22
.env.example
@@ -1,8 +1,18 @@
|
||||
# LLM Provider 配置
|
||||
# 支持 openai / gemini
|
||||
LLM_PROVIDER=openai
|
||||
|
||||
# 火山引擎配置
|
||||
OPENAI_API_KEY=sk-c44i1hy64xgzwox6x08o4zug93frq6rgn84oqugf2pje1tg4
|
||||
OPENAI_BASE_URL=https://api.xiaomimimo.com/v1/chat/completions
|
||||
# 文本模型
|
||||
OPENAI_MODEL=mimo-v2-flash
|
||||
# OPENAI_MODEL=deepseek-r1-250528
|
||||
# OpenAI 兼容接口配置
|
||||
OPENAI_API_KEY=your-api-key-here
|
||||
OPENAI_BASE_URL=http://127.0.0.1:9999/v1
|
||||
OPENAI_MODEL=your-model-name
|
||||
|
||||
# Gemini 配置(当 LLM_PROVIDER=gemini 时生效)
|
||||
# GEMINI_API_KEY=your-gemini-api-key
|
||||
# GEMINI_BASE_URL=https://generativelanguage.googleapis.com
|
||||
# GEMINI_MODEL=gemini-2.5-flash
|
||||
|
||||
# 应用配置(可选)
|
||||
# APP_MAX_ROUNDS=20
|
||||
# APP_CHUNK_SIZE=100000
|
||||
# APP_CACHE_ENABLED=true
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -6,6 +6,8 @@ __pycache__/
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
|
||||
1
.kiro/specs/agent-robustness-optimization/.config.kiro
Normal file
1
.kiro/specs/agent-robustness-optimization/.config.kiro
Normal file
@@ -0,0 +1 @@
|
||||
{"specId": "ea41aaef-0737-4255-bcad-90f156a5b2d5", "workflowType": "requirements-first", "specType": "feature"}
|
||||
515
.kiro/specs/agent-robustness-optimization/design.md
Normal file
515
.kiro/specs/agent-robustness-optimization/design.md
Normal file
@@ -0,0 +1,515 @@
|
||||
# Design Document: Agent Robustness Optimization
|
||||
|
||||
## Overview
|
||||
|
||||
This design addresses five areas of improvement for the AI Data Analysis Agent: data privacy fallback recovery, conversation history trimming, analysis template integration, frontend progress display, and multi-file chunked/parallel loading. The changes span the Python backend (`data_analysis_agent.py`, `config/app_config.py`, `utils/data_privacy.py`, `utils/data_loader.py`, `web/main.py`) and the vanilla JS frontend (`web/static/script.js`, `web/static/index.html`, `web/static/clean_style.css`).
|
||||
|
||||
The core design principle is **minimal invasiveness**: each feature is implemented as a composable module or method that plugs into the existing agent loop, avoiding large-scale refactors of the `DataAnalysisAgent.analyze()` main loop.
|
||||
|
||||
## Architecture
|
||||
|
||||
The system follows a layered architecture where the `DataAnalysisAgent` orchestrates LLM calls and code execution, the FastAPI server manages sessions and exposes APIs, and the frontend polls for status updates.
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph Frontend
|
||||
UI[script.js + index.html]
|
||||
end
|
||||
|
||||
subgraph FastAPI Server
|
||||
API[web/main.py]
|
||||
SM[SessionManager]
|
||||
end
|
||||
|
||||
subgraph Agent Core
|
||||
DA[DataAnalysisAgent]
|
||||
EC[ErrorClassifier]
|
||||
HG[HintGenerator]
|
||||
HT[HistoryTrimmer]
|
||||
TI[TemplateIntegration]
|
||||
end
|
||||
|
||||
subgraph Utilities
|
||||
DP[data_privacy.py]
|
||||
DL[data_loader.py]
|
||||
AT[analysis_templates.py]
|
||||
CE[code_executor.py]
|
||||
end
|
||||
|
||||
subgraph Config
|
||||
AC[app_config.py]
|
||||
end
|
||||
|
||||
UI -->|POST /api/start, GET /api/status, GET /api/templates| API
|
||||
API --> SM
|
||||
API --> DA
|
||||
DA --> EC
|
||||
DA --> HG
|
||||
DA --> HT
|
||||
DA --> TI
|
||||
DA --> CE
|
||||
HG --> DP
|
||||
DL --> AC
|
||||
DA --> DL
|
||||
TI --> AT
|
||||
EC --> AC
|
||||
HT --> AC
|
||||
```
|
||||
|
||||
### Change Impact Summary
|
||||
|
||||
| Area | Files Modified | New Files |
|
||||
|------|---------------|-----------|
|
||||
| Data Privacy Fallback | `data_analysis_agent.py`, `utils/data_privacy.py`, `config/app_config.py` | None |
|
||||
| Conversation Trimming | `data_analysis_agent.py`, `config/app_config.py` | None |
|
||||
| Template System | `data_analysis_agent.py`, `web/main.py`, `web/static/script.js`, `web/static/index.html`, `web/static/clean_style.css` | None |
|
||||
| Progress Bar | `web/main.py`, `web/static/script.js`, `web/static/index.html`, `web/static/clean_style.css` | None |
|
||||
| Multi-File Loading | `utils/data_loader.py`, `data_analysis_agent.py`, `config/app_config.py` | None |
|
||||
|
||||
## Components and Interfaces
|
||||
|
||||
### 1. Error Classifier (`data_analysis_agent.py`)
|
||||
|
||||
A new method `_classify_error(error_message: str) -> str` on `DataAnalysisAgent` that inspects error messages and returns `"data_context"` or `"other"`.
|
||||
|
||||
```python
|
||||
DATA_CONTEXT_PATTERNS = [
|
||||
r"KeyError:\s*['\"](.+?)['\"]",
|
||||
r"ValueError.*(?:column|col|field)",
|
||||
r"NameError.*(?:df|data|frame)",
|
||||
r"(?:empty|no\s+data|0\s+rows)",
|
||||
r"IndexError.*(?:out of range|out of bounds)",
|
||||
]
|
||||
|
||||
def _classify_error(self, error_message: str) -> str:
|
||||
"""Classify execution error as data-context or other."""
|
||||
for pattern in DATA_CONTEXT_PATTERNS:
|
||||
if re.search(pattern, error_message, re.IGNORECASE):
|
||||
return "data_context"
|
||||
return "other"
|
||||
```
|
||||
|
||||
### 2. Enriched Hint Generator (`utils/data_privacy.py`)
|
||||
|
||||
A new function `generate_enriched_hint(error_message: str, safe_profile: str) -> str` that extracts the referenced column name from the error, looks it up in the safe profile, and returns a hint string containing only schema-level metadata.
|
||||
|
||||
```python
|
||||
def generate_enriched_hint(error_message: str, safe_profile: str) -> str:
|
||||
"""
|
||||
Generate an enriched hint from the safe profile for a data-context error.
|
||||
Returns schema-level metadata only — no real data values.
|
||||
"""
|
||||
column_name = _extract_column_from_error(error_message)
|
||||
column_meta = _lookup_column_in_profile(column_name, safe_profile)
|
||||
|
||||
hint = "[RETRY CONTEXT] 上一次代码执行因数据上下文错误失败。\n"
|
||||
hint += f"错误信息: {error_message}\n"
|
||||
if column_meta:
|
||||
hint += f"相关列 '{column_name}' 的结构信息:\n"
|
||||
hint += f" - 数据类型: {column_meta['dtype']}\n"
|
||||
hint += f" - 唯一值数量: {column_meta['unique_count']}\n"
|
||||
hint += f" - 空值率: {column_meta['null_rate']}\n"
|
||||
hint += f" - 特征描述: {column_meta['description']}\n"
|
||||
hint += "请根据以上结构信息修正代码,不要假设具体的数据值。"
|
||||
return hint
|
||||
|
||||
def _extract_column_from_error(error_message: str) -> Optional[str]:
|
||||
"""Extract column name from error message patterns like KeyError: 'col_name'."""
|
||||
match = re.search(r"KeyError:\s*['\"](.+?)['\"]", error_message)
|
||||
if match:
|
||||
return match.group(1)
|
||||
match = re.search(r"column\s+['\"](.+?)['\"]", error_message, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
def _lookup_column_in_profile(column_name: Optional[str], safe_profile: str) -> Optional[dict]:
|
||||
"""Look up column metadata in the safe profile markdown table."""
|
||||
if not column_name:
|
||||
return None
|
||||
# Parse the markdown table rows for the matching column
|
||||
for line in safe_profile.split("\n"):
|
||||
if line.startswith("|") and column_name in line:
|
||||
parts = [p.strip() for p in line.split("|") if p.strip()]
|
||||
if len(parts) >= 5 and parts[0] == column_name:
|
||||
return {
|
||||
"dtype": parts[1],
|
||||
"null_rate": parts[2],
|
||||
"unique_count": parts[3],
|
||||
"description": parts[4],
|
||||
}
|
||||
return None
|
||||
```
|
||||
|
||||
### 3. Conversation History Trimmer (`data_analysis_agent.py`)
|
||||
|
||||
A new method `_trim_conversation_history()` on `DataAnalysisAgent` that implements sliding window trimming with summary compression.
|
||||
|
||||
```python
|
||||
def _trim_conversation_history(self):
|
||||
"""Apply sliding window trimming to conversation history."""
|
||||
window_size = app_config.conversation_window_size
|
||||
max_messages = window_size * 2 # pairs of user+assistant messages
|
||||
|
||||
if len(self.conversation_history) <= max_messages:
|
||||
return # No trimming needed
|
||||
|
||||
first_message = self.conversation_history[0] # Always retain
|
||||
|
||||
# Determine trim boundary: skip first message + possible existing summary
|
||||
start_idx = 1
|
||||
has_existing_summary = (
|
||||
len(self.conversation_history) > 1
|
||||
and self.conversation_history[1]["role"] == "user"
|
||||
and self.conversation_history[1]["content"].startswith("[分析摘要]")
|
||||
)
|
||||
if has_existing_summary:
|
||||
start_idx = 2
|
||||
|
||||
# Messages to trim vs keep
|
||||
messages_to_consider = self.conversation_history[start_idx:]
|
||||
messages_to_trim = messages_to_consider[:-max_messages]
|
||||
messages_to_keep = messages_to_consider[-max_messages:]
|
||||
|
||||
if not messages_to_trim:
|
||||
return
|
||||
|
||||
# Generate summary of trimmed messages
|
||||
summary = self._compress_trimmed_messages(messages_to_trim)
|
||||
|
||||
# Rebuild history: first_message + summary + recent messages
|
||||
self.conversation_history = [first_message]
|
||||
if summary:
|
||||
self.conversation_history.append({"role": "user", "content": summary})
|
||||
self.conversation_history.extend(messages_to_keep)
|
||||
|
||||
def _compress_trimmed_messages(self, messages: list) -> str:
|
||||
"""Compress trimmed messages into a summary string."""
|
||||
summary_parts = ["[分析摘要] 以下是之前分析轮次的概要:"]
|
||||
round_num = 0
|
||||
|
||||
for msg in messages:
|
||||
content = msg["content"]
|
||||
if msg["role"] == "assistant":
|
||||
round_num += 1
|
||||
# Extract action type from YAML-like content
|
||||
action = "generate_code"
|
||||
if "action: \"collect_figures\"" in content or "action: collect_figures" in content:
|
||||
action = "collect_figures"
|
||||
elif "action: \"analysis_complete\"" in content or "action: analysis_complete" in content:
|
||||
action = "analysis_complete"
|
||||
summary_parts.append(f"- 轮次{round_num}: 动作={action}")
|
||||
elif msg["role"] == "user" and "代码执行反馈" in content:
|
||||
success = "失败" if "[ERROR]" in content or "执行错误" in content else "成功"
|
||||
summary_parts[-1] += f", 执行结果={success}"
|
||||
|
||||
return "\n".join(summary_parts)
|
||||
```
|
||||
|
||||
### 4. Template Integration (`data_analysis_agent.py` + `web/main.py`)
|
||||
|
||||
The `analyze()` method gains an optional `template_name` parameter. When provided, the template prompt is prepended to the user requirement.
|
||||
|
||||
**Agent side:**
|
||||
```python
|
||||
def analyze(self, user_input: str, files=None, session_output_dir=None,
|
||||
reset_session=True, max_rounds=None, template_name=None):
|
||||
# ... existing init code ...
|
||||
if template_name:
|
||||
from utils.analysis_templates import get_template
|
||||
template = get_template(template_name) # Raises ValueError if invalid
|
||||
template_prompt = template.get_full_prompt()
|
||||
user_input = f"{template_prompt}\n\n{user_input}"
|
||||
# ... rest of analyze ...
|
||||
```
|
||||
|
||||
**API side (`web/main.py`):**
|
||||
```python
|
||||
# New endpoint
|
||||
@app.get("/api/templates")
|
||||
async def list_available_templates():
|
||||
from utils.analysis_templates import list_templates
|
||||
return {"templates": list_templates()}
|
||||
|
||||
# Modified StartRequest
|
||||
class StartRequest(BaseModel):
|
||||
requirement: str
|
||||
template: Optional[str] = None
|
||||
```
|
||||
|
||||
### 5. Progress Bar Integration
|
||||
|
||||
**Backend (`web/main.py`):** Update `run_analysis_task` to set progress fields on `SessionData` via a callback or by polling the agent's `current_round`. The simplest approach is to add a progress callback to the agent.
|
||||
|
||||
```python
|
||||
# In DataAnalysisAgent
|
||||
def set_progress_callback(self, callback):
|
||||
"""Set a callback function(current_round, max_rounds, message) for progress updates."""
|
||||
self._progress_callback = callback
|
||||
|
||||
# Called at the start of each round in the analyze() loop:
|
||||
if hasattr(self, '_progress_callback') and self._progress_callback:
|
||||
self._progress_callback(self.current_round, self.max_rounds, f"第{self.current_round}轮分析中...")
|
||||
```
|
||||
|
||||
**Backend (`web/main.py`):** In `run_analysis_task`, wire the callback:
|
||||
```python
|
||||
def progress_cb(current, total, message):
|
||||
session.current_round = current
|
||||
session.max_rounds = total
|
||||
session.progress_percentage = round((current / total) * 100, 1) if total > 0 else 0
|
||||
session.status_message = message
|
||||
|
||||
agent.set_progress_callback(progress_cb)
|
||||
```
|
||||
|
||||
**API response:** Add progress fields to `GET /api/status`:
|
||||
```python
|
||||
return {
|
||||
"is_running": session.is_running,
|
||||
"log": log_content,
|
||||
"has_report": ...,
|
||||
"current_round": session.current_round,
|
||||
"max_rounds": session.max_rounds,
|
||||
"progress_percentage": session.progress_percentage,
|
||||
"status_message": session.status_message,
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
**Frontend (`script.js`):** During polling, render a progress bar when `is_running` is true:
|
||||
```javascript
|
||||
// In the polling callback:
|
||||
if (data.is_running) {
|
||||
updateProgressBar(data.progress_percentage, data.status_message);
|
||||
}
|
||||
```
|
||||
|
||||
### 6. Multi-File Chunked & Parallel Loading
|
||||
|
||||
**Chunked loading enhancement (`utils/data_loader.py`):**
|
||||
|
||||
```python
|
||||
def load_and_profile_data_smart(file_paths: list, max_file_size_mb: int = None) -> str:
|
||||
"""Smart loader: uses chunked reading for large files, regular for small."""
|
||||
if max_file_size_mb is None:
|
||||
max_file_size_mb = app_config.max_file_size_mb
|
||||
|
||||
profile_summary = "# 数据画像报告 (Data Profile)\n\n"
|
||||
for file_path in file_paths:
|
||||
file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
|
||||
if file_size_mb > max_file_size_mb:
|
||||
profile_summary += _profile_chunked(file_path)
|
||||
else:
|
||||
profile_summary += _profile_full(file_path)
|
||||
return profile_summary
|
||||
|
||||
def _profile_chunked(file_path: str) -> str:
|
||||
"""Profile a large file by reading first chunk + sampling subsequent chunks."""
|
||||
chunks = load_data_chunked(file_path)
|
||||
first_chunk = next(chunks, None)
|
||||
if first_chunk is None:
|
||||
return f"[ERROR] 无法读取文件: {file_path}\n"
|
||||
|
||||
# Sample from subsequent chunks
|
||||
sample_rows = [first_chunk]
|
||||
for i, chunk in enumerate(chunks):
|
||||
if i % 5 == 0: # Sample every 5th chunk
|
||||
sample_rows.append(chunk.sample(min(100, len(chunk))))
|
||||
|
||||
combined = pd.concat(sample_rows, ignore_index=True)
|
||||
# Generate profile from combined sample
|
||||
return _generate_profile_for_df(combined, file_path, sampled=True)
|
||||
```
|
||||
|
||||
**Parallel profiling (`data_analysis_agent.py`):**
|
||||
|
||||
```python
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
def _profile_files_parallel(self, file_paths: list) -> tuple[str, str]:
|
||||
"""Profile multiple files concurrently."""
|
||||
max_workers = app_config.max_parallel_profiles
|
||||
safe_profiles = []
|
||||
local_profiles = []
|
||||
|
||||
def profile_single(path):
|
||||
safe = build_safe_profile([path])
|
||||
local = build_local_profile([path])
|
||||
return path, safe, local
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = {executor.submit(profile_single, p): p for p in file_paths}
|
||||
for future in as_completed(futures):
|
||||
path = futures[future]
|
||||
try:
|
||||
_, safe, local = future.result()
|
||||
safe_profiles.append(safe)
|
||||
local_profiles.append(local)
|
||||
except Exception as e:
|
||||
error_entry = f"## 文件: {os.path.basename(path)}\n[ERROR] 分析失败: {e}\n\n"
|
||||
safe_profiles.append(error_entry)
|
||||
local_profiles.append(error_entry)
|
||||
|
||||
return "\n".join(safe_profiles), "\n".join(local_profiles)
|
||||
```
|
||||
|
||||
## Data Models
|
||||
|
||||
### AppConfig Extensions (`config/app_config.py`)
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class AppConfig:
|
||||
# ... existing fields ...
|
||||
|
||||
# New fields
|
||||
max_data_context_retries: int = field(default=2)
|
||||
conversation_window_size: int = field(default=10)
|
||||
max_parallel_profiles: int = field(default=4)
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> 'AppConfig':
|
||||
config = cls()
|
||||
# ... existing env overrides ...
|
||||
if val := os.getenv("APP_MAX_DATA_CONTEXT_RETRIES"):
|
||||
config.max_data_context_retries = int(val)
|
||||
if val := os.getenv("APP_CONVERSATION_WINDOW_SIZE"):
|
||||
config.conversation_window_size = int(val)
|
||||
if val := os.getenv("APP_MAX_PARALLEL_PROFILES"):
|
||||
config.max_parallel_profiles = int(val)
|
||||
return config
|
||||
```
|
||||
|
||||
### StartRequest Extension (`web/main.py`)
|
||||
|
||||
```python
|
||||
class StartRequest(BaseModel):
|
||||
requirement: str
|
||||
template: Optional[str] = None # New field
|
||||
```
|
||||
|
||||
### SessionData Progress Fields (already exist, just need wiring)
|
||||
|
||||
The `SessionData` class already has `current_round`, `max_rounds`, `progress_percentage`, and `status_message` fields. These just need to be updated during analysis and included in the `/api/status` response.
|
||||
|
||||
## Correctness Properties
|
||||
|
||||
*A property is a characteristic or behavior that should hold true across all valid executions of a system — essentially, a formal statement about what the system should do. Properties serve as the bridge between human-readable specifications and machine-verifiable correctness guarantees.*
|
||||
|
||||
### Property 1: Error Classification Correctness
|
||||
|
||||
*For any* error message string, if it contains a data-context pattern (KeyError on a column name, ValueError on column values, NameError for data variables, or empty DataFrame conditions), `_classify_error` SHALL return `"data_context"`; otherwise it SHALL return `"other"`.
|
||||
|
||||
**Validates: Requirements 1.1**
|
||||
|
||||
### Property 2: Retry Below Limit Produces Enriched Hint
|
||||
|
||||
*For any* `max_data_context_retries` value and any current retry count strictly less than that value, when a data-context error is detected, the agent SHALL produce an enriched hint message rather than forwarding the raw error.
|
||||
|
||||
**Validates: Requirements 1.3**
|
||||
|
||||
### Property 3: Enriched Hint Contains Correct Column Metadata Without Real Data
|
||||
|
||||
*For any* error message referencing a column name present in the Safe_Profile, the generated enriched hint SHALL contain that column's data type, unique value count, null rate, and categorical description, and SHALL NOT contain any real data values (min, max, mean, sample rows) from the Local_Profile.
|
||||
|
||||
**Validates: Requirements 2.1, 2.2, 2.4**
|
||||
|
||||
### Property 4: Environment Variable Override for Config Fields
|
||||
|
||||
*For any* positive integer value set as the `APP_MAX_DATA_CONTEXT_RETRIES` environment variable, `AppConfig.from_env()` SHALL produce a config where `max_data_context_retries` equals that integer value.
|
||||
|
||||
**Validates: Requirements 3.2**
|
||||
|
||||
### Property 5: Sliding Window Trimming Preserves First Message and Retains Recent Pairs
|
||||
|
||||
*For any* conversation history whose length exceeds `2 * conversation_window_size` and any `conversation_window_size >= 1`, after trimming: (a) the first user message is always retained at index 0, and (b) the most recent `conversation_window_size` message pairs are retained in full.
|
||||
|
||||
**Validates: Requirements 4.2, 4.3**
|
||||
|
||||
### Property 6: Trimming Summary Contains Round Info and Excludes Code/Raw Output
|
||||
|
||||
*For any* set of trimmed conversation messages, the generated summary SHALL list each trimmed round's action type and execution success/failure, and SHALL NOT contain any code blocks (``` markers) or raw execution output.
|
||||
|
||||
**Validates: Requirements 4.4, 5.1, 5.2**
|
||||
|
||||
### Property 7: Template Prompt Integration
|
||||
|
||||
*For any* valid template name in `TEMPLATE_REGISTRY` and any user requirement string, the initial conversation message SHALL contain the template's `get_full_prompt()` output prepended to the user requirement.
|
||||
|
||||
**Validates: Requirements 6.1, 6.2**
|
||||
|
||||
### Property 8: Invalid Template Name Raises Descriptive Error
|
||||
|
||||
*For any* string that is not a key in `TEMPLATE_REGISTRY`, calling `get_template()` SHALL raise a `ValueError` whose message contains the list of available template names.
|
||||
|
||||
**Validates: Requirements 6.3**
|
||||
|
||||
### Property 9: Chunked Loading Threshold
|
||||
|
||||
*For any* file path and `max_file_size_mb` threshold, if the file's size in MB exceeds the threshold, the smart loader SHALL use chunked loading; otherwise it SHALL use full loading.
|
||||
|
||||
**Validates: Requirements 10.1**
|
||||
|
||||
### Property 10: Chunked Profiling Uses First Chunk Plus Samples
|
||||
|
||||
*For any* file loaded in chunked mode, the generated profile SHALL be based on the first chunk plus sampled rows from subsequent chunks, not from the entire file loaded into memory.
|
||||
|
||||
**Validates: Requirements 10.3**
|
||||
|
||||
### Property 11: Parallel Profile Merge With Error Resilience
|
||||
|
||||
*For any* set of file paths where some are valid and some are invalid/corrupted, the merged profile output SHALL contain valid profile entries for successful files and error entries for failed files, with no files missing from the output.
|
||||
|
||||
**Validates: Requirements 11.2, 11.3**
|
||||
|
||||
## Error Handling
|
||||
|
||||
| Scenario | Handling Strategy |
|
||||
|----------|------------------|
|
||||
| Data-context error below retry limit | Generate enriched hint, retry with LLM |
|
||||
| Data-context error at retry limit | Fall back to normal sanitized error forwarding |
|
||||
| Invalid template name | Raise `ValueError` with available template list |
|
||||
| File too large for memory | Automatically switch to chunked loading |
|
||||
| Chunked loading fails | Return descriptive error, continue with other files |
|
||||
| Single file profiling fails in parallel | Include error entry, continue profiling remaining files |
|
||||
| Conversation history exceeds window | Trim old messages, generate compressed summary |
|
||||
| Summary generation fails | Log warning, proceed without summary (graceful degradation) |
|
||||
| Progress callback fails | Log warning, analysis continues without progress updates |
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### Property-Based Tests (using `hypothesis`)
|
||||
|
||||
Each correctness property maps to a property-based test with minimum 100 iterations. The test library is `hypothesis` (Python).
|
||||
|
||||
- **Property 1**: Generate random error strings with/without data-context patterns → verify classification
|
||||
- **Property 2**: Generate random retry counts and limits → verify hint vs raw error behavior
|
||||
- **Property 3**: Generate random Safe_Profile tables and error messages → verify hint content and absence of real data
|
||||
- **Property 4**: Generate random positive integers → set env var → verify config
|
||||
- **Property 5**: Generate random conversation histories and window sizes → verify trimming invariants
|
||||
- **Property 6**: Generate random trimmed message sets → verify summary content and absence of code blocks
|
||||
- **Property 7**: Pick random valid template names and requirement strings → verify prompt construction
|
||||
- **Property 8**: Generate random strings not in registry → verify ValueError
|
||||
- **Property 9**: Generate random file sizes and thresholds → verify loading method selection
|
||||
- **Property 10**: Generate random chunked data → verify profile source
|
||||
- **Property 11**: Generate random file sets with failures → verify merged output
|
||||
|
||||
Tag format: `Feature: agent-robustness-optimization, Property {N}: {title}`
|
||||
|
||||
### Unit Tests
|
||||
|
||||
- Error classifier with specific known error messages (KeyError, ValueError, NameError, generic errors)
|
||||
- Enriched hint generation with known column profiles
|
||||
- Conversation trimming with exact message counts at boundary conditions
|
||||
- Template retrieval for each registered template
|
||||
- Progress callback wiring
|
||||
- API endpoint response shapes (`GET /api/templates`, `GET /api/status` with progress fields)
|
||||
|
||||
### Integration Tests
|
||||
|
||||
- `GET /api/templates` returns all registered templates
|
||||
- `POST /api/start` with `template` field passes template to agent
|
||||
- `GET /api/status` includes progress fields during analysis
|
||||
- Multi-file parallel profiling with real CSV files
|
||||
- End-to-end: start analysis with template → verify template prompt in conversation history
|
||||
142
.kiro/specs/agent-robustness-optimization/requirements.md
Normal file
142
.kiro/specs/agent-robustness-optimization/requirements.md
Normal file
@@ -0,0 +1,142 @@
|
||||
# Requirements Document
|
||||
|
||||
## Introduction
|
||||
|
||||
This document specifies the requirements for improving the robustness, efficiency, and usability of the AI Data Analysis Agent. The improvements span five areas: a data privacy fallback mechanism for recovering from LLM-generated code failures when real data is unavailable, conversation history trimming to reduce token consumption and prevent data leakage, integration of the existing analysis template system, frontend progress bar display, and multi-file parallel/chunked analysis support.
|
||||
|
||||
## Glossary
|
||||
|
||||
- **Agent**: The `DataAnalysisAgent` class in `data_analysis_agent.py` that orchestrates LLM calls and IPython code execution for data analysis.
|
||||
- **Safe_Profile**: The schema-only data description generated by `build_safe_profile()` in `utils/data_privacy.py`, containing column names, data types, null rates, and unique value counts — but no real data values.
|
||||
- **Local_Profile**: The full data profile generated by `build_local_profile()` containing real data values, statistics, and sample rows — used only in the local execution environment.
|
||||
- **Code_Executor**: The `CodeExecutor` class in `utils/code_executor.py` that runs Python code in an IPython sandbox and returns execution results.
|
||||
- **Conversation_History**: The list of `{"role": ..., "content": ...}` message dictionaries maintained by the Agent across analysis rounds.
|
||||
- **Feedback_Sanitizer**: The `sanitize_execution_feedback()` function in `utils/data_privacy.py` that removes real data values from execution output before sending to the LLM.
|
||||
- **Template_Registry**: The `TEMPLATE_REGISTRY` dictionary in `utils/analysis_templates.py` mapping template names to template classes.
|
||||
- **Session_Data**: The `SessionData` class in `web/main.py` that tracks session state including `progress_percentage`, `current_round`, `max_rounds`, and `status_message`.
|
||||
- **Polling_Loop**: The `setInterval`-based polling mechanism in `web/static/script.js` that fetches `/api/status` every 2 seconds.
|
||||
- **Data_Loader**: The module `utils/data_loader.py` providing `load_and_profile_data`, `load_data_chunked`, and `load_data_with_cache` functions.
|
||||
- **AppConfig**: The `AppConfig` dataclass in `config/app_config.py` holding configuration values such as `max_rounds`, `chunk_size`, and `max_file_size_mb`.
|
||||
|
||||
## Requirements
|
||||
|
||||
### Requirement 1: Data Privacy Fallback — Error Detection
|
||||
|
||||
**User Story:** As a system operator, I want the Agent to detect when LLM-generated code fails due to missing real data context, so that the system can attempt intelligent recovery instead of wasting an analysis round.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN the Code_Executor returns a failed execution result, THE Agent SHALL classify the error as either a data-context error or a non-data error by inspecting the error message for patterns such as `KeyError`, `ValueError` on column values, `NameError` for undefined data variables, or empty DataFrame conditions.
|
||||
2. WHEN a data-context error is detected, THE Agent SHALL increment a per-round retry counter for the current analysis round.
|
||||
3. WHILE the retry counter for a given round is below the configured maximum retry limit, THE Agent SHALL attempt recovery by generating an enriched hint prompt rather than forwarding the raw error to the LLM as a normal failure.
|
||||
4. IF the retry counter reaches the configured maximum retry limit, THEN THE Agent SHALL fall back to normal error handling by forwarding the sanitized error feedback to the LLM and proceeding to the next round.
|
||||
|
||||
### Requirement 2: Data Privacy Fallback — Enriched Hint Generation
|
||||
|
||||
**User Story:** As a system operator, I want the Agent to provide the LLM with enriched schema hints when data-context errors occur, so that the LLM can generate corrected code without receiving raw data values.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN a data-context error is detected and retry is permitted, THE Agent SHALL generate an enriched hint containing the relevant column's data type, unique value count, null rate, and a categorical description (e.g., "low-cardinality category with 5 classes") extracted from the Safe_Profile.
|
||||
2. WHEN the error involves a specific column name referenced in the error message, THE Agent SHALL include that column's schema metadata in the enriched hint.
|
||||
3. THE Agent SHALL append the enriched hint to the conversation history as a user message with a prefix indicating it is a retry context, before requesting a new LLM response.
|
||||
4. THE Agent SHALL NOT include any real data values, sample rows, or statistical values (min, max, mean) from the Local_Profile in the enriched hint sent to the LLM.
|
||||
|
||||
### Requirement 3: Data Privacy Fallback — Configuration
|
||||
|
||||
**User Story:** As a system operator, I want to configure the maximum number of data-context retries, so that I can balance between recovery attempts and analysis throughput.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. THE AppConfig SHALL include a `max_data_context_retries` field with a default value of 2.
|
||||
2. WHEN the `APP_MAX_DATA_CONTEXT_RETRIES` environment variable is set, THE AppConfig SHALL use its integer value to override the default.
|
||||
3. THE Agent SHALL read the `max_data_context_retries` value from AppConfig during initialization.
|
||||
|
||||
### Requirement 4: Conversation History Trimming — Sliding Window
|
||||
|
||||
**User Story:** As a system operator, I want the conversation history to be trimmed using a sliding window, so that token consumption stays bounded and early execution results containing potential data leakage are removed.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. THE AppConfig SHALL include a `conversation_window_size` field with a default value of 10, representing the maximum number of recent message pairs to retain in full.
|
||||
2. WHEN the Conversation_History length exceeds twice the `conversation_window_size` (counting individual messages), THE Agent SHALL retain only the most recent `conversation_window_size` pairs of messages in full detail.
|
||||
3. THE Agent SHALL always retain the first user message (containing the original requirement and Safe_Profile) regardless of window trimming.
|
||||
4. WHEN messages are trimmed from the Conversation_History, THE Agent SHALL generate a compressed summary of the trimmed messages and prepend it after the first user message.
|
||||
|
||||
### Requirement 5: Conversation History Trimming — Summary Compression
|
||||
|
||||
**User Story:** As a system operator, I want trimmed conversation rounds to be compressed into a summary, so that the LLM retains awareness of prior analysis steps without consuming excessive tokens.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN conversation messages are trimmed, THE Agent SHALL produce a summary string that lists each trimmed round's action type (generate_code, collect_figures), a one-line description of what was done, and whether execution succeeded or failed.
|
||||
2. THE summary SHALL NOT contain any code blocks, raw execution output, or data values from prior rounds.
|
||||
3. THE summary SHALL be inserted into the Conversation_History as a single user message immediately after the first user message, replacing any previous summary message.
|
||||
4. IF no messages have been trimmed, THEN THE Agent SHALL NOT insert a summary message.
|
||||
|
||||
### Requirement 6: Analysis Template System — Backend Integration
|
||||
|
||||
**User Story:** As a user, I want to select a predefined analysis template when starting an analysis, so that the Agent follows a structured analysis plan tailored to my scenario.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN a template name is provided in the analysis request, THE Agent SHALL retrieve the corresponding template from the Template_Registry using the `get_template()` function.
|
||||
2. WHEN a valid template is retrieved, THE Agent SHALL call `get_full_prompt()` on the template and prepend the resulting structured prompt to the user's requirement in the initial conversation message.
|
||||
3. IF an invalid template name is provided, THEN THE Agent SHALL raise a descriptive error listing available template names.
|
||||
4. WHEN no template name is provided, THE Agent SHALL proceed with the default unstructured analysis flow.
|
||||
|
||||
### Requirement 7: Analysis Template System — API Endpoint
|
||||
|
||||
**User Story:** As a frontend developer, I want API endpoints to list available templates and to accept a template selection when starting analysis, so that the frontend can offer template choices to users.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. THE FastAPI server SHALL expose a `GET /api/templates` endpoint that returns the list of available templates by calling `list_templates()`, with each entry containing `name`, `display_name`, and `description`.
|
||||
2. THE `POST /api/start` request body SHALL accept an optional `template` field containing the template name string.
|
||||
3. WHEN the `template` field is present in the start request, THE FastAPI server SHALL pass the template name to the Agent's `analyze()` method.
|
||||
4. WHEN the `template` field is absent or empty, THE FastAPI server SHALL start analysis without a template.
|
||||
|
||||
### Requirement 8: Analysis Template System — Frontend Template Selector
|
||||
|
||||
**User Story:** As a user, I want to see and select analysis templates in the web interface before starting analysis, so that I can choose a structured analysis approach.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN the web page loads, THE frontend SHALL fetch the template list from `GET /api/templates` and render selectable template cards above the requirement input area.
|
||||
2. WHEN a user selects a template card, THE frontend SHALL visually highlight the selected template and store the template name.
|
||||
3. WHEN the user clicks "Start Analysis" with a template selected, THE frontend SHALL include the template name in the `POST /api/start` request body.
|
||||
4. THE frontend SHALL provide a "No Template (Free Analysis)" option that is selected by default, allowing users to proceed without a template.
|
||||
|
||||
### Requirement 9: Frontend Progress Bar Display
|
||||
|
||||
**User Story:** As a user, I want to see a real-time progress bar during analysis, so that I can understand how far the analysis has progressed.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. THE FastAPI server SHALL update the Session_Data's `current_round`, `max_rounds`, `progress_percentage`, and `status_message` fields during each analysis round in the `run_analysis_task` function.
|
||||
2. THE `GET /api/status` response SHALL include `current_round`, `max_rounds`, `progress_percentage`, and `status_message` fields.
|
||||
3. WHEN the Polling_Loop receives status data with `is_running` equal to true, THE frontend SHALL render a progress bar element showing the `progress_percentage` value and the `status_message` text.
|
||||
4. WHEN `progress_percentage` changes between polls, THE frontend SHALL animate the progress bar width transition smoothly.
|
||||
5. WHEN `is_running` becomes false, THE frontend SHALL set the progress bar to 100% and display a completion message.
|
||||
|
||||
### Requirement 10: Multi-File Chunked Loading
|
||||
|
||||
**User Story:** As a user, I want large data files to be loaded in chunks, so that the system can handle files that exceed available memory.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN a data file's size exceeds the `max_file_size_mb` threshold in AppConfig, THE Data_Loader SHALL use `load_data_chunked()` to stream the file in chunks of `chunk_size` rows instead of loading the entire file into memory.
|
||||
2. WHEN chunked loading is used, THE Agent SHALL instruct the Code_Executor to make the chunked iterator available in the notebook environment as a variable, so that LLM-generated code can process data in chunks.
|
||||
3. WHEN chunked loading is used for profiling, THE Agent SHALL generate the Safe_Profile by reading only the first chunk plus sampling from subsequent chunks, rather than loading the entire file.
|
||||
4. IF a file cannot be loaded even in chunked mode, THEN THE Data_Loader SHALL return a descriptive error message indicating the failure reason.
|
||||
|
||||
### Requirement 11: Multi-File Parallel Profiling
|
||||
|
||||
**User Story:** As a user, I want multiple data files to be profiled concurrently, so that the initial data exploration phase completes faster when multiple files are uploaded.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN multiple files are provided for analysis, THE Agent SHALL profile each file concurrently using thread-based parallelism rather than sequentially.
|
||||
2. THE Agent SHALL collect all profiling results and merge them into a single Safe_Profile string and a single Local_Profile string, maintaining the same format as the current sequential output.
|
||||
3. IF any individual file profiling fails, THEN THE Agent SHALL include an error entry for that file in the profile output and continue profiling the remaining files.
|
||||
4. THE AppConfig SHALL include a `max_parallel_profiles` field with a default value of 4, controlling the maximum number of concurrent profiling threads.
|
||||
74
.kiro/specs/agent-robustness-optimization/tasks.md
Normal file
74
.kiro/specs/agent-robustness-optimization/tasks.md
Normal file
@@ -0,0 +1,74 @@
|
||||
# Tasks — Agent Robustness Optimization
|
||||
|
||||
## Priority 1: Configuration Foundation
|
||||
|
||||
- [x] 1. Add new config fields to AppConfig
|
||||
- [x] 1.1 Add `max_data_context_retries` field (default=2) with `APP_MAX_DATA_CONTEXT_RETRIES` env override to `config/app_config.py`
|
||||
- [x] 1.2 Add `conversation_window_size` field (default=10) with `APP_CONVERSATION_WINDOW_SIZE` env override to `config/app_config.py`
|
||||
- [x] 1.3 Add `max_parallel_profiles` field (default=4) with `APP_MAX_PARALLEL_PROFILES` env override to `config/app_config.py`
|
||||
|
||||
## Priority 2: Data Privacy Fallback (R1–R3)
|
||||
|
||||
- [x] 2. Implement error classification
|
||||
- [x] 2.1 Add `_classify_error(error_message: str) -> str` method to `DataAnalysisAgent` in `data_analysis_agent.py` with regex patterns for KeyError, ValueError, NameError, empty DataFrame
|
||||
- [x] 2.2 Add `_extract_column_from_error(error_message: str) -> Optional[str]` function to `utils/data_privacy.py`
|
||||
- [x] 2.3 Add `_lookup_column_in_profile(column_name, safe_profile) -> Optional[dict]` function to `utils/data_privacy.py`
|
||||
- [x] 3. Implement enriched hint generation
|
||||
- [x] 3.1 Add `generate_enriched_hint(error_message: str, safe_profile: str) -> str` function to `utils/data_privacy.py`
|
||||
- [x] 3.2 Integrate retry logic into the `analyze()` loop in `data_analysis_agent.py`: add per-round retry counter, call `_classify_error` on failures, generate enriched hint when below retry limit, fall back to normal error handling at limit
|
||||
|
||||
## Priority 3: Conversation History Trimming (R4–R5)
|
||||
|
||||
- [x] 4. Implement conversation trimming
|
||||
- [x] 4.1 Add `_trim_conversation_history()` method to `DataAnalysisAgent` implementing sliding window with first-message preservation
|
||||
- [x] 4.2 Add `_compress_trimmed_messages(messages: list) -> str` method to `DataAnalysisAgent` that generates summary with action types and success/failure, excluding code blocks and raw output
|
||||
- [x] 4.3 Call `_trim_conversation_history()` at the start of each round in the `analyze()` loop, after the first round
|
||||
|
||||
## Priority 4: Analysis Template System (R6–R8)
|
||||
|
||||
- [x] 5. Backend template integration
|
||||
- [x] 5.1 Add optional `template_name` parameter to `DataAnalysisAgent.analyze()` method; retrieve template via `get_template()`, prepend `get_full_prompt()` to user requirement
|
||||
- [x] 5.2 Add `GET /api/templates` endpoint to `web/main.py` returning `list_templates()` result
|
||||
- [x] 5.3 Add optional `template` field to `StartRequest` model in `web/main.py`; pass template name to agent in `run_analysis_task`
|
||||
- [x] 6. Frontend template selector
|
||||
- [x] 6.1 Add template selector HTML section (cards above requirement input) to `web/static/index.html`
|
||||
- [x] 6.2 Add template fetching, selection logic, and "No Template" default to `web/static/script.js`
|
||||
- [x] 6.3 Add template card styles (`.template-card`, `.template-card.selected`) to `web/static/clean_style.css`
|
||||
|
||||
## Priority 5: Frontend Progress Bar (R9)
|
||||
|
||||
- [x] 7. Backend progress updates
|
||||
- [x] 7.1 Add `set_progress_callback(callback)` method to `DataAnalysisAgent`; call callback at start of each round in `analyze()` loop
|
||||
- [x] 7.2 Wire progress callback in `run_analysis_task` in `web/main.py` to update `SessionData` progress fields
|
||||
- [x] 7.3 Add `current_round`, `max_rounds`, `progress_percentage`, `status_message` to `GET /api/status` response in `web/main.py`
|
||||
- [x] 8. Frontend progress bar
|
||||
- [x] 8.1 Add progress bar HTML element below the status bar area in `web/static/index.html`
|
||||
- [x] 8.2 Add `updateProgressBar(percentage, message)` function to `web/static/script.js`; call it during polling when `is_running` is true; set to 100% on completion
|
||||
- [x] 8.3 Add progress bar styles with CSS transition animation to `web/static/clean_style.css`
|
||||
|
||||
## Priority 6: Multi-File Chunked & Parallel Loading (R10–R11)
|
||||
|
||||
- [x] 9. Chunked loading enhancement
|
||||
- [x] 9.1 Add `_profile_chunked(file_path: str) -> str` function to `utils/data_loader.py` that profiles using first chunk + sampled subsequent chunks
|
||||
- [x] 9.2 Add `load_and_profile_data_smart(file_paths, max_file_size_mb) -> str` function to `utils/data_loader.py` that selects chunked vs full loading based on file size threshold
|
||||
- [x] 9.3 Update `DataAnalysisAgent.analyze()` to use smart loader and expose chunked iterator in Code_Executor namespace for large files
|
||||
- [x] 10. Parallel profiling
|
||||
- [x] 10.1 Add `_profile_files_parallel(file_paths: list) -> tuple[str, str]` method to `DataAnalysisAgent` using `ThreadPoolExecutor` with `max_parallel_profiles` workers
|
||||
- [x] 10.2 Update `DataAnalysisAgent.analyze()` to call `_profile_files_parallel` when multiple files are provided, replacing sequential `build_safe_profile` + `build_local_profile` calls
|
||||
|
||||
## Priority 7: Testing
|
||||
|
||||
- [x] 11. Write property-based tests
|
||||
- [x] 11.1 ~PBT~ Property test for error classification correctness (Property 1) using `hypothesis`
|
||||
- [x] 11.2 ~PBT~ Property test for enriched hint content and privacy (Property 3) using `hypothesis`
|
||||
- [x] 11.3 ~PBT~ Property test for env var config override (Property 4) using `hypothesis`
|
||||
- [x] 11.4 ~PBT~ Property test for sliding window trimming invariants (Property 5) using `hypothesis`
|
||||
- [x] 11.5 ~PBT~ Property test for trimming summary content (Property 6) using `hypothesis`
|
||||
- [x] 11.6 ~PBT~ Property test for template prompt integration (Property 7) using `hypothesis`
|
||||
- [x] 11.7 ~PBT~ Property test for invalid template error (Property 8) using `hypothesis`
|
||||
- [x] 11.8 ~PBT~ Property test for parallel profile merge with error resilience (Property 11) using `hypothesis`
|
||||
- [x] 12. Write unit and integration tests
|
||||
- [x] 12.1 Unit tests for error classifier with known error messages
|
||||
- [x] 12.2 Unit tests for conversation trimming at boundary conditions
|
||||
- [x] 12.3 Integration tests for `GET /api/templates` and `POST /api/start` with template field
|
||||
- [x] 12.4 Integration tests for `GET /api/status` progress fields
|
||||
1
.kiro/specs/analysis-dashboard-redesign/.config.kiro
Normal file
1
.kiro/specs/analysis-dashboard-redesign/.config.kiro
Normal file
@@ -0,0 +1 @@
|
||||
{"specId": "ea41aaef-0737-4255-bcad-90f156a5b2d5", "workflowType": "requirements-first", "specType": "feature"}
|
||||
393
.kiro/specs/analysis-dashboard-redesign/design.md
Normal file
393
.kiro/specs/analysis-dashboard-redesign/design.md
Normal file
@@ -0,0 +1,393 @@
|
||||
# Design Document: Analysis Dashboard Redesign
|
||||
|
||||
## Overview
|
||||
|
||||
This design transforms the Analysis Dashboard from a raw-log-centric 3-tab layout (Live Log, Report, Gallery) into a structured, evidence-driven 3-tab layout (Execution Process, Data Files, Report). The core architectural change is introducing a **Round_Data** structured data model that flows from the agent's execution loop through the API to the frontend, replacing the current raw text log approach.
|
||||
|
||||
Key design decisions:
|
||||
- **Round_Data as the central abstraction**: Every analysis round produces a structured object containing reasoning, code, result summary, data evidence, and raw log. This single model drives the Execution Process tab, evidence linking, and data file tracking.
|
||||
- **Auto-detection at the CodeExecutor level**: DataFrame detection and CSV export happen transparently in `CodeExecutor.execute_code()`, requiring no LLM cooperation. Prompt guidance is additive — it encourages the LLM to save files explicitly, but the system doesn't depend on it.
|
||||
- **Gallery absorbed into Report**: Images are already rendered inline via `marked.js` Markdown parsing. Removing the Gallery tab is a subtraction, not an addition.
|
||||
- **Evidence linking via HTML comments**: The LLM annotates report paragraphs with `<!-- evidence:round_N -->` comments during final report generation. The backend parses these to build a `supporting_data` mapping. This is a best-effort approach — missing annotations simply mean no "查看支撑数据" button.
|
||||
|
||||
## Architecture
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
subgraph Backend
|
||||
A[DataAnalysisAgent] -->|produces| B[Round_Data objects]
|
||||
A -->|calls| C[CodeExecutor]
|
||||
C -->|auto-detects DataFrames| D[CSV export to session dir]
|
||||
C -->|captures evidence rows| B
|
||||
C -->|parses DATA_FILE_SAVED markers| E[File metadata]
|
||||
B -->|stored on| F[SessionData]
|
||||
E -->|stored on| F
|
||||
F -->|serves| G[GET /api/status]
|
||||
F -->|serves| H[GET /api/data-files]
|
||||
F -->|serves| I[GET /api/report]
|
||||
end
|
||||
|
||||
subgraph Frontend
|
||||
G -->|rounds array| J[Execution Process Tab]
|
||||
H -->|file list + preview| K[Data Files Tab]
|
||||
I -->|paragraphs + supporting_data| L[Report Tab]
|
||||
end
|
||||
```
|
||||
|
||||
### Data Flow
|
||||
|
||||
1. **Agent loop** (`DataAnalysisAgent.analyze`): Each round calls `CodeExecutor.execute_code()`, which returns an enriched result dict containing `evidence_rows`, `auto_exported_files`, and `prompt_saved_files`. The agent wraps this into a `Round_Data` dict and appends it to `SessionData.rounds`.
|
||||
|
||||
2. **Status polling**: Frontend polls `GET /api/status` every 2 seconds. The response now includes a `rounds` array. The frontend incrementally appends new `Round_Card` elements — it tracks the last-seen round count and only renders new entries.
|
||||
|
||||
3. **Data Files**: `GET /api/data-files` reads `SessionData.data_files` plus scans the session directory for CSV/XLSX files (fallback discovery). Preview reads the first 5 rows via pandas.
|
||||
|
||||
4. **Report with evidence**: `GET /api/report` parses `<!-- evidence:round_N -->` annotations, looks up `SessionData.rounds[N].evidence_rows`, and builds a `supporting_data` mapping keyed by paragraph ID.
|
||||
|
||||
## Components and Interfaces
|
||||
|
||||
### 1. CodeExecutor Enhancements (`utils/code_executor.py`)
|
||||
|
||||
**New behavior in `execute_code()`:**
|
||||
|
||||
```python
|
||||
def execute_code(self, code: str) -> Dict[str, Any]:
|
||||
"""Returns dict with keys: success, output, error, variables,
|
||||
evidence_rows, auto_exported_files, prompt_saved_files"""
|
||||
```
|
||||
|
||||
- **DataFrame snapshot before/after**: Before execution, capture `{name: id(obj)}` for all DataFrame variables. After execution, detect new names or changed `id()` values.
|
||||
- **Evidence capture**: If the execution result is a DataFrame (via `result.result`), call `.head(10).to_dict(orient='records')` to produce `evidence_rows`. Also check the last assigned DataFrame variable in the namespace.
|
||||
- **Auto-export**: For each newly detected DataFrame, export to `{session_dir}/{var_name}.csv` with dedup suffix. Record metadata in `auto_exported_files` list.
|
||||
- **Marker parsing**: Scan `captured.stdout` for `[DATA_FILE_SAVED]` lines, parse filename/rows/description, record in `prompt_saved_files` list.
|
||||
|
||||
**Interface contract:**
|
||||
```python
|
||||
# evidence_rows: list[dict] — up to 10 rows as dicts
|
||||
# auto_exported_files: list[dict] — [{variable_name, filename, rows, cols, columns}]
|
||||
# prompt_saved_files: list[dict] — [{filename, rows, description}]
|
||||
```
|
||||
|
||||
### 2. DataAnalysisAgent Changes (`data_analysis_agent.py`)
|
||||
|
||||
**Round_Data construction** in `_handle_generate_code()` and the main loop:
|
||||
|
||||
```python
|
||||
round_data = {
|
||||
"round": self.current_round,
|
||||
"reasoning": yaml_data.get("reasoning", ""),
|
||||
"code": code,
|
||||
"result_summary": self._summarize_result(result),
|
||||
"evidence_rows": result.get("evidence_rows", []),
|
||||
"raw_log": feedback,
|
||||
"auto_exported_files": result.get("auto_exported_files", []),
|
||||
"prompt_saved_files": result.get("prompt_saved_files", []),
|
||||
}
|
||||
```
|
||||
|
||||
The agent appends `round_data` to `SessionData.rounds` (accessed via the progress callback or a direct reference). File metadata from both `auto_exported_files` and `prompt_saved_files` is merged into `SessionData.data_files`.
|
||||
|
||||
**`_summarize_result()`**: Produces a one-line summary from the execution result — e.g., "执行成功,输出 DataFrame (150行×8列)" or "执行失败: KeyError: 'col_x'".
|
||||
|
||||
### 3. SessionData Extension (`web/main.py`)
|
||||
|
||||
```python
|
||||
class SessionData:
|
||||
def __init__(self, session_id: str):
|
||||
# ... existing fields ...
|
||||
self.rounds: List[Dict] = [] # Round_Data objects
|
||||
self.data_files: List[Dict] = [] # File metadata dicts
|
||||
```
|
||||
|
||||
Persistence: `rounds` and `data_files` are written to `results.json` on analysis completion (existing pattern).
|
||||
|
||||
### 4. API Changes (`web/main.py`)
|
||||
|
||||
**`GET /api/status`** — add `rounds` to response:
|
||||
```python
|
||||
return {
|
||||
# ... existing fields ...
|
||||
"rounds": session.rounds,
|
||||
}
|
||||
```
|
||||
|
||||
**`GET /api/data-files`** — new endpoint:
|
||||
```python
|
||||
@app.get("/api/data-files")
|
||||
async def list_data_files(session_id: str = Query(...)):
|
||||
# Returns session.data_files + fallback directory scan
|
||||
```
|
||||
|
||||
**`GET /api/data-files/preview`** — new endpoint:
|
||||
```python
|
||||
@app.get("/api/data-files/preview")
|
||||
async def preview_data_file(session_id: str = Query(...), filename: str = Query(...)):
|
||||
# Reads CSV/XLSX, returns {columns: [...], rows: [...first 5...]}
|
||||
```
|
||||
|
||||
**`GET /api/data-files/download`** — new endpoint:
|
||||
```python
|
||||
@app.get("/api/data-files/download")
|
||||
async def download_data_file(session_id: str = Query(...), filename: str = Query(...)):
|
||||
# Returns FileResponse with appropriate MIME type
|
||||
```
|
||||
|
||||
**`GET /api/report`** — enhanced response:
|
||||
```python
|
||||
return {
|
||||
"content": content,
|
||||
"base_path": web_base_path,
|
||||
"paragraphs": paragraphs,
|
||||
"supporting_data": supporting_data_map, # NEW: {paragraph_id: [evidence_rows]}
|
||||
}
|
||||
```
|
||||
|
||||
### 5. Prompt Changes (`prompts.py`)
|
||||
|
||||
Add to `data_analysis_system_prompt` after the existing code generation rules:
|
||||
|
||||
```
|
||||
**中间数据保存规则**:
|
||||
- 当你生成了有价值的中间数据(筛选子集、聚合表、聚类结果等),请主动保存为CSV/XLSX文件。
|
||||
- 保存后必须打印标记行:`[DATA_FILE_SAVED] filename: {文件名}, rows: {行数}, description: {描述}`
|
||||
- 示例:
|
||||
```python
|
||||
top_issues.to_csv(os.path.join(session_output_dir, "TOP问题汇总.csv"), index=False)
|
||||
print(f"[DATA_FILE_SAVED] filename: TOP问题汇总.csv, rows: {len(top_issues)}, description: 各类型TOP问题聚合统计")
|
||||
```
|
||||
```
|
||||
|
||||
Add to `final_report_system_prompt` for evidence annotation:
|
||||
|
||||
```
|
||||
**证据标注规则**:
|
||||
- 当报告段落的结论来源于某一轮分析的数据,请在段落末尾添加HTML注释标注:`<!-- evidence:round_N -->`
|
||||
- N 为产生该数据的分析轮次编号(从1开始)
|
||||
- 示例:某段落描述了第3轮分析发现的车型分布规律,则在段落末尾添加 `<!-- evidence:round_3 -->`
|
||||
```
|
||||
|
||||
### 6. Frontend Changes
|
||||
|
||||
**`index.html`**:
|
||||
- Replace tab labels: "Live Log" → "执行过程", add "数据文件", keep "Report"
|
||||
- Remove Gallery tab HTML and carousel container
|
||||
- Add Execution Process tab container with round card template
|
||||
- Add Data Files tab container with file card template
|
||||
|
||||
**`script.js`**:
|
||||
- Remove gallery functions and state
|
||||
- Add `renderRoundCards(rounds)` — incremental rendering using a `lastRenderedRound` counter
|
||||
- Add `loadDataFiles()`, `previewDataFile(filename)`, `downloadDataFile(filename)`
|
||||
- Modify `startPolling()` to call `renderRoundCards()` and `loadDataFiles()` on each cycle
|
||||
- Add `showSupportingData(paraId)` for the evidence popover
|
||||
- Modify `renderParagraphReport()` to add "查看支撑数据" buttons when `supporting_data[paraId]` exists
|
||||
- Update `switchTab()` to handle `execution`, `datafiles`, `report`
|
||||
|
||||
**`clean_style.css`**:
|
||||
- Add `.round-card`, `.round-card-header`, `.round-card-body` styles
|
||||
- Add `.data-file-card`, `.data-preview-table` styles
|
||||
- Add `.supporting-data-btn`, `.supporting-data-popover` styles
|
||||
- Remove `.carousel-*` styles
|
||||
|
||||
## Data Models
|
||||
|
||||
### Round_Data (Python dict)
|
||||
|
||||
```python
|
||||
{
|
||||
"round": int, # 1-indexed round number
|
||||
"reasoning": str, # LLM reasoning text (may be empty)
|
||||
"code": str, # Generated Python code
|
||||
"result_summary": str, # One-line execution summary
|
||||
"evidence_rows": list[dict], # Up to 10 rows as [{col: val, ...}]
|
||||
"raw_log": str, # Full execution feedback text
|
||||
"auto_exported_files": list[dict], # Auto-detected DataFrame exports
|
||||
"prompt_saved_files": list[dict], # LLM-guided file saves
|
||||
}
|
||||
```
|
||||
|
||||
### File Metadata (Python dict)
|
||||
|
||||
```python
|
||||
{
|
||||
"filename": str, # e.g., "top_issues.csv"
|
||||
"description": str, # Human-readable description
|
||||
"rows": int, # Row count
|
||||
"cols": int, # Column count (optional, may be 0)
|
||||
"columns": list[str], # Column names (optional)
|
||||
"size_bytes": int, # File size
|
||||
"source": str, # "auto" | "prompt" — how the file was created
|
||||
}
|
||||
```
|
||||
|
||||
### SessionData Extension
|
||||
|
||||
```python
|
||||
class SessionData:
|
||||
rounds: List[Dict] = [] # List of Round_Data dicts
|
||||
data_files: List[Dict] = [] # List of File Metadata dicts
|
||||
```
|
||||
|
||||
### API Response: GET /api/status (extended)
|
||||
|
||||
```json
|
||||
{
|
||||
"is_running": true,
|
||||
"log": "...",
|
||||
"has_report": false,
|
||||
"rounds": [
|
||||
{
|
||||
"round": 1,
|
||||
"reasoning": "正在执行阶段1...",
|
||||
"code": "import pandas as pd\n...",
|
||||
"result_summary": "执行成功,输出 DataFrame (150行×8列)",
|
||||
"evidence_rows": [{"车型": "...", "模块": "..."}],
|
||||
"raw_log": "..."
|
||||
}
|
||||
],
|
||||
"progress_percentage": 25.0,
|
||||
"current_round": 1,
|
||||
"max_rounds": 20,
|
||||
"status_message": "第1/20轮分析中..."
|
||||
}
|
||||
```
|
||||
|
||||
### API Response: GET /api/data-files
|
||||
|
||||
```json
|
||||
{
|
||||
"files": [
|
||||
{
|
||||
"filename": "top_issues.csv",
|
||||
"description": "各类型TOP问题聚合统计",
|
||||
"rows": 25,
|
||||
"cols": 6,
|
||||
"size_bytes": 2048
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### API Response: GET /api/report (extended)
|
||||
|
||||
```json
|
||||
{
|
||||
"content": "...",
|
||||
"base_path": "/outputs/session_xxx",
|
||||
"paragraphs": [...],
|
||||
"supporting_data": {
|
||||
"p-3": [{"车型": "A", "模块": "TSP", "数量": 42}],
|
||||
"p-7": [{"问题类型": "远控", "占比": "35%"}]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Correctness Properties
|
||||
|
||||
*A property is a characteristic or behavior that should hold true across all valid executions of a system — essentially, a formal statement about what the system should do. Properties serve as the bridge between human-readable specifications and machine-verifiable correctness guarantees.*
|
||||
|
||||
### Property 1: Round_Data Structural Completeness and Ordering
|
||||
|
||||
*For any* sequence of analysis rounds (varying in count from 1 to N, with varying execution results including successes, failures, and missing YAML fields), every Round_Data object appended to `SessionData.rounds` SHALL contain all required fields (`round`, `reasoning`, `code`, `result_summary`, `evidence_rows`, `raw_log`) with correct types, and the list SHALL preserve insertion order (i.e., `rounds[i].round <= rounds[i+1].round` for all consecutive pairs).
|
||||
|
||||
**Validates: Requirements 1.1, 1.3, 1.4**
|
||||
|
||||
### Property 2: Evidence Capture Bounded and Correctly Serialized
|
||||
|
||||
*For any* DataFrame of arbitrary size (0 to 10,000 rows, 1 to 50 columns) produced by code execution, the evidence capture SHALL return a list of at most 10 dictionaries, where each dictionary's keys exactly match the DataFrame's column names, and the list length equals `min(10, len(dataframe))`.
|
||||
|
||||
**Validates: Requirements 4.1, 4.2, 4.3**
|
||||
|
||||
### Property 3: Filename Deduplication Uniqueness
|
||||
|
||||
*For any* sequence of auto-export operations (1 to 20) targeting the same variable name in the same session directory, all generated filenames SHALL be unique (no two exports produce the same filename), and no previously existing file SHALL be overwritten.
|
||||
|
||||
**Validates: Requirements 5.3**
|
||||
|
||||
### Property 4: Auto-Export Metadata Completeness
|
||||
|
||||
*For any* newly detected DataFrame variable (with arbitrary variable name, row count, column count, and column names), the auto-export metadata dict SHALL contain all required fields (`variable_name`, `filename`, `rows`, `cols`, `columns`) with values matching the source DataFrame's actual properties.
|
||||
|
||||
**Validates: Requirements 5.4, 5.5**
|
||||
|
||||
### Property 5: DATA_FILE_SAVED Marker Parsing Round-Trip
|
||||
|
||||
*For any* valid filename string (alphanumeric, Chinese characters, underscores, hyphens, with .csv or .xlsx extension), any positive integer row count, and any non-empty description string, formatting these values into the standardized marker format `[DATA_FILE_SAVED] filename: {name}, rows: {count}, description: {desc}` and then parsing the marker SHALL recover the original filename, row count, and description exactly.
|
||||
|
||||
**Validates: Requirements 6.3**
|
||||
|
||||
### Property 6: Data File Preview Bounded Rows
|
||||
|
||||
*For any* CSV file containing 0 to 10,000 rows and 1 to 50 columns, the preview function SHALL return a result with `columns` matching the file's column names exactly, and `rows` containing at most 5 dictionaries, where each dictionary's keys match the column names.
|
||||
|
||||
**Validates: Requirements 7.2**
|
||||
|
||||
### Property 7: Evidence Annotation Parsing Correctness
|
||||
|
||||
*For any* Markdown report text containing a mix of paragraphs with and without `<!-- evidence:round_N -->` annotations (where N varies from 1 to 100), the annotation parser SHALL: (a) correctly extract the round number for every annotated paragraph, (b) exclude non-annotated paragraphs from the `supporting_data` mapping, and (c) produce a mapping where each key is a valid paragraph ID and each value references a valid round number.
|
||||
|
||||
**Validates: Requirements 11.3, 11.4**
|
||||
|
||||
### Property 8: SessionData JSON Serialization Round-Trip
|
||||
|
||||
*For any* `SessionData` instance with arbitrary `rounds` (list of Round_Data dicts) and `data_files` (list of file metadata dicts), serializing these to JSON and deserializing back SHALL produce lists that are equal to the originals.
|
||||
|
||||
**Validates: Requirements 12.4**
|
||||
|
||||
## Error Handling
|
||||
|
||||
### CodeExecutor Errors
|
||||
- **DataFrame evidence capture failure**: If `.head(10).to_dict(orient='records')` raises an exception (e.g., mixed types, memory issues), catch the exception and return an empty `evidence_rows` list. Log a warning but do not fail the execution.
|
||||
- **Auto-export failure**: If CSV writing fails for a detected DataFrame (e.g., permission error, disk full), catch the exception, log a warning with the variable name, and skip that export. Other detected DataFrames should still be exported.
|
||||
- **Marker parsing failure**: If a `[DATA_FILE_SAVED]` line doesn't match the expected format, skip it silently. Malformed markers should not crash the execution pipeline.
|
||||
|
||||
### API Errors
|
||||
- **Missing session**: All new endpoints return HTTP 404 with `{"detail": "Session not found"}` for invalid session IDs.
|
||||
- **Missing file**: `GET /api/data-files/preview` and `GET /api/data-files/download` return HTTP 404 with `{"detail": "File not found: {filename}"}` when the requested file doesn't exist in the session directory.
|
||||
- **Corrupt CSV**: If a CSV file can't be read by pandas during preview, return HTTP 500 with `{"detail": "Failed to read file: {error}"}`.
|
||||
|
||||
### Frontend Errors
|
||||
- **Polling with missing rounds**: If `rounds` is undefined or null in the status response, treat it as an empty array. Don't crash the rendering loop.
|
||||
- **Evidence popover with empty data**: If `supporting_data[paraId]` is an empty array, don't show the button (same as missing).
|
||||
- **Incremental rendering mismatch**: If `rounds.length < lastRenderedRound` (server restart scenario), reset `lastRenderedRound` to 0 and re-render all cards.
|
||||
|
||||
### Agent Errors
|
||||
- **Missing reasoning field**: Already handled — store empty string (Requirement 1.4).
|
||||
- **Evidence annotation missing**: Already handled — paragraphs without annotations simply don't get supporting data buttons. This is by design, not an error.
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### Property-Based Tests (Hypothesis)
|
||||
|
||||
The project already uses `hypothesis` with `max_examples=20` for fast execution (see `tests/test_properties.py`). New property tests will follow the same pattern.
|
||||
|
||||
**Library**: `hypothesis` (already installed)
|
||||
**Configuration**: `max_examples=100` minimum per property (increased from existing 20 for new properties)
|
||||
**Tag format**: `Feature: analysis-dashboard-redesign, Property {N}: {title}`
|
||||
|
||||
Properties to implement:
|
||||
1. **Round_Data structural completeness** — Generate random execution results, verify Round_Data fields
|
||||
2. **Evidence capture bounded** — Generate random DataFrames, verify evidence row count and format
|
||||
3. **Filename deduplication** — Generate sequences of same-name exports, verify uniqueness
|
||||
4. **Auto-export metadata** — Generate random DataFrames, verify metadata fields
|
||||
5. **Marker parsing round-trip** — Generate random filenames/rows/descriptions, verify parse(format(x)) == x
|
||||
6. **Preview bounded rows** — Generate random CSVs, verify preview row count and columns
|
||||
7. **Evidence annotation parsing** — Generate random annotated Markdown, verify extraction
|
||||
8. **SessionData JSON round-trip** — Generate random rounds/data_files, verify serialize/deserialize identity
|
||||
|
||||
### Unit Tests
|
||||
|
||||
- Prompt content assertions (6.1, 6.2, 11.2): Verify prompt strings contain required instruction text
|
||||
- SessionData initialization (12.1, 12.2): Verify new attributes exist with correct defaults
|
||||
- API response shape (2.1, 2.3): Verify status endpoint returns rounds array and log field
|
||||
- Tab switching (9.4): Verify switchTab handles new tab identifiers
|
||||
|
||||
### Integration Tests
|
||||
|
||||
- End-to-end round capture: Run a mini analysis session, verify rounds are populated
|
||||
- Data file API flow: Create files, call list/preview/download endpoints, verify responses
|
||||
- Report evidence linking: Generate a report with annotations, call report API, verify supporting_data mapping
|
||||
|
||||
### Manual Testing
|
||||
|
||||
- UI layout verification (3.1-3.6, 8.1-8.5, 9.1-9.3, 10.1-10.4): Visual inspection of tab layout, round cards, data file cards, inline images, and supporting data popovers
|
||||
159
.kiro/specs/analysis-dashboard-redesign/requirements.md
Normal file
159
.kiro/specs/analysis-dashboard-redesign/requirements.md
Normal file
@@ -0,0 +1,159 @@
|
||||
# Requirements Document
|
||||
|
||||
## Introduction
|
||||
|
||||
This feature redesigns the Analysis Dashboard from the current 3-tab layout (Live Log, Report, Gallery) to a new 3-tab layout (Execution Process, Data Files, Report) with richer functionality. The redesign introduces structured round-by-round execution cards, intermediate data file browsing, inline image display within the report, and a data evidence/supporting data feature that links analysis conclusions to the specific data rows that support them. The Gallery tab is removed; its functionality is absorbed into the Report tab.
|
||||
|
||||
## Glossary
|
||||
|
||||
- **Dashboard**: The main analysis output panel in the web frontend (`index.html`) containing tabs for viewing analysis results.
|
||||
- **Execution_Process_Tab**: The new first tab (执行过程) replacing the Live Log tab, displaying analysis rounds as collapsible cards.
|
||||
- **Round_Card**: A collapsible UI card within the Execution_Process_Tab representing one analysis round, containing reasoning, code, result summary, data evidence, and raw log.
|
||||
- **Data_Files_Tab**: The new second tab (数据文件) showing intermediate data files produced during analysis.
|
||||
- **Report_Tab**: The enhanced third tab (报告) with inline images and supporting data links.
|
||||
- **Data_Evidence**: Specific data rows extracted during analysis that support a particular analytical conclusion or claim.
|
||||
- **CodeExecutor**: The Python class (`utils/code_executor.py`) responsible for executing generated analysis code in an IPython environment.
|
||||
- **DataAnalysisAgent**: The Python class (`data_analysis_agent.py`) orchestrating the multi-round LLM-driven analysis workflow.
|
||||
- **SessionData**: The Python class (`web/main.py`) tracking per-session state including running status, output directory, and analysis results.
|
||||
- **Status_API**: The `GET /api/status` endpoint polled every 2 seconds by the frontend to retrieve analysis progress.
|
||||
- **Data_Files_API**: The new set of API endpoints (`GET /api/data-files`, `GET /api/data-files/preview`, `GET /api/data-files/download`) for listing, previewing, and downloading intermediate data files.
|
||||
- **Round_Data**: A structured JSON object representing one analysis round, containing fields for reasoning, code, execution result summary, data evidence rows, and raw log output.
|
||||
- **Auto_Detection**: The mechanism by which CodeExecutor automatically detects new DataFrames created during code execution and exports them as files.
|
||||
- **Prompt_Guidance**: Instructions embedded in the system prompt that direct the LLM to proactively save intermediate analysis results as files.
|
||||
|
||||
## Requirements
|
||||
|
||||
### Requirement 1: Structured Round Data Capture
|
||||
|
||||
**User Story:** As a user, I want each analysis round's data to be captured in a structured format, so that the frontend can render rich execution cards instead of raw log text.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN an analysis round completes, THE DataAnalysisAgent SHALL produce a Round_Data object containing the following fields: round number, AI reasoning text, generated code, execution result summary, data evidence rows (list of dictionaries), and raw log output.
|
||||
2. WHEN the DataAnalysisAgent processes an LLM response with a YAML `reasoning` field, THE DataAnalysisAgent SHALL extract and store the reasoning text in the Round_Data object for that round.
|
||||
3. THE DataAnalysisAgent SHALL append each completed Round_Data object to a list stored on the SessionData instance, preserving insertion order.
|
||||
4. IF the LLM response does not contain a parseable `reasoning` field, THEN THE DataAnalysisAgent SHALL store an empty string as the reasoning text in the Round_Data object.
|
||||
|
||||
### Requirement 2: Structured Status API Response
|
||||
|
||||
**User Story:** As a frontend developer, I want the status API to return structured round data, so that I can render execution cards in real time.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN the frontend polls `GET /api/status`, THE Status_API SHALL return a JSON response containing a `rounds` array of Round_Data objects in addition to the existing fields (`is_running`, `has_report`, `progress_percentage`, `current_round`, `max_rounds`, `status_message`).
|
||||
2. WHEN a new analysis round completes between two polling intervals, THE Status_API SHALL include the newly completed Round_Data object in the `rounds` array on the next poll response.
|
||||
3. THE Status_API SHALL continue to return the `log` field containing raw log text for backward compatibility.
|
||||
|
||||
### Requirement 3: Execution Process Tab UI
|
||||
|
||||
**User Story:** As a user, I want to see each analysis round as a collapsible card with reasoning, code, results, and data evidence, so that I can understand the step-by-step analysis process.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. THE Dashboard SHALL display an "执行过程" (Execution Process) tab as the first tab, replacing the current "Live Log" tab.
|
||||
2. WHEN the Execution_Process_Tab is active, THE Dashboard SHALL render one Round_Card for each entry in the `rounds` array returned by the Status_API.
|
||||
3. THE Round_Card SHALL default to a collapsed state showing only the round number and a one-line execution result summary.
|
||||
4. WHEN a user clicks on a collapsed Round_Card, THE Dashboard SHALL expand the card to reveal: AI reasoning text, generated code (in a collapsible sub-section), execution result summary, data evidence section (labeled "本轮数据案例"), and raw log output (in a collapsible sub-section).
|
||||
5. WHEN a new Round_Data object appears in the polling response, THE Dashboard SHALL append a new Round_Card to the Execution_Process_Tab without removing or re-rendering existing cards.
|
||||
6. WHILE analysis is running, THE Dashboard SHALL auto-scroll the Execution_Process_Tab to keep the latest Round_Card visible.
|
||||
|
||||
### Requirement 4: Data Evidence Capture
|
||||
|
||||
**User Story:** As a user, I want to see the specific data rows that support each analytical conclusion, so that I can verify claims made by the AI agent.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN the CodeExecutor executes code that produces a DataFrame result, THE CodeExecutor SHALL capture up to 10 representative rows from that DataFrame as the data evidence for the current round.
|
||||
2. THE CodeExecutor SHALL serialize data evidence rows as a list of dictionaries (one dictionary per row, keys being column names) and include the list in the execution result returned to the DataAnalysisAgent.
|
||||
3. IF the code execution does not produce a DataFrame result, THEN THE CodeExecutor SHALL return an empty list as the data evidence.
|
||||
4. THE DataAnalysisAgent SHALL include the data evidence list in the Round_Data object for the corresponding round.
|
||||
|
||||
### Requirement 5: DataFrame Auto-Detection and Export
|
||||
|
||||
**User Story:** As a user, I want intermediate DataFrames created during analysis to be automatically saved as files, so that I can browse and download them from the Data Files tab.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN code execution completes, THE CodeExecutor SHALL compare the set of DataFrame variables in the IPython namespace before and after execution to detect newly created DataFrames.
|
||||
2. WHEN a new DataFrame variable is detected, THE CodeExecutor SHALL export the DataFrame to the session output directory as a CSV file named `{variable_name}.csv`.
|
||||
3. IF a file with the same name already exists in the session output directory, THEN THE CodeExecutor SHALL append a numeric suffix (e.g., `_1`, `_2`) to avoid overwriting.
|
||||
4. THE CodeExecutor SHALL record metadata for each auto-exported file: variable name, filename, row count, column count, and column names.
|
||||
5. WHEN auto-export completes, THE CodeExecutor SHALL include the exported file metadata in the execution result returned to the DataAnalysisAgent.
|
||||
|
||||
### Requirement 6: Prompt Guidance for Intermediate File Saving
|
||||
|
||||
**User Story:** As a user, I want the LLM to proactively save intermediate analysis results as files, so that important intermediate datasets are available for review.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. THE system prompt (`prompts.py`) SHALL include instructions directing the LLM to save intermediate analysis results (filtered subsets, aggregation tables, clustering results) as CSV or XLSX files in the `session_output_dir`.
|
||||
2. THE system prompt SHALL instruct the LLM to print a standardized marker line after saving each file, in the format: `[DATA_FILE_SAVED] filename: {name}, rows: {count}, description: {desc}`.
|
||||
3. WHEN the CodeExecutor detects a `[DATA_FILE_SAVED]` marker in the execution output, THE CodeExecutor SHALL parse the marker and record the file metadata (filename, row count, description).
|
||||
|
||||
### Requirement 7: Data Files API
|
||||
|
||||
**User Story:** As a frontend developer, I want API endpoints to list, preview, and download intermediate data files, so that the Data Files tab can display and serve them.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN the frontend requests `GET /api/data-files?session_id={id}`, THE Data_Files_API SHALL return a JSON array of file entries, each containing: filename, description, row count, column count, and file size in bytes.
|
||||
2. WHEN the frontend requests `GET /api/data-files/preview?session_id={id}&filename={name}`, THE Data_Files_API SHALL return a JSON object containing: column names (list of strings), and up to 5 data rows (list of dictionaries).
|
||||
3. WHEN the frontend requests `GET /api/data-files/download?session_id={id}&filename={name}`, THE Data_Files_API SHALL return the file as a downloadable attachment with the appropriate MIME type (`text/csv` for CSV, `application/vnd.openxmlformats-officedocument.spreadsheetml.sheet` for XLSX).
|
||||
4. IF the requested file does not exist, THEN THE Data_Files_API SHALL return HTTP 404 with a descriptive error message.
|
||||
|
||||
### Requirement 8: Data Files Tab UI
|
||||
|
||||
**User Story:** As a user, I want to browse intermediate data files produced during analysis, preview their contents, and download them individually.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. THE Dashboard SHALL display a "数据文件" (Data Files) tab as the second tab.
|
||||
2. WHEN the Data_Files_Tab is active, THE Dashboard SHALL fetch the file list from `GET /api/data-files` and render each file as a card showing: filename, description, and row count.
|
||||
3. WHEN a user clicks on a file card, THE Dashboard SHALL fetch the preview from `GET /api/data-files/preview` and display a table showing column headers and up to 5 data rows.
|
||||
4. WHEN a user clicks the download button on a file card, THE Dashboard SHALL initiate a file download via `GET /api/data-files/download`.
|
||||
5. WHILE analysis is running, THE Dashboard SHALL refresh the file list on each polling cycle to show newly created files.
|
||||
|
||||
### Requirement 9: Gallery Removal and Inline Images in Report
|
||||
|
||||
**User Story:** As a user, I want images displayed inline within report paragraphs instead of in a separate Gallery tab, so that visual evidence is presented in context.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. THE Dashboard SHALL remove the "Gallery" tab from the tab bar.
|
||||
2. THE Dashboard SHALL remove the gallery carousel UI (carousel container, navigation buttons, image info panel) from the HTML.
|
||||
3. THE Report_Tab SHALL render images inline within report paragraphs using standard Markdown image syntax (``), as already supported by the existing `marked.js` rendering.
|
||||
4. THE `switchTab` function in `script.js` SHALL handle only the three new tab identifiers: `execution`, `datafiles`, and `report`.
|
||||
5. THE frontend SHALL remove all gallery-related JavaScript functions (`loadGallery`, `renderGalleryImage`, `prevImage`, `nextImage`) and associated state variables (`galleryImages`, `currentImageIndex`).
|
||||
|
||||
### Requirement 10: Supporting Data Button in Report
|
||||
|
||||
**User Story:** As a user, I want report paragraphs that make data-driven claims to have a "查看支撑数据" button, so that I can view the evidence data that supports each conclusion.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN the Report_Tab renders a paragraph of type `text` that has associated data evidence, THE Dashboard SHALL display a "查看支撑数据" (View Supporting Data) button below the paragraph content.
|
||||
2. WHEN a user clicks the "查看支撑数据" button, THE Dashboard SHALL display a popover or modal showing the associated data evidence rows in a table format.
|
||||
3. THE `GET /api/report` response SHALL include a `supporting_data` mapping (keyed by paragraph ID) containing the data evidence rows relevant to each paragraph.
|
||||
4. IF a paragraph has no associated data evidence, THEN THE Dashboard SHALL not display the "查看支撑数据" button for that paragraph.
|
||||
|
||||
### Requirement 11: Report-to-Evidence Linking in Backend
|
||||
|
||||
**User Story:** As a backend developer, I want the system to associate data evidence from execution rounds with report paragraphs, so that the frontend can display supporting data buttons.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN generating the final report, THE DataAnalysisAgent SHALL pass the collected data evidence from all rounds to the report generation prompt.
|
||||
2. THE final report generation prompt SHALL instruct the LLM to annotate report paragraphs with round references (e.g., `<!-- evidence:round_3 -->`) when a paragraph's content is derived from a specific analysis round.
|
||||
3. WHEN the `GET /api/report` endpoint parses the report, THE backend SHALL extract evidence annotations and build a `supporting_data` mapping by looking up the referenced round's data evidence from the SessionData.
|
||||
4. IF a paragraph contains no evidence annotation, THEN THE backend SHALL exclude that paragraph from the `supporting_data` mapping.
|
||||
|
||||
### Requirement 12: Session Data Model Extension
|
||||
|
||||
**User Story:** As a backend developer, I want the SessionData model to store structured round data and data file metadata, so that the new API endpoints can serve this information.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. THE SessionData class SHALL include a `rounds` attribute (list of Round_Data dictionaries) to store structured data for each completed analysis round.
|
||||
2. THE SessionData class SHALL include a `data_files` attribute (list of file metadata dictionaries) to store information about intermediate data files.
|
||||
3. WHEN a new data file is detected (via auto-detection or prompt-guided saving), THE DataAnalysisAgent SHALL append the file metadata to the SessionData `data_files` list.
|
||||
4. THE SessionData class SHALL persist the `rounds` and `data_files` attributes to the session's `results.json` file upon analysis completion.
|
||||
102
.kiro/specs/analysis-dashboard-redesign/tasks.md
Normal file
102
.kiro/specs/analysis-dashboard-redesign/tasks.md
Normal file
@@ -0,0 +1,102 @@
|
||||
# Tasks: Analysis Dashboard Redesign
|
||||
|
||||
## Phase 1: Backend Data Model + API Changes (Foundation)
|
||||
|
||||
- [x] 1. Extend SessionData model
|
||||
- [x] 1.1 Add `rounds: List[Dict]` attribute to `SessionData.__init__()` in `web/main.py`, initialized to empty list
|
||||
- [x] 1.2 Add `data_files: List[Dict]` attribute to `SessionData.__init__()` in `web/main.py`, initialized to empty list
|
||||
- [x] 1.3 Update `_reconstruct_session()` to load `rounds` and `data_files` from `results.json` when reconstructing historical sessions
|
||||
- [x] 1.4 Update `run_analysis_task()` to persist `session.rounds` and `session.data_files` to `results.json` on analysis completion
|
||||
|
||||
- [x] 2. Update Status API response
|
||||
- [x] 2.1 Add `rounds` field to `GET /api/status` response dict, returning `session.rounds`
|
||||
- [x] 2.2 Verify backward compatibility: ensure `log`, `is_running`, `has_report`, `progress_percentage`, `current_round`, `max_rounds`, `status_message` fields remain unchanged
|
||||
|
||||
- [x] 3. Add Data Files API endpoints
|
||||
- [x] 3.1 Implement `GET /api/data-files` endpoint: return `session.data_files` merged with fallback directory scan for CSV/XLSX files, each entry containing filename, description, rows, cols, size_bytes
|
||||
- [x] 3.2 Implement `GET /api/data-files/preview` endpoint: read CSV/XLSX via pandas, return `{columns: [...], rows: [...first 5 rows as dicts...]}`; return 404 if file not found
|
||||
- [x] 3.3 Implement `GET /api/data-files/download` endpoint: return `FileResponse` with correct MIME type (`text/csv` or `application/vnd.openxmlformats-officedocument.spreadsheetml.sheet`); return 404 if file not found
|
||||
|
||||
- [x] 4. Enhance Report API for evidence linking
|
||||
- [x] 4.1 Implement `_extract_evidence_annotations(paragraphs, session)` function: parse `<!-- evidence:round_N -->` comments from paragraph content, look up `session.rounds[N-1].evidence_rows`, build `supporting_data` mapping keyed by paragraph ID
|
||||
- [x] 4.2 Update `GET /api/report` to include `supporting_data` mapping in response JSON
|
||||
|
||||
## Phase 2: CodeExecutor Enhancements
|
||||
|
||||
- [x] 5. Add evidence capture to CodeExecutor
|
||||
- [x] 5.1 In `execute_code()`, after successful execution, check if `result.result` is a DataFrame; if so, capture `result.result.head(10).to_dict(orient='records')` as `evidence_rows`; wrap in try/except returning empty list on failure
|
||||
- [x] 5.2 Also check the last-assigned DataFrame variable in the namespace as a fallback evidence source when `result.result` is not a DataFrame
|
||||
- [x] 5.3 Include `evidence_rows` key in the returned result dict
|
||||
|
||||
- [x] 6. Add DataFrame auto-detection and export
|
||||
- [x] 6.1 Before `shell.run_cell(code)`, snapshot DataFrame variables: `{name: id(obj) for name, obj in shell.user_ns.items() if isinstance(obj, pd.DataFrame)}`
|
||||
- [x] 6.2 After execution, compare snapshots to detect new or changed DataFrame variables
|
||||
- [x] 6.3 For each new DataFrame, export to `{output_dir}/{var_name}.csv` with numeric suffix deduplication if file exists
|
||||
- [x] 6.4 Record metadata for each export: `{variable_name, filename, rows, cols, columns}` in `auto_exported_files` list
|
||||
- [x] 6.5 Include `auto_exported_files` key in the returned result dict
|
||||
|
||||
- [x] 7. Add DATA_FILE_SAVED marker parsing
|
||||
- [x] 7.1 After execution, scan `captured.stdout` for lines matching `[DATA_FILE_SAVED] filename: {name}, rows: {count}, description: {desc}`
|
||||
- [x] 7.2 Parse each marker line and record `{filename, rows, description}` in `prompt_saved_files` list
|
||||
- [x] 7.3 Include `prompt_saved_files` key in the returned result dict
|
||||
|
||||
## Phase 3: Agent Changes
|
||||
|
||||
- [x] 8. Structured Round_Data construction in DataAnalysisAgent
|
||||
- [x] 8.1 Add `_summarize_result(result)` method: produce one-line summary from execution result (e.g., "执行成功,输出 DataFrame (150行×8列)" or "执行失败: {error}")
|
||||
- [x] 8.2 In `_handle_generate_code()`, construct `round_data` dict with fields: round, reasoning (from `yaml_data.get("reasoning", "")`), code, result_summary, evidence_rows, raw_log, auto_exported_files, prompt_saved_files
|
||||
- [x] 8.3 After constructing round_data, append it to `SessionData.rounds` (via progress callback or direct reference)
|
||||
- [x] 8.4 Merge file metadata from `auto_exported_files` and `prompt_saved_files` into `SessionData.data_files`
|
||||
|
||||
- [x] 9. Update system prompts
|
||||
- [x] 9.1 Add intermediate data saving instructions to `data_analysis_system_prompt` in `prompts.py`: instruct LLM to save intermediate results and print `[DATA_FILE_SAVED]` marker
|
||||
- [x] 9.2 Add evidence annotation instructions to `final_report_system_prompt` in `prompts.py`: instruct LLM to add `<!-- evidence:round_N -->` comments to report paragraphs
|
||||
- [x] 9.3 Update `_build_final_report_prompt()` in `data_analysis_agent.py` to include collected evidence data from all rounds in the prompt context
|
||||
|
||||
## Phase 4: Frontend Tab Restructuring
|
||||
|
||||
- [x] 10. HTML restructuring
|
||||
- [x] 10.1 In `index.html`, replace tab labels: "Live Log" → "执行过程", add "数据文件" tab, keep "Report"; remove "Gallery" tab
|
||||
- [x] 10.2 Replace the `logsTab` div content with an Execution Process container (`executionTab`) containing a scrollable round-cards wrapper
|
||||
- [x] 10.3 Add a `datafilesTab` div with a file-cards grid container and a preview panel area
|
||||
- [x] 10.4 Remove the Gallery tab HTML: carousel container, navigation buttons, image info panel
|
||||
|
||||
- [x] 11. JavaScript: Execution Process Tab
|
||||
- [x] 11.1 Add `lastRenderedRound` state variable and `renderRoundCards(rounds)` function: compare `rounds.length` with `lastRenderedRound`, create and append new Round_Card DOM elements for new entries only
|
||||
- [x] 11.2 Implement Round_Card HTML generation: collapsed state shows round number + result_summary; expanded state shows reasoning, code (collapsible), result_summary, evidence table ("本轮数据案例"), raw log (collapsible)
|
||||
- [x] 11.3 Add click handler for Round_Card toggle (collapse/expand)
|
||||
- [x] 11.4 Add auto-scroll logic: when analysis is running, scroll Execution Process container to bottom after appending new cards
|
||||
|
||||
- [x] 12. JavaScript: Data Files Tab
|
||||
- [x] 12.1 Implement `loadDataFiles()`: fetch `GET /api/data-files`, render file cards showing filename, description, row count
|
||||
- [x] 12.2 Implement `previewDataFile(filename)`: fetch `GET /api/data-files/preview`, render a table with column headers and up to 5 rows
|
||||
- [x] 12.3 Implement `downloadDataFile(filename)`: trigger download via `GET /api/data-files/download`
|
||||
- [x] 12.4 In `startPolling()`, call `loadDataFiles()` on each polling cycle when Data Files tab is active or when analysis is running
|
||||
|
||||
- [x] 13. JavaScript: Gallery removal and tab updates
|
||||
- [x] 13.1 Remove gallery functions: `loadGallery`, `renderGalleryImage`, `prevImage`, `nextImage` and state variables `galleryImages`, `currentImageIndex`
|
||||
- [x] 13.2 Update `switchTab()` to handle `execution`, `datafiles`, `report` identifiers instead of `logs`, `report`, `gallery`
|
||||
- [x] 13.3 Update `startPolling()` to call `renderRoundCards()` with `data.rounds` on each polling cycle
|
||||
|
||||
- [x] 14. JavaScript: Supporting data in Report
|
||||
- [x] 14.1 Update `loadReport()` to store `supporting_data` mapping from API response
|
||||
- [x] 14.2 Update `renderParagraphReport()` to add "查看支撑数据" button below paragraphs that have entries in `supporting_data`
|
||||
- [x] 14.3 Implement `showSupportingData(paraId)`: display a popover/modal with evidence rows rendered as a table
|
||||
|
||||
- [x] 15. CSS updates
|
||||
- [x] 15.1 Add `.round-card`, `.round-card-header`, `.round-card-body`, `.round-card-collapsed`, `.round-card-expanded` styles
|
||||
- [x] 15.2 Add `.data-file-card`, `.data-preview-table` styles
|
||||
- [x] 15.3 Add `.supporting-data-btn`, `.supporting-data-popover` styles
|
||||
- [x] 15.4 Remove `.carousel-*` styles (carousel-container, carousel-slide, carousel-btn, image-info, image-title, image-desc)
|
||||
|
||||
## Phase 5: Property-Based Tests
|
||||
|
||||
- [x] 16. Write property-based tests
|
||||
- [x] 16.1 ~PBT~ Property 1: Round_Data structural completeness — generate random execution results, verify all required fields present with correct types and insertion order preserved
|
||||
- [x] 16.2 ~PBT~ Property 2: Evidence capture bounded — generate random DataFrames (0-10000 rows, 1-50 cols), verify evidence_rows length <= 10 and each row dict has correct keys
|
||||
- [x] 16.3 ~PBT~ Property 3: Filename deduplication — generate sequences of same-name exports (1-20), verify all filenames unique
|
||||
- [x] 16.4 ~PBT~ Property 4: Auto-export metadata completeness — generate random DataFrames, verify metadata contains variable_name, filename, rows, cols, columns with correct values
|
||||
- [x] 16.5 ~PBT~ Property 5: DATA_FILE_SAVED marker parsing round-trip — generate random filenames/rows/descriptions, verify parse(format(x)) == x
|
||||
- [x] 16.6 ~PBT~ Property 6: Data file preview bounded rows — generate random CSVs (0-10000 rows), verify preview returns at most 5 rows with correct column names
|
||||
- [x] 16.7 ~PBT~ Property 7: Evidence annotation parsing — generate random annotated Markdown, verify correct round extraction and non-annotated paragraph exclusion
|
||||
- [x] 16.8 ~PBT~ Property 8: SessionData JSON round-trip — generate random rounds/data_files, verify serialize then deserialize produces equal data
|
||||
21
LICENSE
21
LICENSE
@@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 Data Analysis Agent Team
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
50
README.md
50
README.md
@@ -31,7 +31,9 @@ data_analysis_agent/
|
||||
│ ├── fallback_openai_client.py # 支持故障转移的OpenAI客户端
|
||||
│ ├── extract_code.py # 代码提取工具
|
||||
│ ├── format_execution_result.py # 执行结果格式化
|
||||
│ └── create_session_dir.py # 会话目录管理
|
||||
│ ├── create_session_dir.py # 会话目录管理
|
||||
│ ├── data_loader.py # 数据加载与画像生成
|
||||
│ └── script_generator.py # 可复用脚本生成器
|
||||
├── 📄 data_analysis_agent.py # 主智能体类
|
||||
├── 📄 prompts.py # 系统提示词模板
|
||||
├── 📄 main.py # 使用示例
|
||||
@@ -160,7 +162,7 @@ agent = DataAnalysisAgent(llm_config)
|
||||
# 开始分析
|
||||
files = ["your_data.csv"]
|
||||
report = agent.analyze(
|
||||
user_input="分析销售数据,生成趋势图表和关键指标",
|
||||
user_input="分析XXXXXXXXX数据,生成趋势图表和关键指标",
|
||||
files=files
|
||||
)
|
||||
|
||||
@@ -191,9 +193,9 @@ report = quick_analysis(
|
||||
|
||||
```python
|
||||
# 示例:茅台财务分析
|
||||
files = ["贵州茅台利润表.csv"]
|
||||
files = ["XXXXXXXXx.csv"]
|
||||
report = agent.analyze(
|
||||
user_input="基于贵州茅台的数据,输出五个重要的统计指标,并绘制相关图表。最后生成汇报给我。",
|
||||
user_input="基于数据,输出五个重要的统计指标,并绘制相关图表。最后生成汇报给我。",
|
||||
files=files
|
||||
)
|
||||
```
|
||||
@@ -207,6 +209,33 @@ report = agent.analyze(
|
||||
- 📋 营业成本占比分析
|
||||
- 📄 综合分析报告
|
||||
|
||||
## 🌐 Web界面可视化
|
||||
|
||||
本项目提供了现代化的Web界面,支持零代码交互。
|
||||
|
||||
### 启动方式
|
||||
|
||||
**macOS/Linux:**
|
||||
```bash
|
||||
./start_web.sh
|
||||
```
|
||||
|
||||
**Windows:**
|
||||
```bash
|
||||
start_web.bat
|
||||
```
|
||||
|
||||
访问地址: `http://localhost:8000`
|
||||
|
||||
### 核心功能 (Web)
|
||||
|
||||
- **🖼️ 图表画廊 (Gallery)**: 网格化展示所有生成图表,每张图表附带AI生成的分析解读。
|
||||
- **📜 实时日志**: 像黑客帝国一样实时查看后台分析过程和Agent的思考逻辑。
|
||||
- **📦 一键导出**: 支持一键下载包含 Markdown 报告和所有高清原图的 ZIP 压缩包。
|
||||
- **🛠️ 数据工具箱**:
|
||||
- **Excel合并**: 将多个同构 Excel 文件快速合并为分析可用的 CSV。
|
||||
- **时间排序**: 自动修复 CSV 数据的乱序问题,确保时序分析准确。
|
||||
|
||||
## 🎨 流程可视化
|
||||
|
||||
### 📊 分析过程状态图
|
||||
@@ -239,12 +268,15 @@ stateDiagram-v2
|
||||
```python
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
provider: str = "openai"
|
||||
provider: str = os.environ.get("LLM_PROVIDER", "openai")
|
||||
api_key: str = os.environ.get("OPENAI_API_KEY", "")
|
||||
base_url: str = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||
model: str = os.environ.get("OPENAI_MODEL", "gpt-4")
|
||||
max_tokens: int = 4000
|
||||
temperature: float = 0.1
|
||||
temperature: float = 0.5
|
||||
max_tokens: int = 8192
|
||||
|
||||
# 支持 gemini 等其他 provider 配置
|
||||
# ...
|
||||
```
|
||||
|
||||
### 执行器配置
|
||||
@@ -254,7 +286,9 @@ class LLMConfig:
|
||||
ALLOWED_IMPORTS = {
|
||||
'pandas', 'numpy', 'matplotlib', 'duckdb',
|
||||
'scipy', 'sklearn', 'plotly', 'requests',
|
||||
'os', 'json', 'datetime', 're', 'pathlib'
|
||||
'os', 'json', 'datetime', 're', 'pathlib',
|
||||
'seaborn', 'statsmodels', 'networkx', 'jieba',
|
||||
'wordcloud', 'PIL', 'sqlite3', 'yaml'
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
17
__init__.py
17
__init__.py
@@ -5,9 +5,20 @@ Data Analysis Agent Package
|
||||
一个基于LLM的智能数据分析代理,专门为Jupyter Notebook环境设计。
|
||||
"""
|
||||
|
||||
from .core.notebook_agent import NotebookAgent
|
||||
from .config.llm_config import LLMConfig
|
||||
from .utils.code_executor import CodeExecutor
|
||||
try:
|
||||
from .core.notebook_agent import NotebookAgent
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
NotebookAgent = None
|
||||
|
||||
try:
|
||||
from .config.llm_config import LLMConfig
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
from config.llm_config import LLMConfig
|
||||
|
||||
try:
|
||||
from .utils.code_executor import CodeExecutor
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
from utils.code_executor import CodeExecutor
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "Data Analysis Agent Team"
|
||||
|
||||
62
bootstrap.py
Normal file
62
bootstrap.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import sys
|
||||
import subprocess
|
||||
import importlib.metadata
|
||||
import os
|
||||
|
||||
def check_dependencies():
|
||||
"""Checks if dependencies in requirements.txt are installed."""
|
||||
requirements_file = "requirements.txt"
|
||||
if not os.path.exists(requirements_file):
|
||||
print(f"Warning: {requirements_file} not found. Skipping dependency check.")
|
||||
return
|
||||
|
||||
print("Checking dependencies...")
|
||||
missing_packages = []
|
||||
|
||||
with open(requirements_file, "r") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
# Simple parsing for package name.
|
||||
# This handles 'package>=version', 'package==version', 'package'
|
||||
# It does NOT handle complex markers perfectly, but suffices for basic checking.
|
||||
package_name = line.split("=")[0].split(">")[0].split("<")[0].strip()
|
||||
|
||||
try:
|
||||
importlib.metadata.version(package_name)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
missing_packages.append(line)
|
||||
|
||||
if missing_packages:
|
||||
print(f"Missing dependencies: {', '.join(missing_packages)}")
|
||||
print("Installing missing dependencies...")
|
||||
try:
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", requirements_file])
|
||||
print("Dependencies installed successfully.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error installing dependencies: {e}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("All dependencies checked.")
|
||||
|
||||
def main():
|
||||
check_dependencies()
|
||||
|
||||
print("Starting application...")
|
||||
try:
|
||||
# Run the main application
|
||||
# Using sys.executable ensures we use the same python interpreter
|
||||
subprocess.run([sys.executable, "main.py"], check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Application exited with error: {e}")
|
||||
sys.exit(e.returncode)
|
||||
except KeyboardInterrupt:
|
||||
print("\nApplication stopped by user.")
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -4,5 +4,6 @@
|
||||
"""
|
||||
|
||||
from .llm_config import LLMConfig
|
||||
from .app_config import AppConfig, app_config
|
||||
|
||||
__all__ = ['LLMConfig']
|
||||
__all__ = ['LLMConfig', 'AppConfig', 'app_config']
|
||||
|
||||
87
config/app_config.py
Normal file
87
config/app_config.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
应用配置中心 - 集中管理所有配置项
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppConfig:
|
||||
"""应用配置中心"""
|
||||
|
||||
# 分析配置
|
||||
max_rounds: int = field(default=20)
|
||||
force_max_rounds: bool = field(default=False)
|
||||
default_output_dir: str = field(default="outputs")
|
||||
|
||||
# 数据处理配置
|
||||
max_file_size_mb: int = field(default=500) # 最大文件大小(MB)
|
||||
chunk_size: int = field(default=100000) # 分块读取大小
|
||||
data_cache_enabled: bool = field(default=True)
|
||||
cache_dir: str = field(default=".cache/data")
|
||||
|
||||
# LLM配置
|
||||
llm_cache_enabled: bool = field(default=True)
|
||||
llm_cache_dir: str = field(default=".cache/llm")
|
||||
llm_stream_enabled: bool = field(default=False)
|
||||
|
||||
# 代码执行配置
|
||||
code_timeout: int = field(default=300) # 代码执行超时(秒)
|
||||
|
||||
# Web配置
|
||||
web_host: str = field(default="0.0.0.0")
|
||||
web_port: int = field(default=8000)
|
||||
upload_dir: str = field(default="uploads")
|
||||
|
||||
# 日志配置
|
||||
log_filename: str = field(default="log.txt")
|
||||
enable_code_logging: bool = field(default=False) # 是否记录生成的代码
|
||||
|
||||
# 健壮性配置
|
||||
max_data_context_retries: int = field(default=2) # 数据上下文错误最大重试次数
|
||||
conversation_window_size: int = field(default=10) # 对话历史滑动窗口大小(消息对数)
|
||||
max_parallel_profiles: int = field(default=4) # 并行数据画像最大线程数
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> 'AppConfig':
|
||||
"""从环境变量创建配置"""
|
||||
config = cls()
|
||||
|
||||
# 从环境变量覆盖配置
|
||||
if max_rounds := os.getenv("APP_MAX_ROUNDS"):
|
||||
config.max_rounds = int(max_rounds)
|
||||
|
||||
if chunk_size := os.getenv("APP_CHUNK_SIZE"):
|
||||
config.chunk_size = int(chunk_size)
|
||||
|
||||
if cache_enabled := os.getenv("APP_CACHE_ENABLED"):
|
||||
config.data_cache_enabled = cache_enabled.lower() == "true"
|
||||
|
||||
if val := os.getenv("APP_MAX_DATA_CONTEXT_RETRIES"):
|
||||
config.max_data_context_retries = int(val)
|
||||
if val := os.getenv("APP_CONVERSATION_WINDOW_SIZE"):
|
||||
config.conversation_window_size = int(val)
|
||||
if val := os.getenv("APP_MAX_PARALLEL_PROFILES"):
|
||||
config.max_parallel_profiles = int(val)
|
||||
|
||||
return config
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证配置"""
|
||||
if self.max_rounds <= 0:
|
||||
raise ValueError("max_rounds must be positive")
|
||||
|
||||
if self.chunk_size <= 0:
|
||||
raise ValueError("chunk_size must be positive")
|
||||
|
||||
if self.code_timeout <= 0:
|
||||
raise ValueError("code_timeout must be positive")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# 全局配置实例
|
||||
app_config = AppConfig.from_env()
|
||||
@@ -17,12 +17,25 @@ load_dotenv()
|
||||
class LLMConfig:
|
||||
"""LLM配置"""
|
||||
|
||||
provider: str = "openai" # openai, anthropic, etc.
|
||||
api_key: str = os.environ.get("OPENAI_API_KEY", "sk-c44i1hy64xgzwox6x08o4zug93frq6rgn84oqugf2pje1tg4")
|
||||
base_url: str = os.environ.get("OPENAI_BASE_URL", "https://api.xiaomimimo.com/v1")
|
||||
model: str = os.environ.get("OPENAI_MODEL", "mimo-v2-flash")
|
||||
temperature: float = 0.3
|
||||
max_tokens: int = 131072
|
||||
provider: str = os.environ.get("LLM_PROVIDER", "openai") # openai, gemini, etc.
|
||||
api_key: str = os.environ.get("OPENAI_API_KEY", "")
|
||||
base_url: str = os.environ.get("OPENAI_BASE_URL", "http://127.0.0.1:9999/v1")
|
||||
model: str = os.environ.get("OPENAI_MODEL", "gemini-3-flash")
|
||||
temperature: float = 0.5
|
||||
max_tokens: int = 8192 # 降低默认值,避免某些API不支持过大的值
|
||||
|
||||
def __post_init__(self):
|
||||
"""配置初始化后的处理"""
|
||||
if self.provider == "gemini":
|
||||
# 如果使用 Gemini,尝试从环境变量加载 Gemini 配置,或者使用默认的 Gemini 配置
|
||||
# 注意:如果 OPENAI_API_KEY 已设置且 GEMINI_API_KEY 未设置,可能会沿用 OpenAI 的 Key,
|
||||
# 但既然用户切换了 provider,通常会有配套的 Key。
|
||||
self.api_key = os.environ.get("GEMINI_API_KEY", "")
|
||||
# Gemini 的 OpenAI 兼容接口地址
|
||||
self.base_url = os.environ.get("GEMINI_BASE_URL", "https://gemini.jeason.online")
|
||||
self.model = os.environ.get("GEMINI_MODEL", "gemini-2.5-flash")
|
||||
# Gemini 有更严格的 token 限制
|
||||
self.max_tokens = 8192
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
|
||||
22
config/templates/anomaly_detection.yaml
Normal file
22
config/templates/anomaly_detection.yaml
Normal file
@@ -0,0 +1,22 @@
|
||||
name: 异常值检测分析
|
||||
description: 识别数据中的异常值和离群点
|
||||
steps:
|
||||
- name: 数值列统计分析
|
||||
description: 计算数值列的统计特征
|
||||
prompt: 计算所有数值列的均值、标准差、四分位数等统计量
|
||||
|
||||
- name: 箱线图可视化
|
||||
description: 使用箱线图识别异常值
|
||||
prompt: 为每个数值列绘制箱线图,直观展示异常值分布
|
||||
|
||||
- name: Z-Score异常检测
|
||||
description: 使用Z-Score方法检测异常值
|
||||
prompt: 计算每个数值的Z-Score,标记|Z|>3的异常值
|
||||
|
||||
- name: IQR异常检测
|
||||
description: 使用四分位距方法检测异常值
|
||||
prompt: 使用IQR方法(Q1-1.5*IQR, Q3+1.5*IQR)检测异常值
|
||||
|
||||
- name: 异常值汇总报告
|
||||
description: 整理所有检测到的异常值
|
||||
prompt: 汇总所有异常值,分析其特征和可能原因,提供处理建议
|
||||
18
config/templates/comparison.yaml
Normal file
18
config/templates/comparison.yaml
Normal file
@@ -0,0 +1,18 @@
|
||||
name: 分组对比分析
|
||||
description: 对比不同分组之间的差异和特征
|
||||
steps:
|
||||
- name: 分组统计
|
||||
description: 计算各组的统计指标
|
||||
prompt: 按分组列分组,计算数值列的均值、中位数、标准差
|
||||
|
||||
- name: 分组可视化对比
|
||||
description: 绘制对比图表
|
||||
prompt: 绘制各组的柱状图和箱线图,直观对比差异
|
||||
|
||||
- name: 差异显著性检验
|
||||
description: 统计检验组间差异
|
||||
prompt: 进行t检验或方差分析,判断组间差异是否显著
|
||||
|
||||
- name: 对比结论
|
||||
description: 总结对比结果
|
||||
prompt: 总结各组特征、主要差异和业务洞察
|
||||
50
config/templates/health_report.yaml
Normal file
50
config/templates/health_report.yaml
Normal file
@@ -0,0 +1,50 @@
|
||||
name: 车联网工单健康度报告
|
||||
description: 全面分析车联网技术支持工单的健康状况,从多个维度评估工单处理效率和质量
|
||||
steps:
|
||||
- name: 数据概览与质量检查
|
||||
description: 检查数据完整性、缺失值、异常值等
|
||||
prompt: 加载数据并进行质量检查,输出数据概况和潜在问题
|
||||
|
||||
- name: 工单总量分析
|
||||
description: 统计总工单数、时间分布、趋势变化
|
||||
prompt: 计算总工单数,按时间维度统计工单量,绘制时间序列趋势图
|
||||
|
||||
- name: 车型维度分析
|
||||
description: 分析不同车型的工单分布和问题特征
|
||||
prompt: 统计各车型工单数量,绘制车型分布图,识别高风险车型
|
||||
|
||||
- name: 模块维度分析
|
||||
description: 分析工单涉及的技术模块分布
|
||||
prompt: 统计各技术模块的工单量,绘制模块分布图,识别高频问题模块
|
||||
|
||||
- name: 功能维度分析
|
||||
description: 分析具体功能点的问题分布
|
||||
prompt: 统计各功能的工单量,绘制TOP功能问题排行,分析功能稳定性
|
||||
|
||||
- name: 问题严重程度分析
|
||||
description: 分析工单的严重程度分布
|
||||
prompt: 统计不同严重程度的工单比例,绘制严重程度分布图
|
||||
|
||||
- name: 处理时长分析
|
||||
description: 分析工单处理时效性
|
||||
prompt: 计算平均处理时长、SLA达成率,识别超时工单,绘制时长分布图
|
||||
|
||||
- name: 责任人工作负载分析
|
||||
description: 分析各责任人的工单负载和处理效率
|
||||
prompt: 统计各责任人的工单数和处理效率,绘制负载分布图
|
||||
|
||||
- name: 来源渠道分析
|
||||
description: 分析工单来源渠道分布
|
||||
prompt: 统计各来源渠道的工单量,绘制渠道分布图
|
||||
|
||||
- name: 高频问题深度分析
|
||||
description: 识别并深入分析高频问题
|
||||
prompt: 提取TOP10高频问题,分析问题原因、影响范围和解决方案
|
||||
|
||||
- name: 综合健康度评分
|
||||
description: 基于多个维度计算综合健康度评分
|
||||
prompt: 综合考虑工单量、处理时长、问题严重度等指标,计算健康度评分
|
||||
|
||||
- name: 生成最终报告
|
||||
description: 整合所有分析结果,生成完整报告
|
||||
prompt: 整合所有图表和分析结论,生成一份完整的车联网工单健康度报告
|
||||
22
config/templates/trend_analysis.yaml
Normal file
22
config/templates/trend_analysis.yaml
Normal file
@@ -0,0 +1,22 @@
|
||||
name: 时间序列趋势分析
|
||||
description: 分析数据的时间趋势、季节性和周期性特征
|
||||
steps:
|
||||
- name: 时间序列数据准备
|
||||
description: 将数据转换为时间序列格式
|
||||
prompt: 将时间列转换为日期格式,按时间排序数据
|
||||
|
||||
- name: 趋势可视化
|
||||
description: 绘制时间序列图
|
||||
prompt: 绘制数值随时间的变化趋势图,添加移动平均线
|
||||
|
||||
- name: 趋势分析
|
||||
description: 识别上升、下降或平稳趋势
|
||||
prompt: 计算趋势线斜率,判断整体趋势方向和变化速率
|
||||
|
||||
- name: 季节性分析
|
||||
description: 检测季节性模式
|
||||
prompt: 分析月度、季度等周期性模式,绘制季节性分解图
|
||||
|
||||
- name: 异常点检测
|
||||
description: 识别时间序列中的异常点
|
||||
prompt: 使用统计方法检测时间序列中的异常值,标注在图表上
|
||||
6
conftest.py
Normal file
6
conftest.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Root conftest.py — configures pytest to find project modules."""
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
@@ -10,16 +10,50 @@
|
||||
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import yaml
|
||||
from typing import Dict, Any, List, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from utils.create_session_dir import create_session_output_dir
|
||||
from utils.format_execution_result import format_execution_result
|
||||
from utils.extract_code import extract_code_from_response
|
||||
from utils.data_loader import load_and_profile_data
|
||||
from utils.data_loader import load_and_profile_data, load_data_chunked, load_and_profile_data_smart
|
||||
from utils.llm_helper import LLMHelper
|
||||
from utils.code_executor import CodeExecutor
|
||||
from utils.script_generator import generate_reusable_script
|
||||
from utils.data_privacy import build_safe_profile, build_local_profile, sanitize_execution_feedback, generate_enriched_hint
|
||||
from config.llm_config import LLMConfig
|
||||
from prompts import data_analysis_system_prompt, final_report_system_prompt
|
||||
from config.app_config import app_config
|
||||
from prompts import data_analysis_system_prompt, final_report_system_prompt, data_analysis_followup_prompt
|
||||
|
||||
|
||||
# Regex patterns that indicate a data-context error (column/variable/DataFrame issues)
|
||||
DATA_CONTEXT_PATTERNS = [
|
||||
# KeyError - missing key/column
|
||||
r"KeyError:\s*['\"](.+?)['\"]",
|
||||
# ValueError - value-related issues
|
||||
r"ValueError.*(?:column|col|field|shape|axis)",
|
||||
# NameError - undefined variables
|
||||
r"NameError.*(?:df|data|frame|series)",
|
||||
# Empty/missing data
|
||||
r"(?:empty|no\s+data|0\s+rows|No\s+data)",
|
||||
# IndexError - out of bounds
|
||||
r"IndexError.*(?:out of range|out of bounds)",
|
||||
# AttributeError - missing attributes
|
||||
r"AttributeError.*(?:DataFrame|Series|object)\s+has\s+no\s+attribute",
|
||||
# Pandas-specific errors
|
||||
r"pd\.errors\.(?:EmptyDataError|ParserError|MergeError)",
|
||||
r"MergeError: No common columns",
|
||||
# Type errors
|
||||
r"TypeError.*(?:unsupported operand|expected string|cannot convert)",
|
||||
# UnboundLocalError - undefined local variables
|
||||
r"UnboundLocalError.*referenced before assignment",
|
||||
# Syntax errors
|
||||
r"SyntaxError: invalid syntax",
|
||||
# Module/Import errors for data libraries
|
||||
r"ModuleNotFoundError.*(?:pandas|numpy|matplotlib)",
|
||||
r"ImportError.*(?:pandas|numpy|matplotlib)",
|
||||
]
|
||||
|
||||
|
||||
class DataAnalysisAgent:
|
||||
@@ -60,6 +94,57 @@ class DataAnalysisAgent:
|
||||
self.current_round = 0
|
||||
self.session_output_dir = None
|
||||
self.executor = None
|
||||
self.data_profile = "" # 存储数据画像(完整版,本地使用)
|
||||
self.data_profile_safe = "" # 存储安全画像(发给LLM)
|
||||
self.data_files = [] # 存储数据文件列表
|
||||
self.user_requirement = "" # 存储用户需求
|
||||
self._progress_callback = None # 进度回调函数
|
||||
self._session_ref = None # Reference to SessionData for round tracking
|
||||
|
||||
def set_session_ref(self, session):
|
||||
"""Set a reference to the SessionData instance for appending round data.
|
||||
|
||||
Args:
|
||||
session: The SessionData instance for the current analysis session.
|
||||
"""
|
||||
self._session_ref = session
|
||||
|
||||
def set_progress_callback(self, callback):
|
||||
"""Set a callback function(current_round, max_rounds, message) for progress updates."""
|
||||
self._progress_callback = callback
|
||||
|
||||
def _summarize_result(self, result: Dict[str, Any]) -> str:
|
||||
"""Produce a one-line summary from a code execution result.
|
||||
|
||||
Args:
|
||||
result: The execution result dict from CodeExecutor.
|
||||
|
||||
Returns:
|
||||
A concise summary string, e.g. "执行成功,输出 DataFrame (150行×8列)"
|
||||
or "执行失败: KeyError: 'col_x'".
|
||||
"""
|
||||
if result.get("success"):
|
||||
evidence_rows = result.get("evidence_rows", [])
|
||||
if evidence_rows:
|
||||
num_rows = len(evidence_rows)
|
||||
num_cols = len(evidence_rows[0]) if evidence_rows else 0
|
||||
# Check auto_exported_files for more accurate row/col counts
|
||||
auto_files = result.get("auto_exported_files", [])
|
||||
if auto_files:
|
||||
last_file = auto_files[-1]
|
||||
num_rows = last_file.get("rows", num_rows)
|
||||
num_cols = last_file.get("cols", num_cols)
|
||||
return f"执行成功,输出 DataFrame ({num_rows}行×{num_cols}列)"
|
||||
output = result.get("output", "")
|
||||
if output:
|
||||
first_line = output.strip().split("\n")[0][:80]
|
||||
return f"执行成功: {first_line}"
|
||||
return "执行成功"
|
||||
else:
|
||||
error = result.get("error", "未知错误")
|
||||
if len(error) > 100:
|
||||
error = error[:100] + "..."
|
||||
return f"执行失败: {error}"
|
||||
|
||||
def _process_response(self, response: str) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -73,9 +158,23 @@ class DataAnalysisAgent:
|
||||
"""
|
||||
try:
|
||||
yaml_data = self.llm.parse_yaml_response(response)
|
||||
action = yaml_data.get("action", "generate_code")
|
||||
action = yaml_data.get("action", "")
|
||||
|
||||
print(f"🎯 检测到动作: {action}")
|
||||
# If YAML parsing returned empty/no action, try to detect action from raw text
|
||||
if not action:
|
||||
if "analysis_complete" in response:
|
||||
action = "analysis_complete"
|
||||
# Try to extract final_report from raw text
|
||||
if not yaml_data.get("final_report"):
|
||||
yaml_data["action"] = "analysis_complete"
|
||||
yaml_data["final_report"] = ""
|
||||
elif "collect_figures" in response:
|
||||
action = "collect_figures"
|
||||
yaml_data["action"] = "collect_figures"
|
||||
else:
|
||||
action = "generate_code"
|
||||
|
||||
print(f"[TARGET] 检测到动作: {action}")
|
||||
|
||||
if action == "analysis_complete":
|
||||
return self._handle_analysis_complete(response, yaml_data)
|
||||
@@ -84,18 +183,27 @@ class DataAnalysisAgent:
|
||||
elif action == "generate_code":
|
||||
return self._handle_generate_code(response, yaml_data)
|
||||
else:
|
||||
print(f"⚠️ 未知动作类型: {action},按generate_code处理")
|
||||
print(f"[WARN] 未知动作类型: {action},按generate_code处理")
|
||||
return self._handle_generate_code(response, yaml_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ 解析响应失败: {str(e)},按generate_code处理")
|
||||
print(f"[WARN] 解析响应失败: {str(e)},尝试提取代码并按generate_code处理")
|
||||
# Check if this is actually an analysis_complete or collect_figures response
|
||||
if "analysis_complete" in response:
|
||||
return self._handle_analysis_complete(response, {"final_report": ""})
|
||||
if "collect_figures" in response:
|
||||
return self._handle_collect_figures(response, {"figures_to_collect": []})
|
||||
# 即使YAML解析失败,也尝试提取代码
|
||||
extracted_code = extract_code_from_response(response)
|
||||
if extracted_code:
|
||||
return self._handle_generate_code(response, {"code": extracted_code})
|
||||
return self._handle_generate_code(response, {})
|
||||
|
||||
def _handle_analysis_complete(
|
||||
self, response: str, yaml_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""处理分析完成动作"""
|
||||
print("✅ 分析任务完成")
|
||||
print("[OK] 分析任务完成")
|
||||
final_report = yaml_data.get("final_report", "分析完成,无最终报告")
|
||||
return {
|
||||
"action": "analysis_complete",
|
||||
@@ -108,10 +216,12 @@ class DataAnalysisAgent:
|
||||
self, response: str, yaml_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""处理图片收集动作"""
|
||||
print("📊 开始收集图片")
|
||||
print("[CHART] 开始收集图片")
|
||||
figures_to_collect = yaml_data.get("figures_to_collect", [])
|
||||
|
||||
collected_figures = []
|
||||
# 使用seen_paths集合来去重,防止重复收集
|
||||
seen_paths = set()
|
||||
|
||||
for figure_info in figures_to_collect:
|
||||
figure_number = figure_info.get("figure_number", "未知")
|
||||
@@ -125,29 +235,36 @@ class DataAnalysisAgent:
|
||||
description = figure_info.get("description", "")
|
||||
analysis = figure_info.get("analysis", "")
|
||||
|
||||
print(f"📈 收集图片 {figure_number}: {filename}")
|
||||
print(f" 📂 路径: {file_path}")
|
||||
print(f" 📝 描述: {description}")
|
||||
print(f" 🔍 分析: {analysis}")
|
||||
print(f"[GRAPH] 收集图片 {figure_number}: {filename}")
|
||||
print(f" [DIR] 路径: {file_path}")
|
||||
print(f" [NOTE] 描述: {description}")
|
||||
print(f" [SEARCH] 分析: {analysis}")
|
||||
|
||||
# 验证文件是否存在
|
||||
# 只有文件真正存在时才加入列表,防止报告出现裂图
|
||||
if file_path and os.path.exists(file_path):
|
||||
print(f" ✅ 文件存在: {file_path}")
|
||||
elif file_path:
|
||||
print(f" ⚠️ 文件不存在: {file_path}")
|
||||
# 检查是否已经收集过该路径
|
||||
abs_path = os.path.abspath(file_path)
|
||||
if abs_path not in seen_paths:
|
||||
print(f" [OK] 文件存在: {file_path}")
|
||||
# 记录图片信息
|
||||
collected_figures.append(
|
||||
{
|
||||
"figure_number": figure_number,
|
||||
"filename": filename,
|
||||
"file_path": file_path,
|
||||
"description": description,
|
||||
"analysis": analysis,
|
||||
}
|
||||
)
|
||||
seen_paths.add(abs_path)
|
||||
else:
|
||||
print(f" [WARN] 跳过重复图片: {file_path}")
|
||||
else:
|
||||
print(f" ⚠️ 未提供文件路径")
|
||||
|
||||
# 记录图片信息
|
||||
collected_figures.append(
|
||||
{
|
||||
"figure_number": figure_number,
|
||||
"filename": filename,
|
||||
"file_path": file_path,
|
||||
"description": description,
|
||||
"analysis": analysis,
|
||||
}
|
||||
)
|
||||
if file_path:
|
||||
print(f" [WARN] 文件不存在: {file_path}")
|
||||
else:
|
||||
print(f" [WARN] 未提供文件路径")
|
||||
|
||||
return {
|
||||
"action": "collect_figures",
|
||||
@@ -162,6 +279,7 @@ class DataAnalysisAgent:
|
||||
"""处理代码生成和执行动作"""
|
||||
# 从YAML数据中获取代码(更准确)
|
||||
code = yaml_data.get("code", "")
|
||||
reasoning = yaml_data.get("reasoning", "")
|
||||
|
||||
# 如果YAML中没有代码,尝试从响应中提取
|
||||
if not code:
|
||||
@@ -171,7 +289,6 @@ class DataAnalysisAgent:
|
||||
if code:
|
||||
code = code.strip()
|
||||
if code.startswith("```"):
|
||||
import re
|
||||
# 去除开头的 ```python 或 ```
|
||||
code = re.sub(r"^```[a-zA-Z]*\n", "", code)
|
||||
# 去除结尾的 ```
|
||||
@@ -179,7 +296,7 @@ class DataAnalysisAgent:
|
||||
code = code.strip()
|
||||
|
||||
if code:
|
||||
print(f"🔧 执行代码:\n{code}")
|
||||
print(f"[TOOL] 执行代码:\n{code}")
|
||||
print("-" * 40)
|
||||
|
||||
# 执行代码
|
||||
@@ -187,11 +304,12 @@ class DataAnalysisAgent:
|
||||
|
||||
# 格式化执行结果
|
||||
feedback = format_execution_result(result)
|
||||
print(f"📋 执行反馈:\n{feedback}")
|
||||
print(f"[LIST] 执行反馈:\n{feedback}")
|
||||
|
||||
return {
|
||||
"action": "generate_code",
|
||||
"code": code,
|
||||
"reasoning": reasoning,
|
||||
"result": result,
|
||||
"feedback": feedback,
|
||||
"response": response,
|
||||
@@ -199,97 +317,369 @@ class DataAnalysisAgent:
|
||||
}
|
||||
else:
|
||||
# 如果没有代码,说明LLM响应格式有问题,需要重新生成
|
||||
print("⚠️ 未从响应中提取到可执行代码,要求LLM重新生成")
|
||||
print("[WARN] 未从响应中提取到可执行代码,要求LLM重新生成")
|
||||
return {
|
||||
"action": "invalid_response",
|
||||
"reasoning": reasoning,
|
||||
"error": "响应中缺少可执行代码",
|
||||
"response": response,
|
||||
"continue": True,
|
||||
}
|
||||
|
||||
def analyze(self, user_input: str, files: List[str] = None) -> Dict[str, Any]:
|
||||
def _classify_error(self, error_message: str) -> str:
|
||||
"""Classify execution error as data-context or other.
|
||||
|
||||
Inspects the error message against DATA_CONTEXT_PATTERNS to determine
|
||||
if the error is related to data context (missing columns, undefined
|
||||
data variables, empty DataFrames, etc.).
|
||||
|
||||
Args:
|
||||
error_message: The error message string from code execution.
|
||||
|
||||
Returns:
|
||||
"data_context" if the error matches a data-context pattern,
|
||||
"other" otherwise.
|
||||
"""
|
||||
for pattern in DATA_CONTEXT_PATTERNS:
|
||||
if re.search(pattern, error_message, re.IGNORECASE):
|
||||
return "data_context"
|
||||
return "other"
|
||||
|
||||
def _trim_conversation_history(self):
|
||||
"""Apply sliding window trimming to conversation history.
|
||||
|
||||
Retains the first user message (original requirement + Safe_Profile) at
|
||||
index 0, generates a compressed summary of old messages, and keeps only
|
||||
the most recent ``conversation_window_size`` message pairs in full.
|
||||
"""
|
||||
window_size = app_config.conversation_window_size
|
||||
max_messages = window_size * 2 # pairs of user+assistant messages
|
||||
|
||||
if len(self.conversation_history) <= max_messages:
|
||||
return # No trimming needed
|
||||
|
||||
first_message = self.conversation_history[0] # Always retain
|
||||
|
||||
# Determine trim boundary: skip first message + possible existing summary
|
||||
start_idx = 1
|
||||
has_existing_summary = (
|
||||
len(self.conversation_history) > 1
|
||||
and self.conversation_history[1]["role"] == "user"
|
||||
and self.conversation_history[1]["content"].startswith("[分析摘要]")
|
||||
)
|
||||
if has_existing_summary:
|
||||
start_idx = 2
|
||||
|
||||
# Messages to trim vs keep
|
||||
messages_to_consider = self.conversation_history[start_idx:]
|
||||
messages_to_trim = messages_to_consider[:-max_messages]
|
||||
messages_to_keep = messages_to_consider[-max_messages:]
|
||||
|
||||
if not messages_to_trim:
|
||||
return
|
||||
|
||||
# Generate summary of trimmed messages
|
||||
summary = self._compress_trimmed_messages(messages_to_trim)
|
||||
|
||||
# Rebuild history: first_message + summary + recent messages
|
||||
self.conversation_history = [first_message]
|
||||
if summary:
|
||||
self.conversation_history.append({"role": "user", "content": summary})
|
||||
self.conversation_history.extend(messages_to_keep)
|
||||
|
||||
def _compress_trimmed_messages(self, messages: list) -> str:
|
||||
"""Compress trimmed messages into a concise summary string.
|
||||
|
||||
Extracts the action type from each assistant message, the execution
|
||||
outcome (success / failure), and completed SOP stages from the
|
||||
subsequent user feedback message. Code blocks and raw execution
|
||||
output are excluded.
|
||||
|
||||
The summary explicitly lists completed SOP stages so the LLM does
|
||||
not restart from stage 1 after conversation trimming.
|
||||
|
||||
Args:
|
||||
messages: List of conversation message dicts to compress.
|
||||
|
||||
Returns:
|
||||
A summary string prefixed with ``[分析摘要]``.
|
||||
"""
|
||||
summary_parts = ["[分析摘要] 以下是之前分析轮次的概要:"]
|
||||
round_num = 0
|
||||
completed_stages = set()
|
||||
|
||||
# SOP stage keywords to detect from assistant messages
|
||||
stage_keywords = {
|
||||
"阶段1": "数据探索与加载",
|
||||
"阶段2": "基础分布分析",
|
||||
"阶段3": "时序与来源分析",
|
||||
"阶段4": "深度交叉分析",
|
||||
"阶段5": "效率分析",
|
||||
"阶段6": "高级挖掘",
|
||||
}
|
||||
|
||||
for msg in messages:
|
||||
content = msg["content"]
|
||||
if msg["role"] == "assistant":
|
||||
round_num += 1
|
||||
# Extract action type from YAML-like content
|
||||
action = "generate_code"
|
||||
if "action: \"collect_figures\"" in content or "action: collect_figures" in content:
|
||||
action = "collect_figures"
|
||||
elif "action: \"analysis_complete\"" in content or "action: analysis_complete" in content:
|
||||
action = "analysis_complete"
|
||||
|
||||
# Detect completed SOP stages
|
||||
for stage_key, stage_name in stage_keywords.items():
|
||||
if stage_key in content or stage_name in content:
|
||||
completed_stages.add(f"{stage_key}: {stage_name}")
|
||||
|
||||
summary_parts.append(f"- 轮次{round_num}: 动作={action}")
|
||||
elif msg["role"] == "user" and "代码执行反馈" in content:
|
||||
success = "失败" if "[ERROR]" in content or "执行错误" in content else "成功"
|
||||
if summary_parts and summary_parts[-1].startswith("- 轮次"):
|
||||
summary_parts[-1] += f", 执行结果={success}"
|
||||
|
||||
# Append completed stages so the LLM knows where to continue
|
||||
if completed_stages:
|
||||
summary_parts.append("")
|
||||
summary_parts.append("**已完成的SOP阶段** (请勿重复执行):")
|
||||
for stage in sorted(completed_stages):
|
||||
summary_parts.append(f" ✓ {stage}")
|
||||
summary_parts.append("")
|
||||
summary_parts.append("请从下一个未完成的阶段继续,不要重新执行已完成的阶段。")
|
||||
|
||||
return "\n".join(summary_parts)
|
||||
|
||||
def _profile_files_parallel(self, file_paths: list) -> tuple:
|
||||
"""Profile multiple files concurrently using ThreadPoolExecutor.
|
||||
|
||||
Each file is profiled independently via ``build_safe_profile`` and
|
||||
``build_local_profile``. Results are collected and merged. If any
|
||||
individual file fails, an error entry is included for that file and
|
||||
profiling continues for the remaining files.
|
||||
|
||||
Args:
|
||||
file_paths: List of file paths to profile.
|
||||
|
||||
Returns:
|
||||
A tuple ``(safe_profile, local_profile)`` of merged markdown strings.
|
||||
"""
|
||||
max_workers = app_config.max_parallel_profiles
|
||||
safe_profiles = []
|
||||
local_profiles = []
|
||||
|
||||
def profile_single(path):
|
||||
safe = build_safe_profile([path])
|
||||
local = build_local_profile([path])
|
||||
return path, safe, local
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = {executor.submit(profile_single, p): p for p in file_paths}
|
||||
for future in as_completed(futures):
|
||||
path = futures[future]
|
||||
try:
|
||||
_, safe, local = future.result()
|
||||
safe_profiles.append(safe)
|
||||
local_profiles.append(local)
|
||||
except Exception as e:
|
||||
error_entry = f"## 文件: {os.path.basename(path)}\n[ERROR] 分析失败: {e}\n\n"
|
||||
safe_profiles.append(error_entry)
|
||||
local_profiles.append(error_entry)
|
||||
|
||||
return "\n".join(safe_profiles), "\n".join(local_profiles)
|
||||
|
||||
def analyze(self, user_input: str, files: List[str] = None, session_output_dir: str = None, reset_session: bool = True, max_rounds: int = None, template_name: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
开始分析流程
|
||||
|
||||
Args:
|
||||
user_input: 用户的自然语言需求
|
||||
files: 数据文件路径列表
|
||||
session_output_dir: 指定的会话输出目录(可选)
|
||||
reset_session: 是否重置会话 (True: 新开启分析; False: 在现有上下文中继续)
|
||||
max_rounds: 本次分析的最大轮数 (可选,如果不填则使用默认值)
|
||||
template_name: 分析模板名称 (可选,如果提供则使用模板引导分析)
|
||||
|
||||
Returns:
|
||||
分析结果字典
|
||||
"""
|
||||
# 重置状态
|
||||
self.conversation_history = []
|
||||
self.analysis_results = []
|
||||
self.current_round = 0
|
||||
|
||||
# 创建本次分析的专用输出目录
|
||||
self.session_output_dir = create_session_output_dir(
|
||||
self.base_output_dir, user_input
|
||||
)
|
||||
|
||||
# 初始化代码执行器,使用会话目录
|
||||
self.executor = CodeExecutor(self.session_output_dir)
|
||||
|
||||
# 设置会话目录变量到执行环境中
|
||||
self.executor.set_variable("session_output_dir", self.session_output_dir)
|
||||
|
||||
# 设用工具生成数据画像
|
||||
data_profile = ""
|
||||
if files:
|
||||
print("🔍 正在生成数据画像...")
|
||||
data_profile = load_and_profile_data(files)
|
||||
print("✅ 数据画像生成完毕")
|
||||
|
||||
# 构建初始prompt
|
||||
initial_prompt = f"""用户需求: {user_input}"""
|
||||
if files:
|
||||
initial_prompt += f"\n数据文件: {', '.join(files)}"
|
||||
|
||||
if data_profile:
|
||||
initial_prompt += f"\n\n{data_profile}\n\n请根据上述【数据画像】中的统计信息(如高频值、缺失率、数据范围)来制定分析策略。如果发现明显的高频问题或异常分布,请优先进行深度分析。"
|
||||
# 确定本次运行的轮数限制
|
||||
current_max_rounds = max_rounds if max_rounds is not None else self.max_rounds
|
||||
|
||||
print(f"🚀 开始数据分析任务")
|
||||
print(f"📝 用户需求: {user_input}")
|
||||
if files:
|
||||
print(f"📁 数据文件: {', '.join(files)}")
|
||||
print(f"📂 输出目录: {self.session_output_dir}")
|
||||
print(f"🔢 最大轮数: {self.max_rounds}")
|
||||
# Template integration: prepend template prompt to user input if provided
|
||||
if template_name:
|
||||
from utils.analysis_templates import get_template
|
||||
template = get_template(template_name) # Raises ValueError if invalid
|
||||
template_prompt = template.get_full_prompt()
|
||||
user_input = f"{template_prompt}\n\n{user_input}"
|
||||
|
||||
if reset_session:
|
||||
# --- 初始化新会话 ---
|
||||
self.conversation_history = []
|
||||
self.analysis_results = []
|
||||
self.current_round = 0
|
||||
self.data_files = files or [] # 保存数据文件列表
|
||||
self.user_requirement = user_input # 保存用户需求
|
||||
|
||||
# 创建本次分析的专用输出目录
|
||||
if session_output_dir:
|
||||
self.session_output_dir = session_output_dir
|
||||
else:
|
||||
self.session_output_dir = create_session_output_dir(
|
||||
self.base_output_dir, user_input
|
||||
)
|
||||
|
||||
# 初始化代码执行器,使用会话目录
|
||||
self.executor = CodeExecutor(self.session_output_dir)
|
||||
|
||||
# 设置会话目录变量到执行环境中
|
||||
self.executor.set_variable("session_output_dir", self.session_output_dir)
|
||||
|
||||
# 生成数据画像(分级:安全级发给LLM,完整级留本地)
|
||||
data_profile_safe = ""
|
||||
data_profile_local = ""
|
||||
if files:
|
||||
print("[SEARCH] 正在生成数据画像...")
|
||||
try:
|
||||
if len(files) > 1:
|
||||
# Parallel profiling for multiple files
|
||||
data_profile_safe, data_profile_local = self._profile_files_parallel(files)
|
||||
else:
|
||||
data_profile_safe = build_safe_profile(files)
|
||||
data_profile_local = build_local_profile(files)
|
||||
print("[OK] 数据画像生成完毕(安全级 + 本地级)")
|
||||
except Exception as e:
|
||||
print(f"[WARN] 数据画像生成失败: {e}")
|
||||
|
||||
# Expose chunked iterators for large files in the Code_Executor namespace
|
||||
for fp in files:
|
||||
try:
|
||||
if os.path.exists(fp):
|
||||
file_size_mb = os.path.getsize(fp) / (1024 * 1024)
|
||||
if file_size_mb > app_config.max_file_size_mb:
|
||||
var_name = "chunked_iter_" + os.path.splitext(os.path.basename(fp))[0]
|
||||
# Store a factory so the iterator can be re-created
|
||||
self.executor.set_variable(var_name, lambda p=fp: load_data_chunked(p))
|
||||
print(f"[OK] 大文件 {os.path.basename(fp)} 的分块迭代器已注入为 {var_name}()")
|
||||
except Exception as e:
|
||||
print(f"[WARN] 注入分块迭代器失败 ({os.path.basename(fp)}): {e}")
|
||||
|
||||
# 安全画像发给LLM,完整画像留给最终报告生成
|
||||
self.data_profile = data_profile_local # 本地完整版用于最终报告
|
||||
self.data_profile_safe = data_profile_safe # 安全版用于LLM对话
|
||||
|
||||
# 构建初始prompt(只发送安全级画像给LLM)
|
||||
initial_prompt = f"""用户需求: {user_input}"""
|
||||
if files:
|
||||
initial_prompt += f"\n数据文件: {', '.join(files)}"
|
||||
|
||||
if data_profile_safe:
|
||||
initial_prompt += f"\n\n{data_profile_safe}\n\n请根据上述【数据结构概览】中的列名、数据类型和特征描述来制定分析策略。先通过代码探索数据的实际分布,再进行深度分析。"
|
||||
|
||||
print(f"[START] 开始数据分析任务")
|
||||
print(f"[NOTE] 用户需求: {user_input}")
|
||||
if files:
|
||||
print(f"[FOLDER] 数据文件: {', '.join(files)}")
|
||||
print(f"[DIR] 输出目录: {self.session_output_dir}")
|
||||
|
||||
# 添加到对话历史
|
||||
self.conversation_history.append({"role": "user", "content": initial_prompt})
|
||||
|
||||
else:
|
||||
# --- 继续现有会话 ---
|
||||
# 如果是追问,且没有指定轮数,默认减少轮数,避免过度分析
|
||||
if max_rounds is None:
|
||||
current_max_rounds = 10 # 追问通常不需要那么长的思考链,10轮足够
|
||||
|
||||
print(f"\n[START] 继续分析任务 (追问模式)")
|
||||
print(f"[NOTE] 后续需求: {user_input}")
|
||||
|
||||
# 重置当前轮数计数器,以便给新任务足够的轮次
|
||||
self.current_round = 0
|
||||
|
||||
# 添加到对话历史
|
||||
# 提示Agent这是后续追问,可以简化步骤
|
||||
follow_up_prompt = f"后续需求: {user_input}\n(注意:这是后续追问,请直接针对该问题进行分析,无需从头开始执行完整SOP。)"
|
||||
self.conversation_history.append({"role": "user", "content": follow_up_prompt})
|
||||
|
||||
print(f"[NUM] 本次最大轮数: {current_max_rounds}")
|
||||
if self.force_max_rounds:
|
||||
print(f"⚡ 强制模式: 将运行满 {self.max_rounds} 轮(忽略AI完成信号)")
|
||||
print(f"[FAST] 强制模式: 将运行满 {current_max_rounds} 轮(忽略AI完成信号)")
|
||||
print("=" * 60)
|
||||
# 添加到对话历史
|
||||
self.conversation_history.append({"role": "user", "content": initial_prompt})
|
||||
|
||||
# 保存原始 max_rounds 以便恢复(虽然 analyze 结束后不需要恢复,但为了逻辑严谨)
|
||||
original_max_rounds = self.max_rounds
|
||||
self.max_rounds = current_max_rounds
|
||||
|
||||
# 初始化连续失败计数器
|
||||
consecutive_failures = 0
|
||||
# Per-round data-context retry counter
|
||||
data_context_retries = 0
|
||||
last_retry_round = 0
|
||||
|
||||
while self.current_round < self.max_rounds:
|
||||
self.current_round += 1
|
||||
print(f"\n🔄 第 {self.current_round} 轮分析")
|
||||
# Notify progress callback
|
||||
if self._progress_callback:
|
||||
self._progress_callback(self.current_round, self.max_rounds, f"第{self.current_round}/{self.max_rounds}轮分析中...")
|
||||
# Reset data-context retry counter when entering a new round
|
||||
if self.current_round != last_retry_round:
|
||||
data_context_retries = 0
|
||||
|
||||
# Trim conversation history after the first round to bound token usage
|
||||
if self.current_round > 1:
|
||||
self._trim_conversation_history()
|
||||
|
||||
print(f"\n[LOOP] 第 {self.current_round} 轮分析")
|
||||
# 调用LLM生成响应
|
||||
try: # 获取当前执行环境的变量信息
|
||||
notebook_variables = self.executor.get_environment_info()
|
||||
|
||||
# Select prompt based on mode
|
||||
if self.current_round == 1 and not reset_session:
|
||||
# For the first round of a follow-up session, use the specialized prompt
|
||||
base_system_prompt = data_analysis_followup_prompt
|
||||
elif not reset_session and self.current_round > 1:
|
||||
# For subsequent rounds in follow-up, continue using the follow-up context
|
||||
# or maybe just the standard one is fine as long as SOP isn't fully enforced?
|
||||
# Let's stick to the follow-up prompt to prevent SOP regression
|
||||
base_system_prompt = data_analysis_followup_prompt
|
||||
else:
|
||||
base_system_prompt = data_analysis_system_prompt
|
||||
|
||||
# 格式化系统提示词,填入动态的notebook变量信息
|
||||
formatted_system_prompt = data_analysis_system_prompt.format(
|
||||
formatted_system_prompt = base_system_prompt.format(
|
||||
notebook_variables=notebook_variables
|
||||
)
|
||||
print(f"🐛 [DEBUG] System Prompt Head:\n{formatted_system_prompt[:500]}...\n[...]")
|
||||
print(f"🐛 [DEBUG] System Prompt Rules Check: 'stop_words' in prompt? {'stop_words' in formatted_system_prompt}")
|
||||
print(f"[DEBUG] [DEBUG] System Prompt Head:\n{formatted_system_prompt[:500]}...\n[...]")
|
||||
print(f"[DEBUG] [DEBUG] System Prompt Rules Check: 'stop_words' in prompt? {'stop_words' in formatted_system_prompt}")
|
||||
|
||||
response = self.llm.call(
|
||||
prompt=self._build_conversation_prompt(),
|
||||
system_prompt=formatted_system_prompt,
|
||||
)
|
||||
|
||||
print(f"🤖 助手响应:\n{response}")
|
||||
print(f"[AI] 助手响应:\n{response}")
|
||||
|
||||
# 使用统一的响应处理方法
|
||||
process_result = self._process_response(response)
|
||||
|
||||
# 根据处理结果决定是否继续(仅在非强制模式下)
|
||||
if process_result.get("action") == "invalid_response":
|
||||
consecutive_failures += 1
|
||||
print(f"[WARN] 连续失败次数: {consecutive_failures}/3")
|
||||
if consecutive_failures >= 3:
|
||||
print(f"[ERROR] 连续3次无法获取有效响应,分析终止。请检查网络或配置。")
|
||||
break
|
||||
else:
|
||||
consecutive_failures = 0 # 重置计数器
|
||||
|
||||
if not self.force_max_rounds and not process_result.get(
|
||||
"continue", True
|
||||
):
|
||||
print(f"\n✅ 分析完成!")
|
||||
print(f"\n[OK] 分析完成!")
|
||||
break
|
||||
|
||||
# 添加到对话历史
|
||||
@@ -300,8 +690,43 @@ class DataAnalysisAgent:
|
||||
# 根据动作类型添加不同的反馈
|
||||
if process_result["action"] == "generate_code":
|
||||
feedback = process_result.get("feedback", "")
|
||||
result = process_result.get("result", {})
|
||||
execution_failed = not result.get("success", True)
|
||||
|
||||
# --- Data-context retry logic ---
|
||||
if execution_failed:
|
||||
error_output = result.get("error", "") or feedback
|
||||
error_class = self._classify_error(error_output)
|
||||
|
||||
if error_class == "data_context" and data_context_retries < app_config.max_data_context_retries:
|
||||
data_context_retries += 1
|
||||
last_retry_round = self.current_round
|
||||
print(f"[RETRY] 数据上下文错误,重试 {data_context_retries}/{app_config.max_data_context_retries}")
|
||||
# Generate enriched hint from safe profile
|
||||
enriched_hint = generate_enriched_hint(error_output, self.data_profile_safe)
|
||||
# Add enriched hint to conversation history (assistant response already added above)
|
||||
self.conversation_history.append(
|
||||
{"role": "user", "content": enriched_hint}
|
||||
)
|
||||
# Record the failed attempt
|
||||
self.analysis_results.append(
|
||||
{
|
||||
"round": self.current_round,
|
||||
"code": process_result.get("code", ""),
|
||||
"result": result,
|
||||
"response": response,
|
||||
"retry": True,
|
||||
}
|
||||
)
|
||||
# Retry within the same round: decrement round counter so the
|
||||
# outer loop's increment brings us back to the same round number
|
||||
self.current_round -= 1
|
||||
continue
|
||||
|
||||
# Normal feedback path (no retry or non-data-context error or at limit)
|
||||
safe_feedback = sanitize_execution_feedback(feedback)
|
||||
self.conversation_history.append(
|
||||
{"role": "user", "content": f"代码执行反馈:\n{feedback}"}
|
||||
{"role": "user", "content": f"代码执行反馈:\n{safe_feedback}"}
|
||||
)
|
||||
|
||||
# 记录分析结果
|
||||
@@ -313,10 +738,55 @@ class DataAnalysisAgent:
|
||||
"response": response,
|
||||
}
|
||||
)
|
||||
|
||||
# --- Construct Round_Data and append to session ---
|
||||
result = process_result.get("result", {})
|
||||
round_data = {
|
||||
"round": self.current_round,
|
||||
"reasoning": process_result.get("reasoning", ""),
|
||||
"code": process_result.get("code", ""),
|
||||
"result_summary": self._summarize_result(result),
|
||||
"evidence_rows": result.get("evidence_rows", []),
|
||||
"raw_log": feedback,
|
||||
"auto_exported_files": result.get("auto_exported_files", []),
|
||||
"prompt_saved_files": result.get("prompt_saved_files", []),
|
||||
}
|
||||
|
||||
if self._session_ref:
|
||||
self._session_ref.rounds.append(round_data)
|
||||
# Merge file metadata into SessionData.data_files
|
||||
for f in round_data.get("auto_exported_files", []):
|
||||
if f.get("skipped"):
|
||||
continue # Large DataFrame — not written to disk
|
||||
self._session_ref.data_files.append({
|
||||
"filename": f.get("filename", ""),
|
||||
"description": f"自动导出: {f.get('variable_name', '')}",
|
||||
"rows": f.get("rows", 0),
|
||||
"cols": f.get("cols", 0),
|
||||
"columns": f.get("columns", []),
|
||||
"size_bytes": 0,
|
||||
"source": "auto",
|
||||
})
|
||||
for f in round_data.get("prompt_saved_files", []):
|
||||
self._session_ref.data_files.append({
|
||||
"filename": f.get("filename", ""),
|
||||
"description": f.get("description", ""),
|
||||
"rows": f.get("rows", 0),
|
||||
"cols": 0,
|
||||
"columns": [],
|
||||
"size_bytes": 0,
|
||||
"source": "prompt",
|
||||
})
|
||||
elif process_result["action"] == "collect_figures":
|
||||
# 记录图片收集结果
|
||||
collected_figures = process_result.get("collected_figures", [])
|
||||
feedback = f"已收集 {len(collected_figures)} 个图片及其分析"
|
||||
|
||||
missing_figures = process_result.get("missing_figures", [])
|
||||
|
||||
feedback = f"已收集 {len(collected_figures)} 个有效图片及其分析。"
|
||||
if missing_figures:
|
||||
feedback += f"\n[WARN] 以下图片未找到,请检查代码是否成功保存了这些图片: {missing_figures}"
|
||||
|
||||
self.conversation_history.append(
|
||||
{
|
||||
"role": "user",
|
||||
@@ -330,13 +800,15 @@ class DataAnalysisAgent:
|
||||
"round": self.current_round,
|
||||
"action": "collect_figures",
|
||||
"collected_figures": collected_figures,
|
||||
"missing_figures": missing_figures,
|
||||
|
||||
"response": response,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"LLM调用错误: {str(e)}"
|
||||
print(f"❌ {error_msg}")
|
||||
print(f"[ERROR] {error_msg}")
|
||||
self.conversation_history.append(
|
||||
{
|
||||
"role": "user",
|
||||
@@ -345,7 +817,7 @@ class DataAnalysisAgent:
|
||||
)
|
||||
# 生成最终总结
|
||||
if self.current_round >= self.max_rounds:
|
||||
print(f"\n⚠️ 已达到最大轮数 ({self.max_rounds}),分析结束")
|
||||
print(f"\n[WARN] 已达到最大轮数 ({self.max_rounds}),分析结束")
|
||||
|
||||
return self._generate_final_report()
|
||||
|
||||
@@ -371,10 +843,39 @@ class DataAnalysisAgent:
|
||||
if result.get("action") == "collect_figures":
|
||||
all_figures.extend(result.get("collected_figures", []))
|
||||
|
||||
print(f"\n📊 开始生成最终分析报告...")
|
||||
print(f"📂 输出目录: {self.session_output_dir}")
|
||||
print(f"🔢 总轮数: {self.current_round}")
|
||||
print(f"📈 收集图片: {len(all_figures)} 个")
|
||||
print(f"\n[CHART] 开始生成最终分析报告...")
|
||||
print(f"[DIR] 输出目录: {self.session_output_dir}")
|
||||
|
||||
# --- 自动补全/发现图片机制 ---
|
||||
# 扫描目录下所有的png文件
|
||||
try:
|
||||
import glob
|
||||
existing_pngs = glob.glob(os.path.join(self.session_output_dir, "*.png"))
|
||||
|
||||
# 获取已收集的图片路径集合
|
||||
collected_paths = set()
|
||||
for fig in all_figures:
|
||||
if fig.get("file_path"):
|
||||
collected_paths.add(os.path.abspath(fig.get("file_path")))
|
||||
|
||||
# 检查是否有漏网之鱼
|
||||
for png_path in existing_pngs:
|
||||
abs_png_path = os.path.abspath(png_path)
|
||||
if abs_png_path not in collected_paths:
|
||||
print(f"[SEARCH] [自动发现] 补充未显式收集的图片: {os.path.basename(png_path)}")
|
||||
all_figures.append({
|
||||
"figure_number": "Auto",
|
||||
"filename": os.path.basename(png_path),
|
||||
"file_path": abs_png_path,
|
||||
"description": f"自动发现的分析图表: {os.path.basename(png_path)}",
|
||||
"analysis": "(该图表由系统自动捕获,Agent未提供具体分析文本,请结合图表标题理解)"
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"[WARN] 自动发现图片失败: {e}")
|
||||
# ---------------------------
|
||||
|
||||
print(f"[NUM] 总轮数: {self.current_round}")
|
||||
print(f"[GRAPH] 收集图片: {len(all_figures)} 个")
|
||||
|
||||
# 构建用于生成最终报告的提示词
|
||||
final_report_prompt = self._build_final_report_prompt(all_figures)
|
||||
@@ -386,23 +887,22 @@ class DataAnalysisAgent:
|
||||
max_tokens=16384, # 设置较大的token限制以容纳完整报告
|
||||
)
|
||||
|
||||
# 解析响应,提取最终报告
|
||||
try:
|
||||
yaml_data = self.llm.parse_yaml_response(response)
|
||||
if yaml_data.get("action") == "analysis_complete":
|
||||
final_report_content = yaml_data.get("final_report", "报告生成失败")
|
||||
else:
|
||||
final_report_content = (
|
||||
"LLM未返回analysis_complete动作,报告生成失败"
|
||||
)
|
||||
except:
|
||||
# 如果解析失败,直接使用响应内容
|
||||
final_report_content = response
|
||||
# 直接使用LLM响应作为最终报告(因为我们在prompt中要求直接输出Markdown)
|
||||
final_report_content = response
|
||||
|
||||
# 兼容旧逻辑:如果意外返回了YAML,尝试解析
|
||||
if response.strip().startswith("action:") or "final_report:" in response:
|
||||
try:
|
||||
yaml_data = self.llm.parse_yaml_response(response)
|
||||
if yaml_data.get("action") == "analysis_complete":
|
||||
final_report_content = yaml_data.get("final_report", response)
|
||||
except:
|
||||
pass # 解析失败则保持原样
|
||||
|
||||
print("✅ 最终报告生成完成")
|
||||
print("[OK] 最终报告生成完成")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 生成最终报告时出错: {str(e)}")
|
||||
print(f"[ERROR] 生成最终报告时出错: {str(e)}")
|
||||
final_report_content = f"报告生成失败: {str(e)}"
|
||||
|
||||
# 保存最终报告到文件
|
||||
@@ -410,9 +910,21 @@ class DataAnalysisAgent:
|
||||
try:
|
||||
with open(report_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(final_report_content)
|
||||
print(f"📄 最终报告已保存至: {report_file_path}")
|
||||
print(f"[DOC] 最终报告已保存至: {report_file_path}")
|
||||
except Exception as e:
|
||||
print(f"❌ 保存报告文件失败: {str(e)}")
|
||||
print(f"[ERROR] 保存报告文件失败: {str(e)}")
|
||||
|
||||
# 生成可复用脚本
|
||||
script_path = ""
|
||||
try:
|
||||
script_path = generate_reusable_script(
|
||||
analysis_results=self.analysis_results,
|
||||
data_files=self.data_files,
|
||||
session_output_dir=self.session_output_dir,
|
||||
user_requirement=self.user_requirement
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[WARN] 脚本生成失败: {e}")
|
||||
|
||||
# 返回完整的分析结果
|
||||
return {
|
||||
@@ -423,6 +935,7 @@ class DataAnalysisAgent:
|
||||
"conversation_history": self.conversation_history,
|
||||
"final_report": final_report_content,
|
||||
"report_file_path": report_file_path,
|
||||
"reusable_script_path": script_path,
|
||||
}
|
||||
|
||||
def _build_final_report_prompt(self, all_figures: List[Dict[str, Any]]) -> str:
|
||||
@@ -457,24 +970,63 @@ class DataAnalysisAgent:
|
||||
f"输出: {exec_result.get('output')[:]}\n\n"
|
||||
)
|
||||
|
||||
# 构建各轮次证据数据摘要
|
||||
evidence_summary = ""
|
||||
if self._session_ref and self._session_ref.rounds:
|
||||
evidence_parts = []
|
||||
for rd in self._session_ref.rounds:
|
||||
round_num = rd.get("round", 0)
|
||||
summary = rd.get("result_summary", "")
|
||||
evidence = rd.get("evidence_rows", [])
|
||||
reasoning = rd.get("reasoning", "")
|
||||
part = f"第{round_num}轮: {summary}"
|
||||
if reasoning:
|
||||
part += f"\n 推理: {reasoning[:200]}"
|
||||
if evidence:
|
||||
part += f"\n 数据样本({len(evidence)}行): {json.dumps(evidence[:3], ensure_ascii=False, default=str)}"
|
||||
evidence_parts.append(part)
|
||||
evidence_summary = "\n".join(evidence_parts)
|
||||
|
||||
# 使用 prompts.py 中的统一提示词模板,并添加相对路径使用说明
|
||||
prompt = final_report_system_prompt.format(
|
||||
current_round=self.current_round,
|
||||
session_output_dir=self.session_output_dir,
|
||||
data_profile=self.data_profile, # 注入数据画像
|
||||
figures_summary=figures_summary,
|
||||
code_results_summary=code_results_summary,
|
||||
)
|
||||
|
||||
# Append evidence data from all rounds for evidence annotation
|
||||
if evidence_summary:
|
||||
prompt += f"""
|
||||
|
||||
**各轮次分析证据数据 (Evidence by Round)**:
|
||||
以下是每轮分析的结果摘要和数据样本,请在报告中使用 `<!-- evidence:round_N -->` 标注引用了哪一轮的数据:
|
||||
|
||||
{evidence_summary}
|
||||
"""
|
||||
|
||||
# 在提示词中明确要求使用相对路径
|
||||
prompt += """
|
||||
|
||||
📁 **图片路径使用说明**:
|
||||
[FOLDER] **图片路径使用说明**:
|
||||
报告和图片都在同一目录下,请在报告中使用相对路径引用图片:
|
||||
- 格式:
|
||||
- 格式:
|
||||
- 示例:
|
||||
- 这样可以确保报告在不同环境下都能正确显示图片
|
||||
- 注意:必须使用实际生成的图片文件名,严禁使用占位符
|
||||
"""
|
||||
|
||||
# Append actual data files list so the LLM uses real filenames in the report
|
||||
if self._session_ref and self._session_ref.data_files:
|
||||
data_files_summary = "\n**已生成的数据文件列表** (请在报告中使用这些实际文件名,替换模板中的占位文件名如 [4-1TSP问题聚类.xlsx]):\n"
|
||||
for df_meta in self._session_ref.data_files:
|
||||
fname = df_meta.get("filename", "")
|
||||
desc = df_meta.get("description", "")
|
||||
rows = df_meta.get("rows", 0)
|
||||
data_files_summary += f"- {fname} ({rows}行): {desc}\n"
|
||||
data_files_summary += "\n注意:报告模板中的 `[4-1TSP问题聚类.xlsx]` 等占位文件名必须替换为上述实际文件名。如果某类聚类文件未生成,请说明原因(如数据量不足或该分类不适用),不要保留占位符。\n"
|
||||
prompt += data_files_summary
|
||||
|
||||
return prompt
|
||||
|
||||
def reset(self):
|
||||
|
||||
89
data_preprocessing/README.md
Normal file
89
data_preprocessing/README.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# 数据预处理模块
|
||||
|
||||
独立的数据清洗工具,用于在正式分析前准备数据。
|
||||
|
||||
## 功能
|
||||
|
||||
- **数据合并**:将多个 Excel/CSV 文件合并为单一 CSV
|
||||
- **时间排序**:按时间列对数据进行排序
|
||||
- **目录管理**:标准化的原始数据和输出数据目录
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
project/
|
||||
├── raw_data/ # 原始数据存放目录
|
||||
│ ├── remotecontrol/ # 按数据来源分类
|
||||
│ └── ...
|
||||
├── cleaned_data/ # 清洗后数据输出目录
|
||||
│ ├── xxx_merged.csv
|
||||
│ └── xxx_sorted.csv
|
||||
└── data_preprocessing/ # 本模块
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 命令行
|
||||
|
||||
```bash
|
||||
# 初始化目录结构
|
||||
python -m data_preprocessing.cli init
|
||||
|
||||
# 合并 Excel 文件
|
||||
python -m data_preprocessing.cli merge --source raw_data/remotecontrol
|
||||
|
||||
# 合并并按时间排序
|
||||
python -m data_preprocessing.cli merge --source raw_data/remotecontrol --sort-by SendTime
|
||||
|
||||
# 指定输出路径
|
||||
python -m data_preprocessing.cli merge -s raw_data/remotecontrol -o cleaned_data/my_output.csv
|
||||
|
||||
# 排序已有 CSV
|
||||
python -m data_preprocessing.cli sort --input some_file.csv --time-col SendTime
|
||||
|
||||
# 原地排序(覆盖原文件)
|
||||
python -m data_preprocessing.cli sort --input data.csv --inplace
|
||||
```
|
||||
|
||||
### Python API
|
||||
|
||||
```python
|
||||
from data_preprocessing import merge_files, sort_by_time, Config
|
||||
|
||||
# 合并文件
|
||||
output_path = merge_files(
|
||||
source_dir="raw_data/remotecontrol",
|
||||
output_file="cleaned_data/merged.csv",
|
||||
pattern="*.xlsx",
|
||||
time_column="SendTime" # 可选:合并后排序
|
||||
)
|
||||
|
||||
# 排序 CSV
|
||||
sorted_path = sort_by_time(
|
||||
input_path="data.csv",
|
||||
output_path="sorted_data.csv",
|
||||
time_column="CreateTime"
|
||||
)
|
||||
|
||||
# 自定义配置
|
||||
config = Config()
|
||||
config.raw_data_dir = "/path/to/raw"
|
||||
config.cleaned_data_dir = "/path/to/cleaned"
|
||||
config.ensure_dirs()
|
||||
```
|
||||
|
||||
## 配置项
|
||||
|
||||
| 配置项 | 默认值 | 说明 |
|
||||
|--------|--------|------|
|
||||
| `raw_data_dir` | `raw_data/` | 原始数据目录 |
|
||||
| `cleaned_data_dir` | `cleaned_data/` | 清洗输出目录 |
|
||||
| `default_time_column` | `SendTime` | 默认时间列名 |
|
||||
| `csv_encoding` | `utf-8-sig` | CSV 编码格式 |
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 本模块与 `DataAnalysisAgent` 完全独立,不会相互调用
|
||||
2. 合并时会自动添加 `_source_file` 列标记数据来源(可用 `--no-source-col` 禁用)
|
||||
3. Excel 文件会自动合并所有 Sheet
|
||||
4. 无效时间值在排序时会被放到最后
|
||||
14
data_preprocessing/__init__.py
Normal file
14
data_preprocessing/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据预处理模块
|
||||
|
||||
提供独立的数据清洗功能:
|
||||
- 按时间排序
|
||||
- 同类数据合并
|
||||
"""
|
||||
|
||||
from .sorter import sort_by_time
|
||||
from .merger import merge_files
|
||||
from .config import Config
|
||||
|
||||
__all__ = ["sort_by_time", "merge_files", "Config"]
|
||||
140
data_preprocessing/cli.py
Normal file
140
data_preprocessing/cli.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据预处理命令行接口
|
||||
|
||||
使用示例:
|
||||
# 合并 Excel 文件
|
||||
python -m data_preprocessing.cli merge --source raw_data/remotecontrol --output cleaned_data/merged.csv
|
||||
|
||||
# 合并并排序
|
||||
python -m data_preprocessing.cli merge --source raw_data/remotecontrol --sort-by SendTime
|
||||
|
||||
# 排序已有 CSV
|
||||
python -m data_preprocessing.cli sort --input data.csv --output sorted.csv --time-col SendTime
|
||||
|
||||
# 初始化目录结构
|
||||
python -m data_preprocessing.cli init
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from .config import default_config
|
||||
from .sorter import sort_by_time
|
||||
from .merger import merge_files
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="data_preprocessing",
|
||||
description="数据预处理工具:排序、合并",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
示例:
|
||||
%(prog)s merge --source raw_data/remotecontrol --sort-by SendTime
|
||||
%(prog)s sort --input data.csv --time-col CreateTime
|
||||
%(prog)s init
|
||||
"""
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="可用命令")
|
||||
|
||||
# ========== merge 命令 ==========
|
||||
merge_parser = subparsers.add_parser("merge", help="合并同类文件")
|
||||
merge_parser.add_argument(
|
||||
"--source", "-s",
|
||||
required=True,
|
||||
help="源数据目录路径"
|
||||
)
|
||||
merge_parser.add_argument(
|
||||
"--output", "-o",
|
||||
default=None,
|
||||
help="输出文件路径 (默认: cleaned_data/<目录名>_merged.csv)"
|
||||
)
|
||||
merge_parser.add_argument(
|
||||
"--pattern", "-p",
|
||||
default="*.xlsx",
|
||||
help="文件匹配模式 (默认: *.xlsx)"
|
||||
)
|
||||
merge_parser.add_argument(
|
||||
"--sort-by",
|
||||
default=None,
|
||||
dest="time_column",
|
||||
help="合并后按此时间列排序"
|
||||
)
|
||||
merge_parser.add_argument(
|
||||
"--no-source-col",
|
||||
action="store_true",
|
||||
help="不添加来源文件列"
|
||||
)
|
||||
|
||||
# ========== sort 命令 ==========
|
||||
sort_parser = subparsers.add_parser("sort", help="按时间排序 CSV")
|
||||
sort_parser.add_argument(
|
||||
"--input", "-i",
|
||||
required=True,
|
||||
help="输入 CSV 文件路径"
|
||||
)
|
||||
sort_parser.add_argument(
|
||||
"--output", "-o",
|
||||
default=None,
|
||||
help="输出文件路径 (默认: cleaned_data/<文件名>_sorted.csv)"
|
||||
)
|
||||
sort_parser.add_argument(
|
||||
"--time-col", "-t",
|
||||
default=None,
|
||||
dest="time_column",
|
||||
help=f"时间列名 (默认: {default_config.default_time_column})"
|
||||
)
|
||||
sort_parser.add_argument(
|
||||
"--inplace",
|
||||
action="store_true",
|
||||
help="原地覆盖输入文件"
|
||||
)
|
||||
|
||||
# ========== init 命令 ==========
|
||||
init_parser = subparsers.add_parser("init", help="初始化目录结构")
|
||||
|
||||
# 解析参数
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command is None:
|
||||
parser.print_help()
|
||||
sys.exit(0)
|
||||
|
||||
try:
|
||||
if args.command == "merge":
|
||||
result = merge_files(
|
||||
source_dir=args.source,
|
||||
output_file=args.output,
|
||||
pattern=args.pattern,
|
||||
time_column=args.time_column,
|
||||
add_source_column=not args.no_source_col
|
||||
)
|
||||
print(f"\n✅ 合并成功: {result}")
|
||||
|
||||
elif args.command == "sort":
|
||||
result = sort_by_time(
|
||||
input_path=args.input,
|
||||
output_path=args.output,
|
||||
time_column=args.time_column,
|
||||
inplace=args.inplace
|
||||
)
|
||||
print(f"\n✅ 排序成功: {result}")
|
||||
|
||||
elif args.command == "init":
|
||||
default_config.ensure_dirs()
|
||||
print("\n✅ 目录初始化完成")
|
||||
|
||||
except FileNotFoundError as e:
|
||||
print(f"\n❌ 错误: {e}")
|
||||
sys.exit(1)
|
||||
except KeyError as e:
|
||||
print(f"\n❌ 错误: {e}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ 未知错误: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
42
data_preprocessing/config.py
Normal file
42
data_preprocessing/config.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据预处理模块配置
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
# 获取项目根目录
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""预处理模块配置"""
|
||||
|
||||
# 原始数据存放目录
|
||||
raw_data_dir: str = os.path.join(PROJECT_ROOT, "raw_data")
|
||||
|
||||
# 清洗后数据输出目录
|
||||
cleaned_data_dir: str = os.path.join(PROJECT_ROOT, "cleaned_data")
|
||||
|
||||
# 默认时间列名
|
||||
default_time_column: str = "SendTime"
|
||||
|
||||
# 支持的文件扩展名
|
||||
supported_extensions: tuple = (".csv", ".xlsx", ".xls")
|
||||
|
||||
# CSV 编码
|
||||
csv_encoding: str = "utf-8-sig"
|
||||
|
||||
def ensure_dirs(self):
|
||||
"""确保目录存在"""
|
||||
os.makedirs(self.raw_data_dir, exist_ok=True)
|
||||
os.makedirs(self.cleaned_data_dir, exist_ok=True)
|
||||
print(f"[OK] 目录已就绪:")
|
||||
print(f" 原始数据: {self.raw_data_dir}")
|
||||
print(f" 清洗输出: {self.cleaned_data_dir}")
|
||||
|
||||
|
||||
# 默认配置实例
|
||||
default_config = Config()
|
||||
83
data_preprocessing/merge_excel.py
Normal file
83
data_preprocessing/merge_excel.py
Normal file
@@ -0,0 +1,83 @@
|
||||
|
||||
import pandas as pd
|
||||
import glob
|
||||
import os
|
||||
|
||||
def merge_excel_files(source_dir="remotecontrol", output_file="merged_all_files.csv"):
|
||||
"""
|
||||
将指定目录下的所有 Excel 文件 (.xlsx, .xls) 合并为一个 CSV 文件。
|
||||
"""
|
||||
print(f"[SEARCH] 正在扫描目录: {source_dir} ...")
|
||||
|
||||
# 支持 xlsx 和 xls
|
||||
files_xlsx = glob.glob(os.path.join(source_dir, "*.xlsx"))
|
||||
files_xls = glob.glob(os.path.join(source_dir, "*.xls"))
|
||||
files = files_xlsx + files_xls
|
||||
|
||||
if not files:
|
||||
print("[WARN] 未找到 Excel 文件。")
|
||||
return
|
||||
|
||||
# 按文件名中的数字进行排序 (例如: 1.xlsx, 2.xlsx, ..., 10.xlsx)
|
||||
try:
|
||||
files.sort(key=lambda x: int(os.path.basename(x).split('.')[0]))
|
||||
print("[NUM] 已按文件名数字顺序排序")
|
||||
except ValueError:
|
||||
# 如果文件名不是纯数字,退回到字母排序
|
||||
files.sort()
|
||||
print("[TEXT] 已按文件名包含非数字字符,使用字母顺序排序")
|
||||
|
||||
print(f"[DIR] 找到 {len(files)} 个文件: {files}")
|
||||
|
||||
all_dfs = []
|
||||
for file in files:
|
||||
try:
|
||||
print(f"[READ] 读取: {file}")
|
||||
# 使用 ExcelFile 读取所有 sheet
|
||||
xls = pd.ExcelFile(file)
|
||||
print(f" [PAGES] 包含 Sheets: {xls.sheet_names}")
|
||||
|
||||
file_dfs = []
|
||||
for sheet_name in xls.sheet_names:
|
||||
df = pd.read_excel(xls, sheet_name=sheet_name)
|
||||
if not df.empty:
|
||||
print(f" [OK] Sheet '{sheet_name}' 读取成功: {len(df)} 行")
|
||||
file_dfs.append(df)
|
||||
else:
|
||||
print(f" [WARN] Sheet '{sheet_name}' 为空,跳过")
|
||||
|
||||
if file_dfs:
|
||||
# 合并该文件的所有非空 sheet
|
||||
file_merged_df = pd.concat(file_dfs, ignore_index=True)
|
||||
# 可选:添加一列标记来源文件
|
||||
file_merged_df['Source_File'] = os.path.basename(file)
|
||||
all_dfs.append(file_merged_df)
|
||||
else:
|
||||
print(f"[WARN] 文件 {file} 所有 Sheet 均为空")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 读取 {file} 失败: {e}")
|
||||
|
||||
if all_dfs:
|
||||
print("[LOOP] 正在合并数据...")
|
||||
merged_df = pd.concat(all_dfs, ignore_index=True)
|
||||
|
||||
# 按 SendTime 排序
|
||||
if 'SendTime' in merged_df.columns:
|
||||
print("[TIMER] 正在按 SendTime 排序...")
|
||||
merged_df['SendTime'] = pd.to_datetime(merged_df['SendTime'], errors='coerce')
|
||||
merged_df = merged_df.sort_values(by='SendTime')
|
||||
else:
|
||||
print("[WARN] 未找到 SendTime 列,跳过排序")
|
||||
|
||||
print(f"[CACHE] 保存到: {output_file}")
|
||||
merged_df.to_csv(output_file, index=False, encoding="utf-8-sig")
|
||||
|
||||
print(f"[OK] 合并及排序完成!总行数: {len(merged_df)}")
|
||||
print(f" 输出文件: {os.path.abspath(output_file)}")
|
||||
else:
|
||||
print("[WARN] 没有成功读取到任何数据。")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 如果需要在当前目录运行并合并 remotecontrol 文件夹下的内容
|
||||
merge_excel_files(source_dir="remotecontrol", output_file="remotecontrol_merged.csv")
|
||||
148
data_preprocessing/merger.py
Normal file
148
data_preprocessing/merger.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据合并模块
|
||||
|
||||
合并同类 Excel/CSV 文件
|
||||
"""
|
||||
|
||||
import os
|
||||
import glob
|
||||
import pandas as pd
|
||||
from typing import Optional, List
|
||||
from .config import default_config
|
||||
|
||||
|
||||
def merge_files(
|
||||
source_dir: str,
|
||||
output_file: Optional[str] = None,
|
||||
pattern: str = "*.xlsx",
|
||||
time_column: Optional[str] = None,
|
||||
add_source_column: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
合并目录下的所有同类文件
|
||||
|
||||
Args:
|
||||
source_dir: 源数据目录
|
||||
output_file: 输出 CSV 文件路径。如果为 None,则输出到 cleaned_data 目录
|
||||
pattern: 文件匹配模式 (e.g., "*.xlsx", "*.csv", "*.xls")
|
||||
time_column: 可选,合并后按此列排序
|
||||
add_source_column: 是否添加来源文件列
|
||||
|
||||
Returns:
|
||||
输出文件的绝对路径
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 目录不存在或未找到匹配文件
|
||||
"""
|
||||
if not os.path.isdir(source_dir):
|
||||
raise FileNotFoundError(f"目录不存在: {source_dir}")
|
||||
|
||||
print(f"[SCAN] 正在扫描目录: {source_dir}")
|
||||
print(f" 匹配模式: {pattern}")
|
||||
|
||||
# 查找匹配文件
|
||||
files = glob.glob(os.path.join(source_dir, pattern))
|
||||
|
||||
# 如果是 xlsx,也尝试匹配 xls
|
||||
if pattern == "*.xlsx":
|
||||
files.extend(glob.glob(os.path.join(source_dir, "*.xls")))
|
||||
|
||||
if not files:
|
||||
raise FileNotFoundError(f"未找到匹配 '{pattern}' 的文件")
|
||||
|
||||
# 排序文件列表
|
||||
files = _sort_files(files)
|
||||
print(f"[FOUND] 找到 {len(files)} 个文件")
|
||||
|
||||
# 确定输出路径
|
||||
if output_file is None:
|
||||
default_config.ensure_dirs()
|
||||
dir_name = os.path.basename(os.path.normpath(source_dir))
|
||||
output_file = os.path.join(
|
||||
default_config.cleaned_data_dir,
|
||||
f"{dir_name}_merged.csv"
|
||||
)
|
||||
|
||||
# 合并数据
|
||||
all_dfs = []
|
||||
for file in files:
|
||||
try:
|
||||
df = _read_file(file)
|
||||
if df is not None and not df.empty:
|
||||
if add_source_column:
|
||||
df['_source_file'] = os.path.basename(file)
|
||||
all_dfs.append(df)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 读取失败 {file}: {e}")
|
||||
|
||||
if not all_dfs:
|
||||
raise ValueError("没有成功读取到任何数据")
|
||||
|
||||
print(f"[MERGE] 正在合并 {len(all_dfs)} 个数据源...")
|
||||
merged_df = pd.concat(all_dfs, ignore_index=True)
|
||||
print(f" 合并后总行数: {len(merged_df)}")
|
||||
|
||||
# 可选:按时间排序
|
||||
if time_column and time_column in merged_df.columns:
|
||||
print(f"[SORT] 正在按 '{time_column}' 排序...")
|
||||
merged_df[time_column] = pd.to_datetime(merged_df[time_column], errors='coerce')
|
||||
merged_df = merged_df.sort_values(by=time_column, na_position='last')
|
||||
elif time_column:
|
||||
print(f"[WARN] 未找到时间列 '{time_column}',跳过排序")
|
||||
|
||||
# 保存结果
|
||||
print(f"[SAVE] 正在保存: {output_file}")
|
||||
merged_df.to_csv(output_file, index=False, encoding=default_config.csv_encoding)
|
||||
|
||||
abs_output = os.path.abspath(output_file)
|
||||
print(f"[OK] 合并完成!")
|
||||
print(f" 输出文件: {abs_output}")
|
||||
print(f" 总行数: {len(merged_df)}")
|
||||
|
||||
return abs_output
|
||||
|
||||
|
||||
def _sort_files(files: List[str]) -> List[str]:
|
||||
"""对文件列表进行智能排序"""
|
||||
try:
|
||||
# 尝试按文件名中的数字排序
|
||||
files.sort(key=lambda x: int(os.path.basename(x).split('.')[0]))
|
||||
print("[SORT] 已按文件名数字顺序排序")
|
||||
except ValueError:
|
||||
# 退回到字母排序
|
||||
files.sort()
|
||||
print("[SORT] 已按文件名字母顺序排序")
|
||||
return files
|
||||
|
||||
|
||||
def _read_file(file_path: str) -> Optional[pd.DataFrame]:
|
||||
"""读取单个文件(支持 CSV 和 Excel)"""
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
|
||||
print(f"[READ] 读取: {os.path.basename(file_path)}")
|
||||
|
||||
if ext == '.csv':
|
||||
df = pd.read_csv(file_path, low_memory=False)
|
||||
print(f" 行数: {len(df)}")
|
||||
return df
|
||||
|
||||
elif ext in ('.xlsx', '.xls'):
|
||||
# 读取 Excel 所有 sheet 并合并
|
||||
xls = pd.ExcelFile(file_path)
|
||||
print(f" Sheets: {xls.sheet_names}")
|
||||
|
||||
sheet_dfs = []
|
||||
for sheet_name in xls.sheet_names:
|
||||
df = pd.read_excel(xls, sheet_name=sheet_name)
|
||||
if not df.empty:
|
||||
print(f" - Sheet '{sheet_name}': {len(df)} 行")
|
||||
sheet_dfs.append(df)
|
||||
|
||||
if sheet_dfs:
|
||||
return pd.concat(sheet_dfs, ignore_index=True)
|
||||
return None
|
||||
|
||||
else:
|
||||
print(f"[WARN] 不支持的文件格式: {ext}")
|
||||
return None
|
||||
45
data_preprocessing/sort_csv.py
Normal file
45
data_preprocessing/sort_csv.py
Normal file
@@ -0,0 +1,45 @@
|
||||
|
||||
import pandas as pd
|
||||
import os
|
||||
|
||||
def sort_csv_by_time(file_path="remotecontrol_merged.csv", time_col="SendTime"):
|
||||
"""
|
||||
读取 CSV 文件,按时间列排序,并保存。
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
print(f"[ERROR] 文件不存在: {file_path}")
|
||||
return
|
||||
|
||||
print(f"[READ] 正在读取 {file_path} ...")
|
||||
try:
|
||||
# 读取 CSV
|
||||
df = pd.read_csv(file_path, low_memory=False)
|
||||
print(f" [CHART] 数据行数: {len(df)}")
|
||||
|
||||
if time_col not in df.columns:
|
||||
print(f"[ERROR] 未找到时间列: {time_col}")
|
||||
print(f" 可用列: {list(df.columns)}")
|
||||
return
|
||||
|
||||
print(f"[LOOP] 正在解析时间列 '{time_col}' ...")
|
||||
# 转换为 datetime 对象,无法解析的设为 NaT
|
||||
df[time_col] = pd.to_datetime(df[time_col], errors='coerce')
|
||||
|
||||
# 检查无效时间
|
||||
nat_count = df[time_col].isna().sum()
|
||||
if nat_count > 0:
|
||||
print(f"[WARN] 发现 {nat_count} 行无效时间数据,排序时将排在最后")
|
||||
|
||||
print("[LOOP] 正在按时间排序...")
|
||||
df_sorted = df.sort_values(by=time_col)
|
||||
|
||||
print(f"[CACHE] 正在保存及覆盖文件: {file_path} ...")
|
||||
df_sorted.to_csv(file_path, index=False, encoding="utf-8-sig")
|
||||
|
||||
print("[OK] 排序并保存完成!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR]处理失败: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
sort_csv_by_time()
|
||||
82
data_preprocessing/sorter.py
Normal file
82
data_preprocessing/sorter.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据排序模块
|
||||
|
||||
按时间列对 CSV 文件进行排序
|
||||
"""
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
from typing import Optional
|
||||
from .config import default_config
|
||||
|
||||
|
||||
def sort_by_time(
|
||||
input_path: str,
|
||||
output_path: Optional[str] = None,
|
||||
time_column: str = None,
|
||||
inplace: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
按时间列对 CSV 文件排序
|
||||
|
||||
Args:
|
||||
input_path: 输入 CSV 文件路径
|
||||
output_path: 输出路径。如果为 None 且 inplace=False,则输出到 cleaned_data 目录
|
||||
time_column: 时间列名,默认使用配置中的 default_time_column
|
||||
inplace: 是否原地覆盖输入文件
|
||||
|
||||
Returns:
|
||||
输出文件的绝对路径
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 输入文件不存在
|
||||
KeyError: 时间列不存在
|
||||
"""
|
||||
# 参数处理
|
||||
time_column = time_column or default_config.default_time_column
|
||||
|
||||
if not os.path.exists(input_path):
|
||||
raise FileNotFoundError(f"文件不存在: {input_path}")
|
||||
|
||||
# 确定输出路径
|
||||
if inplace:
|
||||
output_path = input_path
|
||||
elif output_path is None:
|
||||
default_config.ensure_dirs()
|
||||
basename = os.path.basename(input_path)
|
||||
name, ext = os.path.splitext(basename)
|
||||
output_path = os.path.join(
|
||||
default_config.cleaned_data_dir,
|
||||
f"{name}_sorted{ext}"
|
||||
)
|
||||
|
||||
print(f"[READ] 正在读取: {input_path}")
|
||||
df = pd.read_csv(input_path, low_memory=False)
|
||||
print(f" 数据行数: {len(df)}")
|
||||
|
||||
# 检查时间列是否存在
|
||||
if time_column not in df.columns:
|
||||
available_cols = list(df.columns)
|
||||
raise KeyError(
|
||||
f"未找到时间列 '{time_column}'。可用列: {available_cols}"
|
||||
)
|
||||
|
||||
print(f"[PARSE] 正在解析时间列 '{time_column}'...")
|
||||
df[time_column] = pd.to_datetime(df[time_column], errors='coerce')
|
||||
|
||||
# 统计无效时间
|
||||
nat_count = df[time_column].isna().sum()
|
||||
if nat_count > 0:
|
||||
print(f"[WARN] 发现 {nat_count} 行无效时间数据,排序时将排在最后")
|
||||
|
||||
print("[SORT] 正在按时间排序...")
|
||||
df_sorted = df.sort_values(by=time_column, na_position='last')
|
||||
|
||||
print(f"[SAVE] 正在保存: {output_path}")
|
||||
df_sorted.to_csv(output_path, index=False, encoding=default_config.csv_encoding)
|
||||
|
||||
abs_output = os.path.abspath(output_path)
|
||||
print(f"[OK] 排序完成!输出文件: {abs_output}")
|
||||
|
||||
return abs_output
|
||||
81
main.py
81
main.py
@@ -1,18 +1,81 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
CLI 入口 - 数据分析智能体
|
||||
"""
|
||||
|
||||
import glob
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from data_analysis_agent import DataAnalysisAgent
|
||||
from config.llm_config import LLMConfig
|
||||
from utils.create_session_dir import create_session_output_dir
|
||||
from utils.logger import PrintCapture
|
||||
|
||||
|
||||
def main():
|
||||
llm_config = LLMConfig()
|
||||
# 如果希望强制运行到最大轮数,设置 force_max_rounds=True
|
||||
agent = DataAnalysisAgent(llm_config, force_max_rounds=False)
|
||||
files = ["./UB IOV Support_TR.csv"]
|
||||
report = agent.analyze(
|
||||
user_input="基于所有有关远程控制的问题,以及涉及车控APP的运维工单的数据,输出若干个重要的统计指标,并绘制相关图表。总结一份,车控APP,及远程控制工单健康度报告,最后生成汇报给我。",
|
||||
files=files,
|
||||
)
|
||||
print(report)
|
||||
|
||||
|
||||
# 自动查找数据文件
|
||||
data_extensions = ["*.csv", "*.xlsx", "*.xls"]
|
||||
search_dirs = ["cleaned_data"]
|
||||
files = []
|
||||
|
||||
for search_dir in search_dirs:
|
||||
for ext in data_extensions:
|
||||
pattern = os.path.join(search_dir, ext)
|
||||
files.extend(glob.glob(pattern))
|
||||
|
||||
if not files:
|
||||
print("[WARN] 未在 cleaned_data 目录找到数据文件,尝试使用默认文件")
|
||||
files = ["./cleaned_data.csv"]
|
||||
else:
|
||||
print(f"[DIR] 自动识别到以下数据文件: {files}")
|
||||
|
||||
analysis_requirement = """
|
||||
基于所有运维工单,整理一份工单健康度报告,包括但不限于对所有车联网技术支持工单的全面数据分析,
|
||||
深入挖掘工单处理过程中的关键问题、效率瓶颈及改进机会。请从车型,模块,功能角度,分别展示工单数据、问题类型、模块分布、严重程度、责任人负载、车型分布、来源渠道及处理时长等多个维度。
|
||||
通过多轮交叉分析与趋势洞察,为提升车联网服务质量、优化资源配置及降低运营风险提供数据驱动的决策依据,问题总揽,高频问题、重点问题分析,输出若干个重要的统计指标,并绘制相关图表;
|
||||
结合图表,总结一份,车联网运维工单健康度报告,汇报给我。
|
||||
"""
|
||||
|
||||
# 创建会话目录
|
||||
base_output_dir = "outputs"
|
||||
session_output_dir = create_session_output_dir(base_output_dir, analysis_requirement)
|
||||
|
||||
# 使用 PrintCapture 替代全局 stdout 劫持
|
||||
log_path = os.path.join(session_output_dir, "log.txt")
|
||||
|
||||
with PrintCapture(log_path):
|
||||
print(f"\n{'='*20} Run Started at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} {'='*20}\n")
|
||||
print(f"[DOC] 日志文件已保存至: {log_path}")
|
||||
|
||||
agent = DataAnalysisAgent(llm_config, force_max_rounds=False)
|
||||
|
||||
# 交互式分析循环
|
||||
while True:
|
||||
is_first_run = agent.current_round == 0 and not agent.conversation_history
|
||||
|
||||
report = agent.analyze(
|
||||
user_input=analysis_requirement,
|
||||
files=files if is_first_run else None,
|
||||
session_output_dir=session_output_dir,
|
||||
reset_session=is_first_run,
|
||||
max_rounds=None if is_first_run else 10,
|
||||
)
|
||||
print("\n" + "=" * 30 + " 当前阶段分析完成 " + "=" * 30)
|
||||
|
||||
print("\n[TIP] 你可以继续对数据提出分析需求,或者输入 'exit'/'quit' 结束程序。")
|
||||
user_response = input("[>] 请输入后续分析需求 (直接回车退出): ").strip()
|
||||
|
||||
if not user_response or user_response.lower() in ["exit", "quit", "n", "no"]:
|
||||
print("[BYE] 分析结束,再见!")
|
||||
break
|
||||
|
||||
analysis_requirement = user_response
|
||||
print(f"\n[LOOP] 收到新需求,正在继续分析...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
590
prompts.py
590
prompts.py
@@ -1,288 +1,384 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
提示词模块 - 集中管理所有LLM提示词
|
||||
"""
|
||||
|
||||
data_analysis_system_prompt = """你是一个专业的数据分析助手,运行在Jupyter Notebook环境中,能够根据用户需求生成和执行Python数据分析代码。
|
||||
**核心使命**:
|
||||
- 接收自然语言需求,分阶段生成高效、安全的数据分析代码。
|
||||
- 深度挖掘数据,不仅仅是绘图,更要发现数据背后的业务洞察。
|
||||
- 输出高质量、可落地的业务分析报告。
|
||||
|
||||
**重要指导原则**:
|
||||
- 当需要执行Python代码(数据加载、分析、可视化)时,使用 `generate_code` 动作
|
||||
- 当需要收集和分析已生成的图表时,使用 `collect_figures` 动作
|
||||
- 当所有分析工作完成,需要输出最终报告时,使用 `analysis_complete` 动作
|
||||
- 每次响应只能选择一种动作类型,不要混合使用
|
||||
- **强制文本清洗**:在处理文本数据(如工单描述、评论)时,**必须**构建并使用`stop_words`列表,剔除年份(2025)、通用动词(work, fix)、介词等无意义高频词。
|
||||
- **主动高级分析**:不仅是画图,必须根据数据特征主动选择算法(时间序列->预测;分类数据->特征重要性;多维数据->聚类)。
|
||||
**核心能力**:
|
||||
1. **代码执行**:自动编写并执行Pandas/Matplotlib代码。
|
||||
2. **多模态分析**:支持时序预测、文本挖掘(N-gram)、多维交叉分析。
|
||||
3. **智能纠错**:遇到报错自动分析原因并修复代码。
|
||||
|
||||
目前jupyter notebook环境下有以下变量:
|
||||
jupyter notebook环境当前变量:
|
||||
{notebook_variables}
|
||||
核心能力:
|
||||
1. 接收用户的自然语言分析需求
|
||||
2. 按步骤生成安全的Python分析代码
|
||||
3. 基于代码执行结果继续优化分析
|
||||
|
||||
Notebook环境特性:
|
||||
- 你运行在IPython Notebook环境中,变量会在各个代码块之间保持
|
||||
- 第一次执行后,pandas、numpy、matplotlib等库已经导入,无需重复导入
|
||||
- 数据框(DataFrame)等变量在执行后会保留,可以直接使用
|
||||
- 因此,除非是第一次使用某个库,否则不需要重复import语句
|
||||
---
|
||||
|
||||
重要约束:
|
||||
1. 仅使用以下数据分析库:pandas, numpy, matplotlib, duckdb, os, json, datetime, re, pathlib
|
||||
2. 图片必须保存到指定的会话目录中,输出绝对路径,禁止使用plt.show(),饼图的标签全部放在图例里面,用颜色区分。
|
||||
4. 表格输出控制:超过15行只显示前5行和后5行
|
||||
5. 中文字体设置:使用系统可用中文字体(macOS推荐:Hiragino Sans GB, Songti SC等)
|
||||
6. 输出格式严格使用YAML
|
||||
**关键红线 (Critical Rules)**:
|
||||
1. **进程保护**:严禁使用 `exit()`、`quit()` 或 `sys.exit()`,这会导致Agent崩溃。
|
||||
2. **数据安全**:严禁使用 `pd.DataFrame({{...}})` 伪造数据。严禁使用 `open()` 写入非结果文件(只能写图片/JSON)。
|
||||
3. **文件验证**:所有文件操作前必须 `os.path.exists()`。Excel读取失败必须尝试 `openpyxl` 引擎或 `read_csv`。
|
||||
4. **绝对路径**:图片保存、文件读取必须使用绝对路径。图片必须保存到 `session_output_dir`。
|
||||
5. **图片保存**:禁止 `plt.show()`。每次绘图后必须紧接 `plt.savefig(path)` 和 `plt.close()`。
|
||||
|
||||
---
|
||||
|
||||
输出目录管理:
|
||||
- 本次分析使用时间戳生成的专用目录,确保每次分析的输出文件隔离
|
||||
- 会话目录格式:session_[时间戳],如 session_20240105_143052
|
||||
- 图片保存路径格式:os.path.join(session_output_dir, '图片名称.png')
|
||||
- 使用有意义的中文文件名:如'营业收入趋势.png', '利润分析对比.png'
|
||||
- 每个图表保存后必须使用plt.close()释放内存
|
||||
- 输出绝对路径:使用os.path.abspath()获取图片的完整路径
|
||||
**代码生成规则 (Code Generation Rules)**:
|
||||
|
||||
数据分析工作流程(必须严格按顺序执行):
|
||||
**1. 执行策略**:
|
||||
- **分步执行**:每次只专注一个分析阶段(如"清洗"或"可视化"),不要试图一次性写完所有代码。
|
||||
- **环境持久化**:Notebook环境中变量(如 `df`)会保留,不要重复导入库或重复加载数据。
|
||||
- **错误处理**:捕获错误并尝试修复,严禁在分析中途放弃。
|
||||
|
||||
**阶段1:数据探索(使用 generate_code 动作)**
|
||||
- 首次数据加载时尝试多种编码:['utf-8', 'gbk', 'gb18030', 'gb2312', 'latin1']
|
||||
- 特殊处理:如果读取失败,尝试指定分隔符 `sep=','` 和错误处理 `error_bad_lines=False`
|
||||
- 使用df.head()查看前几行数据,检查数据是否正确读取
|
||||
- 使用df.info()了解数据类型和缺失值情况
|
||||
- 重点检查:如果数值列显示为NaN但应该有值,说明读取或解析有问题
|
||||
- 使用df.dtypes查看每列的数据类型,确保日期列不是float64
|
||||
- 打印所有列名:df.columns.tolist()
|
||||
- 绝对不要假设列名,必须先查看实际的列名
|
||||
**2. 可视化规范 (Visual Standards)**:
|
||||
- **中文字体**:必须配置字体以解决乱码:
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
import platform
|
||||
system_name = platform.system()
|
||||
if system_name == 'Darwin': plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'PingFang SC', 'sans-serif']
|
||||
elif system_name == 'Windows': plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'sans-serif']
|
||||
else: plt.rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei', 'sans-serif']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
```
|
||||
- **图表类型**:
|
||||
- 类别 > 5:**强制**使用水平条形图 (`plt.barh`),并降序排列。
|
||||
- 类别 ≤ 5:才允许使用饼图,且图例必须外置 (`bbox_to_anchor=(1, 1)`)。
|
||||
- **美学要求**:去除非数据墨水(无边框、无网格),使用 Seaborn 默认色板,标题和标签必须为中文。
|
||||
- **文件命名**:使用中文描述业务含义(如 `核心问题词云.png`),**严禁**出现 `plot`, `dataframe`, `2-gram` 等技术术语。
|
||||
|
||||
**阶段2:数据清洗和检查(使用 generate_code 动作)**
|
||||
- 日期列识别:查找包含'date', 'time', 'Date', 'Time'关键词的列
|
||||
- 日期解析:尝试多种格式 ['%d/%m/%Y', '%Y-%m-%d', '%m/%d/%Y', '%Y/%m/%d', '%d-%m-%Y']
|
||||
- 类型转换:使用pd.to_datetime()转换日期列,指定format参数和errors='coerce'
|
||||
- 空值处理:检查哪些列应该有值但显示NaN,可能是数据读取问题
|
||||
- 检查数据的时间范围和排序
|
||||
- 数据质量检查:确认数值列是否正确,字符串列是否被错误识别
|
||||
**3. 文本挖掘专用规则**:
|
||||
- **N-gram提取**:必须使用 `CountVectorizer(ngram_range=(2, 3))` 提取短语(如 "remote control")。
|
||||
- **停用词过滤**:必须构建 `stop_words` 列表,剔除年份(2025)、通用动词(fix, check)、通用介词(the, for)等。
|
||||
|
||||
**4. 中间数据保存规则**:
|
||||
- 当你生成了有价值的中间数据(筛选子集、聚合表、聚类结果等),请主动保存为CSV/XLSX文件。
|
||||
- 保存后必须打印标记行:`[DATA_FILE_SAVED] filename: {{文件名}}, rows: {{行数}}, description: {{描述}}`
|
||||
- 示例:
|
||||
```python
|
||||
top_issues.to_csv(os.path.join(session_output_dir, "TOP问题汇总.csv"), index=False)
|
||||
print(f"[DATA_FILE_SAVED] filename: TOP问题汇总.csv, rows: {{len(top_issues)}}, description: 各类型TOP问题聚合统计")
|
||||
```
|
||||
- 这些文件会自动出现在"数据文件"面板中,方便用户浏览和下载。
|
||||
|
||||
**阶段3:数据分析和可视化(使用 generate_code 动作)**
|
||||
- 基于实际的列名进行计算
|
||||
- 生成有意义的图表
|
||||
- 图片保存到会话专用目录中
|
||||
- 每生成一个图表后,必须打印绝对路径
|
||||
---
|
||||
|
||||
**标准化分析SOP (Standard Operating Procedure)**:
|
||||
|
||||
**阶段4:深度挖掘与高级分析(使用 generate_code 动作)**
|
||||
- **主动评估数据特征**:在执行前,先分析数据适合哪种高级挖掘:
|
||||
- **时间序列数据**:必须进行趋势预测(使用sklearn/ARIMA/Prophet-like逻辑)和季节性分解。
|
||||
- **多维数值数据**:必须进行聚类分析(K-Means/DBSCAN)以发现用户/产品分层。
|
||||
- **分类/目标数据**:必须计算特征重要性(使用随机森林/相关性矩阵)以识别关键驱动因素。
|
||||
- **异常检测**:使用Isolation Forest或统计方法识别高价值或高风险的离群点。
|
||||
- **拒绝平庸**:不要为了做而做。如果数据量太小(<50行)或特征单一,请明确说明无法进行特定分析,并尝试挖掘其他角度(如分布偏度、帕累托分析)。
|
||||
- **业务导向**:每个模型结果必须翻译成业务语言(例如:“聚类结果显示,A类用户是高价值且对价格不敏感的群体”)。
|
||||
**阶段1:数据探索与智能加载**
|
||||
- 检查文件扩展名与实际格式是否一致(CSV vs Excel)。
|
||||
- 打印 `df.info()`, `df.head()`, 检查缺失值和列名。
|
||||
- 关键字段对齐('Model'->'车型', 'Module'->'模块')。
|
||||
|
||||
**阶段5:高级分析结果可视化(使用 generate_code 动作)**
|
||||
- **专业图表**:为高级分析匹配专用图表:
|
||||
- 聚类 -> 降维散点图 (PCA/t-SNE) 或 平行坐标图
|
||||
- 相关性 -> 热力图 (Heatmap)
|
||||
- 预测 -> 带有置信区间的趋势图
|
||||
- 特征重要性 -> 排序条形图
|
||||
- **保存与输出**:保存模型结果图表,并准备好在报告中解释。
|
||||
**阶段2:基础分布分析**
|
||||
- 生成 `车型分布.png` (水平条形图)
|
||||
- 生成 `模块Top10分布.png` (水平条形图)
|
||||
- 生成 `问题类型Top10分布.png` (水平条形图)
|
||||
|
||||
**阶段6:图片收集和分析(使用 collect_figures 动作)**
|
||||
- 当已生成多个图表后,使用 collect_figures 动作
|
||||
- 收集所有已生成的图片路径和信息
|
||||
- 对每个图片进行详细的分析和解读
|
||||
**阶段3:时序与来源分析**
|
||||
- 生成 `工单来源分布.png` (饼图或条形图)
|
||||
- 生成 `月度工单趋势.png` (折线图)
|
||||
|
||||
**阶段7:最终报告(使用 analysis_complete 动作)**
|
||||
- 当所有分析工作完成后,生成最终的分析报告
|
||||
- 包含对所有图片、模型和分析结果的综合总结
|
||||
- 提供业务建议和预测洞察
|
||||
**阶段4:深度交叉分析**
|
||||
- 生成 `车型_问题类型热力图.png` (Heatmap)
|
||||
- 生成 `模块_严重程度堆叠图.png` (Stacked Bar)
|
||||
|
||||
代码生成规则:
|
||||
1. 每次只专注一个阶段,不要试图一次性完成所有任务
|
||||
2. 基于实际的数据结构而不是假设来编写代码
|
||||
3. Notebook环境中变量会保持,避免重复导入和重复加载相同数据
|
||||
4. 处理错误时,分析具体的错误信息并针对性修复,重新进行改阶段步骤,中途不要跳步骤
|
||||
5. 图片保存使用会话目录变量:session_output_dir
|
||||
6. 图表标题和标签使用中文,使用系统配置的中文字体显示
|
||||
7. 必须打印绝对路径:每次保存图片后,使用os.path.abspath()打印完整的绝对路径
|
||||
8. 图片文件名:同时打印图片的文件名,方便后续收集时识别
|
||||
9. 饼图绘图代码生成必须遵守规则:类别 ≤ 5个:使用饼图 (plt.pie) + 外部图例,百分比标签清晰显示;类别 6-10个:使用水平条形图 (plt.barh) 便于阅读;类别 > 10个:使用排序条形图 + 合并小类别为"其他";学术美学要求**:白色背景、合适颜色、清晰标签、无冗余边框;
|
||||
**阶段5:效率分析**
|
||||
- **必做:按一级分类分组统计**:对每个一级分类(如TSP、APP、DK、咨询等)分别计算工单数量、平均时长、中位数时长,输出汇总表并保存为CSV。
|
||||
示例代码:`df.groupby('一级分类')['解决时长/h'].agg(['count','mean','median'])`
|
||||
- 生成 `处理时长分布.png` (直方图)
|
||||
- 生成 `责任人效率分析.png` (散点图: 工单量 vs 平均时长)
|
||||
|
||||
动作选择指南:
|
||||
- **需要执行Python代码** → 使用 "generate_code"
|
||||
- **已生成多个图表,需要收集分析** → 使用 "collect_figures"
|
||||
- **所有分析完成,输出最终报告** → 使用 "analysis_complete"
|
||||
- **遇到错误需要修复代码** → 使用 "generate_code"
|
||||
**阶段6:高级挖掘 (Active Exploration)**
|
||||
- **必做**:
|
||||
- **文本分析**:对'问题描述'列提取Top 20高频短语(N-gram),生成词云或条形图。
|
||||
- **异常检测**:使用Isolation Forest或3-Sigma原则发现异常工单。
|
||||
- **相关性分析**:生成相关性矩阵热力图(如有数值特征)。
|
||||
|
||||
高级分析技术指南(主动探索模式):
|
||||
- **智能选择算法**:
|
||||
- 遇到时间字段 -> `pd.to_datetime` -> 重采样 -> 移动平均/指数平滑/回归预测
|
||||
- 遇到多数值特征 -> `StandardScaler` -> `KMeans` (使用Elbow法则选k) -> `PCA`降维可视化
|
||||
- 遇到目标变量 -> `Correlation Matrix` -> `RandomForest` (feature_importances_)
|
||||
- **文本挖掘**:
|
||||
- 必须构建**专用停用词表** (Stop Words),过滤掉无效词汇:
|
||||
- 年份/数字:2023, 2024, 2025, 1月, 2月...
|
||||
- 通用动词:work, fix, support, issue, problem, check, test...
|
||||
- 通用介词/代词:the, is, at, which, on, for, this, that...
|
||||
- 仅保留具有实际业务含义的名词/动词短语(如 "connection timeout", "login failed")。
|
||||
- **异常值挖掘**:总是检查是否存在显著偏离均值的异常点,并标记出来进行个案分析。
|
||||
- **可视化增强**:不要只画折线图。使用 `seaborn` 的 `pairplot`, `heatmap`, `lmplot` 等高级图表。
|
||||
---
|
||||
|
||||
可用分析库:
|
||||
**动作选择指南 (Action Selection)**:
|
||||
|
||||
图片收集要求:
|
||||
- 在适当的时候(通常是生成了多个图表后),主动使用 `collect_figures` 动作
|
||||
- 收集时必须包含具体的图片绝对路径(file_path字段)
|
||||
- 提供详细的图片描述和深入的分析
|
||||
- 确保图片路径与之前打印的路径一致
|
||||
1. **generate_code**
|
||||
- 场景:需要执行代码(加载、分析、绘图)。
|
||||
- 格式:
|
||||
```yaml
|
||||
action: "generate_code"
|
||||
reasoning: "正在执行[阶段X]分析,目的是..."
|
||||
code: |
|
||||
# Python Code
|
||||
# ...
|
||||
# 每次生成图片后必须打印绝对路径
|
||||
print(f"图片已保存至: {{os.path.abspath(file_path)}}")
|
||||
next_steps: ["下一步计划"]
|
||||
```
|
||||
|
||||
报告生成要求:
|
||||
- 生成的报告要符合报告的文言需要,不要出现有争议的文字
|
||||
- 在适当的时候(通常是生成了多个图表后),进行图像的对比分析
|
||||
- 涉及的文言,不能出现我,你,他,等主观用于,采用报告式的文言论述
|
||||
- 提供详细的图片描述和深入的分析
|
||||
- 报告中的英文单词,初专有名词(TSP,TBOX等),其余的全部翻译成中文,例如remote control(远控),don't exist in TSP (数据不在TSP上);
|
||||
2. **collect_figures**
|
||||
- 场景:**每完成一个主要阶段(生成了2-3张图)后主动调用**。
|
||||
- 作用:总结当前图表发现,防止单次响应过长。
|
||||
- 格式:
|
||||
```yaml
|
||||
action: "collect_figures"
|
||||
reasoning: "已生成基础分布图表,现在进行汇总分析"
|
||||
figures_to_collect:
|
||||
- figure_number: 1
|
||||
filename: "车型分布.png"
|
||||
file_path: "/abs/path/to/车型分布.png"
|
||||
description: "展示了各车型的工单量差异..."
|
||||
analysis: "从图中可见,X车型工单量占比最高,达到Y%..."
|
||||
```
|
||||
|
||||
三种动作类型及使用时机:
|
||||
3. **analysis_complete**
|
||||
- 场景:所有SOP步骤执行完毕,且已通过 `collect_figures` 收集了足够素材。
|
||||
- 格式:
|
||||
```yaml
|
||||
action: "analysis_complete"
|
||||
final_report: "(此处留空,系统会根据上下文自动生成报告)"
|
||||
```
|
||||
|
||||
**1. 代码生成动作 (generate_code)**
|
||||
适用于:数据加载、探索、清洗、计算、数据分析、图片生成、可视化等需要执行Python代码的情况
|
||||
|
||||
**2. 图片收集动作 (collect_figures)**
|
||||
适用于:已生成多个图表后,需要对图片进行汇总和深入分析的情况
|
||||
|
||||
**3. 分析完成动作 (analysis_complete)**
|
||||
适用于:所有分析工作完成,需要输出最终报告的情况
|
||||
|
||||
响应格式(严格遵守):
|
||||
|
||||
**当需要执行代码时,使用此格式:**
|
||||
```yaml
|
||||
action: "generate_code"
|
||||
reasoning: "详细说明当前步骤的目的和方法,为什么要这样做"
|
||||
code: |
|
||||
# 实际的Python代码
|
||||
import pandas as pd
|
||||
# 具体分析代码...
|
||||
|
||||
# 图片保存示例(如果生成图表)
|
||||
plt.figure(figsize=(10, 6))
|
||||
# 绘图代码...
|
||||
plt.title('图表标题')
|
||||
file_path = os.path.join(session_output_dir, '图表名称.png')
|
||||
plt.savefig(file_path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
# 必须打印绝对路径
|
||||
absolute_path = os.path.abspath(file_path)
|
||||
print(f"图片已保存至: {{absolute_path}}")
|
||||
print(f"图片文件名: {{os.path.basename(absolute_path)}}")
|
||||
|
||||
next_steps: ["下一步计划1", "下一步计划2"]
|
||||
```
|
||||
**当需要收集分析图片时,使用此格式:**
|
||||
```yaml
|
||||
action: "collect_figures"
|
||||
reasoning: "说明为什么现在要收集图片,例如:已生成3个图表,现在收集并分析这些图表的内容"
|
||||
figures_to_collect:
|
||||
- figure_number: 1
|
||||
filename: "营业收入趋势分析.png"
|
||||
file_path: "实际的完整绝对路径"
|
||||
description: "图片概述:展示了什么内容"
|
||||
analysis: "细节分析:从图中可以看出的具体信息和洞察"
|
||||
next_steps: ["后续计划"]
|
||||
```
|
||||
|
||||
**当所有分析完成时,使用此格式:**
|
||||
```yaml
|
||||
action: "analysis_complete"
|
||||
final_report: |
|
||||
完整的最终分析报告内容
|
||||
(可以是多行文本)
|
||||
```
|
||||
|
||||
|
||||
|
||||
特别注意:
|
||||
- 数据读取问题:如果看到大量NaN值,检查编码和分隔符
|
||||
- 日期列问题:如果日期列显示为float64,说明解析失败
|
||||
- 编码错误:逐个尝试 ['utf-8', 'gbk', 'gb18030', 'gb2312', 'latin1']
|
||||
- 列类型错误:检查是否有列被错误识别为数值型但实际是文本
|
||||
- matplotlib错误时,确保使用Agg后端和正确的字体设置
|
||||
- 每次执行后根据反馈调整代码,不要重复相同的错误
|
||||
---
|
||||
|
||||
**特别提示**:
|
||||
- **翻译要求**:报告中的英文专有名词(除了TSP, TBOX, HU等标准缩写)必须翻译成中文(Remote Control -> 远控)。
|
||||
- **客观陈述**:不要使用"data shows", "plot indicates"等技术语言,直接陈述业务事实("X车型在Y模块故障率最高")。
|
||||
- **鲁棒性**:如果代码报错,请深呼吸,分析错误日志,修改代码重试。不要重复无效代码。
|
||||
|
||||
"""
|
||||
|
||||
# 最终报告生成提示词
|
||||
final_report_system_prompt = """你是一个专业的数据分析师,需要基于完整的分析过程生成最终的分析报告。
|
||||
final_report_system_prompt = """你是一位**资深数据分析专家 (Senior Data Analyst)**。你的任务是基于详细的数据分析过程,撰写一份**专业级、可落地的业务分析报告**。
|
||||
|
||||
分析信息:
|
||||
分析轮数: {current_round}
|
||||
输出目录: {session_output_dir}
|
||||
### 输入上下文
|
||||
- **数据全景 (Data Profile)**:
|
||||
{data_profile}
|
||||
|
||||
{figures_summary}
|
||||
|
||||
代码执行结果摘要:
|
||||
- **分析过程与代码发现**:
|
||||
{code_results_summary}
|
||||
|
||||
报告生成要求:
|
||||
报告应使用markdown格式,确保结构清晰;需要包含对所有生成图片的详细分析和说明;
|
||||
生成的报告要符合报告的文言需要,不要出现有争议的文字;
|
||||
在适当的时候(通常是生成了多个图表后),进行图像的对比分析;
|
||||
涉及的文言,不能出现我,你,他,等主观用于,采用报告式的文言论述;
|
||||
提供详细的图片描述和深入的分析;
|
||||
报告中的英文单词,初专有名词(TSP,TBOX等),其余的全部翻译成中文,例如remote control(远控),don't exist in TSP (数据不在TSP上);
|
||||
- **可视化证据链 (Visual Evidence)**:
|
||||
{figures_summary}
|
||||
> **警告**:你必须仔细检查上述列表。如果在 `figures_summary` 中列出了图表,你的报告中就必须引用它。**严禁遗漏任何已生成的图表**。引用格式必须为 ``。
|
||||
|
||||
总结分析过程中的关键发现;提供有价值的结论和建议;内容必须专业且逻辑性强。
|
||||
**重要提醒:图片引用必须使用相对路径格式 ``**
|
||||
### 报告核心要求
|
||||
1. **角色定位**:
|
||||
- 你不仅是数据图表的生产者,更是业务问题的诊断者。
|
||||
- 你的报告需要回答"发生了什么"、"为什么发生"以及"怎么解决"。
|
||||
2. **文风规范 (Strict Tone of Voice)**:
|
||||
- **禁止**:使用第一人称(我、我们)、使用模糊推测词(大概、可能)。
|
||||
- **强制**:客观陈述事实,使用专业术语(同比、环比、占比、TOPN),结论要有数据支撑。
|
||||
3. **结构化输出**:必须严格遵守下方的 5 章节结构,确保逻辑严密。
|
||||
4. **证据标注规则**:
|
||||
- 当报告段落的结论来源于某一轮分析的数据,请在段落末尾添加HTML注释标注:`<!-- evidence:round_N -->`
|
||||
- N 为产生该数据的分析轮次编号(从1开始)
|
||||
- 示例:某段落描述了第3轮分析发现的车型分布规律,则在段落末尾添加 `<!-- evidence:round_3 -->`
|
||||
- 这些标注不会在报告中显示,但会被系统用于关联支撑数据
|
||||
|
||||
图片质量与格式要求:
|
||||
- **学术级图表标准**:所有图表必须达到发表级质量,包含:
|
||||
* 专业的颜色方案(seaborn调色板)
|
||||
* 清晰的标签和图例(无重叠)
|
||||
* 合适的字体大小(≥12pt)
|
||||
* 简洁的布局(白色背景,无冗余元素)
|
||||
- **路径格式**:使用相对路径``
|
||||
- **图表命名**:使用描述性中文名称,如`来源渠道分布.png`
|
||||
响应格式要求:
|
||||
必须严格使用以下YAML格式输出:
|
||||
### 报告结构模板使用说明 (Template Instructions)
|
||||
- **固定格式 (Format)**:所有的 Markdown 标题 (`#`, `##`)、列表项前缀 (`- **...**`)、表格表头是必须保留的**骨架**。
|
||||
- **写作指引 (Prompts)**:方括号 `[...]` 内的文字是给你的**写作提示**,请根据实际分析将其**替换**为具体内容,**不要**在最终报告中保留方括号。
|
||||
- **数据文件引用规则**:模板中的 `[4-1TSP问题聚类.xlsx]` 等占位文件名**必须替换**为实际生成的文件名(见下方传入的已生成数据文件列表)。如果某类文件未生成,请注明原因(如"数据量不足,未执行聚类"或"该分类无对应数据"),不要保留占位符。
|
||||
- **直接输出Markdown**:不要使用JSON或YAML包裹,直接输出Markdown内容。
|
||||
|
||||
```yaml
|
||||
action: "analysis_complete"
|
||||
final_report: |
|
||||
# 数据分析报告
|
||||
|
||||
## 分析概述
|
||||
[概述本次分析的目标和范围]
|
||||
|
||||
## 数据分析过程
|
||||
[总结分析的主要步骤]
|
||||
|
||||
## 关键发现
|
||||
[描述重要的分析结果,使用段落形式而非列表]
|
||||
|
||||
## 图表分析
|
||||
|
||||
### [图表标题]
|
||||

|
||||
|
||||
[对图表的详细分析,使用连续的段落描述,避免使用分点列表]
|
||||
|
||||
### [下一个图表标题]
|
||||

|
||||
|
||||
[对图表的详细分析,使用连续的段落描述]
|
||||
---
|
||||
|
||||
## 深度分析
|
||||
### [图表标题]
|
||||

|
||||
|
||||
[对此前所有的数据,探索关联关系,进行深度剖析,重点问题,高频问题,并以图表介绍,使用连续的段落描述,避免使用分点列表]
|
||||
|
||||
## 结论与建议
|
||||
[基于分析结果提出结论和投资建议,使用段落形式表达]
|
||||
```
|
||||
### 报告结构模板 (Markdown)
|
||||
|
||||
```markdown
|
||||
# 《XX品牌车联网运维分析报告》
|
||||
|
||||
## 1. 整体问题分布与效率分析
|
||||
|
||||
### 1.1 工单类型分布与趋势
|
||||
|
||||
{{总工单数}}单。
|
||||
其中:
|
||||
|
||||
- TSP问题:{{数量}}单 ({{占比}}%)
|
||||
- APP问题:{{数量}}单 ({{占比}}%)
|
||||
- DK问题:{{数量}}单 ({{占比}}%)
|
||||
- 咨询类:{{数量}}单 ({{占比}}%)
|
||||
|
||||
> (可增加环比变化趋势)
|
||||
|
||||
---
|
||||
|
||||
### 1.2 问题解决效率分析
|
||||
|
||||
> (后续可增加环比变化趋势,如工单总流转时间、环比增长趋势图)
|
||||
|
||||
| 工单类型 | 总数量 | 一线处理数量 | 反馈二线数量 | 平均时长(h) | 中位数(h) | 一次解决率(%) | TSP处理次数 |
|
||||
| --- | --- | --- | --- | --- | --- | --- | --- |
|
||||
| TSP问题 | {{数值}} | | | {{数值}} | {{数值}} | {{数值}} | {{数值}} |
|
||||
| APP问题 | {{数值}} | | | {{数值}} | {{数值}} | {{数值}} | {{数值}} |
|
||||
| DK问题 | {{数值}} | | | {{数值}} | {{数值}} | {{数值}} | {{数值}} |
|
||||
| 咨询类 | {{数值}} | | | {{数值}} | {{数值}} | {{数值}} | {{数值}} |
|
||||
| 合计 | | | | | | | |
|
||||
|
||||
---
|
||||
|
||||
### 1.3 问题车型分布
|
||||
|
||||
---
|
||||
|
||||
## 2. 各类问题专题分析
|
||||
|
||||
### 2.1 TSP问题专题
|
||||
|
||||
当月总体情况概述:
|
||||
|
||||
| 工单类型 | 总数量 | 海外一线处理数量 | 国内二线数量 | 平均时长(h) | 中位数(h) |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| TSP问题 | {{数值}} | | | {{数值}} | {{数值}} |
|
||||
|
||||
#### 2.1.1 TSP问题二级分类+三级分布
|
||||
|
||||
#### 2.1.2 TOP问题
|
||||
|
||||
| 高频问题简述 | 关键词示例 | 原因 | 处理方式 | 占比约 |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| 网络超时/偶发延迟 | ack超时、请求超时、一直转圈 | | | {{数值}} |
|
||||
| 车辆唤醒失败 | 唤醒失败、深度睡眠、TBOX未唤醒 | | | {{数值}} |
|
||||
| 控制器反馈失败 | 控制器反馈状态失败、轻微故障 | | | {{数值}} |
|
||||
| TBOX不在线 | 卡不在线、注册异常 | | | {{数值}} |
|
||||
|
||||
> 聚类分析文件(需要输出):[4-1TSP问题聚类.xlsx]
|
||||
|
||||
---
|
||||
|
||||
### 2.2 APP问题专题
|
||||
|
||||
当月总体情况概述:
|
||||
|
||||
| 工单类型 | 总数量 | 一线处理数量 | 反馈二线数量 | 一线平均处理时长(h) | 二线平均处理时长(h) | 平均时长(h) | 中位数(h) |
|
||||
| --- | --- | --- | --- | --- | --- | --- | --- |
|
||||
| APP问题 | {{数值}} | | | {{数值}} | {{数值}} | {{数值}} | {{数值}} |
|
||||
|
||||
#### 2.2.1 APP问题二级分类分布
|
||||
|
||||
#### 2.2.2 TOP问题
|
||||
|
||||
| 高频问题简述 | 关键词示例 | 原因 | 处理方式 | 数量 | 占比约 |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| 问题1 | 关键词1、2、3 | | | {{数值}} | {{数值}} |
|
||||
| 问题2 | 关键词1、2、3 | | | {{数值}} | {{数值}} |
|
||||
| 问题3 | 关键词1、2、3 | | | {{数值}} | {{数值}} |
|
||||
| 问题4 | 关键词1、2、3 | | | {{数值}} | {{数值}} |
|
||||
|
||||
> 聚类分析文件(需要输出):[4-2APP问题聚类.xlsx]
|
||||
|
||||
---
|
||||
|
||||
### 2.3 TBOX问题专题
|
||||
|
||||
> 总流转时间和环比增长趋势(可参考柱状+折线组合图)
|
||||
|
||||
#### 2.3.1 TBOX问题二级分类分布
|
||||
|
||||
#### 2.3.2 TOP问题
|
||||
|
||||
| 高频问题简述 | 关键词示例 | 原因 | 处理方式 | 占比约 |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| 问题1 | 关键词1、2、3 | | | {{数值}} |
|
||||
| 问题2 | 关键词1、2、3 | | | {{数值}} |
|
||||
| 问题3 | 关键词1、2、3 | | | {{数值}} |
|
||||
| 问题4 | 关键词1、2、3 | | | {{数值}} |
|
||||
| 问题5 | 关键词1、2、3 | | | {{数值}} |
|
||||
|
||||
> 聚类分析文件:[4-3TBOX问题聚类.xlsx]
|
||||
|
||||
---
|
||||
|
||||
### 2.4 DMC专题
|
||||
|
||||
> 总流转时间和环比增长趋势(可参考柱状+折线组合图)
|
||||
|
||||
#### 2.4.1 DMC类二级分类分布与解决时长
|
||||
|
||||
#### 2.4.2 TOP问题
|
||||
|
||||
| 高频问题简述 | 关键词示例 | 原因 | 处理方式 | 占比约 |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| 问题1 | 关键词1、2、3 | | | {{数值}} |
|
||||
| 问题2 | 关键词1、2、3 | | | {{数值}} |
|
||||
|
||||
> 聚类分析文件(需要输出):[4-4DMC问题处理.xlsx]
|
||||
|
||||
---
|
||||
|
||||
### 2.5 咨询类专题
|
||||
|
||||
> 总流转时间和环比增长趋势(可参考柱状+折线组合图)
|
||||
|
||||
#### 2.5.1 咨询类二级分类分布与解决时长
|
||||
|
||||
#### 2.5.2 TOP咨询
|
||||
|
||||
| 高频问题简述 | 关键词示例 | 原因 | 处理方式 | 占比约 |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| 问题1 | 关键词1、2、3 | | | {{数值}} |
|
||||
| 问题2 | 关键词1、2、3 | | | {{数值}} |
|
||||
|
||||
> 咨询类文件(需要输出):[4-5咨询类问题处理.xlsx]
|
||||
|
||||
---
|
||||
|
||||
## 3. 建议与附件
|
||||
|
||||
- 工单客诉详情见附件:
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# 追问模式提示词(去除SOP,保留核心规则)
|
||||
data_analysis_followup_prompt = """你是一个专业的数据分析助手,运行在Jupyter Notebook环境中。
|
||||
当前处于**追问模式 (Follow-up Mode)**。用户基于之前的分析结果提出了新的需求。
|
||||
|
||||
**核心使命**:
|
||||
- 直接针对用户的后续需求进行解答,**无需**重新执行完整SOP。
|
||||
- 只有当用户明确要求重新进行全流程分析时,才执行SOP。
|
||||
|
||||
**核心能力**:
|
||||
1. **代码执行**:自动编写并执行Pandas/Matplotlib代码。
|
||||
2. **多模态分析**:支持时序预测、文本挖掘(N-gram)、多维交叉分析。
|
||||
3. **智能纠错**:遇到报错自动分析原因并修复代码。
|
||||
|
||||
jupyter notebook环境当前变量(已包含之前分析的数据df):
|
||||
{notebook_variables}
|
||||
|
||||
---
|
||||
|
||||
**关键红线 (Critical Rules)**:
|
||||
1. **进程保护**:严禁使用 `exit()`、`quit()` 或 `sys.exit()`。
|
||||
2. **数据安全**:严禁伪造数据。严禁写入非结果文件。
|
||||
3. **文件验证**:所有文件操作前必须 `os.path.exists()`。
|
||||
4. **绝对路径**:图片保存必须使用 `session_output_dir` 和 `os.path.abspath`。
|
||||
5. **图片保存**:禁止 `plt.show()`。必须使用 `plt.savefig()`。
|
||||
|
||||
---
|
||||
|
||||
**代码生成规则 (Reuse)**:
|
||||
- **环境持久化**:直接使用已加载的 `df`,不要重复加载数据。
|
||||
- **可视化规范**:中文字体配置、类别>5使用水平条形图、美学要求同上。
|
||||
- **文本挖掘**:如需挖掘,继续遵守N-gram和停用词规则。
|
||||
|
||||
---
|
||||
|
||||
**动作选择指南**:
|
||||
1. **generate_code**
|
||||
- 场景:执行针对追问的代码。
|
||||
- 格式:同标准模式。
|
||||
|
||||
2. **collect_figures**
|
||||
- 场景:如果生成了新的图表,必须收集。
|
||||
- 格式:同标准模式。
|
||||
|
||||
3. **analysis_complete**
|
||||
- 场景:追问回答完毕。
|
||||
- 格式:同标准模式。
|
||||
|
||||
特别注意事项:
|
||||
必须对每个图片进行详细的分析和说明。
|
||||
图片的内容和标题必须与分析内容相关。
|
||||
使用专业的金融分析术语和方法。
|
||||
报告要完整、准确、有价值。
|
||||
**强制要求:所有图片路径都必须使用相对路径格式 `./文件名.png`。
|
||||
为了确保后续markdown转换docx效果良好,请避免在正文中使用分点列表形式,改用段落形式表达。**
|
||||
"""
|
||||
|
||||
3
pyproject.toml
Normal file
3
pyproject.toml
Normal file
@@ -0,0 +1,3 @@
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = ["."]
|
||||
testpaths = ["tests"]
|
||||
0
raw_data/.gitkeep
Normal file
0
raw_data/.gitkeep
Normal file
@@ -50,3 +50,8 @@ flake8>=6.0.0
|
||||
|
||||
# 字体支持(用于matplotlib中文显示)
|
||||
fonttools>=4.38.0
|
||||
|
||||
# Web Interface dependencies
|
||||
fastapi>=0.109.0
|
||||
uvicorn>=0.27.0
|
||||
python-multipart>=0.0.9
|
||||
|
||||
4
start.bat
Normal file
4
start.bat
Normal file
@@ -0,0 +1,4 @@
|
||||
@echo off
|
||||
echo Starting IOV Data Analysis Agent...
|
||||
python bootstrap.py
|
||||
pause
|
||||
3
start.sh
Executable file
3
start.sh
Executable file
@@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
echo "Starting IOV Data Analysis Agent..."
|
||||
python3 bootstrap.py
|
||||
20
start_web.bat
Normal file
20
start_web.bat
Normal file
@@ -0,0 +1,20 @@
|
||||
@echo off
|
||||
chcp 65001 >nul
|
||||
set PYTHONIOENCODING=utf-8
|
||||
|
||||
:: Get local IP address
|
||||
for /f "tokens=2 delims=:" %%a in ('ipconfig ^| findstr /c:"IPv4"') do (
|
||||
for /f "tokens=1" %%b in ("%%a") do set LOCAL_IP=%%b
|
||||
)
|
||||
|
||||
echo.
|
||||
echo IOV Data Analysis Agent
|
||||
echo ========================
|
||||
echo.
|
||||
echo Local: http://localhost:8000
|
||||
if defined LOCAL_IP (
|
||||
echo Network: http://%LOCAL_IP%:8000
|
||||
)
|
||||
echo.
|
||||
python -m uvicorn web.main:app --reload --reload-exclude "outputs" --reload-exclude "uploads" --reload-exclude ".hypothesis" --reload-exclude ".cache" --host 0.0.0.0 --port 8000
|
||||
pause
|
||||
4
start_web.sh
Executable file
4
start_web.sh
Executable file
@@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
echo "Starting IOV Data Analysis Agent Web Interface..."
|
||||
echo "Please open http://localhost:8000 in your browser."
|
||||
python3 -m uvicorn web.main:app --reload --host 0.0.0.0 --port 8000
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# tests package
|
||||
11
tests/conftest.py
Normal file
11
tests/conftest.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Conftest for property-based tests.
|
||||
Ensures the project root is on sys.path for direct module imports.
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add project root to sys.path so we can import modules directly
|
||||
# (e.g., `from config.app_config import AppConfig`)
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
651
tests/test_dashboard_properties.py
Normal file
651
tests/test_dashboard_properties.py
Normal file
@@ -0,0 +1,651 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Property-based tests for analysis-dashboard-redesign features.
|
||||
Uses hypothesis with max_examples=100 as specified in the design document.
|
||||
|
||||
Run: python -m pytest tests/test_dashboard_properties.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import json
|
||||
import tempfile
|
||||
|
||||
# Ensure project root is on path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
from hypothesis.extra.pandas import column, data_frames, range_indexes
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers / Strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Strategy for generating random execution results (success or failure)
|
||||
execution_result_st = st.fixed_dictionaries({
|
||||
"success": st.booleans(),
|
||||
"output": st.text(min_size=0, max_size=200),
|
||||
"error": st.text(min_size=0, max_size=200),
|
||||
"variables": st.just({}),
|
||||
"evidence_rows": st.lists(
|
||||
st.dictionaries(
|
||||
keys=st.text(min_size=1, max_size=10, alphabet="abcdefghijklmnopqrstuvwxyz"),
|
||||
values=st.one_of(st.integers(), st.text(min_size=0, max_size=20), st.floats(allow_nan=False)),
|
||||
min_size=1,
|
||||
max_size=5,
|
||||
),
|
||||
min_size=0,
|
||||
max_size=10,
|
||||
),
|
||||
"auto_exported_files": st.just([]),
|
||||
"prompt_saved_files": st.just([]),
|
||||
})
|
||||
|
||||
# Strategy for reasoning text (may be empty, simulating missing YAML field)
|
||||
reasoning_st = st.one_of(st.just(""), st.text(min_size=1, max_size=200))
|
||||
|
||||
# Strategy for code text
|
||||
code_st = st.text(min_size=1, max_size=500, alphabet="abcdefghijklmnopqrstuvwxyz0123456789 =()._\n")
|
||||
|
||||
# Strategy for feedback/raw_log text
|
||||
feedback_st = st.text(min_size=0, max_size=300)
|
||||
|
||||
|
||||
def build_round_data(round_num, reasoning, code, result, feedback):
|
||||
"""Construct a Round_Data dict the same way DataAnalysisAgent._handle_generate_code does."""
|
||||
def summarize_result(r):
|
||||
if r.get("success"):
|
||||
evidence_rows = r.get("evidence_rows", [])
|
||||
if evidence_rows:
|
||||
num_rows = len(evidence_rows)
|
||||
num_cols = len(evidence_rows[0]) if evidence_rows else 0
|
||||
return f"执行成功,输出 DataFrame ({num_rows}行×{num_cols}列)"
|
||||
output = r.get("output", "")
|
||||
if output:
|
||||
first_line = output.strip().split("\n")[0][:80]
|
||||
return f"执行成功: {first_line}"
|
||||
return "执行成功"
|
||||
else:
|
||||
error = r.get("error", "未知错误")
|
||||
if len(error) > 100:
|
||||
error = error[:100] + "..."
|
||||
return f"执行失败: {error}"
|
||||
|
||||
return {
|
||||
"round": round_num,
|
||||
"reasoning": reasoning,
|
||||
"code": code,
|
||||
"result_summary": summarize_result(result),
|
||||
"evidence_rows": result.get("evidence_rows", []),
|
||||
"raw_log": feedback,
|
||||
"auto_exported_files": result.get("auto_exported_files", []),
|
||||
"prompt_saved_files": result.get("prompt_saved_files", []),
|
||||
}
|
||||
|
||||
|
||||
# Regex for parsing DATA_FILE_SAVED markers (same as CodeExecutor)
|
||||
_DATA_FILE_SAVED_RE = re.compile(
|
||||
r"\[DATA_FILE_SAVED\]\s*filename:\s*(.+?),\s*rows:\s*(\d+),\s*description:\s*(.+)"
|
||||
)
|
||||
|
||||
|
||||
def parse_data_file_saved_markers(stdout_text):
|
||||
"""Parse [DATA_FILE_SAVED] marker lines — mirrors CodeExecutor._parse_data_file_saved_markers."""
|
||||
results = []
|
||||
for line in stdout_text.splitlines():
|
||||
m = _DATA_FILE_SAVED_RE.search(line)
|
||||
if m:
|
||||
results.append({
|
||||
"filename": m.group(1).strip(),
|
||||
"rows": int(m.group(2)),
|
||||
"description": m.group(3).strip(),
|
||||
})
|
||||
return results
|
||||
|
||||
|
||||
# Evidence annotation regex (same as web/main.py)
|
||||
_EVIDENCE_PATTERN = re.compile(r"<!--\s*evidence:round_(\d+)\s*-->")
|
||||
|
||||
|
||||
def split_report_to_paragraphs(markdown_content):
|
||||
"""Mirrors _split_report_to_paragraphs from web/main.py."""
|
||||
lines = markdown_content.split("\n")
|
||||
paragraphs = []
|
||||
current_block = []
|
||||
current_type = "text"
|
||||
para_id = 0
|
||||
|
||||
def flush_block():
|
||||
nonlocal para_id, current_block, current_type
|
||||
text = "\n".join(current_block).strip()
|
||||
if text:
|
||||
paragraphs.append({
|
||||
"id": f"p-{para_id}",
|
||||
"type": current_type,
|
||||
"content": text,
|
||||
})
|
||||
para_id += 1
|
||||
current_block = []
|
||||
current_type = "text"
|
||||
|
||||
in_table = False
|
||||
in_code = False
|
||||
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
|
||||
if stripped.startswith("```"):
|
||||
if in_code:
|
||||
current_block.append(line)
|
||||
flush_block()
|
||||
in_code = False
|
||||
continue
|
||||
else:
|
||||
flush_block()
|
||||
current_block.append(line)
|
||||
current_type = "code"
|
||||
in_code = True
|
||||
continue
|
||||
|
||||
if in_code:
|
||||
current_block.append(line)
|
||||
continue
|
||||
|
||||
if re.match(r"^#{1,6}\s", stripped):
|
||||
flush_block()
|
||||
current_block.append(line)
|
||||
current_type = "heading"
|
||||
flush_block()
|
||||
continue
|
||||
|
||||
if re.match(r"^!\[.*\]\(.*\)", stripped):
|
||||
flush_block()
|
||||
current_block.append(line)
|
||||
current_type = "image"
|
||||
flush_block()
|
||||
continue
|
||||
|
||||
if stripped.startswith("|"):
|
||||
if not in_table:
|
||||
flush_block()
|
||||
in_table = True
|
||||
current_type = "table"
|
||||
current_block.append(line)
|
||||
continue
|
||||
else:
|
||||
if in_table:
|
||||
flush_block()
|
||||
in_table = False
|
||||
|
||||
if not stripped:
|
||||
flush_block()
|
||||
continue
|
||||
|
||||
current_block.append(line)
|
||||
|
||||
flush_block()
|
||||
return paragraphs
|
||||
|
||||
|
||||
def extract_evidence_annotations(paragraphs, rounds):
|
||||
"""Mirrors _extract_evidence_annotations from web/main.py, using a rounds list instead of session."""
|
||||
supporting_data = {}
|
||||
for para in paragraphs:
|
||||
content = para.get("content", "")
|
||||
match = _EVIDENCE_PATTERN.search(content)
|
||||
if match:
|
||||
round_num = int(match.group(1))
|
||||
idx = round_num - 1
|
||||
if 0 <= idx < len(rounds):
|
||||
evidence_rows = rounds[idx].get("evidence_rows", [])
|
||||
if evidence_rows:
|
||||
supporting_data[para["id"]] = evidence_rows
|
||||
return supporting_data
|
||||
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 1: Round_Data Structural Completeness (Task 16.1)
|
||||
# Feature: analysis-dashboard-redesign, Property 1: Round_Data structural completeness
|
||||
# Validates: Requirements 1.1, 1.3, 1.4
|
||||
# ===========================================================================
|
||||
|
||||
ROUND_DATA_REQUIRED_FIELDS = {
|
||||
"round": int,
|
||||
"reasoning": str,
|
||||
"code": str,
|
||||
"result_summary": str,
|
||||
"evidence_rows": list,
|
||||
"raw_log": str,
|
||||
}
|
||||
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
num_rounds=st.integers(min_value=1, max_value=20),
|
||||
results=st.lists(execution_result_st, min_size=1, max_size=20),
|
||||
reasonings=st.lists(reasoning_st, min_size=1, max_size=20),
|
||||
codes=st.lists(code_st, min_size=1, max_size=20),
|
||||
feedbacks=st.lists(feedback_st, min_size=1, max_size=20),
|
||||
)
|
||||
def test_prop1_round_data_structural_completeness(num_rounds, results, reasonings, codes, feedbacks):
|
||||
"""Round_Data objects must contain all required fields with correct types and preserve insertion order.
|
||||
|
||||
**Validates: Requirements 1.1, 1.3, 1.4**
|
||||
"""
|
||||
# Build a list of rounds using the same number of entries
|
||||
count = min(num_rounds, len(results), len(reasonings), len(codes), len(feedbacks))
|
||||
rounds_list = []
|
||||
for i in range(count):
|
||||
rd = build_round_data(i + 1, reasonings[i], codes[i], results[i], feedbacks[i])
|
||||
rounds_list.append(rd)
|
||||
|
||||
# Verify all required fields present with correct types
|
||||
for rd in rounds_list:
|
||||
for field, expected_type in ROUND_DATA_REQUIRED_FIELDS.items():
|
||||
assert field in rd, f"Missing field: {field}"
|
||||
assert isinstance(rd[field], expected_type), (
|
||||
f"Field '{field}' expected {expected_type.__name__}, got {type(rd[field]).__name__}"
|
||||
)
|
||||
|
||||
# Verify insertion order preserved
|
||||
for i in range(len(rounds_list) - 1):
|
||||
assert rounds_list[i]["round"] <= rounds_list[i + 1]["round"], (
|
||||
f"Insertion order violated: round {rounds_list[i]['round']} > {rounds_list[i + 1]['round']}"
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 2: Evidence Capture Bounded (Task 16.2)
|
||||
# Feature: analysis-dashboard-redesign, Property 2: Evidence capture bounded
|
||||
# Validates: Requirements 4.1, 4.2, 4.3
|
||||
# ===========================================================================
|
||||
|
||||
# Strategy for generating random DataFrames with 0-10000 rows and 1-50 columns
|
||||
col_name_st = st.text(
|
||||
min_size=1, max_size=10,
|
||||
alphabet=st.sampled_from("abcdefghijklmnopqrstuvwxyz_"),
|
||||
).filter(lambda s: s[0] != "_") # column names shouldn't start with _
|
||||
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
num_rows=st.integers(min_value=0, max_value=10000),
|
||||
num_cols=st.integers(min_value=1, max_value=50),
|
||||
)
|
||||
def test_prop2_evidence_capture_bounded(num_rows, num_cols):
|
||||
"""Evidence capture must return at most 10 rows with keys matching DataFrame columns.
|
||||
|
||||
**Validates: Requirements 4.1, 4.2, 4.3**
|
||||
"""
|
||||
# Generate a DataFrame with the given dimensions
|
||||
import numpy as np
|
||||
columns = [f"col_{i}" for i in range(num_cols)]
|
||||
if num_rows == 0:
|
||||
df = pd.DataFrame(columns=columns)
|
||||
else:
|
||||
data = np.random.randint(0, 100, size=(num_rows, num_cols))
|
||||
df = pd.DataFrame(data, columns=columns)
|
||||
|
||||
# Simulate the evidence capture logic: df.head(10).to_dict(orient='records')
|
||||
evidence_rows = df.head(10).to_dict(orient="records")
|
||||
|
||||
# Verify length constraints
|
||||
assert len(evidence_rows) <= 10
|
||||
assert len(evidence_rows) == min(10, len(df))
|
||||
|
||||
# Verify each row dict has keys matching the DataFrame's column names
|
||||
expected_keys = set(df.columns)
|
||||
for row in evidence_rows:
|
||||
assert set(row.keys()) == expected_keys
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 3: Filename Deduplication (Task 16.3)
|
||||
# Feature: analysis-dashboard-redesign, Property 3: Filename deduplication
|
||||
# Validates: Requirements 5.3
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
num_exports=st.integers(min_value=1, max_value=20),
|
||||
var_name=st.text(min_size=1, max_size=20, alphabet="abcdefghijklmnopqrstuvwxyz_0123456789").filter(
|
||||
lambda s: s[0].isalpha()
|
||||
),
|
||||
)
|
||||
def test_prop3_filename_deduplication(num_exports, var_name):
|
||||
"""All generated filenames from same-name exports must be unique.
|
||||
|
||||
**Validates: Requirements 5.3**
|
||||
"""
|
||||
output_dir = tempfile.mkdtemp()
|
||||
generated_filenames = []
|
||||
|
||||
for _ in range(num_exports):
|
||||
# Simulate _export_dataframe dedup logic
|
||||
base_filename = f"{var_name}.csv"
|
||||
filepath = os.path.join(output_dir, base_filename)
|
||||
|
||||
if os.path.exists(filepath):
|
||||
suffix = 1
|
||||
while True:
|
||||
dedup_filename = f"{var_name}_{suffix}.csv"
|
||||
filepath = os.path.join(output_dir, dedup_filename)
|
||||
if not os.path.exists(filepath):
|
||||
base_filename = dedup_filename
|
||||
break
|
||||
suffix += 1
|
||||
|
||||
# Create the file to simulate the export
|
||||
with open(filepath, "w") as f:
|
||||
f.write("dummy")
|
||||
|
||||
generated_filenames.append(base_filename)
|
||||
|
||||
# Verify all filenames are unique
|
||||
assert len(generated_filenames) == len(set(generated_filenames)), (
|
||||
f"Duplicate filenames found: {generated_filenames}"
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 4: Auto-Export Metadata Completeness (Task 16.4)
|
||||
# Feature: analysis-dashboard-redesign, Property 4: Auto-export metadata completeness
|
||||
# Validates: Requirements 5.4, 5.5
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
var_name=st.text(min_size=1, max_size=20, alphabet="abcdefghijklmnopqrstuvwxyz_0123456789").filter(
|
||||
lambda s: s[0].isalpha()
|
||||
),
|
||||
num_rows=st.integers(min_value=0, max_value=1000),
|
||||
num_cols=st.integers(min_value=1, max_value=50),
|
||||
)
|
||||
def test_prop4_auto_export_metadata_completeness(var_name, num_rows, num_cols):
|
||||
"""Auto-export metadata must contain all required fields with correct values.
|
||||
|
||||
**Validates: Requirements 5.4, 5.5**
|
||||
"""
|
||||
import numpy as np
|
||||
output_dir = tempfile.mkdtemp()
|
||||
columns = [f"col_{i}" for i in range(num_cols)]
|
||||
|
||||
if num_rows == 0:
|
||||
df = pd.DataFrame(columns=columns)
|
||||
else:
|
||||
data = np.random.randint(0, 100, size=(num_rows, num_cols))
|
||||
df = pd.DataFrame(data, columns=columns)
|
||||
|
||||
# Simulate _export_dataframe logic
|
||||
base_filename = f"{var_name}.csv"
|
||||
filepath = os.path.join(output_dir, base_filename)
|
||||
|
||||
if os.path.exists(filepath):
|
||||
suffix = 1
|
||||
while True:
|
||||
dedup_filename = f"{var_name}_{suffix}.csv"
|
||||
filepath = os.path.join(output_dir, dedup_filename)
|
||||
if not os.path.exists(filepath):
|
||||
base_filename = dedup_filename
|
||||
break
|
||||
suffix += 1
|
||||
|
||||
df.to_csv(filepath, index=False)
|
||||
metadata = {
|
||||
"variable_name": var_name,
|
||||
"filename": base_filename,
|
||||
"rows": len(df),
|
||||
"cols": len(df.columns),
|
||||
"columns": list(df.columns),
|
||||
}
|
||||
|
||||
# Verify all required fields present
|
||||
for field in ("variable_name", "filename", "rows", "cols", "columns"):
|
||||
assert field in metadata, f"Missing field: {field}"
|
||||
|
||||
# Verify values match the source DataFrame
|
||||
assert metadata["rows"] == len(df)
|
||||
assert metadata["cols"] == len(df.columns)
|
||||
assert metadata["columns"] == list(df.columns)
|
||||
assert metadata["variable_name"] == var_name
|
||||
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 5: DATA_FILE_SAVED Marker Parsing Round-Trip (Task 16.5)
|
||||
# Feature: analysis-dashboard-redesign, Property 5: DATA_FILE_SAVED marker parsing round-trip
|
||||
# Validates: Requirements 6.3
|
||||
# ===========================================================================
|
||||
|
||||
# Strategy for filenames: alphanumeric + Chinese + underscores + hyphens, with extension
|
||||
filename_base_st = st.text(
|
||||
min_size=1,
|
||||
max_size=30,
|
||||
alphabet=st.sampled_from(
|
||||
"abcdefghijklmnopqrstuvwxyz"
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
"0123456789"
|
||||
"_-"
|
||||
"数据分析结果汇总报告"
|
||||
),
|
||||
).filter(lambda s: len(s.strip()) > 0 and "," not in s)
|
||||
|
||||
filename_ext_st = st.sampled_from([".csv", ".xlsx"])
|
||||
|
||||
filename_st = st.builds(lambda base, ext: base.strip() + ext, filename_base_st, filename_ext_st)
|
||||
|
||||
description_st = st.text(
|
||||
min_size=1,
|
||||
max_size=100,
|
||||
alphabet=st.sampled_from(
|
||||
"abcdefghijklmnopqrstuvwxyz"
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
"0123456789 "
|
||||
"各类型问题聚合统计分析结果"
|
||||
),
|
||||
).filter(lambda s: len(s.strip()) > 0)
|
||||
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
filename=filename_st,
|
||||
rows=st.integers(min_value=1, max_value=1000000),
|
||||
description=description_st,
|
||||
)
|
||||
def test_prop5_data_file_saved_marker_round_trip(filename, rows, description):
|
||||
"""Formatting then parsing a DATA_FILE_SAVED marker must recover original values.
|
||||
|
||||
**Validates: Requirements 6.3**
|
||||
"""
|
||||
# Format the marker
|
||||
marker = f"[DATA_FILE_SAVED] filename: {filename}, rows: {rows}, description: {description}"
|
||||
|
||||
# Parse using the same logic as CodeExecutor
|
||||
parsed = parse_data_file_saved_markers(marker)
|
||||
|
||||
assert len(parsed) == 1, f"Expected 1 parsed result, got {len(parsed)}"
|
||||
assert parsed[0]["filename"] == filename.strip()
|
||||
assert parsed[0]["rows"] == rows
|
||||
assert parsed[0]["description"] == description.strip()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 6: Data File Preview Bounded Rows (Task 16.6)
|
||||
# Feature: analysis-dashboard-redesign, Property 6: Data file preview bounded rows
|
||||
# Validates: Requirements 7.2
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
num_rows=st.integers(min_value=0, max_value=10000),
|
||||
num_cols=st.integers(min_value=1, max_value=50),
|
||||
)
|
||||
def test_prop6_data_file_preview_bounded_rows(num_rows, num_cols):
|
||||
"""Preview of a CSV file must return at most 5 rows with correct column names.
|
||||
|
||||
**Validates: Requirements 7.2**
|
||||
"""
|
||||
import numpy as np
|
||||
columns = [f"col_{i}" for i in range(num_cols)]
|
||||
|
||||
if num_rows == 0:
|
||||
df = pd.DataFrame(columns=columns)
|
||||
else:
|
||||
data = np.random.randint(0, 100, size=(num_rows, num_cols))
|
||||
df = pd.DataFrame(data, columns=columns)
|
||||
|
||||
# Write to a temp CSV file
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
csv_path = os.path.join(tmp_dir, "test_data.csv")
|
||||
df.to_csv(csv_path, index=False)
|
||||
|
||||
# Read back using the same logic as the preview endpoint
|
||||
preview_df = pd.read_csv(csv_path, nrows=5)
|
||||
|
||||
# Verify at most 5 rows
|
||||
assert len(preview_df) <= 5
|
||||
assert len(preview_df) == min(5, num_rows)
|
||||
|
||||
# Verify column names match exactly
|
||||
assert list(preview_df.columns) == columns
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 7: Evidence Annotation Parsing (Task 16.7)
|
||||
# Feature: analysis-dashboard-redesign, Property 7: Evidence annotation parsing
|
||||
# Validates: Requirements 11.3, 11.4
|
||||
# ===========================================================================
|
||||
|
||||
# Strategy for generating paragraphs with/without evidence annotations
|
||||
annotated_paragraph_st = st.builds(
|
||||
lambda text, round_num: f"{text} <!-- evidence:round_{round_num} -->",
|
||||
st.text(min_size=1, max_size=100, alphabet="abcdefghijklmnopqrstuvwxyz .,!"),
|
||||
st.integers(min_value=1, max_value=100),
|
||||
)
|
||||
|
||||
plain_paragraph_st = st.text(
|
||||
min_size=1,
|
||||
max_size=100,
|
||||
alphabet="abcdefghijklmnopqrstuvwxyz .,!",
|
||||
).filter(lambda s: "evidence:" not in s and len(s.strip()) > 0)
|
||||
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
annotated=st.lists(annotated_paragraph_st, min_size=0, max_size=10),
|
||||
plain=st.lists(plain_paragraph_st, min_size=0, max_size=10),
|
||||
)
|
||||
def test_prop7_evidence_annotation_parsing(annotated, plain):
|
||||
"""Annotated paragraphs must be correctly extracted; non-annotated must be excluded.
|
||||
|
||||
**Validates: Requirements 11.3, 11.4**
|
||||
"""
|
||||
assume(len(annotated) + len(plain) > 0)
|
||||
|
||||
# Build markdown by interleaving annotated and plain paragraphs
|
||||
all_paragraphs = []
|
||||
for p in annotated:
|
||||
all_paragraphs.append(("annotated", p))
|
||||
for p in plain:
|
||||
all_paragraphs.append(("plain", p))
|
||||
|
||||
# Build markdown content with blank lines between paragraphs
|
||||
markdown = "\n\n".join(text for _, text in all_paragraphs)
|
||||
|
||||
# Parse into paragraphs
|
||||
paragraphs = split_report_to_paragraphs(markdown)
|
||||
|
||||
# Build fake rounds data (up to 100 rounds, each with some evidence)
|
||||
rounds = [
|
||||
{"evidence_rows": [{"key": f"value_{i}"}]}
|
||||
for i in range(100)
|
||||
]
|
||||
|
||||
# Extract evidence annotations
|
||||
supporting_data = extract_evidence_annotations(paragraphs, rounds)
|
||||
|
||||
# Verify: annotated paragraphs with valid round numbers should be in supporting_data
|
||||
for para in paragraphs:
|
||||
content = para.get("content", "")
|
||||
match = _EVIDENCE_PATTERN.search(content)
|
||||
if match:
|
||||
round_num = int(match.group(1))
|
||||
idx = round_num - 1
|
||||
if 0 <= idx < len(rounds) and rounds[idx].get("evidence_rows"):
|
||||
assert para["id"] in supporting_data, (
|
||||
f"Annotated paragraph {para['id']} with round {round_num} not in supporting_data"
|
||||
)
|
||||
else:
|
||||
# Non-annotated paragraphs must NOT be in supporting_data
|
||||
assert para["id"] not in supporting_data, (
|
||||
f"Non-annotated paragraph {para['id']} should not be in supporting_data"
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 8: SessionData JSON Round-Trip (Task 16.8)
|
||||
# Feature: analysis-dashboard-redesign, Property 8: SessionData JSON round-trip
|
||||
# Validates: Requirements 12.4
|
||||
# ===========================================================================
|
||||
|
||||
# Strategy for Round_Data dicts
|
||||
round_data_st = st.fixed_dictionaries({
|
||||
"round": st.integers(min_value=1, max_value=100),
|
||||
"reasoning": st.text(min_size=0, max_size=200),
|
||||
"code": st.text(min_size=0, max_size=200),
|
||||
"result_summary": st.text(min_size=0, max_size=200),
|
||||
"evidence_rows": st.lists(
|
||||
st.dictionaries(
|
||||
keys=st.text(min_size=1, max_size=10, alphabet="abcdefghijklmnopqrstuvwxyz"),
|
||||
values=st.one_of(
|
||||
st.integers(min_value=-1000, max_value=1000),
|
||||
st.text(min_size=0, max_size=20),
|
||||
),
|
||||
min_size=0,
|
||||
max_size=5,
|
||||
),
|
||||
min_size=0,
|
||||
max_size=10,
|
||||
),
|
||||
"raw_log": st.text(min_size=0, max_size=200),
|
||||
})
|
||||
|
||||
# Strategy for file metadata dicts
|
||||
file_metadata_st = st.fixed_dictionaries({
|
||||
"filename": st.text(min_size=1, max_size=30, alphabet="abcdefghijklmnopqrstuvwxyz0123456789_."),
|
||||
"description": st.text(min_size=0, max_size=100),
|
||||
"rows": st.integers(min_value=0, max_value=100000),
|
||||
"cols": st.integers(min_value=0, max_value=100),
|
||||
"columns": st.lists(st.text(min_size=1, max_size=10, alphabet="abcdefghijklmnopqrstuvwxyz"), max_size=10),
|
||||
"size_bytes": st.integers(min_value=0, max_value=10000000),
|
||||
"source": st.sampled_from(["auto", "prompt"]),
|
||||
})
|
||||
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
rounds=st.lists(round_data_st, min_size=0, max_size=20),
|
||||
data_files=st.lists(file_metadata_st, min_size=0, max_size=20),
|
||||
)
|
||||
def test_prop8_session_data_json_round_trip(rounds, data_files):
|
||||
"""Serializing rounds and data_files to JSON and back must produce equal data.
|
||||
|
||||
**Validates: Requirements 12.4**
|
||||
"""
|
||||
data = {
|
||||
"rounds": rounds,
|
||||
"data_files": data_files,
|
||||
}
|
||||
|
||||
# Serialize using the same approach as the codebase
|
||||
serialized = json.dumps(data, default=str)
|
||||
deserialized = json.loads(serialized)
|
||||
|
||||
assert deserialized["rounds"] == rounds
|
||||
assert deserialized["data_files"] == data_files
|
||||
238
tests/test_phase1.py
Normal file
238
tests/test_phase1.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Unit tests for Phase 1: Backend Data Model + API Changes
|
||||
|
||||
Run: python -m pytest tests/test_phase1.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import tempfile
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 1: SessionData model extension
|
||||
# ===========================================================================
|
||||
|
||||
class TestSessionDataExtension:
|
||||
def test_rounds_initialized_to_empty_list(self):
|
||||
"""Task 1.1: rounds attribute exists and defaults to []"""
|
||||
from web.main import SessionData
|
||||
session = SessionData("test-id")
|
||||
assert hasattr(session, "rounds")
|
||||
assert session.rounds == []
|
||||
assert isinstance(session.rounds, list)
|
||||
|
||||
def test_data_files_initialized_to_empty_list(self):
|
||||
"""Task 1.2: data_files attribute exists and defaults to []"""
|
||||
from web.main import SessionData
|
||||
session = SessionData("test-id")
|
||||
assert hasattr(session, "data_files")
|
||||
assert session.data_files == []
|
||||
assert isinstance(session.data_files, list)
|
||||
|
||||
def test_existing_fields_unchanged(self):
|
||||
"""Existing SessionData fields still work."""
|
||||
from web.main import SessionData
|
||||
session = SessionData("test-id")
|
||||
assert session.session_id == "test-id"
|
||||
assert session.is_running is False
|
||||
assert session.analysis_results == []
|
||||
assert session.current_round == 0
|
||||
assert session.max_rounds == 20
|
||||
|
||||
def test_reconstruct_session_loads_rounds_and_data_files(self):
|
||||
"""Task 1.3: _reconstruct_session loads rounds and data_files from results.json"""
|
||||
from web.main import SessionManager
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
session_dir = os.path.join(tmpdir, "session_test123")
|
||||
os.makedirs(session_dir)
|
||||
|
||||
test_rounds = [{"round": 1, "reasoning": "test", "code": "x=1"}]
|
||||
test_data_files = [{"filename": "out.csv", "rows": 10}]
|
||||
results = {
|
||||
"analysis_results": [{"round": 1}],
|
||||
"rounds": test_rounds,
|
||||
"data_files": test_data_files,
|
||||
}
|
||||
with open(os.path.join(session_dir, "results.json"), "w") as f:
|
||||
json.dump(results, f)
|
||||
|
||||
sm = SessionManager()
|
||||
session = sm._reconstruct_session("test123", session_dir)
|
||||
|
||||
assert session.rounds == test_rounds
|
||||
assert session.data_files == test_data_files
|
||||
assert session.analysis_results == [{"round": 1}]
|
||||
|
||||
def test_reconstruct_session_legacy_format(self):
|
||||
"""Task 1.3: _reconstruct_session handles legacy list format gracefully"""
|
||||
from web.main import SessionManager
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
session_dir = os.path.join(tmpdir, "session_legacy")
|
||||
os.makedirs(session_dir)
|
||||
|
||||
# Legacy format: results.json is a plain list
|
||||
legacy_results = [{"round": 1, "code": "x=1"}]
|
||||
with open(os.path.join(session_dir, "results.json"), "w") as f:
|
||||
json.dump(legacy_results, f)
|
||||
|
||||
sm = SessionManager()
|
||||
session = sm._reconstruct_session("legacy", session_dir)
|
||||
|
||||
assert session.analysis_results == legacy_results
|
||||
assert session.rounds == []
|
||||
assert session.data_files == []
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 2: Status API response
|
||||
# ===========================================================================
|
||||
|
||||
class TestStatusAPIResponse:
|
||||
def test_status_response_contains_rounds(self):
|
||||
"""Task 2.1: GET /api/status response includes rounds field"""
|
||||
from web.main import SessionData, session_manager
|
||||
|
||||
session = SessionData("status-test")
|
||||
session.rounds = [{"round": 1, "reasoning": "r1"}]
|
||||
with session_manager.lock:
|
||||
session_manager.sessions["status-test"] = session
|
||||
|
||||
try:
|
||||
# Simulate what the endpoint returns
|
||||
response = {
|
||||
"is_running": session.is_running,
|
||||
"log": "",
|
||||
"has_report": session.generated_report is not None,
|
||||
"rounds": session.rounds,
|
||||
"current_round": session.current_round,
|
||||
"max_rounds": session.max_rounds,
|
||||
"progress_percentage": session.progress_percentage,
|
||||
"status_message": session.status_message,
|
||||
}
|
||||
assert "rounds" in response
|
||||
assert response["rounds"] == [{"round": 1, "reasoning": "r1"}]
|
||||
finally:
|
||||
with session_manager.lock:
|
||||
del session_manager.sessions["status-test"]
|
||||
|
||||
def test_status_backward_compat_fields(self):
|
||||
"""Task 2.2: Existing fields remain unchanged"""
|
||||
from web.main import SessionData
|
||||
|
||||
session = SessionData("compat-test")
|
||||
session.status_message = "分析中"
|
||||
session.progress_percentage = 50.0
|
||||
session.current_round = 5
|
||||
session.max_rounds = 20
|
||||
|
||||
response = {
|
||||
"is_running": session.is_running,
|
||||
"log": "",
|
||||
"has_report": session.generated_report is not None,
|
||||
"progress_percentage": session.progress_percentage,
|
||||
"current_round": session.current_round,
|
||||
"max_rounds": session.max_rounds,
|
||||
"status_message": session.status_message,
|
||||
"rounds": session.rounds,
|
||||
}
|
||||
|
||||
assert response["is_running"] is False
|
||||
assert response["has_report"] is False
|
||||
assert response["progress_percentage"] == 50.0
|
||||
assert response["current_round"] == 5
|
||||
assert response["max_rounds"] == 20
|
||||
assert response["status_message"] == "分析中"
|
||||
assert "log" in response
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 4: Evidence extraction
|
||||
# ===========================================================================
|
||||
|
||||
class TestEvidenceExtraction:
|
||||
def test_extract_evidence_basic(self):
|
||||
"""Task 4.1: Parse evidence annotations and build supporting_data"""
|
||||
from web.main import _extract_evidence_annotations, SessionData
|
||||
|
||||
session = SessionData("ev-test")
|
||||
session.rounds = [
|
||||
{"round": 1, "evidence_rows": [{"col": "val1"}]},
|
||||
{"round": 2, "evidence_rows": [{"col": "val2"}]},
|
||||
]
|
||||
|
||||
paragraphs = [
|
||||
{"id": "p-0", "type": "text", "content": "Some intro text"},
|
||||
{"id": "p-1", "type": "text", "content": "Analysis result <!-- evidence:round_1 -->"},
|
||||
{"id": "p-2", "type": "text", "content": "More analysis <!-- evidence:round_2 -->"},
|
||||
]
|
||||
|
||||
result = _extract_evidence_annotations(paragraphs, session)
|
||||
|
||||
assert "p-0" not in result # no annotation
|
||||
assert result["p-1"] == [{"col": "val1"}]
|
||||
assert result["p-2"] == [{"col": "val2"}]
|
||||
|
||||
def test_extract_evidence_no_annotations(self):
|
||||
"""Task 4.1: No annotations means empty mapping"""
|
||||
from web.main import _extract_evidence_annotations, SessionData
|
||||
|
||||
session = SessionData("ev-test2")
|
||||
session.rounds = [{"round": 1, "evidence_rows": [{"a": 1}]}]
|
||||
|
||||
paragraphs = [
|
||||
{"id": "p-0", "type": "text", "content": "No evidence here"},
|
||||
]
|
||||
|
||||
result = _extract_evidence_annotations(paragraphs, session)
|
||||
assert result == {}
|
||||
|
||||
def test_extract_evidence_out_of_range_round(self):
|
||||
"""Task 4.1: Round number beyond available rounds is ignored"""
|
||||
from web.main import _extract_evidence_annotations, SessionData
|
||||
|
||||
session = SessionData("ev-test3")
|
||||
session.rounds = [{"round": 1, "evidence_rows": [{"a": 1}]}]
|
||||
|
||||
paragraphs = [
|
||||
{"id": "p-0", "type": "text", "content": "Ref to round 5 <!-- evidence:round_5 -->"},
|
||||
]
|
||||
|
||||
result = _extract_evidence_annotations(paragraphs, session)
|
||||
assert result == {}
|
||||
|
||||
def test_extract_evidence_empty_evidence_rows(self):
|
||||
"""Task 4.1: Round with empty evidence_rows is excluded"""
|
||||
from web.main import _extract_evidence_annotations, SessionData
|
||||
|
||||
session = SessionData("ev-test4")
|
||||
session.rounds = [{"round": 1, "evidence_rows": []}]
|
||||
|
||||
paragraphs = [
|
||||
{"id": "p-0", "type": "text", "content": "Has annotation <!-- evidence:round_1 -->"},
|
||||
]
|
||||
|
||||
result = _extract_evidence_annotations(paragraphs, session)
|
||||
assert result == {}
|
||||
|
||||
def test_extract_evidence_whitespace_in_comment(self):
|
||||
"""Task 4.1: Handles whitespace variations in HTML comment"""
|
||||
from web.main import _extract_evidence_annotations, SessionData
|
||||
|
||||
session = SessionData("ev-test5")
|
||||
session.rounds = [{"round": 1, "evidence_rows": [{"x": 42}]}]
|
||||
|
||||
paragraphs = [
|
||||
{"id": "p-0", "type": "text", "content": "Text <!-- evidence:round_1 -->"},
|
||||
]
|
||||
|
||||
result = _extract_evidence_annotations(paragraphs, session)
|
||||
assert result["p-0"] == [{"x": 42}]
|
||||
217
tests/test_phase2.py
Normal file
217
tests/test_phase2.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Unit tests for Phase 2: CodeExecutor Enhancements
|
||||
|
||||
Run: python -m pytest tests/test_phase2.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from utils.code_executor import CodeExecutor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def executor(tmp_path):
|
||||
"""Create a CodeExecutor with a temp output directory."""
|
||||
return CodeExecutor(output_dir=str(tmp_path))
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 5: Evidence capture
|
||||
# ===========================================================================
|
||||
|
||||
class TestEvidenceCapture:
|
||||
def test_evidence_from_result_dataframe(self, executor):
|
||||
"""5.1: When result.result is a DataFrame, capture head(10) as evidence_rows."""
|
||||
code = "import pandas as pd\npd.DataFrame({'a': [1,2,3], 'b': [4,5,6]})"
|
||||
result = executor.execute_code(code)
|
||||
assert result["success"] is True
|
||||
assert "evidence_rows" in result
|
||||
assert len(result["evidence_rows"]) == 3
|
||||
assert result["evidence_rows"][0] == {"a": 1, "b": 4}
|
||||
|
||||
def test_evidence_capped_at_10(self, executor):
|
||||
"""5.1: Evidence rows are capped at 10."""
|
||||
code = "import pandas as pd\npd.DataFrame({'x': list(range(100))})"
|
||||
result = executor.execute_code(code)
|
||||
assert result["success"] is True
|
||||
assert len(result["evidence_rows"]) == 10
|
||||
|
||||
def test_evidence_fallback_to_namespace(self, executor):
|
||||
"""5.2: When result.result is not a DataFrame, fallback to namespace."""
|
||||
code = "import pandas as pd\nmy_data = pd.DataFrame({'col': [10, 20]})\nprint('done')"
|
||||
result = executor.execute_code(code)
|
||||
assert result["success"] is True
|
||||
assert len(result["evidence_rows"]) == 2
|
||||
assert result["evidence_rows"][0] == {"col": 10}
|
||||
|
||||
def test_evidence_empty_when_no_dataframe(self, executor):
|
||||
"""5.3: Returns empty list when no DataFrame is produced."""
|
||||
executor.reset_environment()
|
||||
code = "x = 42"
|
||||
result = executor.execute_code(code)
|
||||
assert result["success"] is True
|
||||
assert result["evidence_rows"] == []
|
||||
|
||||
def test_evidence_key_in_failure(self, executor):
|
||||
"""5.3: evidence_rows key present even on failure."""
|
||||
code = "import not_a_real_module"
|
||||
result = executor.execute_code(code)
|
||||
assert "evidence_rows" in result
|
||||
assert result["evidence_rows"] == []
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 6: DataFrame auto-detection and export
|
||||
# ===========================================================================
|
||||
|
||||
class TestDataFrameAutoExport:
|
||||
def test_new_dataframe_exported(self, executor, tmp_path):
|
||||
"""6.1-6.4: New DataFrame is detected and exported to CSV."""
|
||||
code = "import pandas as pd\nresult_df = pd.DataFrame({'a': [1], 'b': [2]})"
|
||||
result = executor.execute_code(code)
|
||||
assert result["success"] is True
|
||||
assert len(result["auto_exported_files"]) >= 1
|
||||
exported = result["auto_exported_files"][0]
|
||||
assert exported["variable_name"] == "result_df"
|
||||
assert exported["filename"] == "result_df.csv"
|
||||
assert exported["rows"] == 1
|
||||
assert exported["cols"] == 2
|
||||
assert exported["columns"] == ["a", "b"]
|
||||
# Verify file actually exists
|
||||
assert os.path.exists(os.path.join(str(tmp_path), "result_df.csv"))
|
||||
|
||||
def test_dedup_suffix(self, executor, tmp_path):
|
||||
"""6.3: Numeric suffix deduplication when file exists."""
|
||||
# Create first file
|
||||
code1 = "import pandas as pd\nmy_df = pd.DataFrame({'x': [1]})"
|
||||
result1 = executor.execute_code(code1)
|
||||
assert result1["success"] is True
|
||||
|
||||
# Reset the DataFrame to force a new id
|
||||
code2 = "my_df = pd.DataFrame({'x': [2]})"
|
||||
result2 = executor.execute_code(code2)
|
||||
assert result2["success"] is True
|
||||
exported_files = result2["auto_exported_files"]
|
||||
assert len(exported_files) >= 1
|
||||
# The second export should have _1 suffix
|
||||
assert exported_files[0]["filename"] == "my_df_1.csv"
|
||||
|
||||
def test_skip_module_names(self, executor):
|
||||
"""6.1: Module-level names like pd, np are skipped."""
|
||||
code = "x = 42" # pd and np already in namespace from setup
|
||||
result = executor.execute_code(code)
|
||||
# Should not export pd or np as DataFrames
|
||||
for f in result["auto_exported_files"]:
|
||||
assert f["variable_name"] not in ("pd", "np", "plt", "sns")
|
||||
|
||||
def test_auto_exported_files_key_in_result(self, executor):
|
||||
"""6.5: auto_exported_files key always present."""
|
||||
code = "x = 1"
|
||||
result = executor.execute_code(code)
|
||||
assert "auto_exported_files" in result
|
||||
assert isinstance(result["auto_exported_files"], list)
|
||||
|
||||
def test_changed_dataframe_detected(self, executor, tmp_path):
|
||||
"""6.2: Changed DataFrame (same name, new object) is detected."""
|
||||
code1 = "import pandas as pd\ndf_test = pd.DataFrame({'a': [1]})"
|
||||
executor.execute_code(code1)
|
||||
|
||||
code2 = "df_test = pd.DataFrame({'a': [1, 2, 3]})"
|
||||
result2 = executor.execute_code(code2)
|
||||
assert result2["success"] is True
|
||||
exported = [f for f in result2["auto_exported_files"] if f["variable_name"] == "df_test"]
|
||||
assert len(exported) == 1
|
||||
assert exported[0]["rows"] == 3
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 7: DATA_FILE_SAVED marker parsing
|
||||
# ===========================================================================
|
||||
|
||||
class TestDataFileSavedMarkerParsing:
|
||||
def test_parse_single_marker(self, executor):
|
||||
"""7.1-7.2: Parse a single DATA_FILE_SAVED marker from stdout."""
|
||||
code = 'print("[DATA_FILE_SAVED] filename: output.csv, rows: 42, description: Test data")'
|
||||
result = executor.execute_code(code)
|
||||
assert result["success"] is True
|
||||
assert len(result["prompt_saved_files"]) == 1
|
||||
parsed = result["prompt_saved_files"][0]
|
||||
assert parsed["filename"] == "output.csv"
|
||||
assert parsed["rows"] == 42
|
||||
assert parsed["description"] == "Test data"
|
||||
|
||||
def test_parse_multiple_markers(self, executor):
|
||||
"""7.1-7.2: Parse multiple markers."""
|
||||
code = (
|
||||
'print("[DATA_FILE_SAVED] filename: a.csv, rows: 10, description: File A")\n'
|
||||
'print("[DATA_FILE_SAVED] filename: b.xlsx, rows: 20, description: File B")'
|
||||
)
|
||||
result = executor.execute_code(code)
|
||||
assert result["success"] is True
|
||||
assert len(result["prompt_saved_files"]) == 2
|
||||
assert result["prompt_saved_files"][0]["filename"] == "a.csv"
|
||||
assert result["prompt_saved_files"][1]["filename"] == "b.xlsx"
|
||||
|
||||
def test_no_markers(self, executor):
|
||||
"""7.3: No markers means empty list."""
|
||||
code = 'print("hello world")'
|
||||
result = executor.execute_code(code)
|
||||
assert result["success"] is True
|
||||
assert result["prompt_saved_files"] == []
|
||||
|
||||
def test_prompt_saved_files_key_in_result(self, executor):
|
||||
"""7.3: prompt_saved_files key always present."""
|
||||
code = "x = 1"
|
||||
result = executor.execute_code(code)
|
||||
assert "prompt_saved_files" in result
|
||||
assert isinstance(result["prompt_saved_files"], list)
|
||||
|
||||
def test_malformed_marker_skipped(self, executor):
|
||||
"""7.1: Malformed markers are silently skipped."""
|
||||
code = 'print("[DATA_FILE_SAVED] this is not valid")'
|
||||
result = executor.execute_code(code)
|
||||
assert result["success"] is True
|
||||
assert result["prompt_saved_files"] == []
|
||||
|
||||
def test_chinese_filename_and_description(self, executor):
|
||||
"""7.2: Chinese characters in filename and description."""
|
||||
code = 'print("[DATA_FILE_SAVED] filename: 数据汇总.csv, rows: 100, description: 各类型TOP问题聚合统计")'
|
||||
result = executor.execute_code(code)
|
||||
assert result["success"] is True
|
||||
assert len(result["prompt_saved_files"]) == 1
|
||||
assert result["prompt_saved_files"][0]["filename"] == "数据汇总.csv"
|
||||
assert result["prompt_saved_files"][0]["description"] == "各类型TOP问题聚合统计"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Return structure integrity
|
||||
# ===========================================================================
|
||||
|
||||
class TestReturnStructure:
|
||||
def test_success_return_has_all_keys(self, executor):
|
||||
"""All 7 keys present on success."""
|
||||
result = executor.execute_code("x = 1")
|
||||
expected_keys = {"success", "output", "error", "variables",
|
||||
"evidence_rows", "auto_exported_files", "prompt_saved_files"}
|
||||
assert expected_keys.issubset(set(result.keys()))
|
||||
|
||||
def test_safety_failure_has_all_keys(self, executor):
|
||||
"""All 7 keys present on safety check failure."""
|
||||
result = executor.execute_code("import socket")
|
||||
expected_keys = {"success", "output", "error", "variables",
|
||||
"evidence_rows", "auto_exported_files", "prompt_saved_files"}
|
||||
assert expected_keys.issubset(set(result.keys()))
|
||||
|
||||
def test_execution_error_has_all_keys(self, executor):
|
||||
"""All 7 keys present on execution error."""
|
||||
result = executor.execute_code("1/0")
|
||||
expected_keys = {"success", "output", "error", "variables",
|
||||
"evidence_rows", "auto_exported_files", "prompt_saved_files"}
|
||||
assert expected_keys.issubset(set(result.keys()))
|
||||
233
tests/test_phase3.py
Normal file
233
tests/test_phase3.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Unit tests for Phase 3: Agent Changes
|
||||
|
||||
Run: python -m pytest tests/test_phase3.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
from data_analysis_agent import DataAnalysisAgent
|
||||
from prompts import data_analysis_system_prompt, final_report_system_prompt
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 8.1: _summarize_result
|
||||
# ===========================================================================
|
||||
|
||||
class TestSummarizeResult:
|
||||
@pytest.fixture
|
||||
def agent(self):
|
||||
"""Create a minimal DataAnalysisAgent for testing."""
|
||||
agent = DataAnalysisAgent.__new__(DataAnalysisAgent)
|
||||
agent._session_ref = None
|
||||
return agent
|
||||
|
||||
def test_success_with_evidence_rows(self, agent):
|
||||
"""8.1: Success with evidence rows produces DataFrame summary."""
|
||||
result = {
|
||||
"success": True,
|
||||
"evidence_rows": [{"a": 1, "b": 2}, {"a": 3, "b": 4}],
|
||||
"auto_exported_files": [{"variable_name": "df", "filename": "df.csv", "rows": 150, "cols": 8, "columns": []}],
|
||||
}
|
||||
summary = agent._summarize_result(result)
|
||||
assert "执行成功" in summary
|
||||
assert "DataFrame" in summary
|
||||
assert "150" in summary
|
||||
assert "8" in summary
|
||||
|
||||
def test_success_with_evidence_no_auto_files(self, agent):
|
||||
"""8.1: Success with evidence but no auto_exported_files uses evidence length."""
|
||||
result = {
|
||||
"success": True,
|
||||
"evidence_rows": [{"x": 1}, {"x": 2}, {"x": 3}],
|
||||
"auto_exported_files": [],
|
||||
}
|
||||
summary = agent._summarize_result(result)
|
||||
assert "执行成功" in summary
|
||||
assert "DataFrame" in summary
|
||||
|
||||
def test_success_with_output(self, agent):
|
||||
"""8.1: Success with output but no evidence shows first line."""
|
||||
result = {
|
||||
"success": True,
|
||||
"evidence_rows": [],
|
||||
"output": "Hello World\nSecond line",
|
||||
}
|
||||
summary = agent._summarize_result(result)
|
||||
assert "执行成功" in summary
|
||||
assert "Hello World" in summary
|
||||
|
||||
def test_success_no_output(self, agent):
|
||||
"""8.1: Success with no output or evidence."""
|
||||
result = {"success": True, "evidence_rows": [], "output": ""}
|
||||
summary = agent._summarize_result(result)
|
||||
assert summary == "执行成功"
|
||||
|
||||
def test_failure_short_error(self, agent):
|
||||
"""8.1: Failure with short error message."""
|
||||
result = {"success": False, "error": "KeyError: 'col_x'"}
|
||||
summary = agent._summarize_result(result)
|
||||
assert "执行失败" in summary
|
||||
assert "KeyError" in summary
|
||||
|
||||
def test_failure_long_error_truncated(self, agent):
|
||||
"""8.1: Failure with long error is truncated to 100 chars."""
|
||||
long_error = "A" * 200
|
||||
result = {"success": False, "error": long_error}
|
||||
summary = agent._summarize_result(result)
|
||||
assert "执行失败" in summary
|
||||
assert "..." in summary
|
||||
# The error portion should be at most 103 chars (100 + "...")
|
||||
error_part = summary.split("执行失败: ")[1]
|
||||
assert len(error_part) <= 104
|
||||
|
||||
def test_failure_no_error_field(self, agent):
|
||||
"""8.1: Failure with missing error field."""
|
||||
result = {"success": False}
|
||||
summary = agent._summarize_result(result)
|
||||
assert "执行失败" in summary
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 8.2-8.4: Round_Data construction and session integration
|
||||
# ===========================================================================
|
||||
|
||||
class TestRoundDataConstruction:
|
||||
def test_handle_generate_code_returns_reasoning(self):
|
||||
"""8.2: _handle_generate_code returns reasoning from yaml_data."""
|
||||
agent = DataAnalysisAgent.__new__(DataAnalysisAgent)
|
||||
agent._session_ref = None
|
||||
# We need a minimal executor mock
|
||||
from unittest.mock import MagicMock
|
||||
agent.executor = MagicMock()
|
||||
agent.executor.execute_code.return_value = {
|
||||
"success": True, "output": "ok", "error": "",
|
||||
"variables": {}, "evidence_rows": [],
|
||||
"auto_exported_files": [], "prompt_saved_files": [],
|
||||
}
|
||||
yaml_data = {"code": "x = 1", "reasoning": "Testing reasoning field"}
|
||||
result = agent._handle_generate_code("response text", yaml_data)
|
||||
assert result["reasoning"] == "Testing reasoning field"
|
||||
|
||||
def test_handle_generate_code_empty_reasoning(self):
|
||||
"""8.2: _handle_generate_code returns empty reasoning when not in yaml_data."""
|
||||
agent = DataAnalysisAgent.__new__(DataAnalysisAgent)
|
||||
agent._session_ref = None
|
||||
from unittest.mock import MagicMock
|
||||
agent.executor = MagicMock()
|
||||
agent.executor.execute_code.return_value = {
|
||||
"success": True, "output": "", "error": "",
|
||||
"variables": {}, "evidence_rows": [],
|
||||
"auto_exported_files": [], "prompt_saved_files": [],
|
||||
}
|
||||
yaml_data = {"code": "x = 1"}
|
||||
result = agent._handle_generate_code("response text", yaml_data)
|
||||
assert result["reasoning"] == ""
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 8.3: set_session_ref
|
||||
# ===========================================================================
|
||||
|
||||
class TestSetSessionRef:
|
||||
def test_session_ref_default_none(self):
|
||||
"""8.3: _session_ref defaults to None."""
|
||||
agent = DataAnalysisAgent()
|
||||
assert agent._session_ref is None
|
||||
|
||||
def test_set_session_ref(self):
|
||||
"""8.3: set_session_ref stores the session reference."""
|
||||
agent = DataAnalysisAgent()
|
||||
|
||||
class FakeSession:
|
||||
rounds = []
|
||||
data_files = []
|
||||
|
||||
session = FakeSession()
|
||||
agent.set_session_ref(session)
|
||||
assert agent._session_ref is session
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 9.1: Prompt - intermediate data saving instructions
|
||||
# ===========================================================================
|
||||
|
||||
class TestPromptDataSaving:
|
||||
def test_data_saving_instructions_in_system_prompt(self):
|
||||
"""9.1: data_analysis_system_prompt contains DATA_FILE_SAVED instructions."""
|
||||
assert "[DATA_FILE_SAVED]" in data_analysis_system_prompt
|
||||
assert "中间数据保存规则" in data_analysis_system_prompt
|
||||
|
||||
def test_data_saving_example_in_prompt(self):
|
||||
"""9.1: Prompt contains example of saving and printing marker."""
|
||||
assert "to_csv" in data_analysis_system_prompt
|
||||
assert "session_output_dir" in data_analysis_system_prompt
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 9.2: Prompt - evidence annotation instructions
|
||||
# ===========================================================================
|
||||
|
||||
class TestPromptEvidenceAnnotation:
|
||||
def test_evidence_annotation_in_report_prompt(self):
|
||||
"""9.2: final_report_system_prompt contains evidence annotation instructions."""
|
||||
assert "evidence:round_" in final_report_system_prompt
|
||||
assert "证据标注规则" in final_report_system_prompt
|
||||
|
||||
def test_evidence_annotation_example(self):
|
||||
"""9.2: Prompt contains example of evidence annotation."""
|
||||
assert "<!-- evidence:round_3 -->" in final_report_system_prompt
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 9.3: _build_final_report_prompt includes evidence
|
||||
# ===========================================================================
|
||||
|
||||
class TestBuildFinalReportPromptEvidence:
|
||||
def test_evidence_included_when_session_has_rounds(self):
|
||||
"""9.3: _build_final_report_prompt includes evidence data when rounds exist."""
|
||||
agent = DataAnalysisAgent.__new__(DataAnalysisAgent)
|
||||
agent.analysis_results = []
|
||||
agent.current_round = 2
|
||||
agent.session_output_dir = "/tmp/test"
|
||||
agent.data_profile = "test profile"
|
||||
|
||||
class FakeSession:
|
||||
rounds = [
|
||||
{
|
||||
"round": 1,
|
||||
"reasoning": "分析车型分布",
|
||||
"result_summary": "执行成功,输出 DataFrame (10行×3列)",
|
||||
"evidence_rows": [{"车型": "A", "数量": 42}],
|
||||
},
|
||||
{
|
||||
"round": 2,
|
||||
"reasoning": "分析模块分布",
|
||||
"result_summary": "执行成功",
|
||||
"evidence_rows": [],
|
||||
},
|
||||
]
|
||||
|
||||
agent._session_ref = FakeSession()
|
||||
prompt = agent._build_final_report_prompt([])
|
||||
assert "各轮次分析证据数据" in prompt
|
||||
assert "第1轮" in prompt
|
||||
assert "第2轮" in prompt
|
||||
assert "车型" in prompt
|
||||
|
||||
def test_no_evidence_when_no_session_ref(self):
|
||||
"""9.3: _build_final_report_prompt works without session ref."""
|
||||
agent = DataAnalysisAgent.__new__(DataAnalysisAgent)
|
||||
agent.analysis_results = []
|
||||
agent.current_round = 1
|
||||
agent.session_output_dir = "/tmp/test"
|
||||
agent.data_profile = "test profile"
|
||||
agent._session_ref = None
|
||||
prompt = agent._build_final_report_prompt([])
|
||||
assert "各轮次分析证据数据" not in prompt
|
||||
285
tests/test_properties.py
Normal file
285
tests/test_properties.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Property-based tests for agent-robustness-optimization features.
|
||||
Uses hypothesis with reduced examples (max_examples=20) for fast execution.
|
||||
|
||||
Run: python -m pytest tests/test_properties.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
|
||||
# Ensure project root is on path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from utils.data_privacy import (
|
||||
_extract_column_from_error,
|
||||
_lookup_column_in_profile,
|
||||
generate_enriched_hint,
|
||||
)
|
||||
from utils.analysis_templates import get_template, list_templates, TEMPLATE_REGISTRY
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DATA_CONTEXT_PATTERNS = [
|
||||
r"KeyError:\s*['\"](.+?)['\"]",
|
||||
r"ValueError.*(?:column|col|field)",
|
||||
r"NameError.*(?:df|data|frame)",
|
||||
r"(?:empty|no\s+data|0\s+rows)",
|
||||
r"IndexError.*(?:out of range|out of bounds)",
|
||||
]
|
||||
|
||||
|
||||
def classify_error(error_message: str) -> str:
|
||||
"""Mirror of DataAnalysisAgent._classify_error for testing without IPython."""
|
||||
for pattern in DATA_CONTEXT_PATTERNS:
|
||||
if re.search(pattern, error_message, re.IGNORECASE):
|
||||
return "data_context"
|
||||
return "other"
|
||||
|
||||
|
||||
SAMPLE_SAFE_PROFILE = """# 数据结构概览 (Schema Profile)
|
||||
|
||||
## 文件: test.csv
|
||||
|
||||
- **维度**: 100 行 x 3 列
|
||||
- **列名**: `车型, 模块, 问题类型`
|
||||
|
||||
### 列结构:
|
||||
|
||||
| 列名 | 数据类型 | 空值率 | 唯一值数 | 特征描述 |
|
||||
|------|---------|--------|---------|----------|
|
||||
| 车型 | object | 0.0% | 5 | 低基数分类(5类) |
|
||||
| 模块 | object | 2.0% | 12 | 中基数分类(12类) |
|
||||
| 问题类型 | object | 0.0% | 8 | 低基数分类(8类) |
|
||||
"""
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 1: Error Classification Correctness (Task 11.1)
|
||||
# ===========================================================================
|
||||
|
||||
# Strategy: generate error messages that contain data-context patterns
|
||||
data_context_error_st = st.one_of(
|
||||
st.from_regex(r"KeyError: '[a-zA-Z_]+'" , fullmatch=True),
|
||||
st.from_regex(r'KeyError: "[a-zA-Z_]+"', fullmatch=True),
|
||||
st.just("ValueError: column 'x' not found"),
|
||||
st.just("NameError: name 'df' is not defined"),
|
||||
st.just("empty DataFrame"),
|
||||
st.just("0 rows returned"),
|
||||
st.just("IndexError: index 5 is out of range"),
|
||||
)
|
||||
|
||||
non_data_error_st = st.one_of(
|
||||
st.just("SyntaxError: invalid syntax"),
|
||||
st.just("TypeError: unsupported operand"),
|
||||
st.just("ZeroDivisionError: division by zero"),
|
||||
st.just("ImportError: No module named 'foo'"),
|
||||
st.text(min_size=1, max_size=50).filter(
|
||||
lambda s: not any(re.search(p, s, re.IGNORECASE) for p in DATA_CONTEXT_PATTERNS)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@settings(max_examples=20)
|
||||
@given(err=data_context_error_st)
|
||||
def test_prop1_data_context_errors_classified(err):
|
||||
"""Data-context error messages must be classified as 'data_context'."""
|
||||
assert classify_error(err) == "data_context"
|
||||
|
||||
|
||||
@settings(max_examples=20)
|
||||
@given(err=non_data_error_st)
|
||||
def test_prop1_non_data_errors_classified(err):
|
||||
"""Non-data error messages must be classified as 'other'."""
|
||||
assert classify_error(err) == "other"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 3: Enriched Hint Contains Column Metadata Without Real Data (11.2)
|
||||
# ===========================================================================
|
||||
|
||||
known_columns = ["车型", "模块", "问题类型"]
|
||||
column_st = st.sampled_from(known_columns)
|
||||
|
||||
|
||||
@settings(max_examples=20)
|
||||
@given(col=column_st)
|
||||
def test_prop3_enriched_hint_contains_column_meta(col):
|
||||
"""Enriched hint for a known column must contain its metadata."""
|
||||
error_msg = f"KeyError: '{col}'"
|
||||
hint = generate_enriched_hint(error_msg, SAMPLE_SAFE_PROFILE)
|
||||
assert col in hint
|
||||
assert "数据类型" in hint
|
||||
assert "唯一值数量" in hint
|
||||
assert "空值率" in hint
|
||||
assert "特征描述" in hint
|
||||
|
||||
|
||||
@settings(max_examples=20)
|
||||
@given(col=column_st)
|
||||
def test_prop3_enriched_hint_no_real_data(col):
|
||||
"""Enriched hint must NOT contain real data values (min/max/mean/sample rows)."""
|
||||
error_msg = f"KeyError: '{col}'"
|
||||
hint = generate_enriched_hint(error_msg, SAMPLE_SAFE_PROFILE)
|
||||
# Should not contain statistical values or sample data
|
||||
for forbidden in ["Min=", "Max=", "Mean=", "TOP 5 高频值"]:
|
||||
assert forbidden not in hint
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 4: Env Var Config Override (Task 11.3)
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=10)
|
||||
@given(val=st.integers(min_value=1, max_value=100))
|
||||
def test_prop4_env_override_max_data_context_retries(val):
|
||||
"""APP_MAX_DATA_CONTEXT_RETRIES env var must override config."""
|
||||
from config.app_config import AppConfig
|
||||
os.environ["APP_MAX_DATA_CONTEXT_RETRIES"] = str(val)
|
||||
try:
|
||||
config = AppConfig.from_env()
|
||||
assert config.max_data_context_retries == val
|
||||
finally:
|
||||
del os.environ["APP_MAX_DATA_CONTEXT_RETRIES"]
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 5: Sliding Window Trimming Invariants (Task 11.4)
|
||||
# ===========================================================================
|
||||
|
||||
def make_history(n_pairs: int, first_msg: str = "initial requirement"):
|
||||
"""Build a fake conversation history with n_pairs of user+assistant messages."""
|
||||
history = [{"role": "user", "content": first_msg}]
|
||||
for i in range(n_pairs):
|
||||
history.append({"role": "assistant", "content": f'action: "generate_code"\ncode: | print({i})'})
|
||||
history.append({"role": "user", "content": f"代码执行反馈:\n成功 round {i}"})
|
||||
return history
|
||||
|
||||
|
||||
@settings(max_examples=20)
|
||||
@given(
|
||||
n_pairs=st.integers(min_value=1, max_value=30),
|
||||
window=st.integers(min_value=1, max_value=10),
|
||||
)
|
||||
def test_prop5_trimming_preserves_first_message(n_pairs, window):
|
||||
"""After trimming, the first user message is always at index 0."""
|
||||
history = make_history(n_pairs, first_msg="ORIGINAL_REQ")
|
||||
max_messages = window * 2
|
||||
|
||||
if len(history) <= max_messages:
|
||||
return # no trimming needed, invariant trivially holds
|
||||
|
||||
first_message = history[0]
|
||||
start_idx = 1
|
||||
has_summary = (
|
||||
len(history) > 1
|
||||
and history[1]["role"] == "user"
|
||||
and history[1]["content"].startswith("[分析摘要]")
|
||||
)
|
||||
if has_summary:
|
||||
start_idx = 2
|
||||
|
||||
messages_to_consider = history[start_idx:]
|
||||
messages_to_trim = messages_to_consider[:-max_messages]
|
||||
messages_to_keep = messages_to_consider[-max_messages:]
|
||||
|
||||
if not messages_to_trim:
|
||||
return
|
||||
|
||||
new_history = [first_message]
|
||||
new_history.append({"role": "user", "content": "[分析摘要] summary"})
|
||||
new_history.extend(messages_to_keep)
|
||||
|
||||
assert new_history[0]["content"] == "ORIGINAL_REQ"
|
||||
assert len(new_history) <= max_messages + 2 # first + summary + window
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 6: Trimming Summary Content (Task 11.5)
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=20)
|
||||
@given(n_pairs=st.integers(min_value=2, max_value=15))
|
||||
def test_prop6_summary_excludes_code_blocks(n_pairs):
|
||||
"""Compressed summary must not contain code blocks or raw output."""
|
||||
history = make_history(n_pairs)
|
||||
# Simulate _compress_trimmed_messages logic
|
||||
summary_parts = ["[分析摘要] 以下是之前分析轮次的概要:"]
|
||||
round_num = 0
|
||||
for msg in history[1:]: # skip first
|
||||
content = msg["content"]
|
||||
if msg["role"] == "assistant":
|
||||
round_num += 1
|
||||
action = "generate_code"
|
||||
if "collect_figures" in content:
|
||||
action = "collect_figures"
|
||||
summary_parts.append(f"- 轮次{round_num}: 动作={action}")
|
||||
elif msg["role"] == "user" and "代码执行反馈" in content:
|
||||
success = "失败" if "[ERROR]" in content or "执行错误" in content else "成功"
|
||||
if summary_parts and summary_parts[-1].startswith("- 轮次"):
|
||||
summary_parts[-1] += f", 执行结果={success}"
|
||||
|
||||
summary = "\n".join(summary_parts)
|
||||
assert "```" not in summary
|
||||
assert "print(" not in summary
|
||||
assert "[分析摘要]" in summary
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 7: Template Prompt Integration (Task 11.6)
|
||||
# ===========================================================================
|
||||
|
||||
valid_template_names = list(TEMPLATE_REGISTRY.keys())
|
||||
|
||||
|
||||
@settings(max_examples=len(valid_template_names))
|
||||
@given(name=st.sampled_from(valid_template_names))
|
||||
def test_prop7_template_prompt_prepended(name):
|
||||
"""For any valid template, get_full_prompt() output must be non-empty."""
|
||||
template = get_template(name)
|
||||
prompt = template.get_full_prompt()
|
||||
assert len(prompt) > 0
|
||||
assert template.display_name in prompt
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 8: Invalid Template Name Raises Error (Task 11.7)
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=20)
|
||||
@given(name=st.text(min_size=1, max_size=30).filter(lambda s: s not in TEMPLATE_REGISTRY))
|
||||
def test_prop8_invalid_template_raises_error(name):
|
||||
"""Invalid template names must raise ValueError listing available templates."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_template(name)
|
||||
error_msg = str(exc_info.value)
|
||||
# Must list available template names
|
||||
for valid_name in TEMPLATE_REGISTRY:
|
||||
assert valid_name in error_msg
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 11: Parallel Profile Merge With Error Resilience (Task 11.8)
|
||||
# ===========================================================================
|
||||
|
||||
def test_prop11_parallel_profile_error_resilience():
|
||||
"""Parallel profiling with mix of valid/invalid files includes all entries."""
|
||||
from utils.data_privacy import build_safe_profile, build_local_profile
|
||||
|
||||
valid_file = "uploads/data_simple_200.csv"
|
||||
invalid_file = "/nonexistent/fake_file.csv"
|
||||
|
||||
# Test build_safe_profile handles missing files gracefully
|
||||
safe = build_safe_profile([valid_file, invalid_file])
|
||||
assert "fake_file.csv" in safe # error entry present
|
||||
if os.path.exists(valid_file):
|
||||
assert "data_simple_200.csv" in safe # valid entry present
|
||||
297
tests/test_unit.py
Normal file
297
tests/test_unit.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Unit and integration tests for agent-robustness-optimization features.
|
||||
|
||||
Run: python -m pytest tests/test_unit.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import re
|
||||
import pytest
|
||||
|
||||
from utils.data_privacy import (
|
||||
_extract_column_from_error,
|
||||
_lookup_column_in_profile,
|
||||
generate_enriched_hint,
|
||||
)
|
||||
from utils.analysis_templates import get_template, list_templates, TEMPLATE_REGISTRY
|
||||
from config.app_config import AppConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Updated patterns matching data_analysis_agent.py
|
||||
DATA_CONTEXT_PATTERNS = [
|
||||
# KeyError - missing key/column
|
||||
r"KeyError:\s*['\"](.+?)['\"]",
|
||||
# ValueError - value-related issues
|
||||
r"ValueError.*(?:column|col|field|shape|axis)",
|
||||
# NameError - undefined variables
|
||||
r"NameError.*(?:df|data|frame|series)",
|
||||
# Empty/missing data
|
||||
r"(?:empty|no\s+data|0\s+rows|No\s+data)",
|
||||
# IndexError - out of bounds
|
||||
r"IndexError.*(?:out of range|out of bounds)",
|
||||
# AttributeError - missing attributes
|
||||
r"AttributeError.*(?:DataFrame|Series|object)\s+has\s+no\s+attribute",
|
||||
# Pandas-specific errors
|
||||
r"pd\.errors\.(?:EmptyDataError|ParserError|MergeError)",
|
||||
r"MergeError: No common columns",
|
||||
# Type errors
|
||||
r"TypeError.*(?:unsupported operand|expected string|cannot convert)",
|
||||
# UnboundLocalError - undefined local variables
|
||||
r"UnboundLocalError.*referenced before assignment",
|
||||
# Syntax errors
|
||||
r"SyntaxError: invalid syntax",
|
||||
# Module/Import errors for data libraries
|
||||
r"ModuleNotFoundError.*(?:pandas|numpy|matplotlib)",
|
||||
r"ImportError.*(?:pandas|numpy|matplotlib)",
|
||||
]
|
||||
|
||||
|
||||
def classify_error(error_message: str) -> str:
|
||||
for pattern in DATA_CONTEXT_PATTERNS:
|
||||
if re.search(pattern, error_message, re.IGNORECASE):
|
||||
return "data_context"
|
||||
return "other"
|
||||
|
||||
|
||||
SAMPLE_PROFILE = """| 列名 | 数据类型 | 空值率 | 唯一值数 | 特征描述 |
|
||||
|------|---------|--------|---------|----------|
|
||||
| 车型 | object | 0.0% | 5 | 低基数分类(5类) |
|
||||
| 模块 | object | 2.0% | 12 | 中基数分类(12类) |
|
||||
"""
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 12.1: Unit tests for error classifier
|
||||
# ===========================================================================
|
||||
|
||||
class TestErrorClassifier:
|
||||
def test_keyerror_single_quotes(self):
|
||||
assert classify_error("KeyError: '车型'") == "data_context"
|
||||
|
||||
def test_keyerror_double_quotes(self):
|
||||
assert classify_error('KeyError: "model_name"') == "data_context"
|
||||
|
||||
def test_valueerror_column(self):
|
||||
assert classify_error("ValueError: column 'x' not in DataFrame") == "data_context"
|
||||
|
||||
def test_nameerror_df(self):
|
||||
assert classify_error("NameError: name 'df' is not defined") == "data_context"
|
||||
|
||||
def test_empty_dataframe(self):
|
||||
assert classify_error("empty DataFrame after filtering") == "data_context"
|
||||
|
||||
def test_zero_rows(self):
|
||||
assert classify_error("0 rows returned from query") == "data_context"
|
||||
|
||||
def test_index_out_of_range(self):
|
||||
assert classify_error("IndexError: index 10 is out of range") == "data_context"
|
||||
|
||||
def test_syntax_error_is_data_context(self):
|
||||
assert classify_error("SyntaxError: invalid syntax") == "data_context"
|
||||
|
||||
def test_type_error_is_data_context(self):
|
||||
assert classify_error("TypeError: unsupported operand") == "data_context"
|
||||
|
||||
def test_generic_text_is_other(self):
|
||||
assert classify_error("Something went wrong") == "other"
|
||||
|
||||
def test_empty_string_is_other(self):
|
||||
assert classify_error("") == "other"
|
||||
|
||||
# ===========================================================================
|
||||
# Additional tests for improved error classifier
|
||||
# ===========================================================================
|
||||
|
||||
def test_attributeerror_dataframe(self):
|
||||
assert classify_error("AttributeError: 'DataFrame' object has no attribute 'xxx'") == "data_context"
|
||||
|
||||
def test_attributeerror_series(self):
|
||||
assert classify_error("AttributeError: 'Series' object has no attribute 'xxx'") == "data_context"
|
||||
|
||||
def test_pd_emptydataerror(self):
|
||||
assert classify_error("pd.errors.EmptyDataError: No data") == "data_context"
|
||||
|
||||
def test_pd_parsererror(self):
|
||||
assert classify_error("pd.errors.ParserError: Error tokenizing data") == "data_context"
|
||||
|
||||
def test_pd_mergeerror(self):
|
||||
assert classify_error("MergeError: No common columns to merge") == "data_context"
|
||||
|
||||
def test_typeerror_unsupported_operand(self):
|
||||
assert classify_error("TypeError: unsupported operand type(s) for +: 'int' and 'str'") == "data_context"
|
||||
|
||||
def test_typeerror_expected_string(self):
|
||||
assert classify_error("TypeError: expected string or bytes-like object") == "data_context"
|
||||
|
||||
def test_unboundlocalerror(self):
|
||||
assert classify_error("UnboundLocalError: local variable 'df' referenced before assignment") == "data_context"
|
||||
|
||||
def test_syntaxerror(self):
|
||||
assert classify_error("SyntaxError: invalid syntax") == "data_context"
|
||||
|
||||
def test_modulenotfounderror(self):
|
||||
assert classify_error("ModuleNotFoundError: No module named 'pandas'") == "data_context"
|
||||
|
||||
def test_importerror(self):
|
||||
assert classify_error("ImportError: cannot import name 'xxx' from 'pandas'") == "data_context"
|
||||
|
||||
def test_valueerror_shape(self):
|
||||
assert classify_error("ValueError: shape mismatch") == "data_context"
|
||||
|
||||
def test_valueerror_axis(self):
|
||||
assert classify_error("ValueError: axis out of bounds") == "data_context"
|
||||
|
||||
def test_nameerror_series(self):
|
||||
assert classify_error("NameError: name 'series' is not defined") == "data_context"
|
||||
|
||||
def test_no_data_message(self):
|
||||
assert classify_error("No data available for analysis") == "data_context"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 12.1 continued: Unit tests for column extraction and lookup
|
||||
# ===========================================================================
|
||||
|
||||
class TestColumnExtraction:
|
||||
def test_extract_from_keyerror(self):
|
||||
assert _extract_column_from_error("KeyError: '车型'") == "车型"
|
||||
|
||||
def test_extract_from_column_phrase(self):
|
||||
assert _extract_column_from_error("column '模块' not found") == "模块"
|
||||
|
||||
def test_extract_none_for_generic(self):
|
||||
assert _extract_column_from_error("SyntaxError: bad") is None
|
||||
|
||||
def test_lookup_existing_column(self):
|
||||
result = _lookup_column_in_profile("车型", SAMPLE_PROFILE)
|
||||
assert result is not None
|
||||
assert result["dtype"] == "object"
|
||||
assert result["unique_count"] == "5"
|
||||
|
||||
def test_lookup_missing_column(self):
|
||||
assert _lookup_column_in_profile("不存在", SAMPLE_PROFILE) is None
|
||||
|
||||
def test_lookup_none_column(self):
|
||||
assert _lookup_column_in_profile(None, SAMPLE_PROFILE) is None
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 12.2: Unit tests for conversation trimming at boundary conditions
|
||||
# ===========================================================================
|
||||
|
||||
class TestConversationTrimming:
|
||||
def _make_history(self, n_pairs):
|
||||
history = [{"role": "user", "content": "ORIGINAL"}]
|
||||
for i in range(n_pairs):
|
||||
history.append({"role": "assistant", "content": f"response {i}"})
|
||||
history.append({"role": "user", "content": f"feedback {i}"})
|
||||
return history
|
||||
|
||||
def test_no_trimming_when_under_limit(self):
|
||||
"""History with 3 pairs and window=5 should not be trimmed."""
|
||||
history = self._make_history(3) # 1 + 6 = 7 messages
|
||||
window = 5
|
||||
max_messages = window * 2 # 10
|
||||
assert len(history) <= max_messages # no trimming
|
||||
|
||||
def test_trimming_at_exact_boundary(self):
|
||||
"""History exactly at 2*window should not be trimmed."""
|
||||
window = 3
|
||||
history = self._make_history(3) # 1 + 6 = 7 messages
|
||||
max_messages = window * 2 # 6
|
||||
# 7 > 6, so trimming should happen
|
||||
assert len(history) > max_messages
|
||||
|
||||
def test_first_message_always_preserved(self):
|
||||
"""After trimming, first message must be preserved."""
|
||||
history = self._make_history(10)
|
||||
window = 2
|
||||
max_messages = window * 2
|
||||
|
||||
first = history[0]
|
||||
to_consider = history[1:]
|
||||
to_keep = to_consider[-max_messages:]
|
||||
|
||||
new_history = [first, {"role": "user", "content": "[分析摘要] ..."}]
|
||||
new_history.extend(to_keep)
|
||||
|
||||
assert new_history[0]["content"] == "ORIGINAL"
|
||||
|
||||
def test_summary_replaces_old_summary(self):
|
||||
"""If a summary already exists at index 1, it should be replaced."""
|
||||
history = [
|
||||
{"role": "user", "content": "ORIGINAL"},
|
||||
{"role": "user", "content": "[分析摘要] old summary"},
|
||||
]
|
||||
for i in range(8):
|
||||
history.append({"role": "assistant", "content": f"resp {i}"})
|
||||
history.append({"role": "user", "content": f"fb {i}"})
|
||||
|
||||
# Simulate trimming with existing summary
|
||||
has_summary = history[1]["content"].startswith("[分析摘要]")
|
||||
assert has_summary
|
||||
start_idx = 2 if has_summary else 1
|
||||
assert start_idx == 2
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 12.3: Tests for template API
|
||||
# ===========================================================================
|
||||
|
||||
class TestTemplateSystem:
|
||||
def test_list_templates_returns_all(self):
|
||||
templates = list_templates()
|
||||
assert len(templates) == len(TEMPLATE_REGISTRY)
|
||||
names = {t["name"] for t in templates}
|
||||
assert names == set(TEMPLATE_REGISTRY.keys())
|
||||
|
||||
def test_get_valid_template(self):
|
||||
for name in TEMPLATE_REGISTRY:
|
||||
t = get_template(name)
|
||||
assert t.name # has a display name
|
||||
assert len(t.steps) > 0 # template has steps
|
||||
|
||||
def test_get_invalid_template_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
get_template("nonexistent_template_xyz")
|
||||
|
||||
def test_template_prompt_not_empty(self):
|
||||
for name in TEMPLATE_REGISTRY:
|
||||
t = get_template(name)
|
||||
prompt = t.get_full_prompt()
|
||||
assert len(prompt) > 50 # should be substantial
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 12.4: Tests for config
|
||||
# ===========================================================================
|
||||
|
||||
class TestAppConfig:
|
||||
def test_defaults(self):
|
||||
config = AppConfig()
|
||||
assert config.max_data_context_retries == 2
|
||||
assert config.conversation_window_size == 10
|
||||
assert config.max_parallel_profiles == 4
|
||||
|
||||
def test_env_override(self):
|
||||
os.environ["APP_MAX_DATA_CONTEXT_RETRIES"] = "5"
|
||||
os.environ["APP_CONVERSATION_WINDOW_SIZE"] = "20"
|
||||
os.environ["APP_MAX_PARALLEL_PROFILES"] = "8"
|
||||
try:
|
||||
config = AppConfig.from_env()
|
||||
assert config.max_data_context_retries == 5
|
||||
assert config.conversation_window_size == 20
|
||||
assert config.max_parallel_profiles == 8
|
||||
finally:
|
||||
del os.environ["APP_MAX_DATA_CONTEXT_RETRIES"]
|
||||
del os.environ["APP_CONVERSATION_WINDOW_SIZE"]
|
||||
del os.environ["APP_MAX_PARALLEL_PROFILES"]
|
||||
@@ -6,5 +6,12 @@
|
||||
from utils.code_executor import CodeExecutor
|
||||
from utils.llm_helper import LLMHelper
|
||||
from utils.fallback_openai_client import AsyncFallbackOpenAIClient
|
||||
from utils.logger import PrintCapture, create_session_logger
|
||||
|
||||
__all__ = ["CodeExecutor", "LLMHelper", "AsyncFallbackOpenAIClient"]
|
||||
__all__ = [
|
||||
"CodeExecutor",
|
||||
"LLMHelper",
|
||||
"AsyncFallbackOpenAIClient",
|
||||
"PrintCapture",
|
||||
"create_session_logger",
|
||||
]
|
||||
153
utils/analysis_templates.py
Normal file
153
utils/analysis_templates.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
分析模板系统 - 从 config/templates/*.yaml 加载模板
|
||||
|
||||
模板文件格式:
|
||||
name: 模板显示名称
|
||||
description: 模板描述
|
||||
steps:
|
||||
- name: 步骤名称
|
||||
description: 步骤描述
|
||||
prompt: 给LLM的指令
|
||||
"""
|
||||
|
||||
import os
|
||||
import glob
|
||||
import yaml
|
||||
from typing import List, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
TEMPLATES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config", "templates")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnalysisStep:
|
||||
"""分析步骤"""
|
||||
name: str
|
||||
description: str
|
||||
prompt: str
|
||||
|
||||
|
||||
class AnalysisTemplate:
|
||||
"""从 YAML 文件加载的分析模板"""
|
||||
|
||||
def __init__(self, name: str, display_name: str, description: str, steps: List[AnalysisStep], filepath: str = ""):
|
||||
self.name = name
|
||||
self.display_name = display_name
|
||||
self.description = description
|
||||
self.steps = steps
|
||||
self.filepath = filepath
|
||||
|
||||
def get_full_prompt(self) -> str:
|
||||
prompt = f"# {self.display_name}\n\n{self.description}\n\n"
|
||||
prompt += "## 分析步骤:\n\n"
|
||||
for i, step in enumerate(self.steps, 1):
|
||||
prompt += f"### {i}. {step.name}\n"
|
||||
prompt += f"{step.description}\n\n"
|
||||
prompt += f"```\n{step.prompt}\n```\n\n"
|
||||
return prompt
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"display_name": self.display_name,
|
||||
"description": self.description,
|
||||
"steps": [{"name": s.name, "description": s.description, "prompt": s.prompt} for s in self.steps],
|
||||
}
|
||||
|
||||
|
||||
def _load_template_from_file(filepath: str) -> AnalysisTemplate:
|
||||
"""从单个 YAML 文件加载模板"""
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
template_name = os.path.splitext(os.path.basename(filepath))[0]
|
||||
steps = []
|
||||
for s in data.get("steps", []):
|
||||
steps.append(AnalysisStep(
|
||||
name=s.get("name", ""),
|
||||
description=s.get("description", ""),
|
||||
prompt=s.get("prompt", ""),
|
||||
))
|
||||
|
||||
return AnalysisTemplate(
|
||||
name=template_name,
|
||||
display_name=data.get("name", template_name),
|
||||
description=data.get("description", ""),
|
||||
steps=steps,
|
||||
filepath=filepath,
|
||||
)
|
||||
|
||||
|
||||
def _scan_templates() -> Dict[str, AnalysisTemplate]:
|
||||
"""扫描 config/templates/ 目录加载所有模板"""
|
||||
registry = {}
|
||||
if not os.path.exists(TEMPLATES_DIR):
|
||||
os.makedirs(TEMPLATES_DIR, exist_ok=True)
|
||||
return registry
|
||||
|
||||
for fpath in sorted(glob.glob(os.path.join(TEMPLATES_DIR, "*.yaml"))):
|
||||
try:
|
||||
tpl = _load_template_from_file(fpath)
|
||||
registry[tpl.name] = tpl
|
||||
except Exception as e:
|
||||
print(f"[WARN] 加载模板失败 {fpath}: {e}")
|
||||
return registry
|
||||
|
||||
|
||||
# Module-level registry, refreshed on each call to support hot-editing
|
||||
def _get_registry() -> Dict[str, AnalysisTemplate]:
|
||||
return _scan_templates()
|
||||
|
||||
|
||||
# Keep TEMPLATE_REGISTRY as a lazy property for backward compatibility with tests
|
||||
TEMPLATE_REGISTRY = _scan_templates()
|
||||
|
||||
|
||||
def get_template(template_name: str) -> AnalysisTemplate:
|
||||
"""获取分析模板(每次从磁盘重新加载以支持热编辑)"""
|
||||
registry = _get_registry()
|
||||
if template_name in registry:
|
||||
return registry[template_name]
|
||||
raise ValueError(f"未找到模板: {template_name}。可用模板: {list(registry.keys())}")
|
||||
|
||||
|
||||
def list_templates() -> List[Dict[str, str]]:
|
||||
"""列出所有可用模板"""
|
||||
registry = _get_registry()
|
||||
return [
|
||||
{"name": tpl.name, "display_name": tpl.display_name, "description": tpl.description}
|
||||
for tpl in registry.values()
|
||||
]
|
||||
|
||||
|
||||
def save_template(template_name: str, data: Dict[str, Any]) -> str:
|
||||
"""保存或更新模板到 YAML 文件,返回文件路径"""
|
||||
os.makedirs(TEMPLATES_DIR, exist_ok=True)
|
||||
filepath = os.path.join(TEMPLATES_DIR, f"{template_name}.yaml")
|
||||
|
||||
yaml_data = {
|
||||
"name": data.get("display_name", data.get("name", template_name)),
|
||||
"description": data.get("description", ""),
|
||||
"steps": data.get("steps", []),
|
||||
}
|
||||
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
yaml.dump(yaml_data, f, allow_unicode=True, default_flow_style=False, sort_keys=False)
|
||||
|
||||
# Refresh global registry
|
||||
global TEMPLATE_REGISTRY
|
||||
TEMPLATE_REGISTRY = _scan_templates()
|
||||
|
||||
return filepath
|
||||
|
||||
|
||||
def delete_template(template_name: str) -> bool:
|
||||
"""删除模板文件"""
|
||||
filepath = os.path.join(TEMPLATES_DIR, f"{template_name}.yaml")
|
||||
if os.path.exists(filepath):
|
||||
os.remove(filepath)
|
||||
global TEMPLATE_REGISTRY
|
||||
TEMPLATE_REGISTRY = _scan_templates()
|
||||
return True
|
||||
return False
|
||||
103
utils/cache_manager.py
Normal file
103
utils/cache_manager.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
缓存管理器 - 支持数据和LLM响应缓存
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import hashlib
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Callable
|
||||
from functools import wraps
|
||||
|
||||
|
||||
class CacheManager:
|
||||
"""缓存管理器"""
|
||||
|
||||
def __init__(self, cache_dir: str = ".cache", enabled: bool = True):
|
||||
self.cache_dir = Path(cache_dir)
|
||||
self.enabled = enabled
|
||||
|
||||
if self.enabled:
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _get_cache_key(self, *args, **kwargs) -> str:
|
||||
"""生成缓存键"""
|
||||
key_data = f"{args}_{kwargs}"
|
||||
return hashlib.md5(key_data.encode()).hexdigest()
|
||||
|
||||
def _get_cache_path(self, key: str) -> Path:
|
||||
"""获取缓存文件路径"""
|
||||
return self.cache_dir / f"{key}.pkl"
|
||||
|
||||
def get(self, key: str) -> Optional[Any]:
|
||||
"""获取缓存"""
|
||||
if not self.enabled:
|
||||
return None
|
||||
|
||||
cache_path = self._get_cache_path(key)
|
||||
if cache_path.exists():
|
||||
try:
|
||||
with open(cache_path, 'rb') as f:
|
||||
return pickle.load(f)
|
||||
except Exception as e:
|
||||
print(f"[WARN] 读取缓存失败: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
def set(self, key: str, value: Any) -> None:
|
||||
"""设置缓存"""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
cache_path = self._get_cache_path(key)
|
||||
try:
|
||||
with open(cache_path, 'wb') as f:
|
||||
pickle.dump(value, f)
|
||||
except Exception as e:
|
||||
print(f"[WARN] 写入缓存失败: {e}")
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空所有缓存"""
|
||||
if self.cache_dir.exists():
|
||||
for cache_file in self.cache_dir.glob("*.pkl"):
|
||||
cache_file.unlink()
|
||||
print("[OK] 缓存已清空")
|
||||
|
||||
def cached(self, key_func: Optional[Callable] = None):
|
||||
"""缓存装饰器"""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if not self.enabled:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# 生成缓存键
|
||||
if key_func:
|
||||
cache_key = key_func(*args, **kwargs)
|
||||
else:
|
||||
cache_key = self._get_cache_key(*args, **kwargs)
|
||||
|
||||
# 尝试从缓存获取
|
||||
cached_value = self.get(cache_key)
|
||||
if cached_value is not None:
|
||||
print(f"[CACHE] 使用缓存: {cache_key[:8]}...")
|
||||
return cached_value
|
||||
|
||||
# 执行函数并缓存结果
|
||||
result = func(*args, **kwargs)
|
||||
self.set(cache_key, result)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
class LLMCacheManager(CacheManager):
|
||||
"""LLM响应缓存管理器"""
|
||||
|
||||
def get_cache_key_from_messages(self, messages: list, model: str = "") -> str:
|
||||
"""从消息列表生成缓存键"""
|
||||
key_data = json.dumps(messages, sort_keys=True) + model
|
||||
return hashlib.md5(key_data.encode()).hexdigest()
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import ast
|
||||
import traceback
|
||||
@@ -15,6 +16,7 @@ from IPython.utils.capture import capture_output
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.font_manager as fm
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class CodeExecutor:
|
||||
@@ -26,7 +28,9 @@ class CodeExecutor:
|
||||
"pandas",
|
||||
"pd",
|
||||
"numpy",
|
||||
"glob",
|
||||
"np",
|
||||
"subprocess",
|
||||
"matplotlib",
|
||||
"matplotlib.pyplot",
|
||||
"plt",
|
||||
@@ -35,6 +39,16 @@ class CodeExecutor:
|
||||
"duckdb",
|
||||
"scipy",
|
||||
"sklearn",
|
||||
"sklearn.feature_extraction.text",
|
||||
"sklearn.preprocessing",
|
||||
"sklearn.model_selection",
|
||||
"sklearn.metrics",
|
||||
"sklearn.ensemble",
|
||||
"sklearn.linear_model",
|
||||
"sklearn.cluster",
|
||||
"sklearn.decomposition",
|
||||
"sklearn.manifold",
|
||||
"statsmodels",
|
||||
"plotly",
|
||||
"dash",
|
||||
"requests",
|
||||
@@ -65,8 +79,49 @@ class CodeExecutor:
|
||||
"dataclasses",
|
||||
"enum",
|
||||
"sqlite3",
|
||||
"jieba",
|
||||
"wordcloud",
|
||||
"PIL",
|
||||
"random",
|
||||
"networkx",
|
||||
"platform",
|
||||
}
|
||||
|
||||
# Maximum rows for auto-export; DataFrames larger than this are skipped
|
||||
# to avoid heavy disk I/O on large datasets.
|
||||
AUTO_EXPORT_MAX_ROWS = 50000
|
||||
|
||||
# Variable names to skip during DataFrame auto-export
|
||||
# (common import aliases, built-in namespace names, and typical
|
||||
# temporary/intermediate variable names that shouldn't be persisted)
|
||||
_SKIP_EXPORT_NAMES = {
|
||||
# Import aliases
|
||||
"pd", "np", "plt", "sns", "os", "json", "sys", "re", "io",
|
||||
"csv", "glob", "duckdb", "display", "math", "datetime", "time",
|
||||
"warnings", "logging", "copy", "pickle", "pathlib", "collections",
|
||||
"itertools", "functools", "operator", "random", "networkx",
|
||||
# Common data variable — the main loaded DataFrame should not be
|
||||
# auto-exported every round; the LLM can save it explicitly via
|
||||
# DATA_FILE_SAVED if needed.
|
||||
"df",
|
||||
# Typical intermediate/temporary variable names from analysis code
|
||||
"cross_table", "cross_table_filtered",
|
||||
"module_issue_table", "module_issue_filtered",
|
||||
"correlation_matrix",
|
||||
"feature_data", "person_stats", "top_persons",
|
||||
"abnormal_durations", "abnormal_orders",
|
||||
"missing_df", "missing_values", "missing_percent",
|
||||
"monthly_counts", "monthly_summary",
|
||||
"distribution_results", "phrase_freq",
|
||||
"normal_durations",
|
||||
"df_check", "df_temp",
|
||||
}
|
||||
|
||||
# Regex for parsing DATA_FILE_SAVED markers
|
||||
_DATA_FILE_SAVED_RE = re.compile(
|
||||
r"\[DATA_FILE_SAVED\]\s*filename:\s*(.+?),\s*rows:\s*(\d+),\s*description:\s*(.+)"
|
||||
)
|
||||
|
||||
def __init__(self, output_dir: str = "outputs"):
|
||||
"""
|
||||
初始化代码执行器
|
||||
@@ -197,6 +252,7 @@ import matplotlib.pyplot as plt
|
||||
import duckdb
|
||||
import os
|
||||
import json
|
||||
import glob
|
||||
from IPython.display import display
|
||||
"""
|
||||
try:
|
||||
@@ -223,12 +279,16 @@ from IPython.display import display
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
if alias.name not in self.ALLOWED_IMPORTS:
|
||||
# 获取根包名 (e.g. sklearn.preprocessing -> sklearn)
|
||||
root_package = alias.name.split('.')[0]
|
||||
if root_package not in self.ALLOWED_IMPORTS and alias.name not in self.ALLOWED_IMPORTS:
|
||||
return False, f"不允许的导入: {alias.name}"
|
||||
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module not in self.ALLOWED_IMPORTS:
|
||||
return False, f"不允许的导入: {node.module}"
|
||||
if node.module:
|
||||
root_package = node.module.split('.')[0]
|
||||
if root_package not in self.ALLOWED_IMPORTS and node.module not in self.ALLOWED_IMPORTS:
|
||||
return False, f"不允许的导入: {node.module}"
|
||||
|
||||
# 检查属性访问(防止通过os.system等方式绕过)
|
||||
elif isinstance(node, ast.Attribute):
|
||||
@@ -296,6 +356,192 @@ from IPython.display import display
|
||||
|
||||
return str(obj)
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_for_json(rows: List[Dict]) -> List[Dict]:
|
||||
"""Make evidence row values JSON-serializable.
|
||||
|
||||
Handles NaN/inf → None, Timestamp/datetime → isoformat string,
|
||||
numpy scalars → Python native types.
|
||||
"""
|
||||
import math
|
||||
sanitized = []
|
||||
for row in rows:
|
||||
clean = {}
|
||||
for k, v in row.items():
|
||||
if v is None:
|
||||
clean[k] = None
|
||||
elif isinstance(v, float) and (math.isnan(v) or math.isinf(v)):
|
||||
clean[k] = None
|
||||
elif hasattr(v, 'isoformat'): # Timestamp, datetime
|
||||
clean[k] = v.isoformat()
|
||||
elif hasattr(v, 'item'): # numpy scalar
|
||||
clean[k] = v.item()
|
||||
else:
|
||||
try:
|
||||
if pd.isna(v):
|
||||
clean[k] = None
|
||||
continue
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
clean[k] = v
|
||||
sanitized.append(clean)
|
||||
return sanitized
|
||||
|
||||
def _capture_evidence_rows(self, result, shell, df_snapshot_before=None) -> List[Dict]:
|
||||
"""
|
||||
Capture up to 10 evidence rows from the execution result.
|
||||
|
||||
Priority order:
|
||||
1. result.result if it's a DataFrame (direct code output)
|
||||
2. Smallest newly-created DataFrame this round (most likely an analysis result)
|
||||
3. Last DataFrame in namespace (fallback)
|
||||
"""
|
||||
try:
|
||||
# Primary: check if result.result is a DataFrame
|
||||
if result.result is not None and isinstance(result.result, pd.DataFrame):
|
||||
return self._sanitize_for_json(
|
||||
result.result.head(10).to_dict(orient="records")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Secondary: find the smallest NEW DataFrame created this round
|
||||
# (e.g. groupby result, crosstab, etc. — more relevant than the main df)
|
||||
if df_snapshot_before is not None:
|
||||
try:
|
||||
after = self._snapshot_dataframes(shell)
|
||||
new_names = [n for n in after if n not in df_snapshot_before]
|
||||
if new_names:
|
||||
# Pick the smallest new DataFrame (most likely a summary/aggregation)
|
||||
best_df = None
|
||||
best_size = float('inf')
|
||||
for name in new_names:
|
||||
try:
|
||||
obj = shell.user_ns[name]
|
||||
if isinstance(obj, pd.DataFrame) and len(obj) < best_size:
|
||||
best_df = obj
|
||||
best_size = len(obj)
|
||||
except Exception:
|
||||
continue
|
||||
if best_df is not None:
|
||||
return self._sanitize_for_json(
|
||||
best_df.head(10).to_dict(orient="records")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: find the last-assigned DataFrame variable in namespace
|
||||
try:
|
||||
last_df = None
|
||||
for name, obj in shell.user_ns.items():
|
||||
if (
|
||||
not name.startswith("_")
|
||||
and name not in self._SKIP_EXPORT_NAMES
|
||||
and isinstance(obj, pd.DataFrame)
|
||||
):
|
||||
last_df = obj
|
||||
if last_df is not None:
|
||||
return self._sanitize_for_json(
|
||||
last_df.head(10).to_dict(orient="records")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return []
|
||||
|
||||
def _snapshot_dataframes(self, shell) -> Dict[str, int]:
|
||||
"""Snapshot current DataFrame variables as {name: id(obj)}."""
|
||||
snapshot = {}
|
||||
try:
|
||||
for name, obj in shell.user_ns.items():
|
||||
if (
|
||||
not name.startswith("_")
|
||||
and name not in self._SKIP_EXPORT_NAMES
|
||||
and isinstance(obj, pd.DataFrame)
|
||||
):
|
||||
snapshot[name] = id(obj)
|
||||
except Exception:
|
||||
pass
|
||||
return snapshot
|
||||
|
||||
def _detect_new_dataframes(
|
||||
self, before: Dict[str, int], after: Dict[str, int]
|
||||
) -> List[str]:
|
||||
"""Return variable names of truly NEW DataFrames only.
|
||||
|
||||
Only returns names that did not exist in the before-snapshot.
|
||||
Changed DataFrames (same name, different id) are excluded to avoid
|
||||
re-exporting the main 'df' or other modified variables every round.
|
||||
"""
|
||||
new_only = []
|
||||
for name, obj_id in after.items():
|
||||
if name not in before:
|
||||
new_only.append(name)
|
||||
return new_only
|
||||
|
||||
def _export_dataframe(self, var_name: str, df) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Export a DataFrame to CSV with dedup suffix. Returns metadata dict or None.
|
||||
Skips export for DataFrames exceeding AUTO_EXPORT_MAX_ROWS to avoid
|
||||
heavy disk I/O on large datasets; only metadata is recorded.
|
||||
"""
|
||||
try:
|
||||
rows_count = len(df)
|
||||
cols_count = len(df.columns)
|
||||
col_names = list(df.columns)
|
||||
|
||||
# Skip writing large DataFrames to disk — record metadata only
|
||||
if rows_count > self.AUTO_EXPORT_MAX_ROWS:
|
||||
return {
|
||||
"variable_name": var_name,
|
||||
"filename": f"(skipped: {var_name} has {rows_count} rows)",
|
||||
"rows": rows_count,
|
||||
"cols": cols_count,
|
||||
"columns": col_names,
|
||||
"skipped": True,
|
||||
}
|
||||
|
||||
base_filename = f"{var_name}.csv"
|
||||
filepath = os.path.join(self.output_dir, base_filename)
|
||||
|
||||
# Dedup: if file exists, try _1, _2, ...
|
||||
if os.path.exists(filepath):
|
||||
suffix = 1
|
||||
while True:
|
||||
dedup_filename = f"{var_name}_{suffix}.csv"
|
||||
filepath = os.path.join(self.output_dir, dedup_filename)
|
||||
if not os.path.exists(filepath):
|
||||
base_filename = dedup_filename
|
||||
break
|
||||
suffix += 1
|
||||
|
||||
df.to_csv(filepath, index=False)
|
||||
return {
|
||||
"variable_name": var_name,
|
||||
"filename": base_filename,
|
||||
"rows": rows_count,
|
||||
"cols": cols_count,
|
||||
"columns": col_names,
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _parse_data_file_saved_markers(self, stdout_text: str) -> List[Dict[str, Any]]:
|
||||
"""Parse [DATA_FILE_SAVED] marker lines from captured stdout."""
|
||||
results = []
|
||||
try:
|
||||
for line in stdout_text.splitlines():
|
||||
m = self._DATA_FILE_SAVED_RE.search(line)
|
||||
if m:
|
||||
results.append({
|
||||
"filename": m.group(1).strip(),
|
||||
"rows": int(m.group(2)),
|
||||
"description": m.group(3).strip(),
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
return results
|
||||
|
||||
def execute_code(self, code: str) -> Dict[str, Any]:
|
||||
"""
|
||||
执行代码并返回结果
|
||||
@@ -308,7 +554,10 @@ from IPython.display import display
|
||||
'success': bool,
|
||||
'output': str,
|
||||
'error': str,
|
||||
'variables': Dict[str, Any] # 新生成的重要变量
|
||||
'variables': Dict[str, Any], # 新生成的重要变量
|
||||
'evidence_rows': List[Dict], # up to 10 evidence rows
|
||||
'auto_exported_files': List[Dict], # auto-detected DataFrame exports
|
||||
'prompt_saved_files': List[Dict], # parsed DATA_FILE_SAVED markers
|
||||
}
|
||||
"""
|
||||
# 检查代码安全性
|
||||
@@ -319,12 +568,18 @@ from IPython.display import display
|
||||
"output": "",
|
||||
"error": f"代码安全检查失败: {safety_error}",
|
||||
"variables": {},
|
||||
"evidence_rows": [],
|
||||
"auto_exported_files": [],
|
||||
"prompt_saved_files": [],
|
||||
}
|
||||
|
||||
# 记录执行前的变量
|
||||
vars_before = set(self.shell.user_ns.keys())
|
||||
|
||||
try:
|
||||
# --- Task 6.1: Snapshot DataFrame variables before execution ---
|
||||
df_snapshot_before = self._snapshot_dataframes(self.shell)
|
||||
|
||||
# 使用IPython的capture_output来捕获所有输出
|
||||
with capture_output() as captured:
|
||||
result = self.shell.run_cell(code)
|
||||
@@ -337,6 +592,9 @@ from IPython.display import display
|
||||
"output": captured.stdout,
|
||||
"error": f"执行前错误: {error_msg}",
|
||||
"variables": {},
|
||||
"evidence_rows": [],
|
||||
"auto_exported_files": [],
|
||||
"prompt_saved_files": self._parse_data_file_saved_markers(captured.stdout),
|
||||
}
|
||||
|
||||
if result.error_in_exec:
|
||||
@@ -346,6 +604,9 @@ from IPython.display import display
|
||||
"output": captured.stdout,
|
||||
"error": f"执行错误: {error_msg}",
|
||||
"variables": {},
|
||||
"evidence_rows": [],
|
||||
"auto_exported_files": [],
|
||||
"prompt_saved_files": self._parse_data_file_saved_markers(captured.stdout),
|
||||
}
|
||||
|
||||
# 获取输出
|
||||
@@ -374,11 +635,63 @@ from IPython.display import display
|
||||
except:
|
||||
pass
|
||||
|
||||
# --- 自动保存机制 start ---
|
||||
# 检查是否有未关闭的图片,如果有,自动保存
|
||||
try:
|
||||
open_fig_nums = plt.get_fignums()
|
||||
if open_fig_nums:
|
||||
for fig_num in open_fig_nums:
|
||||
fig = plt.figure(fig_num)
|
||||
# 生成自动保存的文件名
|
||||
auto_filename = f"autosave_fig_{self.image_counter}_{fig_num}.png"
|
||||
auto_filepath = os.path.join(self.output_dir, auto_filename)
|
||||
|
||||
try:
|
||||
# 尝试保存
|
||||
fig.savefig(auto_filepath, bbox_inches='tight')
|
||||
print(f"[CACHE] [Auto-Save] 检测到未闭合图表,已安全保存至: {auto_filepath}")
|
||||
|
||||
# 添加到输出中,告知Agent
|
||||
output += f"\n[Auto-Save] [WARN] 检测到Figure {fig_num}未关闭,系统已自动保存为: {auto_filename}"
|
||||
self.image_counter += 1
|
||||
except Exception as e:
|
||||
print(f"[WARN] [Auto-Save] 保存失败: {e}")
|
||||
finally:
|
||||
plt.close(fig_num)
|
||||
except Exception as e:
|
||||
print(f"[WARN] [Auto-Save Global] 异常: {e}")
|
||||
# --- 自动保存机制 end ---
|
||||
|
||||
# --- Task 5: Evidence capture ---
|
||||
evidence_rows = self._capture_evidence_rows(result, self.shell, df_snapshot_before)
|
||||
|
||||
# --- Task 6.2-6.4: DataFrame auto-detection and export ---
|
||||
auto_exported_files = []
|
||||
try:
|
||||
df_snapshot_after = self._snapshot_dataframes(self.shell)
|
||||
new_df_names = self._detect_new_dataframes(df_snapshot_before, df_snapshot_after)
|
||||
for var_name in new_df_names:
|
||||
try:
|
||||
df_obj = self.shell.user_ns[var_name]
|
||||
meta = self._export_dataframe(var_name, df_obj)
|
||||
if meta is not None:
|
||||
auto_exported_files.append(meta)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# --- Task 7: DATA_FILE_SAVED marker parsing ---
|
||||
prompt_saved_files = self._parse_data_file_saved_markers(captured.stdout)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"output": output,
|
||||
"error": "",
|
||||
"variables": important_new_vars,
|
||||
"evidence_rows": evidence_rows,
|
||||
"auto_exported_files": auto_exported_files,
|
||||
"prompt_saved_files": prompt_saved_files,
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
@@ -386,6 +699,9 @@ from IPython.display import display
|
||||
"output": captured.stdout if "captured" in locals() else "",
|
||||
"error": f"执行异常: {str(e)}\n{traceback.format_exc()}",
|
||||
"variables": {},
|
||||
"evidence_rows": [],
|
||||
"auto_exported_files": [],
|
||||
"prompt_saved_files": [],
|
||||
}
|
||||
|
||||
def reset_environment(self):
|
||||
|
||||
@@ -2,6 +2,17 @@
|
||||
import os
|
||||
import pandas as pd
|
||||
import io
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Optional, Iterator
|
||||
from config.app_config import app_config
|
||||
from utils.cache_manager import CacheManager
|
||||
|
||||
# 初始化缓存管理器
|
||||
data_cache = CacheManager(
|
||||
cache_dir=app_config.cache_dir,
|
||||
enabled=app_config.data_cache_enabled
|
||||
)
|
||||
|
||||
def load_and_profile_data(file_paths: list) -> str:
|
||||
"""
|
||||
@@ -23,7 +34,7 @@ def load_and_profile_data(file_paths: list) -> str:
|
||||
profile_summary += f"## 文件: {file_name}\n\n"
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
profile_summary += f"⚠️ 文件不存在: {file_path}\n\n"
|
||||
profile_summary += f"[WARN] 文件不存在: {file_path}\n\n"
|
||||
continue
|
||||
|
||||
try:
|
||||
@@ -41,7 +52,7 @@ def load_and_profile_data(file_paths: list) -> str:
|
||||
elif ext in ['.xlsx', '.xls']:
|
||||
df = pd.read_excel(file_path)
|
||||
else:
|
||||
profile_summary += f"⚠️ 不支持的文件格式: {ext}\n\n"
|
||||
profile_summary += f"[WARN] 不支持的文件格式: {ext}\n\n"
|
||||
continue
|
||||
|
||||
# 基础信息
|
||||
@@ -59,7 +70,7 @@ def load_and_profile_data(file_paths: list) -> str:
|
||||
|
||||
profile_summary += f"#### {col} ({dtype})\n"
|
||||
if null_count > 0:
|
||||
profile_summary += f"- ⚠️ 空值: {null_count} ({null_ratio:.1f}%)\n"
|
||||
profile_summary += f"- [WARN] 空值: {null_count} ({null_ratio:.1f}%)\n"
|
||||
|
||||
# 数值列分析
|
||||
if pd.api.types.is_numeric_dtype(dtype):
|
||||
@@ -85,6 +96,227 @@ def load_and_profile_data(file_paths: list) -> str:
|
||||
profile_summary += "\n"
|
||||
|
||||
except Exception as e:
|
||||
profile_summary += f"❌ 读取或分析文件失败: {str(e)}\n\n"
|
||||
profile_summary += f"[ERROR] 读取或分析文件失败: {str(e)}\n\n"
|
||||
|
||||
return profile_summary
|
||||
|
||||
|
||||
def get_file_hash(file_path: str) -> str:
|
||||
"""计算文件哈希值,用于缓存键"""
|
||||
hasher = hashlib.md5()
|
||||
hasher.update(file_path.encode())
|
||||
|
||||
# 添加文件修改时间
|
||||
if os.path.exists(file_path):
|
||||
mtime = os.path.getmtime(file_path)
|
||||
hasher.update(str(mtime).encode())
|
||||
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
def load_data_chunked(file_path: str, chunksize: Optional[int] = None) -> Iterator[pd.DataFrame]:
|
||||
"""
|
||||
流式读取大文件,分块返回DataFrame
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
chunksize: 每块行数,默认使用配置值
|
||||
|
||||
Yields:
|
||||
DataFrame块
|
||||
"""
|
||||
if chunksize is None:
|
||||
chunksize = app_config.chunk_size
|
||||
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
|
||||
if ext == '.csv':
|
||||
# 尝试多种编码
|
||||
for encoding in ['utf-8', 'gbk', 'latin1']:
|
||||
try:
|
||||
chunks = pd.read_csv(file_path, encoding=encoding, chunksize=chunksize)
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 读取CSV文件失败: {e}")
|
||||
break
|
||||
elif ext in ['.xlsx', '.xls']:
|
||||
# Excel文件不支持chunksize,直接读取
|
||||
try:
|
||||
df = pd.read_excel(file_path)
|
||||
# 手动分块
|
||||
for i in range(0, len(df), chunksize):
|
||||
yield df.iloc[i:i+chunksize]
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 读取Excel文件失败: {e}")
|
||||
|
||||
|
||||
def _profile_chunked(file_path: str) -> str:
|
||||
"""
|
||||
Profile a large file by reading the first chunk plus sampled subsequent chunks.
|
||||
|
||||
Uses ``load_data_chunked()`` to stream the file. The first chunk is kept
|
||||
in full; every 5th subsequent chunk contributes up to 100 sampled rows.
|
||||
A markdown profile is generated from the combined sample.
|
||||
|
||||
Args:
|
||||
file_path: Path to the data file.
|
||||
|
||||
Returns:
|
||||
A markdown string containing the sampled profile for this file.
|
||||
"""
|
||||
file_name = os.path.basename(file_path)
|
||||
chunks_iter = load_data_chunked(file_path)
|
||||
first_chunk = next(chunks_iter, None)
|
||||
if first_chunk is None:
|
||||
return f"## 文件: {file_name}\n\n[ERROR] 无法读取文件: {file_path}\n\n"
|
||||
|
||||
sample_parts = [first_chunk]
|
||||
for i, chunk in enumerate(chunks_iter):
|
||||
if i % 5 == 0: # sample every 5th subsequent chunk
|
||||
sample_parts.append(chunk.head(min(100, len(chunk))))
|
||||
|
||||
combined = pd.concat(sample_parts, ignore_index=True)
|
||||
|
||||
# Build profile from the combined sample
|
||||
profile = f"## 文件: {file_name}\n\n"
|
||||
profile += f"- **注意**: 此画像基于抽样数据生成(首块 + 每5块采样100行)\n"
|
||||
rows, cols = combined.shape
|
||||
profile += f"- **样本维度**: {rows} 行 x {cols} 列\n"
|
||||
profile += f"- **列名**: `{', '.join(combined.columns)}`\n\n"
|
||||
profile += "### 列详细分布:\n"
|
||||
|
||||
for col in combined.columns:
|
||||
dtype = combined[col].dtype
|
||||
null_count = combined[col].isnull().sum()
|
||||
null_ratio = (null_count / rows) * 100 if rows > 0 else 0
|
||||
|
||||
profile += f"#### {col} ({dtype})\n"
|
||||
if null_count > 0:
|
||||
profile += f"- [WARN] 空值: {null_count} ({null_ratio:.1f}%)\n"
|
||||
|
||||
if pd.api.types.is_numeric_dtype(dtype):
|
||||
desc = combined[col].describe()
|
||||
profile += f"- 统计: Min={desc['min']:.2f}, Max={desc['max']:.2f}, Mean={desc['mean']:.2f}\n"
|
||||
elif pd.api.types.is_object_dtype(dtype) or pd.api.types.is_categorical_dtype(dtype):
|
||||
unique_count = combined[col].nunique()
|
||||
profile += f"- 唯一值数量: {unique_count}\n"
|
||||
if unique_count > 0:
|
||||
top_n = combined[col].value_counts().head(5)
|
||||
top_items_str = ", ".join([f"{k}({v})" for k, v in top_n.items()])
|
||||
profile += f"- **TOP 5 高频值**: {top_items_str}\n"
|
||||
elif pd.api.types.is_datetime64_any_dtype(dtype):
|
||||
profile += f"- 范围: {combined[col].min()} 至 {combined[col].max()}\n"
|
||||
|
||||
profile += "\n"
|
||||
|
||||
return profile
|
||||
|
||||
|
||||
def load_and_profile_data_smart(file_paths: list, max_file_size_mb: int = None) -> str:
|
||||
"""
|
||||
Smart data loader: selects chunked profiling for large files and full
|
||||
profiling for small files based on a size threshold.
|
||||
|
||||
Args:
|
||||
file_paths: List of file paths to profile.
|
||||
max_file_size_mb: Size threshold in MB. Files larger than this use
|
||||
chunked profiling. Defaults to ``app_config.max_file_size_mb``.
|
||||
|
||||
Returns:
|
||||
A markdown string containing the combined data profile.
|
||||
"""
|
||||
if max_file_size_mb is None:
|
||||
max_file_size_mb = app_config.max_file_size_mb
|
||||
|
||||
profile_summary = "# 数据画像报告 (Data Profile)\n\n"
|
||||
|
||||
if not file_paths:
|
||||
return profile_summary + "未提供数据文件。"
|
||||
|
||||
for file_path in file_paths:
|
||||
if not os.path.exists(file_path):
|
||||
profile_summary += f"## 文件: {os.path.basename(file_path)}\n\n"
|
||||
profile_summary += f"[WARN] 文件不存在: {file_path}\n\n"
|
||||
continue
|
||||
|
||||
try:
|
||||
file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
|
||||
if file_size_mb > max_file_size_mb:
|
||||
profile_summary += _profile_chunked(file_path)
|
||||
else:
|
||||
# Use existing full-load profiling for this single file
|
||||
profile_summary += load_and_profile_data([file_path]).replace(
|
||||
"# 数据画像报告 (Data Profile)\n\n", ""
|
||||
)
|
||||
except Exception as e:
|
||||
profile_summary += f"## 文件: {os.path.basename(file_path)}\n\n"
|
||||
profile_summary += f"[ERROR] 读取或分析文件失败: {str(e)}\n\n"
|
||||
|
||||
return profile_summary
|
||||
|
||||
|
||||
def load_data_with_cache(file_path: str, force_reload: bool = False) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
带缓存的数据加载
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
force_reload: 是否强制重新加载
|
||||
|
||||
Returns:
|
||||
DataFrame或None
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
print(f"[WARN] 文件不存在: {file_path}")
|
||||
return None
|
||||
|
||||
# 检查文件大小
|
||||
file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
|
||||
|
||||
# 对于大文件,建议使用流式处理
|
||||
if file_size_mb > app_config.max_file_size_mb:
|
||||
print(f"[WARN] 文件过大 ({file_size_mb:.1f}MB),建议使用 load_data_chunked() 流式处理")
|
||||
|
||||
# 生成缓存键
|
||||
cache_key = get_file_hash(file_path)
|
||||
|
||||
# 尝试从缓存加载
|
||||
if not force_reload and app_config.data_cache_enabled:
|
||||
cached_data = data_cache.get(cache_key)
|
||||
if cached_data is not None:
|
||||
print(f"[CACHE] 从缓存加载数据: {os.path.basename(file_path)}")
|
||||
return cached_data
|
||||
|
||||
# 加载数据
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
df = None
|
||||
|
||||
try:
|
||||
if ext == '.csv':
|
||||
# 尝试多种编码
|
||||
for encoding in ['utf-8', 'gbk', 'latin1']:
|
||||
try:
|
||||
df = pd.read_csv(file_path, encoding=encoding)
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
elif ext in ['.xlsx', '.xls']:
|
||||
df = pd.read_excel(file_path)
|
||||
else:
|
||||
print(f"[WARN] 不支持的文件格式: {ext}")
|
||||
return None
|
||||
|
||||
# 缓存数据
|
||||
if df is not None and app_config.data_cache_enabled:
|
||||
data_cache.set(cache_key, df)
|
||||
print(f"[OK] 数据已缓存: {os.path.basename(file_path)}")
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 加载数据失败: {e}")
|
||||
return None
|
||||
|
||||
301
utils/data_privacy.py
Normal file
301
utils/data_privacy.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据隐私保护层
|
||||
|
||||
核心原则:发给外部 LLM 的信息只包含 schema 级别的元数据,
|
||||
绝不包含真实数据值。所有真实数据仅在本地代码执行环境中使用。
|
||||
|
||||
分级策略:
|
||||
- SAFE(安全级): 可发送给 LLM — 列名、数据类型、行列数、空值率、唯一值数量
|
||||
- LOCAL(本地级): 仅本地使用 — 真实数据值、TOP N 高频值、统计数值、样本行
|
||||
"""
|
||||
|
||||
import re
|
||||
import pandas as pd
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
def build_safe_profile(file_paths: list) -> str:
|
||||
"""
|
||||
生成可安全发送给外部 LLM 的数据画像。
|
||||
只包含 schema 信息,不包含任何真实数据值。
|
||||
|
||||
Args:
|
||||
file_paths: 数据文件路径列表
|
||||
|
||||
Returns:
|
||||
安全的 Markdown 格式数据画像
|
||||
"""
|
||||
import os
|
||||
|
||||
profile = "# 数据结构概览 (Schema Profile)\n\n"
|
||||
|
||||
if not file_paths:
|
||||
return profile + "未提供数据文件。"
|
||||
|
||||
for file_path in file_paths:
|
||||
file_name = os.path.basename(file_path)
|
||||
profile += f"## 文件: {file_name}\n\n"
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
profile += f"[WARN] 文件不存在: {file_path}\n\n"
|
||||
continue
|
||||
|
||||
try:
|
||||
df = _load_dataframe(file_path)
|
||||
if df is None:
|
||||
continue
|
||||
|
||||
rows, cols = df.shape
|
||||
profile += f"- **维度**: {rows} 行 x {cols} 列\n"
|
||||
profile += f"- **列名**: `{', '.join(df.columns)}`\n\n"
|
||||
profile += "### 列结构:\n\n"
|
||||
profile += "| 列名 | 数据类型 | 空值率 | 唯一值数 | 特征描述 |\n"
|
||||
profile += "|------|---------|--------|---------|----------|\n"
|
||||
|
||||
for col in df.columns:
|
||||
dtype = str(df[col].dtype)
|
||||
null_count = df[col].isnull().sum()
|
||||
null_pct = f"{(null_count / rows) * 100:.1f}%" if rows > 0 else "0%"
|
||||
unique_count = df[col].nunique()
|
||||
|
||||
# 特征描述:只描述数据特征,不暴露具体值
|
||||
feature_desc = _describe_column_safe(df[col], unique_count, rows)
|
||||
|
||||
profile += f"| {col} | {dtype} | {null_pct} | {unique_count} | {feature_desc} |\n"
|
||||
|
||||
profile += "\n"
|
||||
|
||||
except Exception as e:
|
||||
profile += f"[ERROR] 读取文件失败: {str(e)}\n\n"
|
||||
|
||||
return profile
|
||||
|
||||
|
||||
def build_local_profile(file_paths: list) -> str:
|
||||
"""
|
||||
生成完整的本地数据画像(包含真实数据值)。
|
||||
仅用于本地代码执行环境,不发送给 LLM。
|
||||
|
||||
这是原来 load_and_profile_data 的功能,保留完整信息。
|
||||
"""
|
||||
from utils.data_loader import load_and_profile_data
|
||||
return load_and_profile_data(file_paths)
|
||||
|
||||
|
||||
def sanitize_execution_feedback(feedback: str, max_lines: int = 30) -> str:
|
||||
"""
|
||||
对代码执行反馈进行脱敏处理,移除可能包含真实数据的内容。
|
||||
|
||||
保留:
|
||||
- 执行状态(成功/失败)
|
||||
- 错误信息
|
||||
- DataFrame 的 shape 信息
|
||||
- 图片保存路径
|
||||
- 列名信息
|
||||
|
||||
移除/截断:
|
||||
- 具体的数据行(DataFrame 输出)
|
||||
- 大段的数值输出
|
||||
|
||||
Args:
|
||||
feedback: 原始执行反馈
|
||||
max_lines: 最大保留行数
|
||||
|
||||
Returns:
|
||||
脱敏后的反馈
|
||||
"""
|
||||
if not feedback:
|
||||
return feedback
|
||||
|
||||
lines = feedback.split("\n")
|
||||
safe_lines = []
|
||||
in_dataframe_output = False
|
||||
df_line_count = 0
|
||||
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
|
||||
# 始终保留的关键信息
|
||||
if any(kw in stripped for kw in [
|
||||
"图片已保存", "保存至", "[OK]", "[WARN]", "[ERROR]",
|
||||
"[Auto-Save]", "数据表形状", "列名:", ".png",
|
||||
"shape", "columns", "dtype", "info()", "describe()",
|
||||
]):
|
||||
safe_lines.append(line)
|
||||
in_dataframe_output = False
|
||||
continue
|
||||
|
||||
# 检测 DataFrame 输出的开始(通常有列头行)
|
||||
if _looks_like_dataframe_row(stripped):
|
||||
if not in_dataframe_output:
|
||||
in_dataframe_output = True
|
||||
df_line_count = 0
|
||||
safe_lines.append("[数据输出已省略 - 数据仅在本地执行环境中可见]")
|
||||
df_line_count += 1
|
||||
continue
|
||||
|
||||
# 检测纯数值行
|
||||
if _is_numeric_heavy_line(stripped):
|
||||
if not in_dataframe_output:
|
||||
in_dataframe_output = True
|
||||
safe_lines.append("[数值输出已省略]")
|
||||
continue
|
||||
|
||||
# 普通文本行
|
||||
in_dataframe_output = False
|
||||
safe_lines.append(line)
|
||||
|
||||
# 限制总行数
|
||||
if len(safe_lines) > max_lines:
|
||||
safe_lines = safe_lines[:max_lines]
|
||||
safe_lines.append(f"[... 输出已截断,共 {len(lines)} 行]")
|
||||
|
||||
return "\n".join(safe_lines)
|
||||
|
||||
|
||||
def _extract_column_from_error(error_message: str) -> Optional[str]:
|
||||
"""Extract column name from error message patterns like KeyError: 'col_name'.
|
||||
|
||||
Supports:
|
||||
- KeyError: 'column_name' or KeyError: "column_name"
|
||||
- column 'column_name' or column "column_name" (case-insensitive)
|
||||
|
||||
Returns:
|
||||
The extracted column name, or None if no column reference is found.
|
||||
"""
|
||||
match = re.search(r"KeyError:\s*['\"](.+?)['\"]", error_message)
|
||||
if match:
|
||||
return match.group(1)
|
||||
match = re.search(r"column\s+['\"](.+?)['\"]", error_message, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
|
||||
def _lookup_column_in_profile(column_name: Optional[str], safe_profile: str) -> Optional[dict]:
|
||||
"""Look up column metadata in the safe profile markdown table.
|
||||
|
||||
Parses the markdown table rows produced by build_safe_profile() and returns
|
||||
a dict with keys: dtype, null_rate, unique_count, description.
|
||||
|
||||
Args:
|
||||
column_name: The column name to look up (may be None).
|
||||
safe_profile: The safe profile markdown string.
|
||||
|
||||
Returns:
|
||||
A dict of column metadata, or None if not found.
|
||||
"""
|
||||
if not column_name:
|
||||
return None
|
||||
for line in safe_profile.split("\n"):
|
||||
if line.startswith("|") and column_name in line:
|
||||
parts = [p.strip() for p in line.split("|") if p.strip()]
|
||||
if len(parts) >= 5 and parts[0] == column_name:
|
||||
return {
|
||||
"dtype": parts[1],
|
||||
"null_rate": parts[2],
|
||||
"unique_count": parts[3],
|
||||
"description": parts[4],
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def generate_enriched_hint(error_message: str, safe_profile: str) -> str:
|
||||
"""Generate an enriched hint from the safe profile for a data-context error.
|
||||
|
||||
Extracts the referenced column name from the error, looks it up in the safe
|
||||
profile markdown table, and returns a hint string containing only schema-level
|
||||
metadata — no real data values.
|
||||
|
||||
Args:
|
||||
error_message: The error message from code execution.
|
||||
safe_profile: The safe profile markdown string.
|
||||
|
||||
Returns:
|
||||
A hint string with retry context and column metadata (if found).
|
||||
"""
|
||||
column_name = _extract_column_from_error(error_message)
|
||||
column_meta = _lookup_column_in_profile(column_name, safe_profile)
|
||||
|
||||
hint = "[RETRY CONTEXT] 上一次代码执行因数据上下文错误失败。\n"
|
||||
hint += f"错误信息: {error_message}\n"
|
||||
if column_meta:
|
||||
hint += f"相关列 '{column_name}' 的结构信息:\n"
|
||||
hint += f" - 数据类型: {column_meta['dtype']}\n"
|
||||
hint += f" - 唯一值数量: {column_meta['unique_count']}\n"
|
||||
hint += f" - 空值率: {column_meta['null_rate']}\n"
|
||||
hint += f" - 特征描述: {column_meta['description']}\n"
|
||||
hint += "请根据以上结构信息修正代码,不要假设具体的数据值。"
|
||||
return hint
|
||||
|
||||
|
||||
def _load_dataframe(file_path: str):
|
||||
"""加载 DataFrame,支持多种格式和编码"""
|
||||
import os
|
||||
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
if ext == ".csv":
|
||||
for encoding in ["utf-8", "gbk", "gb18030", "latin1"]:
|
||||
try:
|
||||
return pd.read_csv(file_path, encoding=encoding)
|
||||
except (UnicodeDecodeError, Exception):
|
||||
continue
|
||||
elif ext in [".xlsx", ".xls"]:
|
||||
try:
|
||||
return pd.read_excel(file_path)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _describe_column_safe(series: pd.Series, unique_count: int, total_rows: int) -> str:
|
||||
"""安全地描述列特征,不暴露具体值"""
|
||||
dtype = series.dtype
|
||||
|
||||
if pd.api.types.is_numeric_dtype(dtype):
|
||||
if unique_count <= 5:
|
||||
return "低基数数值(可能是分类编码)"
|
||||
elif unique_count < total_rows * 0.05:
|
||||
return "离散数值"
|
||||
else:
|
||||
return "连续数值"
|
||||
|
||||
if pd.api.types.is_datetime64_any_dtype(dtype):
|
||||
return "时间序列"
|
||||
|
||||
# 文本/分类列
|
||||
if unique_count == 1:
|
||||
return "单一值(常量列)"
|
||||
elif unique_count <= 10:
|
||||
return f"低基数分类({unique_count}类)"
|
||||
elif unique_count <= 50:
|
||||
return f"中基数分类({unique_count}类)"
|
||||
elif unique_count > total_rows * 0.8:
|
||||
return "高基数文本(可能是ID或描述)"
|
||||
else:
|
||||
return f"文本分类({unique_count}类)"
|
||||
|
||||
|
||||
def _looks_like_dataframe_row(line: str) -> bool:
|
||||
"""判断一行是否看起来像 DataFrame 输出"""
|
||||
if not line:
|
||||
return False
|
||||
# DataFrame 输出通常有多个空格分隔的列
|
||||
parts = line.split()
|
||||
if len(parts) >= 3:
|
||||
# 第一个元素是索引(数字)
|
||||
try:
|
||||
int(parts[0])
|
||||
return True
|
||||
except ValueError:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def _is_numeric_heavy_line(line: str) -> bool:
|
||||
"""判断一行是否主要由数值组成"""
|
||||
if not line or len(line) < 5:
|
||||
return False
|
||||
digits_and_dots = sum(1 for c in line if c.isdigit() or c in ".,-+eE ")
|
||||
return digits_and_dots / len(line) > 0.7
|
||||
224
utils/data_quality.py
Normal file
224
utils/data_quality.py
Normal file
@@ -0,0 +1,224 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据质量检查模块 - 自动评估数据质量并提供改进建议
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class QualityIssue:
|
||||
"""数据质量问题"""
|
||||
column: str
|
||||
issue_type: str # missing, duplicate, outlier, type_mismatch等
|
||||
severity: str # high, medium, low
|
||||
description: str
|
||||
suggestion: str
|
||||
|
||||
|
||||
class DataQualityChecker:
|
||||
"""数据质量检查器"""
|
||||
|
||||
def __init__(self, df: pd.DataFrame):
|
||||
self.df = df
|
||||
self.issues: List[QualityIssue] = []
|
||||
self.quality_score: float = 100.0
|
||||
|
||||
def check_all(self) -> Dict[str, Any]:
|
||||
"""执行所有质量检查"""
|
||||
self.check_missing_values()
|
||||
self.check_duplicates()
|
||||
self.check_data_types()
|
||||
self.check_outliers()
|
||||
self.check_consistency()
|
||||
|
||||
return self.generate_report()
|
||||
|
||||
def check_missing_values(self) -> None:
|
||||
"""检查缺失值"""
|
||||
for col in self.df.columns:
|
||||
missing_count = self.df[col].isnull().sum()
|
||||
missing_ratio = (missing_count / len(self.df)) * 100
|
||||
|
||||
if missing_ratio > 50:
|
||||
severity = "high"
|
||||
self.quality_score -= 10
|
||||
elif missing_ratio > 20:
|
||||
severity = "medium"
|
||||
self.quality_score -= 5
|
||||
elif missing_ratio > 0:
|
||||
severity = "low"
|
||||
self.quality_score -= 2
|
||||
else:
|
||||
continue
|
||||
|
||||
issue = QualityIssue(
|
||||
column=col,
|
||||
issue_type="missing",
|
||||
severity=severity,
|
||||
description=f"列 '{col}' 存在 {missing_count} 个缺失值 ({missing_ratio:.1f}%)",
|
||||
suggestion=self._suggest_missing_handling(col, missing_ratio)
|
||||
)
|
||||
self.issues.append(issue)
|
||||
|
||||
def check_duplicates(self) -> None:
|
||||
"""检查重复数据"""
|
||||
duplicate_count = self.df.duplicated().sum()
|
||||
if duplicate_count > 0:
|
||||
duplicate_ratio = (duplicate_count / len(self.df)) * 100
|
||||
|
||||
severity = "high" if duplicate_ratio > 10 else "medium"
|
||||
self.quality_score -= 5 if severity == "high" else 3
|
||||
|
||||
issue = QualityIssue(
|
||||
column="全表",
|
||||
issue_type="duplicate",
|
||||
severity=severity,
|
||||
description=f"发现 {duplicate_count} 行重复数据 ({duplicate_ratio:.1f}%)",
|
||||
suggestion="建议使用 df.drop_duplicates() 删除重复行,或检查是否为合理的重复记录"
|
||||
)
|
||||
self.issues.append(issue)
|
||||
|
||||
def check_data_types(self) -> None:
|
||||
"""检查数据类型一致性"""
|
||||
for col in self.df.columns:
|
||||
# 检查是否有数值列被识别为object
|
||||
if self.df[col].dtype == 'object':
|
||||
try:
|
||||
# 尝试转换为数值
|
||||
pd.to_numeric(self.df[col].dropna(), errors='raise')
|
||||
|
||||
issue = QualityIssue(
|
||||
column=col,
|
||||
issue_type="type_mismatch",
|
||||
severity="medium",
|
||||
description=f"列 '{col}' 当前为文本类型,但可以转换为数值类型",
|
||||
suggestion=f"建议使用 df['{col}'] = pd.to_numeric(df['{col}']) 转换类型"
|
||||
)
|
||||
self.issues.append(issue)
|
||||
self.quality_score -= 3
|
||||
except:
|
||||
pass
|
||||
|
||||
def check_outliers(self) -> None:
|
||||
"""检查数值列的异常值"""
|
||||
numeric_cols = self.df.select_dtypes(include=[np.number]).columns
|
||||
|
||||
for col in numeric_cols:
|
||||
q1 = self.df[col].quantile(0.25)
|
||||
q3 = self.df[col].quantile(0.75)
|
||||
iqr = q3 - q1
|
||||
|
||||
lower_bound = q1 - 3 * iqr
|
||||
upper_bound = q3 + 3 * iqr
|
||||
|
||||
outliers = self.df[(self.df[col] < lower_bound) | (self.df[col] > upper_bound)]
|
||||
outlier_count = len(outliers)
|
||||
|
||||
if outlier_count > 0:
|
||||
outlier_ratio = (outlier_count / len(self.df)) * 100
|
||||
|
||||
if outlier_ratio > 5:
|
||||
severity = "medium"
|
||||
self.quality_score -= 3
|
||||
else:
|
||||
severity = "low"
|
||||
self.quality_score -= 1
|
||||
|
||||
issue = QualityIssue(
|
||||
column=col,
|
||||
issue_type="outlier",
|
||||
severity=severity,
|
||||
description=f"列 '{col}' 存在 {outlier_count} 个异常值 ({outlier_ratio:.1f}%)",
|
||||
suggestion=f"建议检查 {lower_bound:.2f} 以下和 {upper_bound:.2f} 以上的值是否合理"
|
||||
)
|
||||
self.issues.append(issue)
|
||||
|
||||
def check_consistency(self) -> None:
|
||||
"""检查数据一致性"""
|
||||
# 检查时间列的时序性
|
||||
datetime_cols = self.df.select_dtypes(include=['datetime64']).columns
|
||||
|
||||
for col in datetime_cols:
|
||||
if not self.df[col].is_monotonic_increasing:
|
||||
issue = QualityIssue(
|
||||
column=col,
|
||||
issue_type="consistency",
|
||||
severity="medium",
|
||||
description=f"时间列 '{col}' 不是单调递增的,可能存在乱序",
|
||||
suggestion=f"建议使用 df.sort_values('{col}') 进行排序"
|
||||
)
|
||||
self.issues.append(issue)
|
||||
self.quality_score -= 3
|
||||
|
||||
def _suggest_missing_handling(self, col: str, missing_ratio: float) -> str:
|
||||
"""建议缺失值处理方法"""
|
||||
if missing_ratio > 70:
|
||||
return f"缺失比例过高,建议删除列 '{col}'"
|
||||
elif missing_ratio > 30:
|
||||
return f"建议填充或删除缺失值:使用中位数/众数填充或删除含缺失值的行"
|
||||
else:
|
||||
if pd.api.types.is_numeric_dtype(self.df[col]):
|
||||
return f"建议使用均值/中位数填充:df['{col}'].fillna(df['{col}'].median())"
|
||||
else:
|
||||
return f"建议使用众数填充:df['{col}'].fillna(df['{col}'].mode()[0])"
|
||||
|
||||
def generate_report(self) -> Dict[str, Any]:
|
||||
"""生成质量报告"""
|
||||
# 确保质量分数在0-100之间
|
||||
self.quality_score = max(0, min(100, self.quality_score))
|
||||
|
||||
# 按严重程度分类
|
||||
high_issues = [i for i in self.issues if i.severity == "high"]
|
||||
medium_issues = [i for i in self.issues if i.severity == "medium"]
|
||||
low_issues = [i for i in self.issues if i.severity == "low"]
|
||||
|
||||
return {
|
||||
"quality_score": round(self.quality_score, 2),
|
||||
"total_issues": len(self.issues),
|
||||
"high_severity": len(high_issues),
|
||||
"medium_severity": len(medium_issues),
|
||||
"low_severity": len(low_issues),
|
||||
"issues": self.issues,
|
||||
"summary": self._generate_summary()
|
||||
}
|
||||
|
||||
def _generate_summary(self) -> str:
|
||||
"""生成可读的摘要"""
|
||||
summary = f"## 数据质量报告\n\n"
|
||||
summary += f"**质量评分**: {self.quality_score:.1f}/100\n\n"
|
||||
|
||||
if self.quality_score >= 90:
|
||||
summary += "[OK] **评级**: 优秀 - 数据质量很好\n\n"
|
||||
elif self.quality_score >= 75:
|
||||
summary += "[WARN] **评级**: 良好 - 存在一些小问题\n\n"
|
||||
elif self.quality_score >= 60:
|
||||
summary += "[WARN] **评级**: 一般 - 需要处理多个问题\n\n"
|
||||
else:
|
||||
summary += "[ERROR] **评级**: 差 - 数据质量问题严重\n\n"
|
||||
|
||||
summary += f"**问题统计**: 共 {len(self.issues)} 个质量问题\n"
|
||||
summary += f"- [RED] 高严重性: {len([i for i in self.issues if i.severity == 'high'])} 个\n"
|
||||
summary += f"- [YELLOW] 中严重性: {len([i for i in self.issues if i.severity == 'medium'])} 个\n"
|
||||
summary += f"- [GREEN] 低严重性: {len([i for i in self.issues if i.severity == 'low'])} 个\n\n"
|
||||
|
||||
if self.issues:
|
||||
summary += "### 主要问题:\n\n"
|
||||
# 只显示高和中严重性的问题
|
||||
for issue in self.issues:
|
||||
if issue.severity in ["high", "medium"]:
|
||||
emoji = "[RED]" if issue.severity == "high" else "[YELLOW]"
|
||||
summary += f"{emoji} **{issue.column}** - {issue.description}\n"
|
||||
summary += f" [TIP] {issue.suggestion}\n\n"
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def quick_quality_check(df: pd.DataFrame) -> str:
|
||||
"""快速数据质量检查"""
|
||||
checker = DataQualityChecker(df)
|
||||
report = checker.check_all()
|
||||
return report['summary']
|
||||
@@ -29,6 +29,22 @@ def extract_code_from_response(response: str) -> Optional[str]:
|
||||
end = response.find('```', start)
|
||||
if end != -1:
|
||||
return response[start:end].strip()
|
||||
|
||||
# 尝试提取 code: | 形式的代码块(针对YAML格式错误但结构清晰的情况)
|
||||
import re
|
||||
# 匹配 code: | 后面的内容,直到遇到下一个键(next_key:)或结尾
|
||||
# 假设代码块至少缩进2个空格
|
||||
pattern = r'code:\s*\|\s*\n((?: {2,}.*\n?)+)'
|
||||
match = re.search(pattern, response)
|
||||
if match:
|
||||
code_block = match.group(1)
|
||||
# 尝试去除公共缩进
|
||||
try:
|
||||
import textwrap
|
||||
return textwrap.dedent(code_block).strip()
|
||||
except:
|
||||
return code_block.strip()
|
||||
|
||||
elif '```' in response:
|
||||
start = response.find('```') + 3
|
||||
end = response.find('```', start)
|
||||
|
||||
@@ -57,7 +57,7 @@ class AsyncFallbackOpenAIClient:
|
||||
self.fallback_client = AsyncOpenAI(api_key=fallback_api_key, base_url=fallback_base_url, **_fallback_args)
|
||||
self.fallback_model_name = fallback_model_name
|
||||
else:
|
||||
print("⚠️ 警告: 未完全配置备用 API 客户端。如果主 API 失败,将无法进行回退。")
|
||||
print("[WARN] 警告: 未完全配置备用 API 客户端。如果主 API 失败,将无法进行回退。")
|
||||
|
||||
self.content_filter_error_code = content_filter_error_code
|
||||
self.content_filter_error_field = content_filter_error_field
|
||||
@@ -90,35 +90,60 @@ class AsyncFallbackOpenAIClient:
|
||||
return completion
|
||||
except (APIConnectionError, APITimeoutError) as e: # 通常可以重试的网络错误
|
||||
last_exception = e
|
||||
print(f"⚠️ {api_name} API 调用时发生可重试错误 ({type(e).__name__}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
|
||||
print(f"[WARN] {api_name} API 调用时发生可重试错误 ({type(e).__name__}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
|
||||
if attempt < max_retries:
|
||||
await asyncio.sleep(self.retry_delay_seconds * (attempt + 1)) # 增加延迟
|
||||
else:
|
||||
print(f"❌ {api_name} API 在达到最大重试次数后仍然失败。")
|
||||
print(f"[ERROR] {api_name} API 在达到最大重试次数后仍然失败。")
|
||||
except APIStatusError as e: # API 返回的特定状态码错误
|
||||
is_content_filter_error = False
|
||||
if e.status_code == 400:
|
||||
try:
|
||||
error_json = e.response.json()
|
||||
error_details = error_json.get("error", {})
|
||||
if (error_details.get("code") == self.content_filter_error_code and
|
||||
self.content_filter_error_field in error_json):
|
||||
is_content_filter_error = True
|
||||
except Exception:
|
||||
pass # 解析错误响应失败,不认为是内容过滤错误
|
||||
retry_after = None
|
||||
|
||||
# 尝试解析错误详情以获取更多信息(如 Google RPC RetryInfo)
|
||||
try:
|
||||
error_json = e.response.json()
|
||||
error_details = error_json.get("error", {})
|
||||
|
||||
# 检查内容过滤错误(针对特定服务商)
|
||||
if (error_details.get("code") == self.content_filter_error_code and
|
||||
self.content_filter_error_field in error_json):
|
||||
is_content_filter_error = True
|
||||
|
||||
# 检查 Google RPC RetryInfo
|
||||
# 格式示例: {'error': {'details': [{'@type': 'type.googleapis.com/google.rpc.RetryInfo', 'retryDelay': '38s'}]}}
|
||||
if "details" in error_details:
|
||||
for detail in error_details["details"]:
|
||||
if detail.get("@type") == "type.googleapis.com/google.rpc.RetryInfo":
|
||||
delay_str = detail.get("retryDelay", "")
|
||||
if delay_str.endswith("s"):
|
||||
try:
|
||||
retry_after = float(delay_str[:-1])
|
||||
print(f"[TIMER] 收到服务器 RetryInfo,等待时间: {retry_after}秒")
|
||||
except ValueError:
|
||||
pass
|
||||
except Exception:
|
||||
pass # 解析错误响应失败,忽略
|
||||
|
||||
if is_content_filter_error and api_name == "主": # 如果是主 API 的内容过滤错误,则直接抛出以便回退
|
||||
raise e
|
||||
|
||||
last_exception = e
|
||||
print(f"⚠️ {api_name} API 调用时发生 APIStatusError ({e.status_code}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
|
||||
print(f"[WARN] {api_name} API 调用时发生 APIStatusError ({e.status_code}): {e}. 尝试次数 {attempt + 1}/{max_retries + 1}")
|
||||
|
||||
if attempt < max_retries:
|
||||
await asyncio.sleep(self.retry_delay_seconds * (attempt + 1))
|
||||
# 如果获取到了明确的 retry_after,则使用它;否则使用默认的指数退避
|
||||
wait_time = retry_after if retry_after is not None else (self.retry_delay_seconds * (attempt + 1))
|
||||
# 如果是 429 Too Many Requests 且没有解析出 retry_after,建议加大等待时间
|
||||
if e.status_code == 429 and retry_after is None:
|
||||
wait_time = max(wait_time, 5.0 * (attempt + 1)) # 429 默认至少等 5 秒
|
||||
|
||||
print(f"[WAIT] 将等待 {wait_time:.2f} 秒后重试...")
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
print(f"❌ {api_name} API 在达到最大重试次数后仍然失败 (APIStatusError)。")
|
||||
print(f"[ERROR] {api_name} API 在达到最大重试次数后仍然失败 (APIStatusError)。")
|
||||
except APIError as e: # 其他不可轻易重试的 OpenAI 错误
|
||||
last_exception = e
|
||||
print(f"❌ {api_name} API 调用时发生不可重试错误 ({type(e).__name__}): {e}")
|
||||
print(f"[ERROR] {api_name} API 调用时发生不可重试错误 ({type(e).__name__}): {e}")
|
||||
break # 不再重试此类错误
|
||||
|
||||
if last_exception:
|
||||
@@ -171,7 +196,7 @@ class AsyncFallbackOpenAIClient:
|
||||
pass
|
||||
|
||||
if is_content_filter_error and self.fallback_client and self.fallback_model_name:
|
||||
print(f"ℹ️ 主 API 内容过滤错误 ({e_primary.status_code})。尝试切换到备用 API ({self.fallback_client.base_url})...")
|
||||
print(f"[INFO] 主 API 内容过滤错误 ({e_primary.status_code})。尝试切换到备用 API ({self.fallback_client.base_url})...")
|
||||
try:
|
||||
fallback_completion = await self._attempt_api_call(
|
||||
client=self.fallback_client,
|
||||
@@ -181,20 +206,20 @@ class AsyncFallbackOpenAIClient:
|
||||
api_name="备用",
|
||||
**kwargs.copy()
|
||||
)
|
||||
print(f"✅ 备用 API 调用成功。")
|
||||
print(f"[OK] 备用 API 调用成功。")
|
||||
return fallback_completion
|
||||
except APIError as e_fallback:
|
||||
print(f"❌ 备用 API 调用最终失败: {type(e_fallback).__name__} - {e_fallback}")
|
||||
print(f"[ERROR] 备用 API 调用最终失败: {type(e_fallback).__name__} - {e_fallback}")
|
||||
raise e_fallback
|
||||
else:
|
||||
if not (self.fallback_client and self.fallback_model_name and is_content_filter_error):
|
||||
# 如果不是内容过滤错误,或者没有可用的备用API,则记录主API的原始错误
|
||||
print(f"ℹ️ 主 API 错误 ({type(e_primary).__name__}: {e_primary}), 且不满足备用条件或备用API未配置。")
|
||||
print(f"[INFO] 主 API 错误 ({type(e_primary).__name__}: {e_primary}), 且不满足备用条件或备用API未配置。")
|
||||
raise e_primary
|
||||
except APIError as e_primary_other:
|
||||
print(f"❌ 主 API 调用最终失败 (非内容过滤,错误类型: {type(e_primary_other).__name__}): {e_primary_other}")
|
||||
print(f"[ERROR] 主 API 调用最终失败 (非内容过滤,错误类型: {type(e_primary_other).__name__}): {e_primary_other}")
|
||||
if self.fallback_client and self.fallback_model_name:
|
||||
print(f"ℹ️ 主 API 失败,尝试切换到备用 API ({self.fallback_client.base_url})...")
|
||||
print(f"[INFO] 主 API 失败,尝试切换到备用 API ({self.fallback_client.base_url})...")
|
||||
try:
|
||||
fallback_completion = await self._attempt_api_call(
|
||||
client=self.fallback_client,
|
||||
@@ -204,10 +229,10 @@ class AsyncFallbackOpenAIClient:
|
||||
api_name="备用",
|
||||
**kwargs.copy()
|
||||
)
|
||||
print(f"✅ 备用 API 调用成功。")
|
||||
print(f"[OK] 备用 API 调用成功。")
|
||||
return fallback_completion
|
||||
except APIError as e_fallback_after_primary_fail:
|
||||
print(f"❌ 备用 API 在主 API 失败后也调用失败: {type(e_fallback_after_primary_fail).__name__} - {e_fallback_after_primary_fail}")
|
||||
print(f"[ERROR] 备用 API 在主 API 失败后也调用失败: {type(e_fallback_after_primary_fail).__name__} - {e_fallback_after_primary_fail}")
|
||||
raise e_fallback_after_primary_fail
|
||||
else:
|
||||
raise e_primary_other
|
||||
|
||||
@@ -7,17 +7,17 @@ def format_execution_result(result: Dict[str, Any]) -> str:
|
||||
feedback = []
|
||||
|
||||
if result['success']:
|
||||
feedback.append("✅ 代码执行成功")
|
||||
feedback.append("[OK] 代码执行成功")
|
||||
|
||||
if result['output']:
|
||||
feedback.append(f"📊 输出结果:\n{result['output']}")
|
||||
feedback.append(f"[CHART] 输出结果:\n{result['output']}")
|
||||
|
||||
if result.get('variables'):
|
||||
feedback.append("📋 新生成的变量:")
|
||||
feedback.append("[LIST] 新生成的变量:")
|
||||
for var_name, var_info in result['variables'].items():
|
||||
feedback.append(f" - {var_name}: {var_info}")
|
||||
else:
|
||||
feedback.append("❌ 代码执行失败")
|
||||
feedback.append("[ERROR] 代码执行失败")
|
||||
feedback.append(f"错误信息: {result['error']}")
|
||||
if result['output']:
|
||||
feedback.append(f"部分输出: {result['output']}")
|
||||
|
||||
@@ -5,8 +5,17 @@ LLM调用辅助模块
|
||||
|
||||
import asyncio
|
||||
import yaml
|
||||
from typing import Optional, Callable, AsyncIterator
|
||||
from config.llm_config import LLMConfig
|
||||
from config.app_config import app_config
|
||||
from utils.fallback_openai_client import AsyncFallbackOpenAIClient
|
||||
from utils.cache_manager import LLMCacheManager
|
||||
|
||||
# 初始化LLM缓存管理器
|
||||
llm_cache = LLMCacheManager(
|
||||
cache_dir=app_config.llm_cache_dir,
|
||||
enabled=app_config.llm_cache_enabled
|
||||
)
|
||||
|
||||
class LLMHelper:
|
||||
"""LLM调用辅助类,支持同步和异步调用"""
|
||||
@@ -75,12 +84,126 @@ class LLMHelper:
|
||||
else:
|
||||
yaml_content = response.strip()
|
||||
|
||||
return yaml.safe_load(yaml_content)
|
||||
# Strip language identifier if LLM used ```python instead of ```yaml
|
||||
# e.g. "python\naction: ..." → "action: ..."
|
||||
import re
|
||||
if re.match(r'^[a-zA-Z]+\n', yaml_content):
|
||||
yaml_content = yaml_content.split('\n', 1)[1]
|
||||
|
||||
# Fix Windows backslash paths that break YAML double-quoted strings.
|
||||
# Replace ALL backslashes inside double-quoted strings with forward slashes.
|
||||
# This handles both "D:\code\..." and "outputs\session_..." patterns.
|
||||
yaml_content = re.sub(
|
||||
r'"([^"]*\\[^"]*)"',
|
||||
lambda m: '"' + m.group(1).replace('\\', '/') + '"',
|
||||
yaml_content,
|
||||
)
|
||||
|
||||
parsed = yaml.safe_load(yaml_content)
|
||||
return parsed if parsed is not None else {}
|
||||
except Exception as e:
|
||||
print(f"YAML解析失败: {e}")
|
||||
print(f"原始响应: {response}")
|
||||
return {}
|
||||
|
||||
|
||||
async def close(self):
|
||||
"""关闭客户端"""
|
||||
await self.client.close()
|
||||
await self.client.close()
|
||||
|
||||
async def async_call_with_cache(
|
||||
self,
|
||||
prompt: str,
|
||||
system_prompt: str = None,
|
||||
max_tokens: int = None,
|
||||
temperature: float = None,
|
||||
use_cache: bool = True
|
||||
) -> str:
|
||||
"""带缓存的异步LLM调用"""
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# 生成缓存键
|
||||
cache_key = llm_cache.get_cache_key_from_messages(messages, self.config.model)
|
||||
|
||||
# 尝试从缓存获取
|
||||
if use_cache and app_config.llm_cache_enabled:
|
||||
cached_response = llm_cache.get(cache_key)
|
||||
if cached_response:
|
||||
print("[CACHE] 使用LLM缓存响应")
|
||||
return cached_response
|
||||
|
||||
# 调用LLM
|
||||
response = await self.async_call(prompt, system_prompt, max_tokens, temperature)
|
||||
|
||||
# 缓存响应
|
||||
if use_cache and app_config.llm_cache_enabled and response:
|
||||
llm_cache.set(cache_key, response)
|
||||
|
||||
return response
|
||||
|
||||
def call_with_cache(
|
||||
self,
|
||||
prompt: str,
|
||||
system_prompt: str = None,
|
||||
max_tokens: int = None,
|
||||
temperature: float = None,
|
||||
use_cache: bool = True
|
||||
) -> str:
|
||||
"""带缓存的同步LLM调用"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
|
||||
return loop.run_until_complete(
|
||||
self.async_call_with_cache(prompt, system_prompt, max_tokens, temperature, use_cache)
|
||||
)
|
||||
|
||||
async def async_call_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
system_prompt: str = None,
|
||||
max_tokens: int = None,
|
||||
temperature: float = None,
|
||||
callback: Optional[Callable[[str], None]] = None
|
||||
) -> AsyncIterator[str]:
|
||||
"""流式异步LLM调用"""
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
kwargs = {
|
||||
'stream': True,
|
||||
'max_tokens': max_tokens or self.config.max_tokens,
|
||||
'temperature': temperature or self.config.temperature
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self.client.chat_completions_create(
|
||||
messages=messages,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
full_response = ""
|
||||
async for chunk in response:
|
||||
if chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
full_response += content
|
||||
|
||||
# 调用回调函数
|
||||
if callback:
|
||||
callback(content)
|
||||
|
||||
yield content
|
||||
|
||||
except Exception as e:
|
||||
print(f"流式LLM调用失败: {e}")
|
||||
yield ""
|
||||
113
utils/logger.py
Normal file
113
utils/logger.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
统一日志模块 - 替代全局 sys.stdout 劫持
|
||||
|
||||
提供线程安全的日志记录,支持同时输出到终端和文件。
|
||||
每个会话拥有独立的日志文件,不会互相干扰。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def create_session_logger(
|
||||
session_id: str,
|
||||
log_dir: str,
|
||||
log_filename: str = "process.log",
|
||||
level: int = logging.INFO,
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
为指定会话创建独立的 Logger 实例。
|
||||
|
||||
Args:
|
||||
session_id: 会话唯一标识
|
||||
log_dir: 日志文件所在目录
|
||||
log_filename: 日志文件名
|
||||
level: 日志级别
|
||||
|
||||
Returns:
|
||||
配置好的 Logger 实例
|
||||
"""
|
||||
logger = logging.getLogger(f"session.{session_id}")
|
||||
logger.setLevel(level)
|
||||
|
||||
# 避免重复添加 handler
|
||||
if logger.handlers:
|
||||
return logger
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
|
||||
# 文件 handler — 写入会话专属日志
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_path = os.path.join(log_dir, log_filename)
|
||||
file_handler = logging.FileHandler(log_path, encoding="utf-8", mode="a")
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
# 终端 handler — 输出到 stderr(不干扰 stdout)
|
||||
console_handler = logging.StreamHandler(sys.stderr)
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# 不向父 logger 传播
|
||||
logger.propagate = False
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
class PrintCapture:
|
||||
"""
|
||||
轻量级 print 捕获器,将 print 输出同时写入日志文件。
|
||||
用于兼容现有大量使用 print() 的代码,无需逐行改造。
|
||||
|
||||
用法:
|
||||
with PrintCapture(log_path) as cap:
|
||||
print("hello") # 同时输出到终端和文件
|
||||
# 退出后 sys.stdout 自动恢复
|
||||
"""
|
||||
|
||||
def __init__(self, log_path: str, filter_patterns: Optional[list] = None):
|
||||
self.log_path = log_path
|
||||
self.filter_patterns = filter_patterns or ["[TOOL] 执行代码:"]
|
||||
self._original_stdout = None
|
||||
self._log_file = None
|
||||
|
||||
def __enter__(self):
|
||||
os.makedirs(os.path.dirname(self.log_path), exist_ok=True)
|
||||
self._original_stdout = sys.stdout
|
||||
self._log_file = open(self.log_path, "a", encoding="utf-8", buffering=1)
|
||||
sys.stdout = self._DualWriter(
|
||||
self._original_stdout, self._log_file, self.filter_patterns
|
||||
)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
sys.stdout = self._original_stdout
|
||||
if self._log_file:
|
||||
self._log_file.close()
|
||||
return False
|
||||
|
||||
class _DualWriter:
|
||||
"""同时写入两个流,支持过滤"""
|
||||
|
||||
def __init__(self, terminal, log_file, filter_patterns):
|
||||
self.terminal = terminal
|
||||
self.log_file = log_file
|
||||
self.filter_patterns = filter_patterns
|
||||
|
||||
def write(self, message):
|
||||
self.terminal.write(message)
|
||||
# 过滤不需要写入日志的内容
|
||||
if any(p in message for p in self.filter_patterns):
|
||||
return
|
||||
self.log_file.write(message)
|
||||
|
||||
def flush(self):
|
||||
self.terminal.flush()
|
||||
self.log_file.flush()
|
||||
280
utils/script_generator.py
Normal file
280
utils/script_generator.py
Normal file
@@ -0,0 +1,280 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可复用脚本生成器
|
||||
|
||||
从分析会话的执行历史中提取成功执行的代码,
|
||||
合并去重后生成可独立运行的 .py 脚本文件。
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Set
|
||||
|
||||
|
||||
def extract_imports(code: str) -> Set[str]:
|
||||
"""从代码中提取所有 import 语句"""
|
||||
imports = set()
|
||||
lines = code.split('\n')
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
if stripped.startswith('import ') or stripped.startswith('from '):
|
||||
# 标准化 import 语句
|
||||
imports.add(stripped)
|
||||
return imports
|
||||
|
||||
|
||||
def remove_imports(code: str) -> str:
|
||||
"""从代码中移除所有 import 语句"""
|
||||
lines = code.split('\n')
|
||||
result_lines = []
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
if not stripped.startswith('import ') and not stripped.startswith('from '):
|
||||
result_lines.append(line)
|
||||
return '\n'.join(result_lines)
|
||||
|
||||
|
||||
def clean_code_block(code: str) -> str:
|
||||
"""清理代码块,移除不必要的内容"""
|
||||
# 移除可能的重复配置代码
|
||||
patterns_to_skip = [
|
||||
r"plt\.rcParams\['font\.sans-serif'\]", # 字体配置在模板中统一处理
|
||||
r"plt\.rcParams\['axes\.unicode_minus'\]",
|
||||
]
|
||||
|
||||
lines = code.split('\n')
|
||||
result_lines = []
|
||||
skip_until_empty = False
|
||||
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
|
||||
# 跳过空行连续的情况
|
||||
if not stripped:
|
||||
if skip_until_empty:
|
||||
skip_until_empty = False
|
||||
continue
|
||||
result_lines.append(line)
|
||||
continue
|
||||
|
||||
# 检查是否需要跳过的模式
|
||||
should_skip = False
|
||||
for pattern in patterns_to_skip:
|
||||
if re.search(pattern, stripped):
|
||||
should_skip = True
|
||||
break
|
||||
|
||||
if not should_skip:
|
||||
result_lines.append(line)
|
||||
|
||||
return '\n'.join(result_lines)
|
||||
|
||||
|
||||
def _is_verification_code(code: str) -> bool:
|
||||
"""Detect code blocks that only check/list files without doing real analysis.
|
||||
|
||||
These are typically generated when the LLM runs os.listdir / os.path.exists
|
||||
loops to verify outputs, and should not appear in the reusable script.
|
||||
"""
|
||||
lines = [l.strip() for l in code.strip().splitlines() if l.strip() and not l.strip().startswith('#')]
|
||||
if not lines:
|
||||
return True
|
||||
|
||||
verification_indicators = 0
|
||||
analysis_indicators = 0
|
||||
|
||||
for line in lines:
|
||||
# Verification patterns
|
||||
if any(kw in line for kw in [
|
||||
'os.listdir(', 'os.path.exists(', 'os.path.getsize(',
|
||||
'os.path.isfile(', '✓', '✗', 'all_exist',
|
||||
]):
|
||||
verification_indicators += 1
|
||||
# Analysis patterns (actual computation / plotting / saving)
|
||||
if any(kw in line for kw in [
|
||||
'.plot(', 'plt.', '.to_csv(', '.value_counts()',
|
||||
'.groupby(', '.corr(', '.fit_transform(', '.fit_predict(',
|
||||
'pd.read_csv(', 'pd.crosstab(', '.describe()',
|
||||
]):
|
||||
analysis_indicators += 1
|
||||
|
||||
# If the block is dominated by verification with no real analysis, skip it
|
||||
return verification_indicators > 0 and analysis_indicators == 0
|
||||
|
||||
|
||||
def _is_duplicate_data_load(code: str, seen_load_blocks: set) -> bool:
|
||||
"""Detect duplicate data loading blocks (LLM 'amnesia' repeats).
|
||||
|
||||
Computes a fingerprint from the code's structural lines (ignoring
|
||||
whitespace and comments) and returns True if we've seen it before.
|
||||
"""
|
||||
# Extract structural fingerprint: non-empty, non-comment lines
|
||||
structural_lines = []
|
||||
for line in code.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped and not stripped.startswith('#'):
|
||||
structural_lines.append(stripped)
|
||||
|
||||
fingerprint = '\n'.join(structural_lines[:30]) # First 30 lines are enough
|
||||
|
||||
if fingerprint in seen_load_blocks:
|
||||
return True
|
||||
seen_load_blocks.add(fingerprint)
|
||||
return False
|
||||
|
||||
|
||||
def generate_reusable_script(
|
||||
analysis_results: List[Dict[str, Any]],
|
||||
data_files: List[str],
|
||||
session_output_dir: str,
|
||||
user_requirement: str = ""
|
||||
) -> str:
|
||||
"""
|
||||
从分析结果中生成可复用的 Python 脚本
|
||||
|
||||
Args:
|
||||
analysis_results: 分析过程中记录的结果列表,每个元素包含 'code', 'result' 等
|
||||
data_files: 原始数据文件路径列表
|
||||
session_output_dir: 会话输出目录
|
||||
user_requirement: 用户的原始需求描述
|
||||
|
||||
Returns:
|
||||
生成的脚本文件路径
|
||||
"""
|
||||
# 收集所有成功执行的代码
|
||||
all_imports = set()
|
||||
code_blocks = []
|
||||
seen_load_blocks: Set[str] = set()
|
||||
|
||||
for result in analysis_results:
|
||||
# 只处理 generate_code 类型的结果
|
||||
if result.get("action") == "collect_figures":
|
||||
continue
|
||||
# Skip retry attempts
|
||||
if result.get("retry"):
|
||||
continue
|
||||
|
||||
code = result.get("code", "")
|
||||
exec_result = result.get("result", {})
|
||||
|
||||
# 只收集成功执行的代码
|
||||
if code and exec_result.get("success", False):
|
||||
# Skip pure verification/file-check code (e.g. os.listdir loops)
|
||||
if _is_verification_code(code):
|
||||
continue
|
||||
|
||||
# Skip duplicate data-loading blocks (LLM amnesia repeats)
|
||||
if _is_duplicate_data_load(code, seen_load_blocks):
|
||||
continue
|
||||
|
||||
# 提取 imports
|
||||
imports = extract_imports(code)
|
||||
all_imports.update(imports)
|
||||
|
||||
# 清理代码块
|
||||
cleaned_code = remove_imports(code)
|
||||
cleaned_code = clean_code_block(cleaned_code)
|
||||
|
||||
# 只添加非空的代码块
|
||||
if cleaned_code.strip():
|
||||
code_blocks.append({
|
||||
"round": result.get("round", 0),
|
||||
"code": cleaned_code.strip()
|
||||
})
|
||||
|
||||
if not code_blocks:
|
||||
print("[WARN] 没有成功执行的代码块,跳过脚本生成")
|
||||
return ""
|
||||
|
||||
# 生成脚本内容
|
||||
now = datetime.now()
|
||||
timestamp = now.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# 构建脚本头部
|
||||
script_header = f'''#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据分析脚本 - 自动生成
|
||||
=====================================
|
||||
原始数据文件: {', '.join(data_files)}
|
||||
生成时间: {now.strftime("%Y-%m-%d %H:%M:%S")}
|
||||
原始需求: {user_requirement[:200] + '...' if len(user_requirement) > 200 else user_requirement}
|
||||
=====================================
|
||||
|
||||
使用方法:
|
||||
1. 修改下方 DATA_FILES 列表中的文件路径
|
||||
2. 修改 OUTPUT_DIR 指定输出目录
|
||||
3. 运行: python {os.path.basename(session_output_dir)}_分析脚本.py
|
||||
"""
|
||||
|
||||
import os
|
||||
'''
|
||||
|
||||
# 添加标准 imports(去重后排序)
|
||||
standard_imports = sorted([imp for imp in all_imports if imp.startswith('import ')])
|
||||
from_imports = sorted([imp for imp in all_imports if imp.startswith('from ')])
|
||||
|
||||
imports_section = '\n'.join(standard_imports + from_imports)
|
||||
|
||||
# 配置区域
|
||||
config_section = f'''
|
||||
# ========== 配置区域 (可修改) ==========
|
||||
|
||||
# 数据文件路径 - 修改此处以分析不同的数据
|
||||
DATA_FILES = {repr(data_files)}
|
||||
|
||||
# 输出目录 - 图片和报告将保存在此目录
|
||||
OUTPUT_DIR = "./analysis_output"
|
||||
|
||||
# 创建输出目录
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# ========== 字体配置 (中文显示) ==========
|
||||
import platform
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
system_name = platform.system()
|
||||
if system_name == 'Darwin':
|
||||
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'PingFang SC', 'sans-serif']
|
||||
elif system_name == 'Windows':
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'sans-serif']
|
||||
else:
|
||||
plt.rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei', 'sans-serif']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
# 设置 session_output_dir 变量(兼容原始代码)
|
||||
session_output_dir = OUTPUT_DIR
|
||||
'''
|
||||
|
||||
# 合并代码块
|
||||
code_section = "\n# ========== 分析代码 ==========\n\n"
|
||||
|
||||
for i, block in enumerate(code_blocks, 1):
|
||||
code_section += f"# --- 第 {block['round']} 轮分析 ---\n"
|
||||
code_section += block['code'] + "\n\n"
|
||||
|
||||
# 脚本尾部
|
||||
script_footer = '''
|
||||
# ========== 完成 ==========
|
||||
print("\\n" + "=" * 50)
|
||||
print("[OK] 分析完成!")
|
||||
print(f"[OUTPUT] 输出目录: {os.path.abspath(OUTPUT_DIR)}")
|
||||
print("=" * 50)
|
||||
'''
|
||||
|
||||
# 组装完整脚本
|
||||
full_script = script_header + imports_section + config_section + code_section + script_footer
|
||||
|
||||
# 保存脚本文件
|
||||
script_filename = f"分析脚本_{timestamp}.py"
|
||||
script_path = os.path.join(session_output_dir, script_filename)
|
||||
|
||||
try:
|
||||
with open(script_path, 'w', encoding='utf-8') as f:
|
||||
f.write(full_script)
|
||||
print(f"[OK] 可复用脚本已生成: {script_path}")
|
||||
return script_path
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 保存脚本失败: {e}")
|
||||
return ""
|
||||
1217
web/main.py
Normal file
1217
web/main.py
Normal file
File diff suppressed because it is too large
Load Diff
1120
web/static/clean_style.css
Normal file
1120
web/static/clean_style.css
Normal file
File diff suppressed because it is too large
Load Diff
204
web/static/index.html
Normal file
204
web/static/index.html
Normal file
@@ -0,0 +1,204 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>IOV Data Analysis Agent</title>
|
||||
<link rel="stylesheet" href="/static/clean_style.css">
|
||||
|
||||
<!-- Fonts -->
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link
|
||||
href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&family=JetBrains+Mono:wght@400;500&display=swap"
|
||||
rel="stylesheet">
|
||||
|
||||
<!-- Icons -->
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
|
||||
|
||||
<!-- Markdown -->
|
||||
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="app-container">
|
||||
<!-- Sidebar -->
|
||||
<aside class="sidebar">
|
||||
<div class="brand">
|
||||
<i class="fa-solid fa-cube"></i>
|
||||
<span>IOV Agent</span>
|
||||
</div>
|
||||
|
||||
<nav class="nav-menu">
|
||||
<button class="nav-item active" onclick="switchView('analysis')">
|
||||
<i class="fa-solid fa-chart-line"></i> Analysis
|
||||
</button>
|
||||
|
||||
<div class="nav-divider"></div>
|
||||
<div class="nav-section-title">History</div>
|
||||
<div id="historyList" class="history-list">
|
||||
<!-- History items loaded via JS -->
|
||||
<div style="padding:0.5rem; font-size:0.8rem; color:#9CA3AF;">Loading...</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<div class="status-bar">
|
||||
<div id="statusDot" class="status-dot"></div>
|
||||
<span id="statusText">Ready</span>
|
||||
</div>
|
||||
</aside>
|
||||
|
||||
<!-- Main Content -->
|
||||
<main class="main-content">
|
||||
<header class="header">
|
||||
<h2 id="pageTitle">Analysis Dashboard</h2>
|
||||
</header>
|
||||
|
||||
<div class="content-area">
|
||||
<!-- VIEW: ANALYSIS -->
|
||||
<div id="viewAnalysis" class="section active">
|
||||
<div class="analysis-grid">
|
||||
|
||||
<!-- Configuration Panel -->
|
||||
<div class="panel">
|
||||
<div class="panel-title">
|
||||
<span>Configuration</span>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label class="form-label">1. Data Upload</label>
|
||||
<div id="uploadZone" class="upload-area">
|
||||
<i class="fa-solid fa-cloud-arrow-up upload-icon"></i>
|
||||
<p>Click or Drag CSV/Excel Files</p>
|
||||
<div id="fileList" class="file-list"></div>
|
||||
</div>
|
||||
<input type="file" id="fileInput" multiple accept=".csv,.xlsx,.xls" hidden>
|
||||
</div>
|
||||
|
||||
<div class="form-group" id="templateSelectorGroup">
|
||||
<label class="form-label">2. Analysis Template</label>
|
||||
<div id="templateSelector" class="template-cards">
|
||||
<div class="template-card selected" data-template="" onclick="selectTemplate(this, '')">
|
||||
<div class="template-card-title">No Template</div>
|
||||
<div class="template-card-desc">Free Analysis</div>
|
||||
</div>
|
||||
<!-- Dynamic template cards loaded via JS -->
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label class="form-label">3. Requirement</label>
|
||||
<textarea id="requirementInput" class="form-textarea"
|
||||
placeholder="Describe what you want to analyze..."></textarea>
|
||||
</div>
|
||||
|
||||
<button id="startBtn" class="btn btn-primary" style="margin-top: 1rem; width: 100%;">
|
||||
<i class="fa-solid fa-play"></i> Start Analysis
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Output Panel -->
|
||||
<div class="panel" style="overflow:hidden; display:flex; flex-direction:column;">
|
||||
<div class="panel-title" style="margin-bottom:0.5rem;">
|
||||
<span>Output</span>
|
||||
<div class="tabs">
|
||||
<div class="tab active" onclick="switchTab('execution')">执行过程</div>
|
||||
<div class="tab" onclick="switchTab('datafiles')">数据文件</div>
|
||||
<div class="tab" onclick="switchTab('report')">Report</div>
|
||||
</div>
|
||||
<button id="downloadScriptBtn" class="btn btn-sm btn-secondary hidden"
|
||||
onclick="downloadScript()" style="margin-left:auto;">
|
||||
<i class="fa-solid fa-code"></i> Script
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Progress Bar -->
|
||||
<div id="progressBarContainer" class="progress-bar-container hidden">
|
||||
<div class="progress-bar-info">
|
||||
<span id="progressLabel" class="progress-label">Round 0/0</span>
|
||||
<span id="progressPercent" class="progress-percent">0%</span>
|
||||
</div>
|
||||
<div class="progress-bar-track">
|
||||
<div id="progressBarFill" class="progress-bar-fill" style="width: 0%"></div>
|
||||
</div>
|
||||
<div id="progressMessage" class="progress-message"></div>
|
||||
</div>
|
||||
|
||||
<div class="output-container" id="outputContainer">
|
||||
<!-- Execution Process Tab -->
|
||||
<div id="executionTab" class="tab-content active" style="height:100%; overflow-y:auto;">
|
||||
<div id="roundCardsWrapper" class="round-cards-wrapper">
|
||||
<div class="empty-state" id="executionEmptyState">
|
||||
<p>Waiting to start...</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Data Files Tab -->
|
||||
<div id="datafilesTab" class="tab-content hidden" style="height:100%; overflow-y:auto;">
|
||||
<div id="fileCardsGrid" class="file-cards-grid"></div>
|
||||
<div id="dataPreviewPanel" class="data-preview-panel hidden">
|
||||
<div class="data-preview-header">
|
||||
<span id="previewFileName"></span>
|
||||
<button class="btn btn-sm btn-secondary" onclick="closePreview()"><i class="fa-solid fa-xmark"></i></button>
|
||||
</div>
|
||||
<div id="previewTableContainer"></div>
|
||||
</div>
|
||||
<div class="empty-state" id="datafilesEmptyState">
|
||||
<p>No data files yet.</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Report Tab -->
|
||||
<div id="reportTab" class="tab-content hidden" style="height:100%; overflow-y:auto;">
|
||||
<div id="reportContainer" class="report-content markdown-body">
|
||||
<div class="empty-state">
|
||||
<p>Report will appear here after analysis.</p>
|
||||
</div>
|
||||
</div>
|
||||
<div id="followUpSection" class="hidden"
|
||||
style="margin-top:2rem; border-top:1px solid var(--border-color); padding-top:1rem;">
|
||||
<div class="form-group">
|
||||
<label class="form-label">Follow-up Analysis</label>
|
||||
<div style="display:flex; gap:0.5rem;">
|
||||
<input type="text" id="followUpInput" class="form-input"
|
||||
placeholder="Ask a follow-up question...">
|
||||
<button class="btn btn-primary btn-sm"
|
||||
onclick="sendFollowUp()">Send</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div style="margin-top:1rem; text-align:right">
|
||||
<button id="exportBtn" class="btn btn-secondary btn-sm"
|
||||
onclick="triggerExport()">
|
||||
<i class="fa-solid fa-download"></i> Export ZIP
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Supporting Data Modal -->
|
||||
<div class="supporting-data-modal hidden" id="supportingDataModal">
|
||||
<div class="supporting-data-content">
|
||||
<div class="supporting-data-header">
|
||||
<span>支撑数据</span>
|
||||
<button onclick="closeSupportingData()">×</button>
|
||||
</div>
|
||||
<div class="supporting-data-body" id="supportingDataBody">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</main>
|
||||
</div>
|
||||
|
||||
<script src="/static/script.js"></script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
899
web/static/script.js
Normal file
899
web/static/script.js
Normal file
@@ -0,0 +1,899 @@
|
||||
// DOM Elements
|
||||
const uploadZone = document.getElementById('uploadZone');
|
||||
const fileInput = document.getElementById('fileInput');
|
||||
const fileList = document.getElementById('fileList');
|
||||
const startBtn = document.getElementById('startBtn');
|
||||
const requirementInput = document.getElementById('requirementInput');
|
||||
const statusDot = document.getElementById('statusDot');
|
||||
const statusText = document.getElementById('statusText');
|
||||
const reportContainer = document.getElementById('reportContainer');
|
||||
const downloadScriptBtn = document.getElementById('downloadScriptBtn');
|
||||
|
||||
let isRunning = false;
|
||||
let pollingInterval = null;
|
||||
let currentSessionId = null;
|
||||
let selectedTemplate = '';
|
||||
|
||||
// 报告段落数据(用于润色功能)
|
||||
let reportParagraphs = [];
|
||||
|
||||
// Supporting data from report API
|
||||
let supportingData = {};
|
||||
|
||||
// Execution Process state
|
||||
let lastRenderedRound = 0;
|
||||
|
||||
// --- Progress Bar ---
|
||||
function updateProgressBar(percentage, message, currentRound, maxRounds) {
|
||||
const container = document.getElementById('progressBarContainer');
|
||||
const fill = document.getElementById('progressBarFill');
|
||||
const label = document.getElementById('progressLabel');
|
||||
const percent = document.getElementById('progressPercent');
|
||||
const msg = document.getElementById('progressMessage');
|
||||
if (!container || !fill) return;
|
||||
|
||||
container.classList.remove('hidden');
|
||||
fill.style.width = percentage + '%';
|
||||
if (label) label.textContent = `Round ${currentRound || 0}/${maxRounds || 0}`;
|
||||
if (percent) percent.textContent = Math.round(percentage) + '%';
|
||||
if (msg) msg.textContent = message || '';
|
||||
}
|
||||
|
||||
function hideProgressBar() {
|
||||
const container = document.getElementById('progressBarContainer');
|
||||
if (container) container.classList.add('hidden');
|
||||
}
|
||||
|
||||
// --- Upload Logic ---
|
||||
if (uploadZone) {
|
||||
uploadZone.addEventListener('dragover', (e) => {
|
||||
e.preventDefault();
|
||||
uploadZone.classList.add('dragover');
|
||||
});
|
||||
uploadZone.addEventListener('dragleave', () => uploadZone.classList.remove('dragover'));
|
||||
uploadZone.addEventListener('drop', (e) => {
|
||||
e.preventDefault();
|
||||
uploadZone.classList.remove('dragover');
|
||||
handleFiles(e.dataTransfer.files);
|
||||
});
|
||||
uploadZone.addEventListener('click', () => fileInput.click());
|
||||
}
|
||||
|
||||
if (fileInput) {
|
||||
fileInput.addEventListener('change', (e) => handleFiles(e.target.files));
|
||||
fileInput.addEventListener('click', (e) => e.stopPropagation());
|
||||
}
|
||||
|
||||
// Track all uploaded file paths for this session
|
||||
let uploadedFilePaths = [];
|
||||
|
||||
async function handleFiles(files) {
|
||||
if (files.length === 0) return;
|
||||
|
||||
const formData = new FormData();
|
||||
for (const file of files) {
|
||||
formData.append('files', file);
|
||||
}
|
||||
|
||||
try {
|
||||
const res = await fetch('/api/upload', { method: 'POST', body: formData });
|
||||
if (!res.ok) { alert('Upload failed'); return; }
|
||||
const data = await res.json();
|
||||
// Accumulate uploaded paths
|
||||
if (data.paths) {
|
||||
uploadedFilePaths = uploadedFilePaths.concat(data.paths);
|
||||
}
|
||||
renderFileList();
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
alert('Upload failed');
|
||||
}
|
||||
}
|
||||
|
||||
function renderFileList() {
|
||||
fileList.innerHTML = '';
|
||||
for (let i = 0; i < uploadedFilePaths.length; i++) {
|
||||
const fname = uploadedFilePaths[i].split('/').pop().split('\\').pop();
|
||||
const fileItem = document.createElement('div');
|
||||
fileItem.className = 'file-item';
|
||||
fileItem.innerHTML = `<i class="fa-regular fa-file-excel"></i> ${fname}
|
||||
<span class="file-remove" onclick="removeUploadedFile(${i})" title="移除">×</span>`;
|
||||
fileList.appendChild(fileItem);
|
||||
}
|
||||
}
|
||||
|
||||
window.removeUploadedFile = function(index) {
|
||||
uploadedFilePaths.splice(index, 1);
|
||||
renderFileList();
|
||||
}
|
||||
|
||||
// --- Template Logic ---
|
||||
async function loadTemplates() {
|
||||
try {
|
||||
const res = await fetch('/api/templates');
|
||||
if (!res.ok) return;
|
||||
const data = await res.json();
|
||||
const selector = document.getElementById('templateSelector');
|
||||
if (!selector || !data.templates) return;
|
||||
|
||||
for (const tpl of data.templates) {
|
||||
const card = document.createElement('div');
|
||||
card.className = 'template-card';
|
||||
card.setAttribute('data-template', tpl.name);
|
||||
card.onclick = function() { selectTemplate(this, tpl.name); };
|
||||
card.innerHTML = `<div class="template-card-title">${tpl.display_name}</div><div class="template-card-desc">${tpl.description}</div>`;
|
||||
selector.appendChild(card);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to load templates', e);
|
||||
}
|
||||
}
|
||||
|
||||
window.selectTemplate = function(el, name) {
|
||||
document.querySelectorAll('.template-card').forEach(c => c.classList.remove('selected'));
|
||||
el.classList.add('selected');
|
||||
selectedTemplate = name;
|
||||
}
|
||||
|
||||
// --- Analysis Logic ---
|
||||
if (startBtn) {
|
||||
startBtn.addEventListener('click', startAnalysis);
|
||||
}
|
||||
|
||||
async function startAnalysis() {
|
||||
if (isRunning) return;
|
||||
|
||||
const requirement = requirementInput.value.trim();
|
||||
if (!requirement) {
|
||||
alert('Please enter analysis requirement');
|
||||
return;
|
||||
}
|
||||
|
||||
setRunningState(true);
|
||||
// Reset execution state for new analysis
|
||||
lastRenderedRound = 0;
|
||||
const wrapper = document.getElementById('roundCardsWrapper');
|
||||
if (wrapper) wrapper.innerHTML = '';
|
||||
const emptyState = document.getElementById('executionEmptyState');
|
||||
if (emptyState) emptyState.remove();
|
||||
|
||||
try {
|
||||
const body = { requirement };
|
||||
if (selectedTemplate) {
|
||||
body.template = selectedTemplate;
|
||||
}
|
||||
if (uploadedFilePaths.length > 0) {
|
||||
body.files = uploadedFilePaths;
|
||||
}
|
||||
|
||||
const res = await fetch('/api/start', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(body)
|
||||
});
|
||||
|
||||
if (res.ok) {
|
||||
const data = await res.json();
|
||||
currentSessionId = data.session_id;
|
||||
startPolling();
|
||||
switchTab('execution');
|
||||
} else {
|
||||
const err = await res.json();
|
||||
alert('Failed to start: ' + err.detail);
|
||||
setRunningState(false);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
alert('Error starting analysis');
|
||||
setRunningState(false);
|
||||
}
|
||||
}
|
||||
|
||||
function setRunningState(running) {
|
||||
isRunning = running;
|
||||
startBtn.disabled = running;
|
||||
|
||||
if (running) {
|
||||
startBtn.innerHTML = '<i class="fa-solid fa-spinner fa-spin"></i> Analysis in Progress...';
|
||||
statusDot.className = 'status-dot running';
|
||||
statusText.innerText = 'Analyzing';
|
||||
statusText.style.color = 'var(--primary-color)';
|
||||
const followUpSection = document.getElementById('followUpSection');
|
||||
if (followUpSection) followUpSection.classList.add('hidden');
|
||||
if (downloadScriptBtn) downloadScriptBtn.classList.add('hidden');
|
||||
const tplGroup = document.getElementById('templateSelectorGroup');
|
||||
if (tplGroup) tplGroup.classList.add('hidden');
|
||||
} else {
|
||||
startBtn.innerHTML = '<i class="fa-solid fa-play"></i> Start Analysis';
|
||||
statusDot.className = 'status-dot';
|
||||
statusText.innerText = 'Completed';
|
||||
statusText.style.color = 'var(--text-secondary)';
|
||||
const followUpSection = document.getElementById('followUpSection');
|
||||
if (currentSessionId && followUpSection) followUpSection.classList.remove('hidden');
|
||||
}
|
||||
}
|
||||
|
||||
// --- Polling ---
|
||||
function startPolling() {
|
||||
if (pollingInterval) clearInterval(pollingInterval);
|
||||
if (!currentSessionId) return;
|
||||
|
||||
pollingInterval = setInterval(async () => {
|
||||
try {
|
||||
const res = await fetch(`/api/status?session_id=${currentSessionId}`);
|
||||
if (!res.ok) return;
|
||||
const data = await res.json();
|
||||
|
||||
// Render round cards incrementally
|
||||
const rounds = data.rounds || [];
|
||||
renderRoundCards(rounds);
|
||||
|
||||
// Load data files during polling
|
||||
loadDataFiles();
|
||||
|
||||
// Update progress bar during analysis
|
||||
// Use rounds.length (actual completed analysis rounds) for display
|
||||
// instead of current_round (which includes non-code rounds like collect_figures)
|
||||
if (data.is_running && data.progress_percentage !== undefined) {
|
||||
const displayRound = rounds.length || data.current_round || 0;
|
||||
updateProgressBar(data.progress_percentage, data.status_message, displayRound, data.max_rounds);
|
||||
}
|
||||
|
||||
if (!data.is_running && isRunning) {
|
||||
const displayRound = rounds.length || data.current_round || data.max_rounds;
|
||||
updateProgressBar(100, 'Analysis complete', displayRound, data.max_rounds);
|
||||
setTimeout(hideProgressBar, 3000);
|
||||
|
||||
setRunningState(false);
|
||||
clearInterval(pollingInterval);
|
||||
|
||||
// Final render of rounds
|
||||
renderRoundCards(data.rounds || []);
|
||||
loadDataFiles();
|
||||
|
||||
if (data.has_report) {
|
||||
await loadReport();
|
||||
switchTab('report');
|
||||
}
|
||||
if (data.script_path && downloadScriptBtn) {
|
||||
downloadScriptBtn.classList.remove('hidden');
|
||||
downloadScriptBtn.style.display = 'inline-flex';
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Polling error', e);
|
||||
}
|
||||
}, 2000);
|
||||
}
|
||||
|
||||
// --- Execution Process Tab (Task 11) ---
|
||||
|
||||
function renderRoundCards(rounds) {
|
||||
if (!rounds || rounds.length === 0) return;
|
||||
|
||||
const wrapper = document.getElementById('roundCardsWrapper');
|
||||
if (!wrapper) return;
|
||||
|
||||
// Handle server restart: if rounds shrunk, re-render all
|
||||
if (rounds.length < lastRenderedRound) {
|
||||
lastRenderedRound = 0;
|
||||
wrapper.innerHTML = '';
|
||||
}
|
||||
|
||||
// Remove empty state if present
|
||||
const emptyState = document.getElementById('executionEmptyState');
|
||||
if (emptyState) emptyState.remove();
|
||||
|
||||
// Only render new rounds
|
||||
for (let i = lastRenderedRound; i < rounds.length; i++) {
|
||||
const rd = rounds[i];
|
||||
const card = createRoundCard(rd);
|
||||
wrapper.appendChild(card);
|
||||
}
|
||||
|
||||
lastRenderedRound = rounds.length;
|
||||
|
||||
// Auto-scroll when running
|
||||
if (isRunning) {
|
||||
const executionTab = document.getElementById('executionTab');
|
||||
if (executionTab) {
|
||||
executionTab.scrollTop = executionTab.scrollHeight;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function createRoundCard(rd) {
|
||||
const roundNum = rd.round || 0;
|
||||
const summary = escapeHtml(rd.result_summary || '');
|
||||
const reasoning = escapeHtml(rd.reasoning || '');
|
||||
const code = escapeHtml(rd.code || '');
|
||||
const rawLog = escapeHtml(rd.raw_log || '');
|
||||
|
||||
const card = document.createElement('div');
|
||||
card.className = 'round-card';
|
||||
card.setAttribute('data-round', roundNum);
|
||||
|
||||
// Build evidence table HTML
|
||||
let evidenceHtml = '';
|
||||
const evidenceRows = rd.evidence_rows || [];
|
||||
if (evidenceRows.length > 0) {
|
||||
const cols = Object.keys(evidenceRows[0]);
|
||||
evidenceHtml = `
|
||||
<div class="round-section">
|
||||
<div class="round-section-title">本轮数据案例</div>
|
||||
<table class="evidence-table">
|
||||
<thead><tr>${cols.map(c => `<th>${escapeHtml(c)}</th>`).join('')}</tr></thead>
|
||||
<tbody>${evidenceRows.map(row =>
|
||||
`<tr>${cols.map(c => `<td>${escapeHtml(String(row[c] ?? ''))}</td>`).join('')}</tr>`
|
||||
).join('')}</tbody>
|
||||
</table>
|
||||
</div>`;
|
||||
}
|
||||
|
||||
card.innerHTML = `
|
||||
<div class="round-card-header" onclick="toggleRoundCard(${roundNum})">
|
||||
<span class="round-number">Round ${roundNum}</span>
|
||||
<span class="round-summary">${summary}</span>
|
||||
<i class="fa-solid fa-chevron-down round-toggle-icon"></i>
|
||||
</div>
|
||||
<div class="round-card-body hidden">
|
||||
<div class="round-section">
|
||||
<div class="round-section-title">AI 推理</div>
|
||||
<div class="round-reasoning">${reasoning}</div>
|
||||
</div>
|
||||
<details class="round-details">
|
||||
<summary>代码</summary>
|
||||
<pre class="round-code">${code}</pre>
|
||||
</details>
|
||||
<div class="round-section">
|
||||
<div class="round-section-title">执行结果</div>
|
||||
<div class="round-result">${summary}</div>
|
||||
</div>
|
||||
${evidenceHtml}
|
||||
<details class="round-details">
|
||||
<summary>原始日志</summary>
|
||||
<pre class="round-raw-log">${rawLog}</pre>
|
||||
</details>
|
||||
</div>
|
||||
`;
|
||||
|
||||
return card;
|
||||
}
|
||||
|
||||
window.toggleRoundCard = function(roundNum) {
|
||||
const card = document.querySelector(`.round-card[data-round="${roundNum}"]`);
|
||||
if (!card) return;
|
||||
const body = card.querySelector('.round-card-body');
|
||||
const icon = card.querySelector('.round-toggle-icon');
|
||||
if (!body) return;
|
||||
|
||||
body.classList.toggle('hidden');
|
||||
if (icon) {
|
||||
icon.classList.toggle('fa-chevron-down');
|
||||
icon.classList.toggle('fa-chevron-up');
|
||||
}
|
||||
card.classList.toggle('expanded');
|
||||
}
|
||||
|
||||
function escapeHtml(text) {
|
||||
const div = document.createElement('div');
|
||||
div.textContent = text;
|
||||
return div.innerHTML;
|
||||
}
|
||||
|
||||
// --- Data Files Tab (Task 12) ---
|
||||
|
||||
async function loadDataFiles() {
|
||||
if (!currentSessionId) return;
|
||||
try {
|
||||
const res = await fetch(`/api/data-files?session_id=${currentSessionId}`);
|
||||
if (!res.ok) return;
|
||||
const data = await res.json();
|
||||
const files = data.files || [];
|
||||
|
||||
const grid = document.getElementById('fileCardsGrid');
|
||||
const emptyState = document.getElementById('datafilesEmptyState');
|
||||
if (!grid) return;
|
||||
|
||||
if (files.length === 0) {
|
||||
grid.innerHTML = '';
|
||||
if (emptyState) emptyState.classList.remove('hidden');
|
||||
return;
|
||||
}
|
||||
|
||||
if (emptyState) emptyState.classList.add('hidden');
|
||||
|
||||
grid.innerHTML = files.map(f => {
|
||||
const desc = escapeHtml(f.description || '');
|
||||
const name = escapeHtml(f.filename || '');
|
||||
const rows = f.rows || 0;
|
||||
const iconClass = name.endsWith('.xlsx') ? 'fa-file-excel' : 'fa-file-csv';
|
||||
return `
|
||||
<div class="data-file-card" onclick="previewDataFile('${escapeHtml(f.filename)}')">
|
||||
<div class="data-file-icon"><i class="fa-regular ${iconClass}"></i></div>
|
||||
<div class="data-file-info">
|
||||
<div class="data-file-name">${name}</div>
|
||||
<div class="data-file-desc">${desc} · ${rows}行</div>
|
||||
</div>
|
||||
<button class="btn btn-sm btn-secondary" onclick="event.stopPropagation(); downloadDataFile('${escapeHtml(f.filename)}')">
|
||||
<i class="fa-solid fa-download"></i>
|
||||
</button>
|
||||
</div>`;
|
||||
}).join('');
|
||||
} catch (e) {
|
||||
console.error('Failed to load data files', e);
|
||||
}
|
||||
}
|
||||
|
||||
window.previewDataFile = async function(filename) {
|
||||
if (!currentSessionId) return;
|
||||
try {
|
||||
const res = await fetch(`/api/data-files/preview?session_id=${currentSessionId}&filename=${encodeURIComponent(filename)}`);
|
||||
if (!res.ok) {
|
||||
alert('Failed to load preview');
|
||||
return;
|
||||
}
|
||||
const data = await res.json();
|
||||
const columns = data.columns || [];
|
||||
const rows = data.rows || [];
|
||||
|
||||
const panel = document.getElementById('dataPreviewPanel');
|
||||
const nameEl = document.getElementById('previewFileName');
|
||||
const container = document.getElementById('previewTableContainer');
|
||||
if (!panel || !container) return;
|
||||
|
||||
nameEl.textContent = filename;
|
||||
|
||||
let tableHtml = '<table class="data-preview-table">';
|
||||
tableHtml += '<thead><tr>' + columns.map(c => `<th>${escapeHtml(c)}</th>`).join('') + '</tr></thead>';
|
||||
tableHtml += '<tbody>';
|
||||
for (const row of rows) {
|
||||
tableHtml += '<tr>' + columns.map(c => `<td>${escapeHtml(String(row[c] ?? ''))}</td>`).join('') + '</tr>';
|
||||
}
|
||||
tableHtml += '</tbody></table>';
|
||||
|
||||
container.innerHTML = tableHtml;
|
||||
panel.classList.remove('hidden');
|
||||
} catch (e) {
|
||||
console.error('Preview failed', e);
|
||||
}
|
||||
}
|
||||
|
||||
window.downloadDataFile = function(filename) {
|
||||
if (!currentSessionId) return;
|
||||
const link = document.createElement('a');
|
||||
link.href = `/api/data-files/download?session_id=${currentSessionId}&filename=${encodeURIComponent(filename)}`;
|
||||
link.download = '';
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
}
|
||||
|
||||
window.closePreview = function() {
|
||||
const panel = document.getElementById('dataPreviewPanel');
|
||||
if (panel) panel.classList.add('hidden');
|
||||
}
|
||||
|
||||
// --- Report Logic with Supporting Data (Task 14) ---
|
||||
|
||||
async function loadReport() {
|
||||
if (!currentSessionId) return;
|
||||
try {
|
||||
const res = await fetch(`/api/report?session_id=${currentSessionId}`);
|
||||
const data = await res.json();
|
||||
|
||||
if (!data.content || data.content === "Report not ready.") {
|
||||
reportContainer.innerHTML = '<div class="empty-state"><p>Analysis in progress or no report generated yet.</p></div>';
|
||||
reportParagraphs = [];
|
||||
supportingData = {};
|
||||
return;
|
||||
}
|
||||
|
||||
reportParagraphs = data.paragraphs || [];
|
||||
supportingData = data.supporting_data || {};
|
||||
|
||||
renderParagraphReport(reportParagraphs);
|
||||
|
||||
} catch (e) {
|
||||
reportContainer.innerHTML = '<p class="error">Failed to load report.</p>';
|
||||
}
|
||||
}
|
||||
|
||||
function renderParagraphReport(paragraphs) {
|
||||
if (!paragraphs || paragraphs.length === 0) {
|
||||
reportContainer.innerHTML = '<div class="empty-state"><p>No report content.</p></div>';
|
||||
return;
|
||||
}
|
||||
|
||||
let html = '';
|
||||
for (const p of paragraphs) {
|
||||
const renderedContent = marked.parse(p.content);
|
||||
const typeClass = `para-${p.type}`;
|
||||
const hasSupportingData = supportingData[p.id] && supportingData[p.id].length > 0;
|
||||
const supportingBtn = hasSupportingData
|
||||
? `<button class="supporting-data-btn" onclick="event.stopPropagation(); showSupportingData('${p.id}')"><i class="fa-solid fa-table"></i> 查看支撑数据</button>`
|
||||
: '';
|
||||
html += `
|
||||
<div class="report-paragraph ${typeClass}" data-para-id="${p.id}" onclick="selectParagraph('${p.id}')">
|
||||
<div class="para-content">${renderedContent}</div>
|
||||
${supportingBtn}
|
||||
<div class="para-actions hidden">
|
||||
<button class="polish-btn" onclick="event.stopPropagation(); polishParagraph('${p.id}', 'context')" title="根据上下文润色">
|
||||
<i class="fa-solid fa-wand-magic-sparkles"></i> 上下文润色
|
||||
</button>
|
||||
<button class="polish-btn" onclick="event.stopPropagation(); polishParagraph('${p.id}', 'data')" title="结合分析数据润色">
|
||||
<i class="fa-solid fa-database"></i> 数据润色
|
||||
</button>
|
||||
<button class="polish-btn" onclick="event.stopPropagation(); showCustomPolish('${p.id}')" title="自定义润色指令">
|
||||
<i class="fa-solid fa-pen"></i> 自定义
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
reportContainer.innerHTML = html;
|
||||
}
|
||||
|
||||
window.showSupportingData = function(paraId) {
|
||||
const rows = supportingData[paraId];
|
||||
if (!rows || rows.length === 0) return;
|
||||
|
||||
const modal = document.getElementById('supportingDataModal');
|
||||
const body = document.getElementById('supportingDataBody');
|
||||
if (!modal || !body) return;
|
||||
|
||||
const cols = Object.keys(rows[0]);
|
||||
let tableHtml = '<table class="evidence-table">';
|
||||
tableHtml += '<thead><tr>' + cols.map(c => `<th>${escapeHtml(c)}</th>`).join('') + '</tr></thead>';
|
||||
tableHtml += '<tbody>';
|
||||
for (const row of rows) {
|
||||
tableHtml += '<tr>' + cols.map(c => `<td>${escapeHtml(String(row[c] ?? ''))}</td>`).join('') + '</tr>';
|
||||
}
|
||||
tableHtml += '</tbody></table>';
|
||||
|
||||
body.innerHTML = tableHtml;
|
||||
modal.classList.remove('hidden');
|
||||
}
|
||||
|
||||
window.closeSupportingData = function() {
|
||||
const modal = document.getElementById('supportingDataModal');
|
||||
if (modal) modal.classList.add('hidden');
|
||||
}
|
||||
|
||||
// --- Paragraph Selection & Polishing (preserved from original) ---
|
||||
|
||||
window.selectParagraph = function(paraId) {
|
||||
document.querySelectorAll('.report-paragraph').forEach(el => {
|
||||
el.classList.remove('selected');
|
||||
el.querySelector('.para-actions')?.classList.add('hidden');
|
||||
});
|
||||
|
||||
const target = document.querySelector(`[data-para-id="${paraId}"]`);
|
||||
if (target) {
|
||||
target.classList.add('selected');
|
||||
target.querySelector('.para-actions')?.classList.remove('hidden');
|
||||
}
|
||||
}
|
||||
|
||||
window.polishParagraph = async function(paraId, mode, customInstruction = '') {
|
||||
if (!currentSessionId) return;
|
||||
|
||||
const target = document.querySelector(`[data-para-id="${paraId}"]`);
|
||||
if (!target) return;
|
||||
|
||||
const actionsEl = target.querySelector('.para-actions');
|
||||
const originalActions = actionsEl.innerHTML;
|
||||
actionsEl.innerHTML = '<span class="polish-loading"><i class="fa-solid fa-spinner fa-spin"></i> AI 润色中...</span>';
|
||||
|
||||
try {
|
||||
const res = await fetch('/api/report/polish', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
session_id: currentSessionId,
|
||||
paragraph_id: paraId,
|
||||
mode: mode,
|
||||
custom_instruction: customInstruction,
|
||||
})
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
const err = await res.json();
|
||||
alert('润色失败: ' + (err.detail || 'Unknown error'));
|
||||
actionsEl.innerHTML = originalActions;
|
||||
return;
|
||||
}
|
||||
|
||||
const data = await res.json();
|
||||
showPolishDiff(target, paraId, data.original, data.polished);
|
||||
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
alert('润色请求失败');
|
||||
actionsEl.innerHTML = originalActions;
|
||||
}
|
||||
}
|
||||
|
||||
function showPolishDiff(targetEl, paraId, original, polished) {
|
||||
const polishedHtml = marked.parse(polished);
|
||||
|
||||
targetEl.innerHTML = `
|
||||
<div class="polish-diff">
|
||||
<div class="diff-header">
|
||||
<span class="diff-title"><i class="fa-solid fa-wand-magic-sparkles"></i> 润色结果预览</span>
|
||||
</div>
|
||||
<div class="diff-panels">
|
||||
<div class="diff-panel diff-original">
|
||||
<div class="diff-label">原文</div>
|
||||
<div class="diff-body">${marked.parse(original)}</div>
|
||||
</div>
|
||||
<div class="diff-panel diff-polished">
|
||||
<div class="diff-label">润色后</div>
|
||||
<div class="diff-body">${polishedHtml}</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="diff-actions">
|
||||
<button class="btn btn-primary btn-sm" id="acceptBtn-${paraId}">
|
||||
<i class="fa-solid fa-check"></i> 采纳
|
||||
</button>
|
||||
<button class="btn btn-secondary btn-sm" id="rejectBtn-${paraId}">
|
||||
<i class="fa-solid fa-xmark"></i> 放弃
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
|
||||
document.getElementById(`acceptBtn-${paraId}`).addEventListener('click', (e) => {
|
||||
e.stopPropagation();
|
||||
applyPolish(paraId, polished);
|
||||
});
|
||||
document.getElementById(`rejectBtn-${paraId}`).addEventListener('click', (e) => {
|
||||
e.stopPropagation();
|
||||
rejectPolish(paraId);
|
||||
});
|
||||
}
|
||||
|
||||
window.applyPolish = async function(paraId, newContent) {
|
||||
if (!currentSessionId) return;
|
||||
|
||||
try {
|
||||
const res = await fetch('/api/report/apply', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
session_id: currentSessionId,
|
||||
paragraph_id: paraId,
|
||||
new_content: newContent,
|
||||
})
|
||||
});
|
||||
|
||||
if (res.ok) {
|
||||
await loadReport();
|
||||
} else {
|
||||
alert('应用失败');
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
alert('应用失败');
|
||||
}
|
||||
}
|
||||
|
||||
window.rejectPolish = function(paraId) {
|
||||
loadReport();
|
||||
}
|
||||
|
||||
window.showCustomPolish = function(paraId) {
|
||||
const target = document.querySelector(`[data-para-id="${paraId}"]`);
|
||||
if (!target) return;
|
||||
|
||||
const actionsEl = target.querySelector('.para-actions');
|
||||
if (!actionsEl) return;
|
||||
|
||||
actionsEl.innerHTML = `
|
||||
<div class="custom-polish-input">
|
||||
<input type="text" class="form-input" id="customInput-${paraId}" placeholder="输入润色指令,如:增加数据对比、语气更正式..." style="flex:1;">
|
||||
<button class="btn btn-primary btn-sm" onclick="event.stopPropagation(); submitCustomPolish('${paraId}')">
|
||||
<i class="fa-solid fa-paper-plane"></i>
|
||||
</button>
|
||||
<button class="btn btn-secondary btn-sm" onclick="event.stopPropagation(); loadReport()">
|
||||
<i class="fa-solid fa-xmark"></i>
|
||||
</button>
|
||||
</div>
|
||||
`;
|
||||
|
||||
document.getElementById(`customInput-${paraId}`)?.focus();
|
||||
}
|
||||
|
||||
window.submitCustomPolish = function(paraId) {
|
||||
const input = document.getElementById(`customInput-${paraId}`);
|
||||
if (!input) return;
|
||||
const instruction = input.value.trim();
|
||||
if (!instruction) {
|
||||
alert('请输入润色指令');
|
||||
return;
|
||||
}
|
||||
polishParagraph(paraId, 'custom', instruction);
|
||||
}
|
||||
|
||||
// --- Download / Export ---
|
||||
window.downloadScript = async function () {
|
||||
if (!currentSessionId) return;
|
||||
const link = document.createElement('a');
|
||||
link.href = `/api/download_script?session_id=${currentSessionId}`;
|
||||
link.download = '';
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
}
|
||||
|
||||
window.triggerExport = async function () {
|
||||
if (!currentSessionId) {
|
||||
alert("No active session to export.");
|
||||
return;
|
||||
}
|
||||
const btn = document.getElementById('exportBtn');
|
||||
const originalContent = btn.innerHTML;
|
||||
btn.innerHTML = '<i class="fa-solid fa-spinner fa-spin"></i> Zipping...';
|
||||
btn.disabled = true;
|
||||
|
||||
try {
|
||||
window.open(`/api/export?session_id=${currentSessionId}`, '_blank');
|
||||
} catch (e) {
|
||||
alert("Export failed: " + e.message);
|
||||
} finally {
|
||||
setTimeout(() => {
|
||||
btn.innerHTML = originalContent;
|
||||
btn.disabled = false;
|
||||
}, 2000);
|
||||
}
|
||||
}
|
||||
|
||||
// --- Follow-up Chat ---
|
||||
window.sendFollowUp = async function () {
|
||||
if (!currentSessionId || isRunning) return;
|
||||
const input = document.getElementById('followUpInput');
|
||||
const message = input.value.trim();
|
||||
if (!message) return;
|
||||
|
||||
input.disabled = true;
|
||||
try {
|
||||
const res = await fetch('/api/chat', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ session_id: currentSessionId, message: message })
|
||||
});
|
||||
|
||||
if (res.ok) {
|
||||
input.value = '';
|
||||
setRunningState(true);
|
||||
// Reset round rendering for follow-up
|
||||
lastRenderedRound = 0;
|
||||
const wrapper = document.getElementById('roundCardsWrapper');
|
||||
if (wrapper) wrapper.innerHTML = '';
|
||||
startPolling();
|
||||
switchTab('execution');
|
||||
} else {
|
||||
alert('Failed to send request');
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
} finally {
|
||||
input.disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
// --- History Logic ---
|
||||
async function loadHistory() {
|
||||
const list = document.getElementById('historyList');
|
||||
if (!list) return;
|
||||
|
||||
try {
|
||||
const res = await fetch('/api/history');
|
||||
const data = await res.json();
|
||||
|
||||
if (data.history.length === 0) {
|
||||
list.innerHTML = '<div style="padding:0.5rem; font-size:0.8rem; color:#9CA3AF;">No history yet</div>';
|
||||
return;
|
||||
}
|
||||
|
||||
let html = '';
|
||||
data.history.forEach(item => {
|
||||
html += `
|
||||
<div class="history-item" onclick="loadSession('${item.id}')" id="hist-${item.id}">
|
||||
<i class="fa-regular fa-clock"></i>
|
||||
<span>${item.id}</span>
|
||||
</div>
|
||||
`;
|
||||
});
|
||||
list.innerHTML = html;
|
||||
} catch (e) {
|
||||
console.error("Failed to load history", e);
|
||||
}
|
||||
}
|
||||
|
||||
window.loadSession = async function (sessionId) {
|
||||
if (isRunning) {
|
||||
alert("Analysis in progress, please wait.");
|
||||
return;
|
||||
}
|
||||
|
||||
currentSessionId = sessionId;
|
||||
|
||||
document.querySelectorAll('.history-item').forEach(el => el.classList.remove('active'));
|
||||
const activeItem = document.getElementById(`hist-${sessionId}`);
|
||||
if (activeItem) activeItem.classList.add('active');
|
||||
|
||||
reportContainer.innerHTML = "";
|
||||
if (downloadScriptBtn) downloadScriptBtn.classList.add('hidden');
|
||||
const tplGroup = document.getElementById('templateSelectorGroup');
|
||||
if (tplGroup) tplGroup.classList.add('hidden');
|
||||
|
||||
// Reset execution state for loaded session
|
||||
lastRenderedRound = 0;
|
||||
const wrapper = document.getElementById('roundCardsWrapper');
|
||||
if (wrapper) wrapper.innerHTML = '';
|
||||
|
||||
try {
|
||||
const res = await fetch(`/api/status?session_id=${sessionId}`);
|
||||
if (res.ok) {
|
||||
const data = await res.json();
|
||||
|
||||
// Render rounds for historical session
|
||||
renderRoundCards(data.rounds || []);
|
||||
|
||||
// Load data files for historical session
|
||||
loadDataFiles();
|
||||
|
||||
if (data.has_report) {
|
||||
await loadReport();
|
||||
if (data.script_path && downloadScriptBtn) {
|
||||
downloadScriptBtn.classList.remove('hidden');
|
||||
downloadScriptBtn.style.display = 'inline-flex';
|
||||
}
|
||||
switchTab('report');
|
||||
} else {
|
||||
switchTab('execution');
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("Error loading session", e);
|
||||
}
|
||||
}
|
||||
|
||||
// --- Init & Navigation (Task 13) ---
|
||||
document.addEventListener('DOMContentLoaded', () => {
|
||||
loadTemplates();
|
||||
loadHistory();
|
||||
});
|
||||
|
||||
window.switchView = function (viewName) {
|
||||
console.log("View switch requested:", viewName);
|
||||
}
|
||||
|
||||
window.switchTab = function (tabName) {
|
||||
// Deactivate all tabs
|
||||
document.querySelectorAll('.tab').forEach(t => t.classList.remove('active'));
|
||||
|
||||
// Hide all tab content
|
||||
['execution', 'datafiles', 'report'].forEach(name => {
|
||||
const content = document.getElementById(`${name}Tab`);
|
||||
if (content) content.classList.add('hidden');
|
||||
});
|
||||
|
||||
// Activate the clicked tab button
|
||||
document.querySelectorAll('.tab').forEach(btn => {
|
||||
if (btn.getAttribute('onclick') && btn.getAttribute('onclick').includes(`'${tabName}'`)) {
|
||||
btn.classList.add('active');
|
||||
}
|
||||
});
|
||||
|
||||
// Show the selected tab content
|
||||
if (tabName === 'execution') {
|
||||
document.getElementById('executionTab').classList.remove('hidden');
|
||||
} else if (tabName === 'datafiles') {
|
||||
document.getElementById('datafilesTab').classList.remove('hidden');
|
||||
if (currentSessionId) loadDataFiles();
|
||||
} else if (tabName === 'report') {
|
||||
document.getElementById('reportTab').classList.remove('hidden');
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user