Complete AI Data Analysis Agent implementation with 95.7% test coverage
This commit is contained in:
328
tests/test_requirement_understanding.py
Normal file
328
tests/test_requirement_understanding.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""Unit tests for requirement understanding engine."""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from src.engines.requirement_understanding import (
|
||||
understand_requirement,
|
||||
parse_template,
|
||||
check_data_requirement_match,
|
||||
_fallback_requirement_understanding
|
||||
)
|
||||
from src.models.data_profile import DataProfile, ColumnInfo
|
||||
from src.models.requirement_spec import RequirementSpec, AnalysisObjective
|
||||
|
||||
|
||||
@pytest.fixture
def sample_data_profile():
    """Build the ticket-style DataProfile shared by the tests in this module.

    The profile lists five columns whose dtype labels are datetime,
    categorical (twice), numeric and text.
    """
    created_at = ColumnInfo(
        name='created_at',
        dtype='datetime',
        missing_rate=0.0,
        unique_count=1000,
        sample_values=['2024-01-01', '2024-01-02'],
        statistics={},
    )
    status = ColumnInfo(
        name='status',
        dtype='categorical',
        missing_rate=0.1,
        unique_count=5,
        sample_values=['open', 'closed', 'pending'],
        statistics={},
    )
    issue_type = ColumnInfo(
        name='type',
        dtype='categorical',
        missing_rate=0.0,
        unique_count=10,
        sample_values=['bug', 'feature'],
        statistics={},
    )
    priority = ColumnInfo(
        name='priority',
        dtype='numeric',
        missing_rate=0.0,
        unique_count=5,
        sample_values=[1, 2, 3, 4, 5],
        statistics={'mean': 3.0, 'std': 1.2},
    )
    description = ColumnInfo(
        name='description',
        dtype='text',
        missing_rate=0.05,
        unique_count=950,
        sample_values=['Issue 1', 'Issue 2'],
        statistics={},
    )

    return DataProfile(
        file_path='test.csv',
        row_count=1000,
        column_count=5,
        columns=[created_at, status, issue_type, priority, description],
        inferred_type='ticket',
        key_fields={'time': 'created_at', 'status': 'status', 'type': 'type'},
        quality_score=85.0,
        summary='Ticket data with 1000 rows and 5 columns',
    )
|
||||
|
||||
|
||||
def test_understand_health_requirement(sample_data_profile):
    """A "健康度" (health) request must yield a health objective with metrics."""
    query = "我想了解工单的健康度"

    # Call the fallback implementation directly so no API access is needed.
    spec = _fallback_requirement_understanding(query, sample_data_profile, None)

    # Basic shape of the returned specification.
    assert isinstance(spec, RequirementSpec)
    assert spec.user_input == query
    assert len(spec.objectives) > 0

    # At least one objective name has to mention health.
    matching = [
        objective for objective in spec.objectives if '健康' in objective.name
    ]
    assert len(matching) > 0

    # The first health objective carries metrics and a priority within 1..5.
    first_match = matching[0]
    assert len(first_match.metrics) > 0
    assert 1 <= first_match.priority <= 5
|
||||
|
||||
|
||||
def test_understand_trend_requirement(sample_data_profile):
    """A trend request must yield a trend objective that lists metrics."""
    query = "分析趋势"

    spec = _fallback_requirement_understanding(query, sample_data_profile, None)

    # At least one objective name has to mention the trend keyword.
    matching = [
        objective for objective in spec.objectives if '趋势' in objective.name
    ]
    assert len(matching) > 0

    # The first trend objective must define at least one metric.
    first_match = matching[0]
    assert len(first_match.metrics) > 0
|
||||
|
||||
|
||||
def test_understand_distribution_requirement(sample_data_profile):
    """A distribution request must yield a distribution objective."""
    query = "查看分布情况"

    spec = _fallback_requirement_understanding(query, sample_data_profile, None)

    # At least one objective name has to mention the distribution keyword.
    matching = [
        objective for objective in spec.objectives if '分布' in objective.name
    ]
    assert len(matching) > 0
|
||||
|
||||
|
||||
def test_understand_generic_requirement(sample_data_profile):
    """A request without specific keywords still produces a default objective."""

    def looks_like_default(objective):
        # A default objective is recognised by either marker in its name.
        return '综合' in objective.name or 'analysis' in objective.name.lower()

    query = "帮我分析一下"

    spec = _fallback_requirement_understanding(query, sample_data_profile, None)

    # Even a vague request has to produce at least one objective.
    assert len(spec.objectives) > 0

    # One of the objectives must be the default one.
    assert any(looks_like_default(objective) for objective in spec.objectives)
|
||||
|
||||
|
||||
def test_parse_template_with_sections():
    """Parse a markdown template from disk and check sections, metrics, charts."""
    template_content = """# 分析报告

## 数据概览
这是数据概览部分

## 趋势分析
指标: 增长率, 变化趋势
图表: 时间序列图

