大更新:架构调整,数据分析能力提升。
This commit is contained in:
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# tests package
|
||||
11
tests/conftest.py
Normal file
11
tests/conftest.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
"""
Conftest for property-based tests.

Ensures the project root is on sys.path for direct module imports.
"""
import sys
import os

# Prepend the project root so tests can import project modules directly
# (e.g. `from config.app_config import AppConfig`) without installation.
_PROJECT_ROOT = os.path.join(os.path.dirname(__file__), "..")
sys.path.insert(0, _PROJECT_ROOT)
|
||||
651
tests/test_dashboard_properties.py
Normal file
651
tests/test_dashboard_properties.py
Normal file
@@ -0,0 +1,651 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Property-based tests for analysis-dashboard-redesign features.
|
||||
Uses hypothesis with max_examples=100 as specified in the design document.
|
||||
|
||||
Run: python -m pytest tests/test_dashboard_properties.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import json
|
||||
import tempfile
|
||||
|
||||
# Ensure project root is on path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
from hypothesis.extra.pandas import column, data_frames, range_indexes
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers / Strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# A single captured evidence row: short lowercase keys mapped to
# ints, short strings, or finite floats.
_evidence_row_st = st.dictionaries(
    keys=st.text(min_size=1, max_size=10, alphabet="abcdefghijklmnopqrstuvwxyz"),
    values=st.one_of(st.integers(), st.text(min_size=0, max_size=20), st.floats(allow_nan=False)),
    min_size=1,
    max_size=5,
)

# Random execution results covering both the success and failure shapes
# produced by the code executor.
execution_result_st = st.fixed_dictionaries({
    "success": st.booleans(),
    "output": st.text(min_size=0, max_size=200),
    "error": st.text(min_size=0, max_size=200),
    "variables": st.just({}),
    "evidence_rows": st.lists(_evidence_row_st, min_size=0, max_size=10),
    "auto_exported_files": st.just([]),
    "prompt_saved_files": st.just([]),
})

# Reasoning text; the empty string simulates a missing YAML field.
reasoning_st = st.one_of(st.just(""), st.text(min_size=1, max_size=200))

# Generated code snippets (restricted alphabet keeps examples readable).
code_st = st.text(min_size=1, max_size=500, alphabet="abcdefghijklmnopqrstuvwxyz0123456789 =()._\n")

# Free-form feedback / raw execution log text.
feedback_st = st.text(min_size=0, max_size=300)
|
||||
|
||||
|
||||
def build_round_data(round_num, reasoning, code, result, feedback):
    """Construct a Round_Data dict the same way DataAnalysisAgent._handle_generate_code does.

    A successful run is summarised by its captured evidence shape (or the
    first stdout line); a failed run by a truncated error message.
    """

    def _summarize(res):
        # Failure path first: truncate very long error messages to 100 chars.
        if not res.get("success"):
            err = res.get("error", "未知错误")
            if len(err) > 100:
                err = err[:100] + "..."
            return f"执行失败: {err}"
        # Success with captured evidence: report the DataFrame shape.
        rows = res.get("evidence_rows", [])
        if rows:
            return f"执行成功,输出 DataFrame ({len(rows)}行×{len(rows[0])}列)"
        # Success with stdout only: show the first line, capped at 80 chars.
        out = res.get("output", "")
        if out:
            first_line = out.strip().split("\n")[0][:80]
            return f"执行成功: {first_line}"
        return "执行成功"

    return {
        "round": round_num,
        "reasoning": reasoning,
        "code": code,
        "result_summary": _summarize(result),
        "evidence_rows": result.get("evidence_rows", []),
        "raw_log": feedback,
        "auto_exported_files": result.get("auto_exported_files", []),
        "prompt_saved_files": result.get("prompt_saved_files", []),
    }
|
||||
|
||||
|
||||
# Regex for parsing DATA_FILE_SAVED markers (same as CodeExecutor).
# Captures: (1) filename up to the first comma, (2) row count, (3) description.
_DATA_FILE_SAVED_RE = re.compile(
    r"\[DATA_FILE_SAVED\]\s*filename:\s*(.+?),\s*rows:\s*(\d+),\s*description:\s*(.+)"
)


def parse_data_file_saved_markers(stdout_text):
    """Parse [DATA_FILE_SAVED] marker lines — mirrors CodeExecutor._parse_data_file_saved_markers."""
    parsed = []
    for raw_line in stdout_text.splitlines():
        match = _DATA_FILE_SAVED_RE.search(raw_line)
        if match is None:
            continue
        filename, rows, description = match.groups()
        parsed.append({
            "filename": filename.strip(),
            "rows": int(rows),
            "description": description.strip(),
        })
    return parsed
|
||||
|
||||
|
||||
# Evidence annotation regex (same as web/main.py)
|
||||
_EVIDENCE_PATTERN = re.compile(r"<!--\s*evidence:round_(\d+)\s*-->")
|
||||
|
||||
|
||||
def split_report_to_paragraphs(markdown_content):
    """Mirrors _split_report_to_paragraphs from web/main.py.

    Splits a markdown document into a list of paragraph dicts of the form
    {"id": "p-N", "type": ..., "content": ...} where type is one of
    "text", "heading", "image", "table", or "code".
    """
    raw_lines = markdown_content.split("\n")
    result = []
    buffer = []
    block_type = "text"
    next_id = 0

    def _flush():
        # Emit the buffered lines as one paragraph (if non-empty) and reset.
        nonlocal next_id, buffer, block_type
        joined = "\n".join(buffer).strip()
        if joined:
            result.append({
                "id": f"p-{next_id}",
                "type": block_type,
                "content": joined,
            })
            next_id += 1
        buffer = []
        block_type = "text"

    inside_table = False
    inside_code = False

    for raw in raw_lines:
        bare = raw.strip()

        # Fenced code blocks: the closing fence is included before flushing.
        if bare.startswith("```"):
            if inside_code:
                buffer.append(raw)
                _flush()
                inside_code = False
            else:
                _flush()
                buffer.append(raw)
                block_type = "code"
                inside_code = True
            continue

        if inside_code:
            buffer.append(raw)
            continue

        # Headings and images each become a single-line paragraph.
        if re.match(r"^#{1,6}\s", bare):
            _flush()
            buffer.append(raw)
            block_type = "heading"
            _flush()
            continue

        if re.match(r"^!\[.*\]\(.*\)", bare):
            _flush()
            buffer.append(raw)
            block_type = "image"
            _flush()
            continue

        # Table rows accumulate until a non-pipe line ends the table.
        if bare.startswith("|"):
            if not inside_table:
                _flush()
                inside_table = True
                block_type = "table"
            buffer.append(raw)
            continue
        elif inside_table:
            _flush()
            inside_table = False

        # A blank line terminates the current text block.
        if not bare:
            _flush()
            continue

        buffer.append(raw)

    _flush()
    return result
|
||||
|
||||
|
||||
def extract_evidence_annotations(paragraphs, rounds):
    """Mirrors _extract_evidence_annotations from web/main.py, using a rounds list instead of session.

    Returns a mapping of paragraph id -> evidence rows for every annotated
    paragraph whose referenced round exists and actually captured evidence.
    """
    supporting = {}
    for paragraph in paragraphs:
        hit = _EVIDENCE_PATTERN.search(paragraph.get("content", ""))
        if hit is None:
            continue
        index = int(hit.group(1)) - 1  # annotations are 1-based
        if not (0 <= index < len(rounds)):
            continue  # references to non-existent rounds are ignored
        rows = rounds[index].get("evidence_rows", [])
        if rows:
            supporting[paragraph["id"]] = rows
    return supporting
|
||||
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 1: Round_Data Structural Completeness (Task 16.1)
|
||||
# Feature: analysis-dashboard-redesign, Property 1: Round_Data structural completeness
|
||||
# Validates: Requirements 1.1, 1.3, 1.4
|
||||
# ===========================================================================
|
||||
|
||||
# Required Round_Data fields and the Python type each must carry.
ROUND_DATA_REQUIRED_FIELDS = {
    "round": int,
    "reasoning": str,
    "code": str,
    "result_summary": str,
    "evidence_rows": list,
    "raw_log": str,
}


