# -*- coding: utf-8 -*-
"""
Unit and integration tests for agent-robustness-optimization features.

Run: python -m pytest tests/test_unit.py -v
"""
import os
import re
import sys

# Make the project root importable before pulling in project modules.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import pytest

from utils.data_privacy import (
    _extract_column_from_error,
    _lookup_column_in_profile,
    generate_enriched_hint,
)
from utils.analysis_templates import get_template, list_templates, TEMPLATE_REGISTRY
from config.app_config import AppConfig


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Updated patterns matching data_analysis_agent.py
DATA_CONTEXT_PATTERNS = [
    # KeyError - missing key/column
    r"KeyError:\s*['\"](.+?)['\"]",
    # ValueError - value-related issues
    r"ValueError.*(?:column|col|field|shape|axis)",
    # NameError - undefined variables
    r"NameError.*(?:df|data|frame|series)",
    # Empty/missing data
    # NOTE(review): the trailing `No\s+data` alternative is redundant because
    # classify_error searches with re.IGNORECASE, and `0\s+rows` also matches
    # inside text such as "10 rows". Left untouched so this list keeps
    # mirroring the agent's patterns.
    r"(?:empty|no\s+data|0\s+rows|No\s+data)",
    # IndexError - out of bounds
    r"IndexError.*(?:out of range|out of bounds)",
    # AttributeError - missing attributes
    r"AttributeError.*(?:DataFrame|Series|object)\s+has\s+no\s+attribute",
    # Pandas-specific errors
    r"pd\.errors\.(?:EmptyDataError|ParserError|MergeError)",
    r"MergeError: No common columns",
    # Type errors
    r"TypeError.*(?:unsupported operand|expected string|cannot convert)",
    # UnboundLocalError - undefined local variables
    r"UnboundLocalError.*referenced before assignment",
    # Syntax errors
    r"SyntaxError: invalid syntax",
    # Module/Import errors for data libraries
    r"ModuleNotFoundError.*(?:pandas|numpy|matplotlib)",
    r"ImportError.*(?:pandas|numpy|matplotlib)",
]


def classify_error(error_message: str) -> str:
    """Classify an error message as ``"data_context"`` or ``"other"``.

    Args:
        error_message: Raw error text to inspect.

    Returns:
        ``"data_context"`` when any entry of ``DATA_CONTEXT_PATTERNS`` is
        found in the message (case-insensitive search), otherwise ``"other"``.
    """
    matched = any(
        re.search(pattern, error_message, re.IGNORECASE)
        for pattern in DATA_CONTEXT_PATTERNS
    )
    return "data_context" if matched else "other"


# Markdown-table profile used as the fixture for the column lookup tests.
SAMPLE_PROFILE = """| 列名 | 数据类型 | 空值率 | 唯一值数 | 特征描述 |
|------|---------|--------|---------|----------|
| 车型 | object | 0.0% | 5 | 低基数分类(5类) |
| 模块 | object | 2.0% | 12 | 中基数分类(12类) |
"""


# ===========================================================================
# Task 12.1: Unit tests for error classifier
# ===========================================================================

