2026-04-19 21:30:08 +08:00
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
"""
|
|
|
|
|
|
Unit and integration tests for agent-robustness-optimization features.
|
|
|
|
|
|
|
|
|
|
|
|
Run: python -m pytest tests/test_unit.py -v
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
import sys
|
|
|
|
|
|
|
|
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
|
|
|
|
|
|
|
|
import re
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
|
|
from utils.data_privacy import (
|
|
|
|
|
|
_extract_column_from_error,
|
|
|
|
|
|
_lookup_column_in_profile,
|
|
|
|
|
|
generate_enriched_hint,
|
|
|
|
|
|
)
|
|
|
|
|
|
from utils.analysis_templates import get_template, list_templates, TEMPLATE_REGISTRY
|
|
|
|
|
|
from config.app_config import AppConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
# Helpers
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
2026-04-20 14:56:39 +08:00
|
|
|
|
# Regexes whose match (searched case-insensitively by classify_error) marks an
# error message as a "data context" problem. Intended to mirror the pattern
# list in data_analysis_agent.py; the two copies must be kept in sync by hand.
# NOTE(review): with re.IGNORECASE the "No\s+data" alternative duplicates
# "no\s+data", and "0\s+rows" has no leading word boundary, so a message like
# "100 rows" also matches. Left unchanged because this list is meant to mirror
# the agent's copy.
DATA_CONTEXT_PATTERNS = [
    r"KeyError:\s*['\"](.+?)['\"]",  # KeyError: quoted missing key/column
    r"ValueError.*(?:column|col|field|shape|axis)",  # ValueError about columns/shape/axis
    r"NameError.*(?:df|data|frame|series)",  # NameError on a data-like name
    r"(?:empty|no\s+data|0\s+rows|No\s+data)",  # empty / missing data, anywhere in text
    r"IndexError.*(?:out of range|out of bounds)",  # IndexError: out of bounds
    r"AttributeError.*(?:DataFrame|Series|object)\s+has\s+no\s+attribute",  # missing attribute
    r"pd\.errors\.(?:EmptyDataError|ParserError|MergeError)",  # pandas error classes
    r"MergeError: No common columns",  # pandas merge without shared columns
    r"TypeError.*(?:unsupported operand|expected string|cannot convert)",  # type errors
    r"UnboundLocalError.*referenced before assignment",  # local used before assignment
    r"SyntaxError: invalid syntax",  # syntax errors
    r"ModuleNotFoundError.*(?:pandas|numpy|matplotlib)",  # missing data library module
    r"ImportError.*(?:pandas|numpy|matplotlib)",  # import failure from a data library
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_error(error_message: str, *, patterns=None) -> str:
    """Classify an error message as data-context related or not.

    Args:
        error_message: The error text to inspect. ``None`` is tolerated and
            classified as ``"other"`` (previously it made ``re.search`` raise
            ``TypeError``).
        patterns: Optional iterable of regex strings to search instead of
            the module-level ``DATA_CONTEXT_PATTERNS``. Keyword-only, so
            existing positional callers are unaffected.

    Returns:
        ``"data_context"`` if any pattern matches anywhere in the message
        (case-insensitively), otherwise ``"other"``.
    """
    if error_message is None:
        return "other"
    if patterns is None:
        patterns = DATA_CONTEXT_PATTERNS
    # re.search (not re.match): a pattern may occur anywhere in the message.
    if any(re.search(pattern, error_message, re.IGNORECASE) for pattern in patterns):
        return "data_context"
    return "other"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Minimal column-profile fixture in Markdown-table form: a header row
# (column name / dtype / null rate / unique count / description), a separator
# row, and one row each for the columns 车型 and 模块. Ends with a newline.
SAMPLE_PROFILE = "\n".join(
    [
        "| 列名 | 数据类型 | 空值率 | 唯一值数 | 特征描述 |",
        "|------|---------|--------|---------|----------|",
        "| 车型 | object | 0.0% | 5 | 低基数分类(5类) |",
        "| 模块 | object | 2.0% | 12 | 中基数分类(12类) |",
        "",
    ]
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ===========================================================================
|
|
|
|
|
|
# Task 12.1: Unit tests for error classifier
|
|
|
|
|
|
# ===========================================================================
|
|
|
|
|
|
|
|
|
|
|
|
class TestErrorClassifier:
    """Tests for the local ``classify_error`` helper.

    Messages in ``test_data_context_errors`` must be classified as
    ``"data_context"``; messages in ``test_other_errors`` must fall through
    to ``"other"``. The negative table guards against patterns that are too
    broad (e.g. matching every NameError or every missing module).
    """

    @pytest.mark.parametrize(
        "message",
        [
            pytest.param("KeyError: '车型'", id="keyerror-single-quotes"),
            pytest.param('KeyError: "model_name"', id="keyerror-double-quotes"),
            pytest.param(
                "ValueError: column 'x' not in DataFrame", id="valueerror-column"
            ),
            pytest.param("ValueError: shape mismatch", id="valueerror-shape"),
            pytest.param("ValueError: axis out of bounds", id="valueerror-axis"),
            pytest.param("NameError: name 'df' is not defined", id="nameerror-df"),
            pytest.param(
                "NameError: name 'series' is not defined", id="nameerror-series"
            ),
            pytest.param("empty DataFrame after filtering", id="empty-dataframe"),
            pytest.param("0 rows returned from query", id="zero-rows"),
            pytest.param("No data available for analysis", id="no-data"),
            pytest.param(
                "IndexError: index 10 is out of range", id="index-out-of-range"
            ),
            pytest.param(
                "AttributeError: 'DataFrame' object has no attribute 'xxx'",
                id="attributeerror-dataframe",
            ),
            pytest.param(
                "AttributeError: 'Series' object has no attribute 'xxx'",
                id="attributeerror-series",
            ),
            pytest.param("pd.errors.EmptyDataError: No data", id="pd-emptydataerror"),
            pytest.param(
                "pd.errors.ParserError: Error tokenizing data", id="pd-parsererror"
            ),
            pytest.param("MergeError: No common columns to merge", id="pd-mergeerror"),
            pytest.param(
                "TypeError: unsupported operand", id="typeerror-unsupported-operand-short"
            ),
            pytest.param(
                "TypeError: unsupported operand type(s) for +: 'int' and 'str'",
                id="typeerror-unsupported-operand",
            ),
            pytest.param(
                "TypeError: expected string or bytes-like object",
                id="typeerror-expected-string",
            ),
            pytest.param(
                "UnboundLocalError: local variable 'df' referenced before assignment",
                id="unboundlocalerror",
            ),
            # Previously asserted twice by two identical test methods.
            pytest.param("SyntaxError: invalid syntax", id="syntaxerror"),
            pytest.param(
                "ModuleNotFoundError: No module named 'pandas'",
                id="modulenotfounderror-pandas",
            ),
            pytest.param(
                "ImportError: cannot import name 'xxx' from 'pandas'",
                id="importerror-pandas",
            ),
        ],
    )
    def test_data_context_errors(self, message):
        assert classify_error(message) == "data_context"

    @pytest.mark.parametrize(
        "message",
        [
            pytest.param("Something went wrong", id="generic-text"),
            pytest.param("", id="empty-string"),
            pytest.param(
                "ZeroDivisionError: division by zero", id="unrelated-exception"
            ),
            # NameError only counts when the name looks like a data variable.
            pytest.param(
                "NameError: name 'foo' is not defined", id="nameerror-non-data-name"
            ),
            # Missing-module errors only count for the data libraries.
            pytest.param(
                "ModuleNotFoundError: No module named 'requests'",
                id="modulenotfounderror-non-data-lib",
            ),
        ],
    )
    def test_other_errors(self, message):
        assert classify_error(message) == "other"
|
|
|
|
|
|
|
2026-04-19 21:30:08 +08:00
|
|
|
|
|
|
|
|
|
|
# ===========================================================================
|
|
|
|
|
|
# Task 12.1 continued: Unit tests for column extraction and lookup
|
|
|
|
|
|
# ===========================================================================
|
|
|
|
|
|
|
|
|
|
|
|
class TestColumnExtraction:
    """Tests for ``_extract_column_from_error`` and ``_lookup_column_in_profile``."""

    def test_extract_from_keyerror(self):
        """The quoted key of a KeyError message is returned as the column name."""
        message = "KeyError: '车型'"
        assert _extract_column_from_error(message) == "车型"

    def test_extract_from_column_phrase(self):
        """A "column '<name>'" phrase yields the quoted name."""
        message = "column '模块' not found"
        assert _extract_column_from_error(message) == "模块"

    def test_extract_none_for_generic(self):
        """A message without a recognizable column name yields None."""
        extracted = _extract_column_from_error("SyntaxError: bad")
        assert extracted is None

    def test_lookup_existing_column(self):
        """A column present in SAMPLE_PROFILE returns its dtype and unique count as strings."""
        info = _lookup_column_in_profile("车型", SAMPLE_PROFILE)
        assert info is not None
        assert (info["dtype"], info["unique_count"]) == ("object", "5")

    def test_lookup_missing_column(self):
        """A column absent from the profile ("不存在" = "does not exist") is not found."""
        found = _lookup_column_in_profile("不存在", SAMPLE_PROFILE)
        assert found is None

    def test_lookup_none_column(self):
        """A None column name is accepted and yields None."""
        found = _lookup_column_in_profile(None, SAMPLE_PROFILE)
        assert found is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ===========================================================================
|
|
|
|
|
|
# Task 12.2: Unit tests for conversation trimming at boundary conditions
|
|
|
|
|
|
# ===========================================================================
|
|
|
|
|
|
|
|
|
|
|
|
class TestConversationTrimming:
    """Boundary tests for sliding-window conversation trimming.

    ``_trim`` consolidates the trimming steps these tests previously inlined:
    trim only when the history has more than ``2 * window`` messages, always
    keep the first message, place a single ``[分析摘要]`` summary at index 1
    (dropping any existing summary there), and keep the last ``2 * window``
    messages after that.

    NOTE(review): this exercises a local simulation, not the agent's own
    trimming code; point these tests at the production implementation once
    it is importable.
    """

    SUMMARY_PREFIX = "[分析摘要]"

    def _make_history(self, n_pairs):
        """Return an ORIGINAL user message followed by *n_pairs* exchanges.

        Each exchange is an assistant message then a user message, so the
        result has ``1 + 2 * n_pairs`` messages.
        """
        history = [{"role": "user", "content": "ORIGINAL"}]
        for i in range(n_pairs):
            history.append({"role": "assistant", "content": f"response {i}"})
            history.append({"role": "user", "content": f"feedback {i}"})
        return history

    def _trim(self, history, window):
        """Return a trimmed copy of *history* following the simulated rule."""
        max_messages = window * 2
        if len(history) <= max_messages:
            return list(history)
        has_summary = len(history) > 1 and history[1]["content"].startswith(
            self.SUMMARY_PREFIX
        )
        # An existing summary at index 1 is skipped and regenerated below.
        start_idx = 2 if has_summary else 1
        # Guard max_messages == 0: lst[-0:] would return the whole list.
        kept = history[start_idx:][-max_messages:] if max_messages else []
        summary = {"role": "user", "content": f"{self.SUMMARY_PREFIX} ..."}
        return [history[0], summary, *kept]

    def test_no_trimming_when_under_limit(self):
        """3 pairs (7 messages) with window=5 (limit 10) is left unchanged."""
        history = self._make_history(3)
        assert len(history) == 7
        assert self._trim(history, window=5) == history

    def test_trimming_at_exact_boundary(self):
        """A history of exactly 2*window messages is not trimmed."""
        window = 3
        history = self._make_history(3)[: window * 2]
        assert len(history) == window * 2
        assert self._trim(history, window) == history

    def test_trimming_one_past_boundary(self):
        """One message past 2*window triggers trimming."""
        window = 3
        history = self._make_history(3)  # 7 messages > limit of 6
        trimmed = self._trim(history, window)
        assert trimmed[0] is history[0]
        assert trimmed[1]["content"].startswith(self.SUMMARY_PREFIX)
        assert trimmed[2:] == history[-window * 2:]

    def test_first_message_always_preserved(self):
        """After trimming, the first message survives and only the tail is kept."""
        window = 2
        history = self._make_history(10)
        trimmed = self._trim(history, window)
        assert trimmed[0]["content"] == "ORIGINAL"
        assert len(trimmed) == 2 + window * 2
        assert trimmed[2:] == history[-window * 2:]

    def test_summary_replaces_old_summary(self):
        """An existing summary at index 1 is replaced, not duplicated."""
        history = [
            {"role": "user", "content": "ORIGINAL"},
            {"role": "user", "content": "[分析摘要] old summary"},
        ]
        for i in range(8):
            history.append({"role": "assistant", "content": f"resp {i}"})
            history.append({"role": "user", "content": f"fb {i}"})

        window = 3
        trimmed = self._trim(history, window)

        summaries = [
            m for m in trimmed if m["content"].startswith(self.SUMMARY_PREFIX)
        ]
        assert len(summaries) == 1
        assert trimmed[1] is summaries[0]
        assert trimmed[1]["content"] != "[分析摘要] old summary"
        assert trimmed[2:] == history[-window * 2:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ===========================================================================
|
|
|
|
|
|
# Task 12.3: Tests for template API
|
|
|
|
|
|
# ===========================================================================
|
|
|
|
|
|
|
|
|
|
|
|
class TestTemplateSystem:
    """Tests for the analysis-template registry API."""

    def test_list_templates_returns_all(self):
        """list_templates() reports exactly one entry per registered template."""
        listed = list_templates()
        registered = set(TEMPLATE_REGISTRY.keys())
        assert len(listed) == len(TEMPLATE_REGISTRY)
        assert {entry["name"] for entry in listed} == registered

    def test_get_valid_template(self):
        """Every registered key resolves to a template with a display name and steps."""
        resolved = [get_template(key) for key in TEMPLATE_REGISTRY]
        for template in resolved:
            assert template.name  # non-empty display name
            assert len(template.steps) > 0  # at least one step

    def test_get_invalid_template_raises(self):
        """Unknown template names are rejected with ValueError."""
        with pytest.raises(ValueError):
            get_template("nonexistent_template_xyz")

    def test_template_prompt_not_empty(self):
        """Each template's full prompt is more than a trivial stub (> 50 chars)."""
        for key in TEMPLATE_REGISTRY:
            prompt = get_template(key).get_full_prompt()
            assert len(prompt) > 50
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ===========================================================================
|
|
|
|
|
|
# Task 12.4: Tests for config
|
|
|
|
|
|
# ===========================================================================
|
|
|
|
|
|
|
|
|
|
|
|
class TestAppConfig:
    """Tests for AppConfig defaults and APP_* environment-variable overrides."""

    # Override values this test expects AppConfig.from_env() to pick up.
    _ENV_OVERRIDES = {
        "APP_MAX_DATA_CONTEXT_RETRIES": "5",
        "APP_CONVERSATION_WINDOW_SIZE": "20",
        "APP_MAX_PARALLEL_PROFILES": "8",
    }

    def test_defaults(self, monkeypatch):
        """A plain AppConfig() carries the documented default values."""
        # Defensive isolation: drop any APP_* overrides inherited from the
        # caller's environment. Harmless if AppConfig() does not read env.
        for name in self._ENV_OVERRIDES:
            monkeypatch.delenv(name, raising=False)
        config = AppConfig()
        assert config.max_data_context_retries == 2
        assert config.conversation_window_size == 10
        assert config.max_parallel_profiles == 4

    def test_env_override(self, monkeypatch):
        """AppConfig.from_env() parses APP_* variables into integer settings."""
        # monkeypatch restores each variable's previous value (or absence)
        # after the test; the old try/finally `del` instead wiped any values
        # that were already set before the test ran.
        for name, value in self._ENV_OVERRIDES.items():
            monkeypatch.setenv(name, value)
        config = AppConfig.from_env()
        assert config.max_data_context_retries == 5
        assert config.conversation_window_size == 20
        assert config.max_parallel_profiles == 8
|