@settings(max_examples=100)
@given(
    num_rounds=st.integers(min_value=1, max_value=20),
    results=st.lists(execution_result_st, min_size=1, max_size=20),
    reasonings=st.lists(reasoning_st, min_size=1, max_size=20),
    codes=st.lists(code_st, min_size=1, max_size=20),
    feedbacks=st.lists(feedback_st, min_size=1, max_size=20),
)
def test_prop1_round_data_structural_completeness(num_rounds, results, reasonings, codes, feedbacks):
    """Round_Data objects must contain all required fields with correct types and preserve insertion order.

    **Validates: Requirements 1.1, 1.3, 1.4**
    """
    # Use as many rounds as every generated list can supply.
    usable = min(num_rounds, len(results), len(reasonings), len(codes), len(feedbacks))
    built = [
        build_round_data(i + 1, reasonings[i], codes[i], results[i], feedbacks[i])
        for i in range(usable)
    ]

    # Every round dict carries every required field with the right type.
    for entry in built:
        for field, expected_type in ROUND_DATA_REQUIRED_FIELDS.items():
            assert field in entry, f"Missing field: {field}"
            assert isinstance(entry[field], expected_type), (
                f"Field '{field}' expected {expected_type.__name__}, got {type(entry[field]).__name__}"
            )

    # Round numbers are non-decreasing, i.e. insertion order is preserved.
    for earlier, later in zip(built, built[1:]):
        assert earlier["round"] <= later["round"], (
            f"Insertion order violated: round {earlier['round']} > {later['round']}"
        )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 2: Evidence Capture Bounded (Task 16.2)
|
||||
# Feature: analysis-dashboard-redesign, Property 2: Evidence capture bounded
|
||||
# Validates: Requirements 4.1, 4.2, 4.3
|
||||
# ===========================================================================
|
||||
|
||||
# Strategy for generating random DataFrames with 0-10000 rows and 1-50 columns.
# NOTE(review): col_name_st is not referenced by any test in this file —
# presumably kept for parity with the design document; confirm before removing.
col_name_st = st.text(
    min_size=1, max_size=10,
    alphabet=st.sampled_from("abcdefghijklmnopqrstuvwxyz_"),
).filter(lambda s: s[0] != "_")  # column names shouldn't start with _


@settings(max_examples=100)
@given(
    num_rows=st.integers(min_value=0, max_value=10000),
    num_cols=st.integers(min_value=1, max_value=50),
)
def test_prop2_evidence_capture_bounded(num_rows, num_cols):
    """Evidence capture must return at most 10 rows with keys matching DataFrame columns.

    **Validates: Requirements 4.1, 4.2, 4.3**
    """
    import numpy as np

    # Build a DataFrame of the requested shape (an empty frame keeps its columns).
    names = [f"col_{i}" for i in range(num_cols)]
    if num_rows:
        frame = pd.DataFrame(np.random.randint(0, 100, size=(num_rows, num_cols)), columns=names)
    else:
        frame = pd.DataFrame(columns=names)

    # The evidence-capture logic under test: first 10 rows as record dicts.
    captured = frame.head(10).to_dict(orient="records")

    # Bounded by 10, and exactly min(10, len) rows are kept.
    assert len(captured) <= 10
    assert len(captured) == min(10, len(frame))

    # Each captured record exposes exactly the frame's columns as keys.
    expected = set(frame.columns)
    for record in captured:
        assert set(record.keys()) == expected
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 3: Filename Deduplication (Task 16.3)
|
||||
# Feature: analysis-dashboard-redesign, Property 3: Filename deduplication
|
||||
# Validates: Requirements 5.3
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=100)
@given(
    num_exports=st.integers(min_value=1, max_value=20),
    var_name=st.text(min_size=1, max_size=20, alphabet="abcdefghijklmnopqrstuvwxyz_0123456789").filter(
        lambda s: s[0].isalpha()
    ),
)
def test_prop3_filename_deduplication(num_exports, var_name):
    """All generated filenames from same-name exports must be unique.

    **Validates: Requirements 5.3**
    """
    export_dir = tempfile.mkdtemp()
    seen = []

    for _ in range(num_exports):
        # Replay the _export_dataframe dedup logic: try `<name>.csv`, then
        # `<name>_1.csv`, `<name>_2.csv`, ... until an unused path is found.
        chosen = f"{var_name}.csv"
        path = os.path.join(export_dir, chosen)
        counter = 1
        while os.path.exists(path):
            chosen = f"{var_name}_{counter}.csv"
            path = os.path.join(export_dir, chosen)
            counter += 1

        # Materialise the file so subsequent iterations collide with it.
        with open(path, "w") as handle:
            handle.write("dummy")

        seen.append(chosen)

    # No two exports may share a filename.
    assert len(seen) == len(set(seen)), (
        f"Duplicate filenames found: {seen}"
    )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 4: Auto-Export Metadata Completeness (Task 16.4)
|
||||
# Feature: analysis-dashboard-redesign, Property 4: Auto-export metadata completeness
|
||||
# Validates: Requirements 5.4, 5.5
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=100)
@given(
    var_name=st.text(min_size=1, max_size=20, alphabet="abcdefghijklmnopqrstuvwxyz_0123456789").filter(
        lambda s: s[0].isalpha()
    ),
    num_rows=st.integers(min_value=0, max_value=1000),
    num_cols=st.integers(min_value=1, max_value=50),
)
def test_prop4_auto_export_metadata_completeness(var_name, num_rows, num_cols):
    """Auto-export metadata must contain all required fields with correct values.

    **Validates: Requirements 5.4, 5.5**
    """
    import numpy as np

    export_dir = tempfile.mkdtemp()
    names = [f"col_{i}" for i in range(num_cols)]
    if num_rows:
        frame = pd.DataFrame(np.random.randint(0, 100, size=(num_rows, num_cols)), columns=names)
    else:
        frame = pd.DataFrame(columns=names)

    # Replay the _export_dataframe naming logic (dedup on collision).
    chosen = f"{var_name}.csv"
    path = os.path.join(export_dir, chosen)
    counter = 1
    while os.path.exists(path):
        chosen = f"{var_name}_{counter}.csv"
        path = os.path.join(export_dir, chosen)
        counter += 1

    frame.to_csv(path, index=False)
    metadata = {
        "variable_name": var_name,
        "filename": chosen,
        "rows": len(frame),
        "cols": len(frame.columns),
        "columns": list(frame.columns),
    }

    # Every required key is present...
    for key in ("variable_name", "filename", "rows", "cols", "columns"):
        assert key in metadata, f"Missing field: {key}"

    # ...and each value matches the exported DataFrame.
    assert metadata["rows"] == len(frame)
    assert metadata["cols"] == len(frame.columns)
    assert metadata["columns"] == list(frame.columns)
    assert metadata["variable_name"] == var_name
|
||||
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 5: DATA_FILE_SAVED Marker Parsing Round-Trip (Task 16.5)
|
||||
# Feature: analysis-dashboard-redesign, Property 5: DATA_FILE_SAVED marker parsing round-trip
|
||||
# Validates: Requirements 6.3
|
||||
# ===========================================================================
|
||||
|
||||
# Strategy for filenames: alphanumeric + Chinese + underscores + hyphens, with extension.
# Commas are excluded because the marker format uses commas as field separators.
filename_base_st = st.text(
    min_size=1,
    max_size=30,
    alphabet=st.sampled_from(
        "abcdefghijklmnopqrstuvwxyz"
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "0123456789"
        "_-"
        "数据分析结果汇总报告"
    ),
).filter(lambda s: len(s.strip()) > 0 and "," not in s)

filename_ext_st = st.sampled_from([".csv", ".xlsx"])

# Pre-stripped base + extension, so the generated filename equals its own .strip().
filename_st = st.builds(lambda base, ext: base.strip() + ext, filename_base_st, filename_ext_st)

description_st = st.text(
    min_size=1,
    max_size=100,
    alphabet=st.sampled_from(
        "abcdefghijklmnopqrstuvwxyz"
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "0123456789 "
        "各类型问题聚合统计分析结果"
    ),
).filter(lambda s: len(s.strip()) > 0)


@settings(max_examples=100)
@given(
    filename=filename_st,
    rows=st.integers(min_value=1, max_value=1000000),
    description=description_st,
)
def test_prop5_data_file_saved_marker_round_trip(filename, rows, description):
    """Formatting then parsing a DATA_FILE_SAVED marker must recover original values.

    **Validates: Requirements 6.3**
    """
    # Format the marker exactly as the executor emits it.
    # BUG FIX: the marker must interpolate the generated `filename`; the previous
    # literal placeholder text made the round-trip assertion below always fail.
    marker = f"[DATA_FILE_SAVED] filename: {filename}, rows: {rows}, description: {description}"

    # Parse using the same logic as CodeExecutor.
    parsed = parse_data_file_saved_markers(marker)

    assert len(parsed) == 1, f"Expected 1 parsed result, got {len(parsed)}"
    assert parsed[0]["filename"] == filename.strip()
    assert parsed[0]["rows"] == rows
    assert parsed[0]["description"] == description.strip()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 6: Data File Preview Bounded Rows (Task 16.6)