class TestErrorClassifier:
    """Checks that classify_error recognises each supported error pattern."""

    def test_keyerror_single_quotes(self):
        assert classify_error("KeyError: '车型'") == "data_context"

    def test_keyerror_double_quotes(self):
        assert classify_error('KeyError: "model_name"') == "data_context"

    def test_valueerror_column(self):
        assert classify_error("ValueError: column 'x' not in DataFrame") == "data_context"

    def test_nameerror_df(self):
        assert classify_error("NameError: name 'df' is not defined") == "data_context"

    def test_empty_dataframe(self):
        assert classify_error("empty DataFrame after filtering") == "data_context"

    def test_zero_rows(self):
        assert classify_error("0 rows returned from query") == "data_context"

    def test_index_out_of_range(self):
        assert classify_error("IndexError: index 10 is out of range") == "data_context"

    def test_syntax_error_is_data_context(self):
        assert classify_error("SyntaxError: invalid syntax") == "data_context"

    def test_type_error_is_data_context(self):
        assert classify_error("TypeError: unsupported operand") == "data_context"

    def test_generic_text_is_other(self):
        assert classify_error("Something went wrong") == "other"

    def test_empty_string_is_other(self):
        assert classify_error("") == "other"

    # =======================================================================
    # Additional tests for improved error classifier
    # =======================================================================

    def test_attributeerror_dataframe(self):
        assert classify_error("AttributeError: 'DataFrame' object has no attribute 'xxx'") == "data_context"

    def test_attributeerror_series(self):
        assert classify_error("AttributeError: 'Series' object has no attribute 'xxx'") == "data_context"

    def test_pd_emptydataerror(self):
        assert classify_error("pd.errors.EmptyDataError: No data") == "data_context"

    def test_pd_parsererror(self):
        assert classify_error("pd.errors.ParserError: Error tokenizing data") == "data_context"

    def test_pd_mergeerror(self):
        assert classify_error("MergeError: No common columns to merge") == "data_context"

    def test_typeerror_unsupported_operand(self):
        assert classify_error("TypeError: unsupported operand type(s) for +: 'int' and 'str'") == "data_context"

    def test_typeerror_expected_string(self):
        assert classify_error("TypeError: expected string or bytes-like object") == "data_context"

    def test_unboundlocalerror(self):
        assert classify_error("UnboundLocalError: local variable 'df' referenced before assignment") == "data_context"

    def test_syntaxerror(self):
        assert classify_error("SyntaxError: invalid syntax") == "data_context"

    def test_modulenotfounderror(self):
        assert classify_error("ModuleNotFoundError: No module named 'pandas'") == "data_context"

    def test_importerror(self):
        assert classify_error("ImportError: cannot import name 'xxx' from 'pandas'") == "data_context"

    def test_valueerror_shape(self):
        assert classify_error("ValueError: shape mismatch") == "data_context"

    def test_valueerror_axis(self):
        assert classify_error("ValueError: axis out of bounds") == "data_context"

    def test_nameerror_series(self):
        assert classify_error("NameError: name 'series' is not defined") == "data_context"

    def test_no_data_message(self):
        assert classify_error("No data available for analysis") == "data_context"


# ===========================================================================
# Task 12.1 continued: Unit tests for column extraction and lookup
# ===========================================================================

class TestColumnExtraction:
    """Tests for extracting a column name from an error and looking it up."""

    def test_extract_from_keyerror(self):
        assert _extract_column_from_error("KeyError: '车型'") == "车型"

    def test_extract_from_column_phrase(self):
        assert _extract_column_from_error("column '模块' not found") == "模块"

    def test_extract_none_for_generic(self):
        assert _extract_column_from_error("SyntaxError: bad") is None

    def test_lookup_existing_column(self):
        result = _lookup_column_in_profile("车型", SAMPLE_PROFILE)
        assert result is not None
        assert result["dtype"] == "object"
        assert result["unique_count"] == "5"

    def test_lookup_missing_column(self):
        assert _lookup_column_in_profile("不存在", SAMPLE_PROFILE) is None

    def test_lookup_none_column(self):
        assert _lookup_column_in_profile(None, SAMPLE_PROFILE) is None


# ===========================================================================
# Task 12.2: Unit tests for conversation trimming at boundary conditions
# ===========================================================================