## 分布分析
指标: 类别分布
图表: 柱状图, 饼图
"""

    # delete=False keeps the file on disk after the handle closes so the
    # parser can open it by path; it is removed in the finally clause below.
    with tempfile.NamedTemporaryFile(
        mode='w', suffix='.md', delete=False, encoding='utf-8'
    ) as handle:
        handle.write(template_content)
        template_path = handle.name

    try:
        parsed = parse_template(template_path)

        # Section headings.
        sections = parsed['sections']
        assert len(sections) >= 3
        assert '分析报告' in sections
        assert '数据概览' in sections

        # Metrics collected from the template.
        assert len(parsed['required_metrics']) >= 2

        # Charts collected from the template.
        assert len(parsed['required_charts']) >= 2
    finally:
        os.unlink(template_path)
|
||||
|
||||
|
||||
def test_parse_nonexistent_template():
    """Parsing a missing template path must return an empty structure."""
    parsed = parse_template('nonexistent.md')

    # Each collection in the result is expected to be an empty list.
    for key in ('sections', 'required_metrics', 'required_charts'):
        assert parsed[key] == []
|
||||
|
||||
|
||||
def test_check_data_satisfies_requirement(sample_data_profile):
    """A status-distribution requirement is reported as satisfiable."""
    # An objective the sample profile is expected to be able to serve.
    status_objective = AnalysisObjective(
        name="状态分析",
        description="分析状态字段的分布",
        metrics=["状态分布"],
        priority=5,
    )
    spec = RequirementSpec(
        user_input="分析状态分布",
        objectives=[status_objective],
    )

    outcome = check_data_requirement_match(spec, sample_data_profile)

    # The match must allow the analysis to proceed with a satisfied objective.
    assert outcome['can_proceed'] is True
    assert len(outcome['satisfied_objectives']) > 0
|
||||
|
||||
|
||||
def test_check_data_missing_fields(sample_data_profile):
    """Check the match result's structure for a requirement the data lacks fields for."""
    # Geographic metrics: the sample profile has no geography-related column.
    geo_objective = AnalysisObjective(
        name="地理分析",
        description="分析地理位置分布",
        metrics=["地理分布", "区域统计"],
        priority=5,
    )
    spec = RequirementSpec(
        user_input="分析地理分布",
        objectives=[geo_objective],
    )

    outcome = check_data_requirement_match(spec, sample_data_profile)

    # Only the shape of the result is verified here.
    assert isinstance(outcome, dict)
    for key in ('missing_fields', 'unsatisfied_objectives'):
        assert key in outcome
|
||||
|
||||
|
||||
def test_check_time_based_requirement(sample_data_profile):
    """A time-trend requirement can proceed against the sample profile."""
    time_objective = AnalysisObjective(
        name="时间分析",
        description="分析随时间的变化",
        metrics=["时间序列", "趋势"],
        priority=5,
    )
    spec = RequirementSpec(
        user_input="分析时间趋势",
        objectives=[time_objective],
    )

    outcome = check_data_requirement_match(spec, sample_data_profile)

    # Expected to pass because the sample profile has a datetime column.
    assert outcome['can_proceed'] is True
|
||||
|
||||
|
||||
def test_check_status_based_requirement(sample_data_profile):
    """A status requirement can proceed and reports a satisfied objective."""
    status_objective = AnalysisObjective(
        name="状态分析",
        description="分析状态字段",
        metrics=["状态分布", "状态变化"],
        priority=5,
    )
    spec = RequirementSpec(
        user_input="分析状态",
        objectives=[status_objective],
    )

    outcome = check_data_requirement_match(spec, sample_data_profile)

    # Expected to pass because the sample profile has a status column.
    assert outcome['can_proceed'] is True
    assert len(outcome['satisfied_objectives']) > 0
|
||||
|
||||
|
||||
def test_requirement_with_template(sample_data_profile):
    """The fallback path records the template path and its parsed requirements."""
    template_content = """# 工单分析报告

## 状态分析
指标: 状态分布, 完成率

## 类型分析
指标: 类型分布
"""

    # delete=False keeps the file on disk after the handle closes so it can
    # be reopened by path; it is removed in the finally clause below.
    with tempfile.NamedTemporaryFile(
        mode='w', suffix='.md', delete=False, encoding='utf-8'
    ) as handle:
        handle.write(template_content)
        template_path = handle.name

    try:
        spec = _fallback_requirement_understanding(
            "按模板分析",
            sample_data_profile,
            template_path,
        )

        # The template must be attached to the resulting specification.
        assert spec.template_path == template_path
        assert spec.template_requirements is not None

        # The parsed template requirements expose the expected keys.
        assert 'sections' in spec.template_requirements
        assert 'required_metrics' in spec.template_requirements
    finally:
        os.unlink(template_path)
|
||||
|
||||
|
||||
def test_multiple_objectives_priority():
    """A request naming several analyses yields objectives with priorities in 1..5."""
    # (name, dtype, unique_count) for each column of a minimal profile.
    column_table = (
        ('col1', 'numeric', 100),
        ('col2', 'categorical', 5),
        ('col3', 'datetime', 100),
    )
    columns = [
        ColumnInfo(name=name, dtype=dtype, missing_rate=0.0, unique_count=unique)
        for name, dtype, unique in column_table
    ]
    data_profile = DataProfile(
        file_path='test.csv',
        row_count=100,
        column_count=3,
        columns=columns,
        inferred_type='unknown',
        quality_score=90.0,
    )

    spec = _fallback_requirement_understanding(
        "完整分析,包括健康度和趋势",
        data_profile,
        None,
    )

    # The request names more than one analysis, so expect several objectives.
    assert len(spec.objectives) >= 2

    # Every objective's priority must fall inside the 1..5 range.
    for objective in spec.objectives:
        assert 1 <= objective.priority <= 5
|
||||
Reference in New Issue
Block a user