|
||||
# Feature: analysis-dashboard-redesign, Property 6: Data file preview bounded rows
|
||||
# Validates: Requirements 7.2
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=100)
@given(
    num_rows=st.integers(min_value=0, max_value=10000),
    num_cols=st.integers(min_value=1, max_value=50),
)
def test_prop6_data_file_preview_bounded_rows(num_rows, num_cols):
    """Preview of a CSV file must return at most 5 rows with correct column names.

    **Validates: Requirements 7.2**
    """
    import numpy as np

    names = [f"col_{i}" for i in range(num_cols)]
    if num_rows:
        frame = pd.DataFrame(np.random.randint(0, 100, size=(num_rows, num_cols)), columns=names)
    else:
        frame = pd.DataFrame(columns=names)

    # Round-trip through a temporary CSV file, as the preview endpoint does.
    csv_path = os.path.join(tempfile.mkdtemp(), "test_data.csv")
    frame.to_csv(csv_path, index=False)
    preview = pd.read_csv(csv_path, nrows=5)

    # At most 5 rows are previewed; fewer when the file is smaller.
    assert len(preview) <= 5
    assert len(preview) == min(5, num_rows)

    # Column names survive the round trip unchanged.
    assert list(preview.columns) == names
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 7: Evidence Annotation Parsing (Task 16.7)
|
||||
# Feature: analysis-dashboard-redesign, Property 7: Evidence annotation parsing
|
||||
# Validates: Requirements 11.3, 11.4
|
||||
# ===========================================================================
|
||||
|
||||
# Paragraphs carrying an evidence annotation for a random round number.
annotated_paragraph_st = st.builds(
    lambda text, round_num: f"{text} <!-- evidence:round_{round_num} -->",
    st.text(min_size=1, max_size=100, alphabet="abcdefghijklmnopqrstuvwxyz .,!"),
    st.integers(min_value=1, max_value=100),
)

# Paragraphs guaranteed to carry no annotation at all.
plain_paragraph_st = st.text(
    min_size=1,
    max_size=100,
    alphabet="abcdefghijklmnopqrstuvwxyz .,!",
).filter(lambda s: "evidence:" not in s and len(s.strip()) > 0)


@settings(max_examples=100)
@given(
    annotated=st.lists(annotated_paragraph_st, min_size=0, max_size=10),
    plain=st.lists(plain_paragraph_st, min_size=0, max_size=10),
)
def test_prop7_evidence_annotation_parsing(annotated, plain):
    """Annotated paragraphs must be correctly extracted; non-annotated must be excluded.

    **Validates: Requirements 11.3, 11.4**
    """
    assume(len(annotated) + len(plain) > 0)

    # Join annotated then plain paragraphs into one markdown document,
    # separated by blank lines so each becomes its own paragraph.
    markdown = "\n\n".join([*annotated, *plain])

    # Parse into paragraphs with the production splitter.
    paragraphs = split_report_to_paragraphs(markdown)

    # 100 fake rounds, each carrying one evidence row.
    rounds = [{"evidence_rows": [{"key": f"value_{i}"}]} for i in range(100)]

    supporting_data = extract_evidence_annotations(paragraphs, rounds)

    for para in paragraphs:
        hit = _EVIDENCE_PATTERN.search(para.get("content", ""))
        if hit:
            # Annotated paragraphs referencing a valid, evidence-bearing round
            # must appear in supporting_data.
            round_num = int(hit.group(1))
            if 0 <= round_num - 1 < len(rounds) and rounds[round_num - 1].get("evidence_rows"):
                assert para["id"] in supporting_data, (
                    f"Annotated paragraph {para['id']} with round {round_num} not in supporting_data"
                )
        else:
            # Non-annotated paragraphs must NOT be in supporting_data.
            assert para["id"] not in supporting_data, (
                f"Non-annotated paragraph {para['id']} should not be in supporting_data"
            )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 8: SessionData JSON Round-Trip (Task 16.8)
|
||||
# Feature: analysis-dashboard-redesign, Property 8: SessionData JSON round-trip
|
||||
# Validates: Requirements 12.4
|
||||
# ===========================================================================
|
||||
|
||||
# Round_Data dicts restricted to JSON-safe values.
round_data_st = st.fixed_dictionaries({
    "round": st.integers(min_value=1, max_value=100),
    "reasoning": st.text(min_size=0, max_size=200),
    "code": st.text(min_size=0, max_size=200),
    "result_summary": st.text(min_size=0, max_size=200),
    "evidence_rows": st.lists(
        st.dictionaries(
            keys=st.text(min_size=1, max_size=10, alphabet="abcdefghijklmnopqrstuvwxyz"),
            values=st.one_of(
                st.integers(min_value=-1000, max_value=1000),
                st.text(min_size=0, max_size=20),
            ),
            min_size=0,
            max_size=5,
        ),
        min_size=0,
        max_size=10,
    ),
    "raw_log": st.text(min_size=0, max_size=200),
})

# Exported-file metadata dicts, also JSON-safe.
file_metadata_st = st.fixed_dictionaries({
    "filename": st.text(min_size=1, max_size=30, alphabet="abcdefghijklmnopqrstuvwxyz0123456789_."),
    "description": st.text(min_size=0, max_size=100),
    "rows": st.integers(min_value=0, max_value=100000),
    "cols": st.integers(min_value=0, max_value=100),
    "columns": st.lists(st.text(min_size=1, max_size=10, alphabet="abcdefghijklmnopqrstuvwxyz"), max_size=10),
    "size_bytes": st.integers(min_value=0, max_value=10000000),
    "source": st.sampled_from(["auto", "prompt"]),
})


@settings(max_examples=100)
@given(
    rounds=st.lists(round_data_st, min_size=0, max_size=20),
    data_files=st.lists(file_metadata_st, min_size=0, max_size=20),
)
def test_prop8_session_data_json_round_trip(rounds, data_files):
    """Serializing rounds and data_files to JSON and back must produce equal data.

    **Validates: Requirements 12.4**
    """
    payload = {
        "rounds": rounds,
        "data_files": data_files,
    }

    # Same serialization approach as the codebase (default=str for safety).
    restored = json.loads(json.dumps(payload, default=str))

    assert restored["rounds"] == rounds
    assert restored["data_files"] == data_files
