Files
vibe_data_ana/tests/test_viz_tools.py

358 lines
11 KiB
Python
Raw Normal View History

"""可视化工具的单元测试。"""
import pytest
import pandas as pd
import numpy as np
import os
from pathlib import Path
import tempfile
import shutil
from src.tools.viz_tools import (
CreateBarChartTool,
CreateLineChartTool,
CreatePieChartTool,
CreateHeatmapTool
)
from src.models import DataProfile, ColumnInfo
@pytest.fixture
def temp_output_dir():
    """Provide a fresh scratch directory for chart files.

    Yields:
        The path (as a string) of a newly created temporary directory.
        The directory tree is removed once the requesting test finishes.
    """
    scratch_dir = tempfile.mkdtemp()
    yield scratch_dir
    # Teardown: best-effort removal; failures while deleting are ignored.
    shutil.rmtree(scratch_dir, ignore_errors=True)
class TestCreateBarChartTool:
    """Tests for the bar chart tool."""

    def test_basic_functionality(self, temp_output_dir):
        """Run with only an x column and check the output file and result metadata."""
        chart_file = os.path.join(temp_output_dir, 'bar_chart.png')
        frame = pd.DataFrame({
            'category': ['A', 'B', 'C', 'A', 'B', 'A'],
            'value': [10, 20, 30, 15, 25, 20],
        })

        outcome = CreateBarChartTool().execute(
            frame, x_column='category', output_path=chart_file
        )

        assert outcome['success'] is True
        assert os.path.exists(chart_file)
        assert outcome['chart_type'] == 'bar'
        assert outcome['x_column'] == 'category'

    def test_with_y_column(self, temp_output_dir):
        """Run with an explicit y column and check it is echoed in the result."""
        chart_file = os.path.join(temp_output_dir, 'bar_chart_y.png')
        frame = pd.DataFrame({
            'category': ['A', 'B', 'C'],
            'value': [100, 200, 300],
        })
        options = {
            'x_column': 'category',
            'y_column': 'value',
            'output_path': chart_file,
        }

        outcome = CreateBarChartTool().execute(frame, **options)

        assert outcome['success'] is True
        assert os.path.exists(chart_file)
        assert outcome['y_column'] == 'value'

    def test_top_n_limit(self, temp_output_dir):
        """Request top_n=10 on 50 categories and expect exactly 10 data points."""
        chart_file = os.path.join(temp_output_dir, 'bar_chart_top.png')
        frame = pd.DataFrame({
            'category': [f'cat_{i}' for i in range(50)],
            'value': range(50),
        })
        options = {
            'x_column': 'category',
            'y_column': 'value',
            'top_n': 10,
            'output_path': chart_file,
        }

        outcome = CreateBarChartTool().execute(frame, **options)

        assert outcome['success'] is True
        assert outcome['data_points'] == 10

    def test_nonexistent_column(self):
        """Asking for a column that is not in the frame must produce an error result."""
        frame = pd.DataFrame({'col1': [1, 2, 3]})

        outcome = CreateBarChartTool().execute(frame, x_column='nonexistent')

        assert 'error' in outcome
class TestCreateLineChartTool:
    """Tests for the line chart tool."""

    def test_basic_functionality(self, temp_output_dir):
        """Plot a small numeric x/y frame and check the output file and chart type."""
        plotter = CreateLineChartTool()
        chart_file = os.path.join(temp_output_dir, 'line_chart.png')
        frame = pd.DataFrame({
            'x': range(10),
            'y': [i * 2 for i in range(10)],
        })

        outcome = plotter.execute(
            frame, x_column='x', y_column='y', output_path=chart_file
        )

        assert outcome['success'] is True
        assert os.path.exists(chart_file)
        assert outcome['chart_type'] == 'line'

    def test_with_datetime(self, temp_output_dir):
        """Plot a daily time series whose x column holds datetime values."""
        plotter = CreateLineChartTool()
        frame = pd.DataFrame({
            'date': pd.date_range('2020-01-01', periods=20, freq='D'),
            'value': range(20),
        })
        chart_file = os.path.join(temp_output_dir, 'line_chart_time.png')

        outcome = plotter.execute(
            frame, x_column='date', y_column='value', output_path=chart_file
        )

        assert outcome['success'] is True
        assert os.path.exists(chart_file)

    def test_large_dataset_sampling(self, temp_output_dir):
        """Feed 2000 rows and expect the reported point count to be capped."""
        plotter = CreateLineChartTool()
        frame = pd.DataFrame({
            'x': range(2000),
            'y': range(2000),
        })
        chart_file = os.path.join(temp_output_dir, 'line_chart_large.png')

        outcome = plotter.execute(
            frame, x_column='x', y_column='y', output_path=chart_file
        )

        assert outcome['success'] is True
        # The data is expected to be down-sampled to roughly 1000 points.
        assert outcome['data_points'] <= 1000