class TestConversationTrimming:
    """Boundary checks for the conversation-window trimming rules.

    These tests re-create the trimming arithmetic locally on plain message
    lists; they do not call any production trimming function.
    """

    def _make_history(self, n_pairs):
        """Build one opening user message followed by n_pairs reply/feedback pairs."""
        history = [{"role": "user", "content": "ORIGINAL"}]
        for i in range(n_pairs):
            history.append({"role": "assistant", "content": f"response {i}"})
            history.append({"role": "user", "content": f"feedback {i}"})
        return history

    def test_no_trimming_when_under_limit(self):
        """History with 3 pairs and window=5 should not be trimmed."""
        history = self._make_history(3)  # 1 + 6 = 7 messages
        window = 5
        max_messages = window * 2  # 10
        assert len(history) <= max_messages  # no trimming

    def test_trimming_at_exact_boundary(self):
        """History one message over 2*window should require trimming."""
        window = 3
        history = self._make_history(3)  # 1 + 6 = 7 messages
        max_messages = window * 2  # 6
        # 7 > 6, so trimming should happen
        assert len(history) > max_messages

    def test_first_message_always_preserved(self):
        """After trimming, first message must be preserved."""
        history = self._make_history(10)
        window = 2
        max_messages = window * 2
        first = history[0]
        to_consider = history[1:]
        to_keep = to_consider[-max_messages:]
        new_history = [first, {"role": "user", "content": "[分析摘要] ..."}]
        new_history.extend(to_keep)
        assert new_history[0]["content"] == "ORIGINAL"
        # The trimmed history is the first message, the summary, and the tail.
        assert len(new_history) == 2 + max_messages
        assert new_history[-1] == history[-1]

    def test_summary_replaces_old_summary(self):
        """If a summary already exists at index 1, it should be replaced."""
        history = [
            {"role": "user", "content": "ORIGINAL"},
            {"role": "user", "content": "[分析摘要] old summary"},
        ]
        for i in range(8):
            history.append({"role": "assistant", "content": f"resp {i}"})
            history.append({"role": "user", "content": f"fb {i}"})
        # Simulate trimming with existing summary
        has_summary = history[1]["content"].startswith("[分析摘要]")
        assert has_summary
        start_idx = 2 if has_summary else 1
        assert start_idx == 2


# ===========================================================================
# Task 12.3: Tests for template API
# ===========================================================================

class TestTemplateSystem:
    """Tests for listing and fetching analysis templates."""

    def test_list_templates_returns_all(self):
        templates = list_templates()
        assert len(templates) == len(TEMPLATE_REGISTRY)
        names = {t["name"] for t in templates}
        assert names == set(TEMPLATE_REGISTRY.keys())

    def test_get_valid_template(self):
        for name in TEMPLATE_REGISTRY:
            t = get_template(name)
            assert t.name  # has a display name
            assert len(t.steps) > 0  # template has steps

    def test_get_invalid_template_raises(self):
        with pytest.raises(ValueError):
            get_template("nonexistent_template_xyz")

    def test_template_prompt_not_empty(self):
        for name in TEMPLATE_REGISTRY:
            t = get_template(name)
            prompt = t.get_full_prompt()
            assert len(prompt) > 50  # should be substantial


# ===========================================================================
# Task 12.4: Tests for config
# ===========================================================================

class TestAppConfig:
    """Tests for AppConfig defaults and environment overrides."""

    def test_defaults(self):
        config = AppConfig()
        assert config.max_data_context_retries == 2
        assert config.conversation_window_size == 10
        assert config.max_parallel_profiles == 4

    def test_env_override(self, monkeypatch):
        """Values set through APP_* environment variables reach from_env()."""
        # monkeypatch restores any pre-existing values on teardown, so the
        # test never deletes variables the surrounding environment had set.
        monkeypatch.setenv("APP_MAX_DATA_CONTEXT_RETRIES", "5")
        monkeypatch.setenv("APP_CONVERSATION_WINDOW_SIZE", "20")
        monkeypatch.setenv("APP_MAX_PARALLEL_PROFILES", "8")

        config = AppConfig.from_env()

        assert config.max_data_context_retries == 5
        assert config.conversation_window_size == 20
        assert config.max_parallel_profiles == 8