|
||||
238
tests/test_phase1.py
Normal file
238
tests/test_phase1.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Unit tests for Phase 1: Backend Data Model + API Changes
|
||||
|
||||
Run: python -m pytest tests/test_phase1.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import tempfile
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 1: SessionData model extension
|
||||
# ===========================================================================
|
||||
|
||||
class TestSessionDataExtension:
    """Task 1: SessionData gains `rounds` and `data_files` attributes."""

    def test_rounds_initialized_to_empty_list(self):
        """Task 1.1: rounds attribute exists and defaults to []"""
        from web.main import SessionData

        sd = SessionData("test-id")
        assert hasattr(sd, "rounds")
        assert sd.rounds == []
        assert isinstance(sd.rounds, list)

    def test_data_files_initialized_to_empty_list(self):
        """Task 1.2: data_files attribute exists and defaults to []"""
        from web.main import SessionData

        sd = SessionData("test-id")
        assert hasattr(sd, "data_files")
        assert sd.data_files == []
        assert isinstance(sd.data_files, list)

    def test_existing_fields_unchanged(self):
        """Existing SessionData fields still work."""
        from web.main import SessionData

        sd = SessionData("test-id")
        assert sd.session_id == "test-id"
        assert sd.is_running is False
        assert sd.analysis_results == []
        assert sd.current_round == 0
        assert sd.max_rounds == 20

    def test_reconstruct_session_loads_rounds_and_data_files(self):
        """Task 1.3: _reconstruct_session loads rounds and data_files from results.json"""
        from web.main import SessionManager

        with tempfile.TemporaryDirectory() as tmpdir:
            session_dir = os.path.join(tmpdir, "session_test123")
            os.makedirs(session_dir)

            expected_rounds = [{"round": 1, "reasoning": "test", "code": "x=1"}]
            expected_files = [{"filename": "out.csv", "rows": 10}]
            payload = {
                "analysis_results": [{"round": 1}],
                "rounds": expected_rounds,
                "data_files": expected_files,
            }
            with open(os.path.join(session_dir, "results.json"), "w") as fh:
                json.dump(payload, fh)

            restored = SessionManager()._reconstruct_session("test123", session_dir)

            assert restored.rounds == expected_rounds
            assert restored.data_files == expected_files
            assert restored.analysis_results == [{"round": 1}]

    def test_reconstruct_session_legacy_format(self):
        """Task 1.3: _reconstruct_session handles legacy list format gracefully"""
        from web.main import SessionManager

        with tempfile.TemporaryDirectory() as tmpdir:
            session_dir = os.path.join(tmpdir, "session_legacy")
            os.makedirs(session_dir)

            # Legacy format: results.json is a plain list of result dicts.
            legacy = [{"round": 1, "code": "x=1"}]
            with open(os.path.join(session_dir, "results.json"), "w") as fh:
                json.dump(legacy, fh)

            restored = SessionManager()._reconstruct_session("legacy", session_dir)

            # Legacy data lands in analysis_results; new fields default empty.
            assert restored.analysis_results == legacy
            assert restored.rounds == []
            assert restored.data_files == []
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 2: Status API response
|
||||
# ===========================================================================
|
||||
|
||||
class TestStatusAPIResponse:
    """Task 2: /api/status exposes `rounds` while keeping legacy fields."""

    def test_status_response_contains_rounds(self):
        """Task 2.1: GET /api/status response includes rounds field"""
        from web.main import SessionData, session_manager

        sd = SessionData("status-test")
        sd.rounds = [{"round": 1, "reasoning": "r1"}]
        with session_manager.lock:
            session_manager.sessions["status-test"] = sd

        try:
            # Simulate what the endpoint returns.
            payload = {
                "is_running": sd.is_running,
                "log": "",
                "has_report": sd.generated_report is not None,
                "rounds": sd.rounds,
                "current_round": sd.current_round,
                "max_rounds": sd.max_rounds,
                "progress_percentage": sd.progress_percentage,
                "status_message": sd.status_message,
            }
            assert "rounds" in payload
            assert payload["rounds"] == [{"round": 1, "reasoning": "r1"}]
        finally:
            # Always remove the session so other tests see a clean manager.
            with session_manager.lock:
                del session_manager.sessions["status-test"]

    def test_status_backward_compat_fields(self):
        """Task 2.2: Existing fields remain unchanged"""
        from web.main import SessionData

        sd = SessionData("compat-test")
        sd.status_message = "分析中"
        sd.progress_percentage = 50.0
        sd.current_round = 5
        sd.max_rounds = 20

        payload = {
            "is_running": sd.is_running,
            "log": "",
            "has_report": sd.generated_report is not None,
            "progress_percentage": sd.progress_percentage,
            "current_round": sd.current_round,
            "max_rounds": sd.max_rounds,
            "status_message": sd.status_message,
            "rounds": sd.rounds,
        }

        assert payload["is_running"] is False
        assert payload["has_report"] is False
        assert payload["progress_percentage"] == 50.0
        assert payload["current_round"] == 5
        assert payload["max_rounds"] == 20
        assert payload["status_message"] == "分析中"
        assert "log" in payload
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 4: Evidence extraction
|
||||
# ===========================================================================
|
||||
|
||||
class TestEvidenceExtraction:
    """Task 4.1: web.main._extract_evidence_annotations.

    The helper scans report paragraphs for ``<!-- evidence:round_N -->``
    annotations and maps each paragraph id to the evidence rows captured
    in the corresponding analysis round.
    """

    def test_extract_evidence_basic(self):
        """Task 4.1: Parse evidence annotations and build supporting_data"""
        from web.main import _extract_evidence_annotations, SessionData

        session = SessionData("ev-test")
        session.rounds = [
            {"round": 1, "evidence_rows": [{"col": "val1"}]},
            {"round": 2, "evidence_rows": [{"col": "val2"}]},
        ]

        paragraphs = [
            {"id": "p-0", "type": "text", "content": "Some intro text"},
            {"id": "p-1", "type": "text", "content": "Analysis result <!-- evidence:round_1 -->"},
            {"id": "p-2", "type": "text", "content": "More analysis <!-- evidence:round_2 -->"},
        ]

        result = _extract_evidence_annotations(paragraphs, session)

        assert "p-0" not in result  # no annotation -> no entry
        assert result["p-1"] == [{"col": "val1"}]
        assert result["p-2"] == [{"col": "val2"}]

    def test_extract_evidence_no_annotations(self):
        """Task 4.1: No annotations means empty mapping"""
        from web.main import _extract_evidence_annotations, SessionData

        session = SessionData("ev-test2")
        session.rounds = [{"round": 1, "evidence_rows": [{"a": 1}]}]

        paragraphs = [
            {"id": "p-0", "type": "text", "content": "No evidence here"},
        ]

        result = _extract_evidence_annotations(paragraphs, session)
        assert result == {}

    def test_extract_evidence_out_of_range_round(self):
        """Task 4.1: Round number beyond available rounds is ignored"""
        from web.main import _extract_evidence_annotations, SessionData

        session = SessionData("ev-test3")
        session.rounds = [{"round": 1, "evidence_rows": [{"a": 1}]}]

        paragraphs = [
            {"id": "p-0", "type": "text", "content": "Ref to round 5 <!-- evidence:round_5 -->"},
        ]

        result = _extract_evidence_annotations(paragraphs, session)
        assert result == {}

    def test_extract_evidence_empty_evidence_rows(self):
        """Task 4.1: Round with empty evidence_rows is excluded"""
        from web.main import _extract_evidence_annotations, SessionData

        session = SessionData("ev-test4")
        session.rounds = [{"round": 1, "evidence_rows": []}]

        paragraphs = [
            {"id": "p-0", "type": "text", "content": "Has annotation <!-- evidence:round_1 -->"},
        ]

        result = _extract_evidence_annotations(paragraphs, session)
        assert result == {}

    def test_extract_evidence_whitespace_in_comment(self):
        """Task 4.1: Handles whitespace variations in HTML comment"""
        from web.main import _extract_evidence_annotations, SessionData

        session = SessionData("ev-test5")
        session.rounds = [{"round": 1, "evidence_rows": [{"x": 42}]}]

        # BUG FIX: the original content used the canonical single-space form,
        # which duplicated test_extract_evidence_basic instead of exercising
        # whitespace tolerance. Pad the comment with extra spaces so the
        # parser's whitespace handling is actually tested.
        paragraphs = [
            {"id": "p-0", "type": "text", "content": "Text <!--   evidence:round_1   -->"},
        ]

        result = _extract_evidence_annotations(paragraphs, session)
        assert result["p-0"] == [{"x": 42}]
|
||||
217
tests/test_phase2.py
Normal file
217
tests/test_phase2.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Unit tests for Phase 2: CodeExecutor Enhancements
|
||||
|
||||
Run: python -m pytest tests/test_phase2.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from utils.code_executor import CodeExecutor
|
||||
|
||||
|
||||
@pytest.fixture
def executor(tmp_path):
    """Provide a CodeExecutor whose outputs land in pytest's per-test tmp dir."""
    output_dir = str(tmp_path)
    return CodeExecutor(output_dir=output_dir)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 5: Evidence capture