class TestCreatePieChartTool:
    """Tests for the pie chart tool."""

    def test_basic_functionality(self, temp_output_dir):
        """Chart a column with three distinct values and check the result metadata."""
        chart_file = os.path.join(temp_output_dir, 'pie_chart.png')
        frame = pd.DataFrame({
            'category': ['A', 'B', 'C', 'A', 'B', 'A'],
        })

        outcome = CreatePieChartTool().execute(
            frame, column='category', output_path=chart_file
        )

        assert outcome['success'] is True
        assert os.path.exists(chart_file)
        assert outcome['chart_type'] == 'pie'
        assert outcome['categories'] == 3

    def test_top_n_with_others(self, temp_output_dir):
        """Request top_n=5 on 20 categories and expect one extra catch-all slice."""
        chart_file = os.path.join(temp_output_dir, 'pie_chart_top.png')
        frame = pd.DataFrame({
            'category': [f'cat_{i}' for i in range(20)] * 5,
        })
        options = {
            'column': 'category',
            'top_n': 5,
            'output_path': chart_file,
        }

        outcome = CreatePieChartTool().execute(frame, **options)

        assert outcome['success'] is True
        # 5 kept categories plus 1 "other" bucket.
        assert outcome['categories'] == 6
class TestCreateHeatmapTool:
    """Tests for the heatmap tool."""

    def test_basic_functionality(self, temp_output_dir):
        """Run on three numeric columns without selecting any and expect all three used."""
        heatmap = CreateHeatmapTool()
        chart_file = os.path.join(temp_output_dir, 'heatmap.png')
        frame = pd.DataFrame({
            'x': range(10),
            'y': [i * 2 for i in range(10)],
            'z': [i * 3 for i in range(10)],
        })

        outcome = heatmap.execute(frame, output_path=chart_file)

        assert outcome['success'] is True
        assert os.path.exists(chart_file)
        assert outcome['chart_type'] == 'heatmap'
        assert len(outcome['columns']) == 3

    def test_with_specific_columns(self, temp_output_dir):
        """Select three of four columns and check the unselected one is excluded."""
        heatmap = CreateHeatmapTool()
        chart_file = os.path.join(temp_output_dir, 'heatmap_cols.png')
        frame = pd.DataFrame({
            'a': range(10),
            'b': range(10, 20),
            'c': range(20, 30),
            'd': range(30, 40),
        })

        outcome = heatmap.execute(
            frame, columns=['a', 'b', 'c'], output_path=chart_file
        )

        used_columns = outcome['columns']
        assert outcome['success'] is True
        assert len(used_columns) == 3
        assert 'd' not in used_columns

    def test_insufficient_columns(self):
        """A frame with a single column must produce an error result."""
        single_column_frame = pd.DataFrame({'x': range(10)})

        outcome = CreateHeatmapTool().execute(single_column_frame)

        assert 'error' in outcome
class TestVisualizationToolsApplicability:
    """Tests for each visualization tool's ``is_applicable`` decision."""

    @staticmethod
    def _column(name, dtype, unique_count):
        """Build a ColumnInfo with no missing values."""
        return ColumnInfo(
            name=name,
            dtype=dtype,
            missing_rate=0.0,
            unique_count=unique_count,
        )

    @staticmethod
    def _profile(*columns):
        """Build a 100-row DataProfile for 'test.csv' holding the given columns."""
        return DataProfile(
            file_path='test.csv',
            row_count=100,
            column_count=len(columns),
            columns=list(columns),
            inferred_type='unknown',
        )

    def test_bar_chart_applicability(self):
        """The bar chart tool applies to a profile with one categorical column."""
        tool = CreateBarChartTool()
        profile = self._profile(self._column('cat', 'categorical', 5))

        assert tool.is_applicable(profile) is True

    def test_line_chart_applicability(self):
        """The line chart tool applies with a numeric column, not with text only."""
        tool = CreateLineChartTool()

        # Profile that contains a numeric column.
        profile_numeric = self._profile(self._column('value', 'numeric', 50))
        assert tool.is_applicable(profile_numeric) is True

        # Profile without any numeric column.
        profile_text = self._profile(self._column('text', 'text', 50))
        assert tool.is_applicable(profile_text) is False

    def test_heatmap_applicability(self):
        """The heatmap tool applies with two numeric columns, not with one."""
        tool = CreateHeatmapTool()

        # Two numeric columns.
        profile_sufficient = self._profile(
            self._column('x', 'numeric', 50),
            self._column('y', 'numeric', 50),
        )
        assert tool.is_applicable(profile_sufficient) is True

        # Only a single numeric column.
        profile_insufficient = self._profile(self._column('x', 'numeric', 50))
        assert tool.is_applicable(profile_insufficient) is False
class TestVisualizationErrorHandling:
    """Tests for how the visualization tools report failures."""

    def test_invalid_output_path(self, temp_output_dir):
        """An output path that can never be created must yield an error result.

        The target's parent is a regular file inside the per-test temporary
        directory. A file cannot serve as a directory, so the chart cannot be
        written on any platform or under any user, including root. This keeps
        the test deterministic and confines every side effect to the
        temporary directory, unlike a hard-coded absolute path that a
        privileged process could actually create.
        """
        tool = CreateBarChartTool()
        df = pd.DataFrame({'cat': ['A', 'B', 'C']})

        # Create an empty regular file and then pretend it is a directory.
        blocker = os.path.join(temp_output_dir, 'not_a_directory')
        with open(blocker, 'w', encoding='utf-8'):
            pass
        output_path = os.path.join(blocker, 'chart.png')

        result = tool.execute(
            df,
            x_column='cat',
            output_path=output_path
        )

        # Success is impossible here, so the tool has to report the failure.
        assert 'error' in result
        assert not os.path.exists(output_path)

    def test_empty_dataframe(self):
        """An empty DataFrame (hence a missing x column) must yield an error result."""
        tool = CreateBarChartTool()
        df = pd.DataFrame()

        result = tool.execute(df, x_column='nonexistent')

        assert 'error' in result