Files
iov_data_analysis_agent/tests/test_unit.py

230 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
Unit and integration tests for agent-robustness-optimization features.
Run: python -m pytest tests/test_unit.py -v
"""
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import re
import pytest
from utils.data_privacy import (
_extract_column_from_error,
_lookup_column_in_profile,
generate_enriched_hint,
)
from utils.analysis_templates import get_template, list_templates, TEMPLATE_REGISTRY
from config.app_config import AppConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Regexes that mark an execution error as a "data context" problem: a missing
# key/column, an undefined dataframe variable, an empty result, or an index
# past the end.  They are applied with re.search + re.IGNORECASE, so a pattern
# may match anywhere in the message.  Note that the empty-result pattern has
# no word boundary, so e.g. "100 rows" also matches "0\s+rows".
# NOTE(review): this re-implements the classifier inside the test module, so
# TestErrorClassifier exercises this copy rather than production code — TODO
# import the real classifier once its location is confirmed.
DATA_CONTEXT_PATTERNS = [
    r"KeyError:\s*['\"](.+?)['\"]",                 # KeyError on a quoted key
    r"ValueError.*(?:column|col|field)",            # ValueError naming a column
    r"NameError.*(?:df|data|frame)",                # undefined df/data/frame
    r"(?:empty|no\s+data|0\s+rows)",                # empty / zero-row result
    r"IndexError.*(?:out of range|out of bounds)",  # index past the end
]


def classify_error(error_message: str) -> str:
    """Classify *error_message* as ``"data_context"`` or ``"other"``.

    Returns ``"data_context"`` when any entry of ``DATA_CONTEXT_PATTERNS``
    matches the message case-insensitively (checked in list order, stopping
    at the first hit); otherwise returns ``"other"``.
    """
    hit = any(
        re.search(pattern, error_message, re.IGNORECASE)
        for pattern in DATA_CONTEXT_PATTERNS
    )
    return "data_context" if hit else "other"
# Markdown data-profile fixture used by the column-lookup tests.
# Header columns: column name | data type | null rate | unique-value count |
# feature description.  Rows describe 车型 ("vehicle model", low-cardinality
# categorical, 5 classes) and 模块 ("module", medium-cardinality, 12 classes).
# The trailing "" makes the string end with a newline after the last row.
SAMPLE_PROFILE = "\n".join([
    "| 列名 | 数据类型 | 空值率 | 唯一值数 | 特征描述 |",
    "|------|---------|--------|---------|----------|",
    "| 车型 | object | 0.0% | 5 | 低基数分类5类 |",
    "| 模块 | object | 2.0% | 12 | 中基数分类12类 |",
    "",
])
# ===========================================================================
# Task 12.1: Unit tests for error classifier
# ===========================================================================
class TestErrorClassifier:
    """Cases for ``classify_error``: data-context errors versus everything else."""

    @staticmethod
    def _expect(message, category):
        # Shared assertion so each case reads as "message -> category".
        assert classify_error(message) == category

    def test_keyerror_single_quotes(self):
        self._expect("KeyError: '车型'", "data_context")

    def test_keyerror_double_quotes(self):
        self._expect('KeyError: "model_name"', "data_context")

    def test_valueerror_column(self):
        self._expect("ValueError: column 'x' not in DataFrame", "data_context")

    def test_nameerror_df(self):
        self._expect("NameError: name 'df' is not defined", "data_context")

    def test_empty_dataframe(self):
        self._expect("empty DataFrame after filtering", "data_context")

    def test_zero_rows(self):
        self._expect("0 rows returned from query", "data_context")

    def test_index_out_of_range(self):
        self._expect("IndexError: index 10 is out of range", "data_context")

    def test_syntax_error_is_other(self):
        self._expect("SyntaxError: invalid syntax", "other")

    def test_type_error_is_other(self):
        self._expect("TypeError: unsupported operand", "other")

    def test_generic_text_is_other(self):
        self._expect("Something went wrong", "other")

    def test_empty_string_is_other(self):
        self._expect("", "other")
# ===========================================================================
# Task 12.1 continued: Unit tests for column extraction and lookup
# ===========================================================================
class TestColumnExtraction:
    """Column-name extraction from error text and lookup in a profile table."""

    def test_extract_from_keyerror(self):
        column = _extract_column_from_error("KeyError: '车型'")
        assert column == "车型"

    def test_extract_from_column_phrase(self):
        column = _extract_column_from_error("column '模块' not found")
        assert column == "模块"

    def test_extract_none_for_generic(self):
        column = _extract_column_from_error("SyntaxError: bad")
        assert column is None

    def test_lookup_existing_column(self):
        info = _lookup_column_in_profile("车型", SAMPLE_PROFILE)
        assert info is not None
        assert info["dtype"] == "object"
        # The lookup is expected to return the raw table cell text, so the
        # unique count is compared as the string "5".
        assert info["unique_count"] == "5"

    def test_lookup_missing_column(self):
        info = _lookup_column_in_profile("不存在", SAMPLE_PROFILE)
        assert info is None

    def test_lookup_none_column(self):
        info = _lookup_column_in_profile(None, SAMPLE_PROFILE)
        assert info is None
# ===========================================================================
# Task 12.2: Unit tests for conversation trimming at boundary conditions
# ===========================================================================
class TestConversationTrimming:
    """Boundary checks for sliding-window conversation trimming.

    Conventions used by these tests:

    - Trimming is considered necessary when the history holds more than
      ``2 * window`` messages.
    - The first (original) user message is always kept.
    - Only the most recent ``2 * window`` messages after it survive.
    - A ``[分析摘要]`` ("analysis summary") message may sit at index 1 and is
      replaced rather than duplicated.

    NOTE(review): these tests simulate the trimming arithmetic inline instead
    of calling the production trimmer, so they cannot catch regressions in
    it. TODO: route them through the real implementation.
    """

    def _make_history(self, n_pairs):
        """Return ``1 + 2 * n_pairs`` messages: the original prompt, then
        ``n_pairs`` assistant/user exchanges."""
        history = [{"role": "user", "content": "ORIGINAL"}]
        for i in range(n_pairs):
            history.append({"role": "assistant", "content": f"response {i}"})
            history.append({"role": "user", "content": f"feedback {i}"})
        return history

    def test_no_trimming_when_under_limit(self):
        """History with 3 pairs and window=5 should not be trimmed."""
        history = self._make_history(3)  # 1 + 6 = 7 messages
        window = 5
        max_messages = window * 2  # 10
        assert len(history) <= max_messages  # no trimming

    def test_trimming_at_exact_boundary(self):
        """Exactly 2*window messages is not trimmed; one more message is."""
        window = 3
        max_messages = window * 2  # 6
        # _make_history always yields an odd count (1 + 2*n), so slice it to
        # land exactly on the boundary.
        at_limit = self._make_history(3)[:max_messages]  # 6 messages
        assert len(at_limit) == max_messages
        assert len(at_limit) <= max_messages  # at the limit: no trimming
        over_limit = self._make_history(3)  # 1 + 6 = 7 messages
        assert len(over_limit) > max_messages  # 7 > 6: trimming happens

    def test_first_message_always_preserved(self):
        """After trimming, the first message and the newest window survive."""
        history = self._make_history(10)  # 21 messages
        window = 2
        max_messages = window * 2
        first, rest = history[0], history[1:]
        kept = rest[-max_messages:]
        new_history = [first, {"role": "user", "content": "[分析摘要] ..."}, *kept]
        assert new_history[0]["content"] == "ORIGINAL"
        assert new_history[1]["content"].startswith("[分析摘要]")
        # Only the most recent window remains, in its original order.
        assert new_history[2:] == history[-max_messages:]
        assert len(new_history) == 2 + max_messages

    def test_summary_replaces_old_summary(self):
        """If a summary already exists at index 1, it should be replaced."""
        history = [
            {"role": "user", "content": "ORIGINAL"},
            {"role": "user", "content": "[分析摘要] old summary"},
        ]
        for i in range(8):
            history.append({"role": "assistant", "content": f"resp {i}"})
            history.append({"role": "user", "content": f"fb {i}"})
        window = 2
        max_messages = window * 2
        # Simulate trimming with an existing summary: skip past it so it is
        # not carried forward alongside the new one.
        has_summary = history[1]["content"].startswith("[分析摘要]")
        assert has_summary
        start_idx = 2 if has_summary else 1
        assert start_idx == 2
        new_summary = {"role": "user", "content": "[分析摘要] new summary"}
        new_history = [
            history[0],
            new_summary,
            *history[start_idx:][-max_messages:],
        ]
        summaries = [
            m for m in new_history if m["content"].startswith("[分析摘要]")
        ]
        assert summaries == [new_summary]  # exactly one, and it is the new one
        assert new_history[0]["content"] == "ORIGINAL"
        assert new_history[-1]["content"] == "fb 7"
# ===========================================================================
# Task 12.3: Tests for template API
# ===========================================================================
class TestTemplateSystem:
    """Public template API: listing, lookup, and per-template content."""

    def test_list_templates_returns_all(self):
        templates = list_templates()
        assert len(templates) == len(TEMPLATE_REGISTRY)
        names = {t["name"] for t in templates}
        assert names == set(TEMPLATE_REGISTRY)

    def test_get_valid_template(self):
        # Guard against a vacuous pass: the loop checks nothing on an empty
        # registry.
        assert TEMPLATE_REGISTRY, "TEMPLATE_REGISTRY is empty"
        for name in TEMPLATE_REGISTRY:
            t = get_template(name)
            assert t.name, f"template {name!r} has no display name"
            steps = t.build_steps()
            assert len(steps) > 0, f"template {name!r} builds no steps"

    def test_get_invalid_template_raises(self):
        with pytest.raises(ValueError):
            get_template("nonexistent_template_xyz")

    def test_template_prompt_not_empty(self):
        assert TEMPLATE_REGISTRY, "TEMPLATE_REGISTRY is empty"
        for name in TEMPLATE_REGISTRY:
            t = get_template(name)
            prompt = t.get_full_prompt()
            # A real prompt should be substantial, not a stub.
            assert len(prompt) > 50, (
                f"template {name!r} prompt too short ({len(prompt)} chars)"
            )
# ===========================================================================
# Task 12.4: Tests for config
# ===========================================================================
class TestAppConfig:
    """Defaults and environment-variable overrides for ``AppConfig``."""

    # Variables that test_env_override expects AppConfig.from_env to honour.
    _ENV_VARS = (
        "APP_MAX_DATA_CONTEXT_RETRIES",
        "APP_CONVERSATION_WINDOW_SIZE",
        "APP_MAX_PARALLEL_PROFILES",
    )

    def test_defaults(self, monkeypatch):
        # Clear overrides inherited from the outer environment so the
        # defaults are checked in isolation. This is harmless if AppConfig()
        # does not read the environment at all.
        for var in self._ENV_VARS:
            monkeypatch.delenv(var, raising=False)
        config = AppConfig()
        assert config.max_data_context_retries == 2
        assert config.conversation_window_size == 10
        assert config.max_parallel_profiles == 4

    def test_env_override(self, monkeypatch):
        # On teardown, monkeypatch restores each variable's previous value
        # (or its absence), so pre-existing settings are not clobbered.
        monkeypatch.setenv("APP_MAX_DATA_CONTEXT_RETRIES", "5")
        monkeypatch.setenv("APP_CONVERSATION_WINDOW_SIZE", "20")
        monkeypatch.setenv("APP_MAX_PARALLEL_PROFILES", "8")
        config = AppConfig.from_env()
        assert config.max_data_context_retries == 5
        assert config.conversation_window_size == 20
        assert config.max_parallel_profiles == 8