|
||||
# ===========================================================================
|
||||
|
||||
class TestEvidenceCapture:
    """Task 5: evidence_rows capture in CodeExecutor.execute_code."""

    def test_evidence_from_result_dataframe(self, executor):
        """5.1: When result.result is a DataFrame, capture head(10) as evidence_rows."""
        src = "import pandas as pd\npd.DataFrame({'a': [1,2,3], 'b': [4,5,6]})"
        res = executor.execute_code(src)
        assert res["success"] is True
        assert "evidence_rows" in res
        assert len(res["evidence_rows"]) == 3
        assert res["evidence_rows"][0] == {"a": 1, "b": 4}

    def test_evidence_capped_at_10(self, executor):
        """5.1: Evidence rows are capped at 10."""
        src = "import pandas as pd\npd.DataFrame({'x': list(range(100))})"
        res = executor.execute_code(src)
        assert res["success"] is True
        assert len(res["evidence_rows"]) == 10

    def test_evidence_fallback_to_namespace(self, executor):
        """5.2: When result.result is not a DataFrame, fallback to namespace."""
        src = "import pandas as pd\nmy_data = pd.DataFrame({'col': [10, 20]})\nprint('done')"
        res = executor.execute_code(src)
        assert res["success"] is True
        assert len(res["evidence_rows"]) == 2
        assert res["evidence_rows"][0] == {"col": 10}

    def test_evidence_empty_when_no_dataframe(self, executor):
        """5.3: Returns empty list when no DataFrame is produced."""
        executor.reset_environment()
        res = executor.execute_code("x = 42")
        assert res["success"] is True
        assert res["evidence_rows"] == []

    def test_evidence_key_in_failure(self, executor):
        """5.3: evidence_rows key present even on failure."""
        res = executor.execute_code("import not_a_real_module")
        assert "evidence_rows" in res
        assert res["evidence_rows"] == []
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 6: DataFrame auto-detection and export
|
||||
# ===========================================================================
|
||||
|
||||
class TestDataFrameAutoExport:
    """Task 6: new/changed DataFrames in the namespace are auto-exported to CSV."""

    def test_new_dataframe_exported(self, executor, tmp_path):
        """6.1-6.4: New DataFrame is detected and exported to CSV."""
        code = "import pandas as pd\nresult_df = pd.DataFrame({'a': [1], 'b': [2]})"
        result = executor.execute_code(code)
        assert result["success"] is True
        assert len(result["auto_exported_files"]) >= 1
        exported = result["auto_exported_files"][0]
        assert exported["variable_name"] == "result_df"
        assert exported["filename"] == "result_df.csv"
        assert exported["rows"] == 1
        assert exported["cols"] == 2
        assert exported["columns"] == ["a", "b"]
        # Verify the file actually exists on disk, not just in metadata
        assert os.path.exists(os.path.join(str(tmp_path), "result_df.csv"))

    def test_dedup_suffix(self, executor):
        """6.3: Numeric suffix deduplication when file exists.

        FIX: dropped the unused ``tmp_path`` fixture parameter.
        """
        # Create first file
        code1 = "import pandas as pd\nmy_df = pd.DataFrame({'x': [1]})"
        result1 = executor.execute_code(code1)
        assert result1["success"] is True

        # Re-assign the DataFrame to force a new object id
        code2 = "my_df = pd.DataFrame({'x': [2]})"
        result2 = executor.execute_code(code2)
        assert result2["success"] is True
        exported_files = result2["auto_exported_files"]
        assert len(exported_files) >= 1
        # The second export should have _1 suffix
        assert exported_files[0]["filename"] == "my_df_1.csv"

    def test_skip_module_names(self, executor):
        """6.1: Module-level names like pd, np are skipped."""
        code = "x = 42"  # pd and np already in namespace from setup
        result = executor.execute_code(code)
        # Should not export pd or np as DataFrames
        for f in result["auto_exported_files"]:
            assert f["variable_name"] not in ("pd", "np", "plt", "sns")

    def test_auto_exported_files_key_in_result(self, executor):
        """6.5: auto_exported_files key always present."""
        code = "x = 1"
        result = executor.execute_code(code)
        assert "auto_exported_files" in result
        assert isinstance(result["auto_exported_files"], list)

    def test_changed_dataframe_detected(self, executor):
        """6.2: Changed DataFrame (same name, new object) is detected.

        FIX: dropped the unused ``tmp_path`` fixture parameter.
        """
        code1 = "import pandas as pd\ndf_test = pd.DataFrame({'a': [1]})"
        executor.execute_code(code1)

        code2 = "df_test = pd.DataFrame({'a': [1, 2, 3]})"
        result2 = executor.execute_code(code2)
        assert result2["success"] is True
        exported = [f for f in result2["auto_exported_files"] if f["variable_name"] == "df_test"]
        assert len(exported) == 1
        assert exported[0]["rows"] == 3
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 7: DATA_FILE_SAVED marker parsing
|
||||
# ===========================================================================
|
||||
|
||||
class TestDataFileSavedMarkerParsing:
    """Task 7: parsing of [DATA_FILE_SAVED] markers printed to stdout."""

    def test_parse_single_marker(self, executor):
        """7.1-7.2: Parse a single DATA_FILE_SAVED marker from stdout."""
        src = 'print("[DATA_FILE_SAVED] filename: output.csv, rows: 42, description: Test data")'
        res = executor.execute_code(src)
        assert res["success"] is True
        saved = res["prompt_saved_files"]
        assert len(saved) == 1
        assert saved[0]["filename"] == "output.csv"
        assert saved[0]["rows"] == 42
        assert saved[0]["description"] == "Test data"

    def test_parse_multiple_markers(self, executor):
        """7.1-7.2: Parse multiple markers."""
        src = (
            'print("[DATA_FILE_SAVED] filename: a.csv, rows: 10, description: File A")\n'
            'print("[DATA_FILE_SAVED] filename: b.xlsx, rows: 20, description: File B")'
        )
        res = executor.execute_code(src)
        assert res["success"] is True
        saved = res["prompt_saved_files"]
        assert len(saved) == 2
        assert saved[0]["filename"] == "a.csv"
        assert saved[1]["filename"] == "b.xlsx"

    def test_no_markers(self, executor):
        """7.3: No markers means empty list."""
        res = executor.execute_code('print("hello world")')
        assert res["success"] is True
        assert res["prompt_saved_files"] == []

    def test_prompt_saved_files_key_in_result(self, executor):
        """7.3: prompt_saved_files key always present."""
        res = executor.execute_code("x = 1")
        assert "prompt_saved_files" in res
        assert isinstance(res["prompt_saved_files"], list)

    def test_malformed_marker_skipped(self, executor):
        """7.1: Malformed markers are silently skipped."""
        res = executor.execute_code('print("[DATA_FILE_SAVED] this is not valid")')
        assert res["success"] is True
        assert res["prompt_saved_files"] == []

    def test_chinese_filename_and_description(self, executor):
        """7.2: Chinese characters in filename and description."""
        src = 'print("[DATA_FILE_SAVED] filename: 数据汇总.csv, rows: 100, description: 各类型TOP问题聚合统计")'
        res = executor.execute_code(src)
        assert res["success"] is True
        saved = res["prompt_saved_files"]
        assert len(saved) == 1
        assert saved[0]["filename"] == "数据汇总.csv"
        assert saved[0]["description"] == "各类型TOP问题聚合统计"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Return structure integrity
|
||||
# ===========================================================================
|
||||
|
||||
class TestReturnStructure:
    """execute_code must always return the full 7-key result structure.

    FIX: the expected-key set was copy-pasted three times; hoist it to a
    single class-level constant so the contract is stated once.
    """

    # Keys required in every result dict, regardless of outcome.
    EXPECTED_KEYS = frozenset({
        "success", "output", "error", "variables",
        "evidence_rows", "auto_exported_files", "prompt_saved_files",
    })

    def test_success_return_has_all_keys(self, executor):
        """All 7 keys present on success."""
        result = executor.execute_code("x = 1")
        assert self.EXPECTED_KEYS.issubset(set(result.keys()))

    def test_safety_failure_has_all_keys(self, executor):
        """All 7 keys present on safety check failure."""
        result = executor.execute_code("import socket")
        assert self.EXPECTED_KEYS.issubset(set(result.keys()))

    def test_execution_error_has_all_keys(self, executor):
        """All 7 keys present on execution error."""
        result = executor.execute_code("1/0")
        assert self.EXPECTED_KEYS.issubset(set(result.keys()))
|
||||
233
tests/test_phase3.py
Normal file
233
tests/test_phase3.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Unit tests for Phase 3: Agent Changes
|
||||
|
||||
Run: python -m pytest tests/test_phase3.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
from data_analysis_agent import DataAnalysisAgent
|
||||
from prompts import data_analysis_system_prompt, final_report_system_prompt
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 8.1: _summarize_result
|
||||
# ===========================================================================
|
||||
|
||||
class TestSummarizeResult:
    """Task 8.1: DataAnalysisAgent._summarize_result renders a one-line summary."""

    @pytest.fixture
    def agent(self):
        """Create a minimal DataAnalysisAgent for testing (bypasses __init__)."""
        agent = DataAnalysisAgent.__new__(DataAnalysisAgent)
        agent._session_ref = None
        return agent

    def test_success_with_evidence_rows(self, agent):
        """8.1: Success with evidence rows produces DataFrame summary."""
        result = {
            "success": True,
            "evidence_rows": [{"a": 1, "b": 2}, {"a": 3, "b": 4}],
            "auto_exported_files": [{"variable_name": "df", "filename": "df.csv", "rows": 150, "cols": 8, "columns": []}],
        }
        summary = agent._summarize_result(result)
        assert "执行成功" in summary
        assert "DataFrame" in summary
        assert "150" in summary
        assert "8" in summary

    def test_success_with_evidence_no_auto_files(self, agent):
        """8.1: Success with evidence but no auto_exported_files uses evidence length."""
        result = {
            "success": True,
            "evidence_rows": [{"x": 1}, {"x": 2}, {"x": 3}],
            "auto_exported_files": [],
        }
        summary = agent._summarize_result(result)
        assert "执行成功" in summary
        assert "DataFrame" in summary

    def test_success_with_output(self, agent):
        """8.1: Success with output but no evidence shows first line."""
        result = {
            "success": True,
            "evidence_rows": [],
            "output": "Hello World\nSecond line",
        }
        summary = agent._summarize_result(result)
        assert "执行成功" in summary
        assert "Hello World" in summary

    def test_success_no_output(self, agent):
        """8.1: Success with no output or evidence."""
        result = {"success": True, "evidence_rows": [], "output": ""}
        assert agent._summarize_result(result) == "执行成功"

    def test_failure_short_error(self, agent):
        """8.1: Failure with short error message."""
        result = {"success": False, "error": "KeyError: 'col_x'"}
        summary = agent._summarize_result(result)
        assert "执行失败" in summary
        assert "KeyError" in summary

    def test_failure_long_error_truncated(self, agent):
        """8.1: Failure with long error is truncated to 100 chars."""
        result = {"success": False, "error": "A" * 200}
        summary = agent._summarize_result(result)
        assert "执行失败" in summary
        assert "..." in summary
        # FIX: the bound was <= 104 while the comment said 103; the error
        # portion is at most 100 truncated chars plus "..." = 103 chars.
        error_part = summary.split("执行失败: ")[1]
        assert len(error_part) <= 103

    def test_failure_no_error_field(self, agent):
        """8.1: Failure with missing error field."""
        summary = agent._summarize_result({"success": False})
        assert "执行失败" in summary
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 8.2-8.4: Round_Data construction and session integration
|
||||
# ===========================================================================
|
||||
|
||||
class TestRoundDataConstruction:
    """Task 8.2: _handle_generate_code propagates the YAML `reasoning` field.

    FIX: the mocked-agent setup was duplicated verbatim in both tests;
    extract it into a single helper.
    """

    @staticmethod
    def _make_agent(output=""):
        """Build a bare agent whose executor returns a canned successful result."""
        from unittest.mock import MagicMock

        agent = DataAnalysisAgent.__new__(DataAnalysisAgent)
        agent._session_ref = None
        agent.executor = MagicMock()
        agent.executor.execute_code.return_value = {
            "success": True, "output": output, "error": "",
            "variables": {}, "evidence_rows": [],
            "auto_exported_files": [], "prompt_saved_files": [],
        }
        return agent

    def test_handle_generate_code_returns_reasoning(self):
        """8.2: _handle_generate_code returns reasoning from yaml_data."""
        agent = self._make_agent(output="ok")
        yaml_data = {"code": "x = 1", "reasoning": "Testing reasoning field"}
        result = agent._handle_generate_code("response text", yaml_data)
        assert result["reasoning"] == "Testing reasoning field"

    def test_handle_generate_code_empty_reasoning(self):
        """8.2: _handle_generate_code returns empty reasoning when not in yaml_data."""
        agent = self._make_agent()
        yaml_data = {"code": "x = 1"}
        result = agent._handle_generate_code("response text", yaml_data)
        assert result["reasoning"] == ""
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 8.3: set_session_ref
|
||||
# ===========================================================================
|
||||
|
||||
class TestSetSessionRef:
    """Task 8.3: session-reference wiring on DataAnalysisAgent."""

    def test_session_ref_default_none(self):
        """8.3: _session_ref defaults to None."""
        assert DataAnalysisAgent()._session_ref is None

    def test_set_session_ref(self):
        """8.3: set_session_ref stores the session reference."""

        class FakeSession:
            rounds = []
            data_files = []

        agent = DataAnalysisAgent()
        fake_session = FakeSession()
        agent.set_session_ref(fake_session)
        assert agent._session_ref is fake_session
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 9.1: Prompt - intermediate data saving instructions
|
||||
# ===========================================================================
|
||||
|
||||
class TestPromptDataSaving:
    """Task 9.1: the system prompt must teach the model to persist intermediate data."""

    def test_data_saving_instructions_in_system_prompt(self):
        """9.1: data_analysis_system_prompt contains DATA_FILE_SAVED instructions."""
        prompt = data_analysis_system_prompt
        assert "[DATA_FILE_SAVED]" in prompt
        assert "中间数据保存规则" in prompt

    def test_data_saving_example_in_prompt(self):
        """9.1: Prompt contains example of saving and printing marker."""
        prompt = data_analysis_system_prompt
        assert "to_csv" in prompt
        assert "session_output_dir" in prompt
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 9.2: Prompt - evidence annotation instructions
|
||||
# ===========================================================================
|
||||
|
||||
class TestPromptEvidenceAnnotation:
    """Task 9.2: the report prompt must instruct evidence annotation."""

    def test_evidence_annotation_in_report_prompt(self):
        """9.2: final_report_system_prompt contains evidence annotation instructions."""
        prompt = final_report_system_prompt
        assert "evidence:round_" in prompt
        assert "证据标注规则" in prompt

    def test_evidence_annotation_example(self):
        """9.2: Prompt contains example of evidence annotation."""
        assert "<!-- evidence:round_3 -->" in final_report_system_prompt
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 9.3: _build_final_report_prompt includes evidence
|
||||
# ===========================================================================
|
||||
|
||||
class TestBuildFinalReportPromptEvidence:
    """Task 9.3: per-round evidence flows into the final-report prompt."""

    @staticmethod
    def _bare_agent():
        """Agent skeleton with only the attributes _build_final_report_prompt reads."""
        agent = DataAnalysisAgent.__new__(DataAnalysisAgent)
        agent.analysis_results = []
        agent.session_output_dir = "/tmp/test"
        agent.data_profile = "test profile"
        return agent

    def test_evidence_included_when_session_has_rounds(self):
        """9.3: _build_final_report_prompt includes evidence data when rounds exist."""
        agent = self._bare_agent()
        agent.current_round = 2

        class FakeSession:
            rounds = [
                {
                    "round": 1,
                    "reasoning": "分析车型分布",
                    "result_summary": "执行成功,输出 DataFrame (10行×3列)",
                    "evidence_rows": [{"车型": "A", "数量": 42}],
                },
                {
                    "round": 2,
                    "reasoning": "分析模块分布",
                    "result_summary": "执行成功",
                    "evidence_rows": [],
                },
            ]

        agent._session_ref = FakeSession()
        prompt = agent._build_final_report_prompt([])
        for expected in ("各轮次分析证据数据", "第1轮", "第2轮", "车型"):
            assert expected in prompt

    def test_no_evidence_when_no_session_ref(self):
        """9.3: _build_final_report_prompt works without session ref."""
        agent = self._bare_agent()
        agent.current_round = 1
        agent._session_ref = None
        prompt = agent._build_final_report_prompt([])
        assert "各轮次分析证据数据" not in prompt
|
||||
285
tests/test_properties.py
Normal file
285
tests/test_properties.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Property-based tests for agent-robustness-optimization features.
|
||||
Uses hypothesis with reduced examples (max_examples=20) for fast execution.
|
||||
|
||||
Run: python -m pytest tests/test_properties.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
|
||||
# Ensure project root is on path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from utils.data_privacy import (
|
||||
_extract_column_from_error,
|
||||
_lookup_column_in_profile,
|
||||
generate_enriched_hint,
|
||||
)
|
||||
from utils.analysis_templates import get_template, list_templates, TEMPLATE_REGISTRY
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Regexes that mark an error message as "data context" (missing column,
# undefined frame, empty result, out-of-range index).
DATA_CONTEXT_PATTERNS = [
    r"KeyError:\s*['\"](.+?)['\"]",
    r"ValueError.*(?:column|col|field)",
    r"NameError.*(?:df|data|frame)",
    r"(?:empty|no\s+data|0\s+rows)",
    r"IndexError.*(?:out of range|out of bounds)",
]


def classify_error(error_message: str) -> str:
    """Mirror of DataAnalysisAgent._classify_error for testing without IPython."""
    hits = (re.search(pattern, error_message, re.IGNORECASE)
            for pattern in DATA_CONTEXT_PATTERNS)
    return "data_context" if any(hits) else "other"
|
||||
|
||||
|
||||
SAMPLE_SAFE_PROFILE = """# 数据结构概览 (Schema Profile)
|
||||
|
||||
## 文件: test.csv
|
||||
|
||||
- **维度**: 100 行 x 3 列
|
||||
- **列名**: `车型, 模块, 问题类型`
|
||||
|
||||
### 列结构:
|
||||
|
||||
| 列名 | 数据类型 | 空值率 | 唯一值数 | 特征描述 |
|
||||
|------|---------|--------|---------|----------|
|
||||
| 车型 | object | 0.0% | 5 | 低基数分类(5类) |
|
||||
| 模块 | object | 2.0% | 12 | 中基数分类(12类) |
|
||||
| 问题类型 | object | 0.0% | 8 | 低基数分类(8类) |
|
||||
"""
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 1: Error Classification Correctness (Task 11.1)
|
||||
# ===========================================================================
|
||||
|
||||
# Strategy: generate error messages that contain data-context patterns
# (each variant corresponds to one entry in DATA_CONTEXT_PATTERNS).
data_context_error_st = st.one_of(
    st.from_regex(r"KeyError: '[a-zA-Z_]+'" , fullmatch=True),
    st.from_regex(r'KeyError: "[a-zA-Z_]+"', fullmatch=True),
    st.just("ValueError: column 'x' not found"),
    st.just("NameError: name 'df' is not defined"),
    st.just("empty DataFrame"),
    st.just("0 rows returned"),
    st.just("IndexError: index 5 is out of range"),
)

# Error messages that must NOT be classified as data-context. The final
# st.text branch filters out any random string that accidentally matches a
# data-context pattern, so the strategy only produces true negatives.
non_data_error_st = st.one_of(
    st.just("SyntaxError: invalid syntax"),
    st.just("TypeError: unsupported operand"),
    st.just("ZeroDivisionError: division by zero"),
    st.just("ImportError: No module named 'foo'"),
    st.text(min_size=1, max_size=50).filter(
        lambda s: not any(re.search(p, s, re.IGNORECASE) for p in DATA_CONTEXT_PATTERNS)
    ),
)
|
||||
|
||||
|
||||
@settings(max_examples=20)
@given(err=data_context_error_st)
def test_prop1_data_context_errors_classified(err):
    """Data-context error messages must be classified as 'data_context'."""
    label = classify_error(err)
    assert label == "data_context"
|
||||
|
||||
|
||||
@settings(max_examples=20)
@given(err=non_data_error_st)
def test_prop1_non_data_errors_classified(err):
    """Non-data error messages must be classified as 'other'."""
    label = classify_error(err)
    assert label == "other"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 3: Enriched Hint Contains Column Metadata Without Real Data (11.2)
|
||||
# ===========================================================================
|
||||
|
||||
# Columns present in SAMPLE_SAFE_PROFILE; drives the KeyError-hint properties.
known_columns = ["车型", "模块", "问题类型"]
column_st = st.sampled_from(known_columns)
|
||||
|
||||
|
||||
@settings(max_examples=20)
@given(col=column_st)
def test_prop3_enriched_hint_contains_column_meta(col):
    """Enriched hint for a known column must contain its metadata."""
    hint = generate_enriched_hint(f"KeyError: '{col}'", SAMPLE_SAFE_PROFILE)
    assert col in hint
    # Every metadata label from the profile table must be surfaced.
    for meta_label in ("数据类型", "唯一值数量", "空值率", "特征描述"):
        assert meta_label in hint
|
||||
|
||||
|
||||
@settings(max_examples=20)
@given(col=column_st)
def test_prop3_enriched_hint_no_real_data(col):
    """Enriched hint must NOT contain real data values (min/max/mean/sample rows)."""
    hint = generate_enriched_hint(f"KeyError: '{col}'", SAMPLE_SAFE_PROFILE)
    # Statistical values and sample-data headers would leak real data.
    banned_markers = ("Min=", "Max=", "Mean=", "TOP 5 高频值")
    for forbidden in banned_markers:
        assert forbidden not in hint
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 4: Env Var Config Override (Task 11.3)
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=10)
@given(val=st.integers(min_value=1, max_value=100))
def test_prop4_env_override_max_data_context_retries(val):
    """APP_MAX_DATA_CONTEXT_RETRIES env var must override config.

    FIX: the original unconditionally deleted the variable afterwards,
    clobbering any value already present in the caller's environment.
    Save and restore it instead.
    """
    from config.app_config import AppConfig

    previous = os.environ.get("APP_MAX_DATA_CONTEXT_RETRIES")
    os.environ["APP_MAX_DATA_CONTEXT_RETRIES"] = str(val)
    try:
        config = AppConfig.from_env()
        assert config.max_data_context_retries == val
    finally:
        if previous is None:
            del os.environ["APP_MAX_DATA_CONTEXT_RETRIES"]
        else:
            os.environ["APP_MAX_DATA_CONTEXT_RETRIES"] = previous
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 5: Sliding Window Trimming Invariants (Task 11.4)
|
||||
# ===========================================================================
|
||||
|
||||
def make_history(n_pairs: int, first_msg: str = "initial requirement"):
    """Build a fake conversation history with n_pairs of user+assistant messages."""
    exchanges = []
    for i in range(n_pairs):
        exchanges.append({"role": "assistant", "content": f'action: "generate_code"\ncode: | print({i})'})
        exchanges.append({"role": "user", "content": f"代码执行反馈:\n成功 round {i}"})
    # The very first message is always the original user requirement.
    return [{"role": "user", "content": first_msg}] + exchanges
|
||||
|
||||
|
||||
@settings(max_examples=20)
@given(
    n_pairs=st.integers(min_value=1, max_value=30),
    window=st.integers(min_value=1, max_value=10),
)
def test_prop5_trimming_preserves_first_message(n_pairs, window):
    """After trimming, the first user message is always at index 0."""
    # NOTE(review): this test re-implements the agent's sliding-window
    # trimming inline rather than calling it; it must be kept in sync with
    # the production logic — TODO confirm against the agent implementation.
    history = make_history(n_pairs, first_msg="ORIGINAL_REQ")
    # Window is counted in user/assistant pairs; two messages per pair.
    max_messages = window * 2

    if len(history) <= max_messages:
        return  # no trimming needed, invariant trivially holds

    first_message = history[0]
    start_idx = 1
    # A previous trim may already have inserted a summary right after the
    # first message; if so, trimming starts after it.
    has_summary = (
        len(history) > 1
        and history[1]["role"] == "user"
        and history[1]["content"].startswith("[分析摘要]")
    )
    if has_summary:
        start_idx = 2

    # Keep only the trailing window; everything before it gets summarized.
    messages_to_consider = history[start_idx:]
    messages_to_trim = messages_to_consider[:-max_messages]
    messages_to_keep = messages_to_consider[-max_messages:]

    if not messages_to_trim:
        return

    # Rebuild: original requirement + summary placeholder + retained window.
    new_history = [first_message]
    new_history.append({"role": "user", "content": "[分析摘要] summary"})
    new_history.extend(messages_to_keep)

    assert new_history[0]["content"] == "ORIGINAL_REQ"
    assert len(new_history) <= max_messages + 2  # first + summary + window
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 6: Trimming Summary Content (Task 11.5)
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=20)
@given(n_pairs=st.integers(min_value=2, max_value=15))
def test_prop6_summary_excludes_code_blocks(n_pairs):
    """Compressed summary must not contain code blocks or raw output."""
    history = make_history(n_pairs)
    # Simulate _compress_trimmed_messages logic
    # NOTE(review): inline mirror of the production compression — keep in
    # sync with the agent's implementation; TODO confirm it still matches.
    summary_parts = ["[分析摘要] 以下是之前分析轮次的概要:"]
    round_num = 0
    for msg in history[1:]:  # skip first (original requirement)
        content = msg["content"]
        if msg["role"] == "assistant":
            # Each assistant message opens a new round; record only the action,
            # never the generated code itself.
            round_num += 1
            action = "generate_code"
            if "collect_figures" in content:
                action = "collect_figures"
            summary_parts.append(f"- 轮次{round_num}: 动作={action}")
        elif msg["role"] == "user" and "代码执行反馈" in content:
            # Fold the execution outcome into the round line just appended.
            success = "失败" if "[ERROR]" in content or "执行错误" in content else "成功"
            if summary_parts and summary_parts[-1].startswith("- 轮次"):
                summary_parts[-1] += f", 执行结果={success}"

    summary = "\n".join(summary_parts)
    # The summary must never leak fenced code or raw print output.
    assert "```" not in summary
    assert "print(" not in summary
    assert "[分析摘要]" in summary
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 7: Template Prompt Integration (Task 11.6)
|
||||
# ===========================================================================
|
||||
|
||||
# All registered template names; drives the sampled_from strategy below.
valid_template_names = list(TEMPLATE_REGISTRY.keys())
|
||||
|
||||
|
||||
@settings(max_examples=len(valid_template_names))
@given(name=st.sampled_from(valid_template_names))
def test_prop7_template_prompt_prepended(name):
    """For any valid template, get_full_prompt() output must be non-empty."""
    template = get_template(name)
    rendered = template.get_full_prompt()
    assert len(rendered) > 0
    assert template.name in rendered
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 8: Invalid Template Name Raises Error (Task 11.7)
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=20)
@given(name=st.text(min_size=1, max_size=30).filter(lambda s: s not in TEMPLATE_REGISTRY))
def test_prop8_invalid_template_raises_error(name):
    """Invalid template names must raise ValueError listing available templates."""
    with pytest.raises(ValueError) as exc_info:
        get_template(name)
    message = str(exc_info.value)
    # The error must advertise every registered template so users can recover.
    unlisted = [known for known in TEMPLATE_REGISTRY if known not in message]
    assert not unlisted
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 11: Parallel Profile Merge With Error Resilience (Task 11.8)
|
||||
# ===========================================================================
|
||||
|
||||
def test_prop11_parallel_profile_error_resilience():
    """Parallel profiling with a mix of valid/invalid files includes all entries.

    build_safe_profile must not abort on an unreadable path: the bad file
    still yields an error entry, while readable files profile normally.
    (Fixed: removed the unused ``build_local_profile`` import.)
    """
    from utils.data_privacy import build_safe_profile

    valid_file = "uploads/data_simple_200.csv"
    invalid_file = "/nonexistent/fake_file.csv"

    # Test build_safe_profile handles missing files gracefully
    safe = build_safe_profile([valid_file, invalid_file])
    assert "fake_file.csv" in safe  # error entry present
    # Only check the valid entry when the fixture file actually exists.
    if os.path.exists(valid_file):
        assert "data_simple_200.csv" in safe  # valid entry present
|
||||
229
tests/test_unit.py
Normal file
229
tests/test_unit.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Unit and integration tests for agent-robustness-optimization features.
|
||||
|
||||
Run: python -m pytest tests/test_unit.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import re
|
||||
import pytest
|
||||
|
||||
from utils.data_privacy import (
|
||||
_extract_column_from_error,
|
||||
_lookup_column_in_profile,
|
||||
generate_enriched_hint,
|
||||
)
|
||||
from utils.analysis_templates import get_template, list_templates, TEMPLATE_REGISTRY
|
||||
from config.app_config import AppConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Regexes that mark an error as stemming from the data context (missing
# columns, undefined frames, empty results) rather than from the code itself.
DATA_CONTEXT_PATTERNS = [
    r"KeyError:\s*['\"](.+?)['\"]",
    r"ValueError.*(?:column|col|field)",
    r"NameError.*(?:df|data|frame)",
    r"(?:empty|no\s+data|0\s+rows)",
    r"IndexError.*(?:out of range|out of bounds)",
]


def classify_error(error_message: str) -> str:
    """Label *error_message* as ``"data_context"`` or ``"other"``.

    Matching is case-insensitive and stops at the first pattern hit.
    """
    hit = any(
        re.search(pattern, error_message, re.IGNORECASE)
        for pattern in DATA_CONTEXT_PATTERNS
    )
    return "data_context" if hit else "other"
|
||||
|
||||
|
||||
SAMPLE_PROFILE = """| 列名 | 数据类型 | 空值率 | 唯一值数 | 特征描述 |
|
||||
|------|---------|--------|---------|----------|
|
||||
| 车型 | object | 0.0% | 5 | 低基数分类(5类) |
|
||||
| 模块 | object | 2.0% | 12 | 中基数分类(12类) |
|
||||
"""
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 12.1: Unit tests for error classifier
|
||||
# ===========================================================================
|
||||
|
||||
class TestErrorClassifier:
    """Each data-context error shape classifies as such; everything else is 'other'."""

    def test_keyerror_single_quotes(self):
        message = "KeyError: '车型'"
        assert classify_error(message) == "data_context"

    def test_keyerror_double_quotes(self):
        message = 'KeyError: "model_name"'
        assert classify_error(message) == "data_context"

    def test_valueerror_column(self):
        message = "ValueError: column 'x' not in DataFrame"
        assert classify_error(message) == "data_context"

    def test_nameerror_df(self):
        message = "NameError: name 'df' is not defined"
        assert classify_error(message) == "data_context"

    def test_empty_dataframe(self):
        message = "empty DataFrame after filtering"
        assert classify_error(message) == "data_context"

    def test_zero_rows(self):
        message = "0 rows returned from query"
        assert classify_error(message) == "data_context"

    def test_index_out_of_range(self):
        message = "IndexError: index 10 is out of range"
        assert classify_error(message) == "data_context"

    def test_syntax_error_is_other(self):
        message = "SyntaxError: invalid syntax"
        assert classify_error(message) == "other"

    def test_type_error_is_other(self):
        message = "TypeError: unsupported operand"
        assert classify_error(message) == "other"

    def test_generic_text_is_other(self):
        message = "Something went wrong"
        assert classify_error(message) == "other"

    def test_empty_string_is_other(self):
        message = ""
        assert classify_error(message) == "other"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 12.1 continued: Unit tests for column extraction and lookup
|
||||
# ===========================================================================
|
||||
|
||||
class TestColumnExtraction:
    """Column-name extraction from error text, and lookup in a profile table."""

    def test_extract_from_keyerror(self):
        extracted = _extract_column_from_error("KeyError: '车型'")
        assert extracted == "车型"

    def test_extract_from_column_phrase(self):
        extracted = _extract_column_from_error("column '模块' not found")
        assert extracted == "模块"

    def test_extract_none_for_generic(self):
        assert _extract_column_from_error("SyntaxError: bad") is None

    def test_lookup_existing_column(self):
        info = _lookup_column_in_profile("车型", SAMPLE_PROFILE)
        assert info is not None
        assert info["dtype"] == "object"
        assert info["unique_count"] == "5"

    def test_lookup_missing_column(self):
        assert _lookup_column_in_profile("不存在", SAMPLE_PROFILE) is None

    def test_lookup_none_column(self):
        assert _lookup_column_in_profile(None, SAMPLE_PROFILE) is None
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 12.2: Unit tests for conversation trimming at boundary conditions
|
||||
# ===========================================================================
|
||||
|
||||
class TestConversationTrimming:
    """Boundary-condition checks for the conversation-window trimming logic.

    A history is ``[original_user_msg] + n_pairs * [assistant, user]``, so it
    always holds ``2 * n_pairs + 1`` messages.  Trimming keeps the first
    message plus the last ``window * 2`` messages, inserting a summary stub.
    """

    def _make_history(self, n_pairs):
        """Build a history of one original message plus *n_pairs* exchange pairs."""
        history = [{"role": "user", "content": "ORIGINAL"}]
        for i in range(n_pairs):
            history.append({"role": "assistant", "content": f"response {i}"})
            history.append({"role": "user", "content": f"feedback {i}"})
        return history

    def test_no_trimming_when_under_limit(self):
        """History with 3 pairs and window=5 should not be trimmed."""
        history = self._make_history(3)  # 1 + 6 = 7 messages
        window = 5
        max_messages = window * 2  # 10
        assert len(history) <= max_messages  # no trimming

    def test_trimming_at_exact_boundary(self):
        """History one message past 2*window must trigger trimming.

        Fixed docstring: the original said "should not be trimmed" while the
        assertion below checks that trimming DOES happen.  With window=3 the
        cap is 6 messages, and the smallest full history (1 original + 3
        pairs = 7 messages) already exceeds it.
        """
        window = 3
        history = self._make_history(3)  # 1 + 6 = 7 messages
        max_messages = window * 2  # 6
        # 7 > 6, so trimming should happen
        assert len(history) > max_messages

    def test_first_message_always_preserved(self):
        """After trimming, the first message must be preserved."""
        history = self._make_history(10)
        window = 2
        max_messages = window * 2

        first = history[0]
        to_consider = history[1:]
        to_keep = to_consider[-max_messages:]

        new_history = [first, {"role": "user", "content": "[分析摘要] ..."}]
        new_history.extend(to_keep)

        assert new_history[0]["content"] == "ORIGINAL"
        # Result is exactly: original + summary stub + the kept tail.
        assert len(new_history) == 2 + max_messages

    def test_summary_replaces_old_summary(self):
        """If a summary already exists at index 1, it should be replaced."""
        history = [
            {"role": "user", "content": "ORIGINAL"},
            {"role": "user", "content": "[分析摘要] old summary"},
        ]
        for i in range(8):
            history.append({"role": "assistant", "content": f"resp {i}"})
            history.append({"role": "user", "content": f"fb {i}"})

        # Simulate trimming with existing summary: messages to consider for
        # trimming start after the old summary, not after the first message.
        has_summary = history[1]["content"].startswith("[分析摘要]")
        assert has_summary
        start_idx = 2 if has_summary else 1
        assert start_idx == 2
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 12.3: Tests for template API
|
||||
# ===========================================================================
|
||||
|
||||
class TestTemplateSystem:
    """Template registry API: listing, retrieval, and prompt generation."""

    def test_list_templates_returns_all(self):
        listed = list_templates()
        assert len(listed) == len(TEMPLATE_REGISTRY)
        listed_names = {entry["name"] for entry in listed}
        assert listed_names == set(TEMPLATE_REGISTRY.keys())

    def test_get_valid_template(self):
        for registered in TEMPLATE_REGISTRY:
            tmpl = get_template(registered)
            assert tmpl.name  # has a display name
            assert len(tmpl.build_steps()) > 0

    def test_get_invalid_template_raises(self):
        with pytest.raises(ValueError):
            get_template("nonexistent_template_xyz")

    def test_template_prompt_not_empty(self):
        for registered in TEMPLATE_REGISTRY:
            prompt = get_template(registered).get_full_prompt()
            assert len(prompt) > 50  # should be substantial
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Task 12.4: Tests for config
|
||||
# ===========================================================================
|
||||
|
||||
class TestAppConfig:
    """AppConfig default values and environment-variable overrides."""

    def test_defaults(self):
        config = AppConfig()
        assert config.max_data_context_retries == 2
        assert config.conversation_window_size == 10
        assert config.max_parallel_profiles == 4

    def test_env_override(self):
        """from_env() must pick up the APP_* environment variables.

        Fixed: the original version unconditionally ``del``-eted the
        variables in ``finally``, clobbering any values that existed before
        the test ran.  Pre-existing values are now saved and restored so the
        test does not leak state into the surrounding environment.
        """
        overrides = {
            "APP_MAX_DATA_CONTEXT_RETRIES": "5",
            "APP_CONVERSATION_WINDOW_SIZE": "20",
            "APP_MAX_PARALLEL_PROFILES": "8",
        }
        saved = {key: os.environ.get(key) for key in overrides}
        os.environ.update(overrides)
        try:
            config = AppConfig.from_env()
            assert config.max_data_context_retries == 5
            assert config.conversation_window_size == 20
            assert config.max_parallel_profiles == 8
        finally:
            # Restore the exact prior state: unset vars are removed,
            # previously-set vars get their original values back.
            for key, previous in saved.items():
                if previous is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = previous
|
||||
Reference in New Issue
Block a user