# Files
# vibe_data_ana/tests/test_config.py
#
# 431 lines
# 14 KiB
# Python
# Raw Normal View History

"""配置管理模块的单元测试。"""
import os
import json
import pytest
from pathlib import Path
from unittest.mock import patch
from src.config import (
LLMConfig,
PerformanceConfig,
OutputConfig,
Config,
get_config,
set_config,
load_config_from_env,
load_config_from_file
)
class TestLLMConfig:
    """Unit tests for ``LLMConfig``: defaults, explicit overrides and validation."""

    def test_default_config(self):
        """Only ``api_key`` is supplied; every other field keeps its default value."""
        cfg = LLMConfig(api_key="test_key")
        expected = {
            "provider": "openai",
            "api_key": "test_key",
            "base_url": "https://api.openai.com/v1",
            "model": "gpt-4",
            "timeout": 120,
            "max_retries": 3,
            "temperature": 0.7,
        }
        for field_name, wanted in expected.items():
            assert getattr(cfg, field_name) == wanted, field_name
        # max_tokens defaults to "no limit", expressed as None.
        assert cfg.max_tokens is None

    def test_custom_config(self):
        """Every field passed to the constructor is stored verbatim."""
        overrides = {
            "provider": "gemini",
            "api_key": "gemini_key",
            "base_url": "https://gemini.api",
            "model": "gemini-pro",
            "timeout": 60,
            "max_retries": 5,
            "temperature": 0.5,
            "max_tokens": 1000,
        }
        cfg = LLMConfig(**overrides)
        for field_name, wanted in overrides.items():
            assert getattr(cfg, field_name) == wanted, field_name

    def test_empty_api_key(self):
        """An empty API key is rejected at construction time."""
        pattern = "API key 不能为空"
        with pytest.raises(ValueError, match=pattern):
            LLMConfig(api_key="")

    def test_invalid_provider(self):
        """A provider name outside the supported set is rejected."""
        bad_kwargs = {"api_key": "test", "provider": "invalid"}
        with pytest.raises(ValueError, match="不支持的 LLM provider"):
            LLMConfig(**bad_kwargs)

    def test_invalid_timeout(self):
        """A zero timeout is rejected (it must be strictly positive)."""
        bad_kwargs = {"api_key": "test", "timeout": 0}
        with pytest.raises(ValueError, match="timeout 必须大于 0"):
            LLMConfig(**bad_kwargs)

    def test_invalid_max_retries(self):
        """A negative retry count is rejected."""
        bad_kwargs = {"api_key": "test", "max_retries": -1}
        with pytest.raises(ValueError, match="max_retries 不能为负数"):
            LLMConfig(**bad_kwargs)
class TestPerformanceConfig:
    """Unit tests for ``PerformanceConfig``: defaults, overrides and validation."""

    def test_default_config(self):
        """A bare ``PerformanceConfig()`` exposes the expected default limits."""
        perf = PerformanceConfig()
        expected = {
            "agent_max_rounds": 20,
            "agent_timeout": 300,
            "tool_max_query_rows": 10_000,
            "tool_execution_timeout": 60,
            "data_max_rows": 1_000_000,
            "data_sample_threshold": 1_000_000,
            "max_concurrent_tasks": 1,
        }
        for attr, wanted in expected.items():
            assert getattr(perf, attr) == wanted, attr

    def test_custom_config(self):
        """Every limit passed to the constructor is stored verbatim."""
        overrides = {
            "agent_max_rounds": 10,
            "agent_timeout": 600,
            "tool_max_query_rows": 5_000,
            "tool_execution_timeout": 30,
            "data_max_rows": 500_000,
            "data_sample_threshold": 500_000,
            "max_concurrent_tasks": 2,
        }
        perf = PerformanceConfig(**overrides)
        for attr, wanted in overrides.items():
            assert getattr(perf, attr) == wanted, attr

    def test_invalid_agent_max_rounds(self):
        """Zero agent rounds is rejected (must be strictly positive)."""
        bad_kwargs = {"agent_max_rounds": 0}
        with pytest.raises(ValueError, match="agent_max_rounds 必须大于 0"):
            PerformanceConfig(**bad_kwargs)

    def test_invalid_tool_max_query_rows(self):
        """A negative query-row limit is rejected."""
        bad_kwargs = {"tool_max_query_rows": -1}
        with pytest.raises(ValueError, match="tool_max_query_rows 必须大于 0"):
            PerformanceConfig(**bad_kwargs)
class TestOutputConfig:
    """Unit tests for ``OutputConfig``: defaults, overrides, validation and path helpers."""

    def test_default_config(self):
        """A bare ``OutputConfig()`` writes everything under ``output``."""
        out = OutputConfig()
        expected_text = {
            "output_dir": "output",
            "log_dir": "output",
            "chart_dir": str(Path("output") / "charts"),
            "report_filename": "analysis_report.md",
            "log_level": "INFO",
        }
        for attr, wanted in expected_text.items():
            assert getattr(out, attr) == wanted, attr
        # Booleans are checked by identity so that e.g. 1 does not pass for True.
        assert out.log_to_file is True
        assert out.log_to_console is True

    def test_custom_config(self):
        """Every field passed to the constructor is stored verbatim."""
        overrides = {
            "output_dir": "results",
            "log_dir": "logs",
            "chart_dir": "charts",
            "report_filename": "report.md",
            "log_level": "DEBUG",
            "log_to_file": False,
            "log_to_console": True,
        }
        out = OutputConfig(**overrides)
        for attr in ("output_dir", "log_dir", "chart_dir", "report_filename", "log_level"):
            assert getattr(out, attr) == overrides[attr], attr
        assert out.log_to_file is False
        assert out.log_to_console is True

    def test_invalid_log_level(self):
        """An unknown log level name is rejected."""
        bad_kwargs = {"log_level": "INVALID"}
        with pytest.raises(ValueError, match="不支持的 log_level"):
            OutputConfig(**bad_kwargs)

    def test_get_paths(self):
        """The ``get_*_path`` helpers return ``Path`` objects built from the configured dirs."""
        out = OutputConfig(output_dir="results", log_dir="logs", chart_dir="charts")
        results_dir = Path("results")
        assert out.get_output_path() == results_dir
        assert out.get_log_path() == Path("logs")
        assert out.get_chart_path() == Path("charts")
        # The report lives inside output_dir under the default report filename.
        assert out.get_report_path() == results_dir / "analysis_report.md"
class TestConfig:
    """Tests for the top-level ``Config`` object and its loaders/serializers."""

    def test_default_config(self):
        """Only the LLM section is given; the other sections fall back to defaults."""
        config = Config(
            llm=LLMConfig(api_key="test_key")
        )
        assert config.llm.api_key == "test_key"
        assert config.performance.agent_max_rounds == 20
        assert config.output.output_dir == "output"
        assert config.code_repo_enable_reuse is True

    def test_from_env(self):
        """Load an OpenAI configuration from environment variables."""
        env_vars = {
            "LLM_PROVIDER": "openai",
            "OPENAI_API_KEY": "env_test_key",
            "OPENAI_BASE_URL": "https://test.api",
            "OPENAI_MODEL": "gpt-3.5-turbo",
            "AGENT_MAX_ROUNDS": "15",
            "AGENT_OUTPUT_DIR": "test_output",
            "TOOL_MAX_QUERY_ROWS": "5000",
            "CODE_REPO_ENABLE_REUSE": "false",
        }
        # clear=True hides the developer's real environment from the loader.
        with patch.dict(os.environ, env_vars, clear=True):
            config = Config.from_env()
            assert config.llm.provider == "openai"
            assert config.llm.api_key == "env_test_key"
            assert config.llm.base_url == "https://test.api"
            assert config.llm.model == "gpt-3.5-turbo"
            assert config.performance.agent_max_rounds == 15
            assert config.performance.tool_max_query_rows == 5000
            assert config.output.output_dir == "test_output"
            assert config.code_repo_enable_reuse is False

    def test_from_env_gemini(self):
        """Load a Gemini configuration from the GEMINI_* environment variables."""
        env_vars = {
            "LLM_PROVIDER": "gemini",
            "GEMINI_API_KEY": "gemini_key",
            "GEMINI_BASE_URL": "https://gemini.api",
            "GEMINI_MODEL": "gemini-pro",
        }
        with patch.dict(os.environ, env_vars, clear=True):
            config = Config.from_env()
            assert config.llm.provider == "gemini"
            assert config.llm.api_key == "gemini_key"
            assert config.llm.base_url == "https://gemini.api"
            assert config.llm.model == "gemini-pro"

    def test_from_dict(self):
        """Build a configuration from a nested dictionary."""
        config_dict = {
            "llm": {
                "provider": "openai",
                "api_key": "dict_test_key",
                "base_url": "https://dict.api",
                "model": "gpt-4",
                "timeout": 90,
                "max_retries": 2,
                "temperature": 0.5,
                "max_tokens": 2000,
            },
            "performance": {
                "agent_max_rounds": 25,
                "tool_max_query_rows": 8000,
            },
            "output": {
                "output_dir": "dict_output",
                "log_level": "DEBUG",
            },
            "code_repo_enable_reuse": False,
        }
        config = Config.from_dict(config_dict)
        assert config.llm.api_key == "dict_test_key"
        assert config.llm.base_url == "https://dict.api"
        assert config.llm.timeout == 90
        assert config.llm.max_retries == 2
        assert config.llm.temperature == 0.5
        assert config.llm.max_tokens == 2000
        assert config.performance.agent_max_rounds == 25
        assert config.performance.tool_max_query_rows == 8000
        assert config.output.output_dir == "dict_output"
        assert config.output.log_level == "DEBUG"
        assert config.code_repo_enable_reuse is False

    def test_from_file(self, tmp_path):
        """Load a configuration from a JSON file on disk."""
        config_file = tmp_path / "test_config.json"
        config_dict = {
            "llm": {
                "provider": "openai",
                "api_key": "file_test_key",
                "model": "gpt-4",
            },
            "performance": {
                "agent_max_rounds": 30,
            },
        }
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        config_file.write_text(json.dumps(config_dict), encoding="utf-8")
        config = Config.from_file(str(config_file))
        assert config.llm.api_key == "file_test_key"
        assert config.llm.model == "gpt-4"
        assert config.performance.agent_max_rounds == 30

    def test_from_file_not_found(self, tmp_path):
        """Loading a missing file raises ``FileNotFoundError``."""
        # Use a path inside the fresh tmp dir so the result cannot depend on
        # whatever happens to exist in the current working directory.
        missing = tmp_path / "nonexistent.json"
        assert not missing.exists()
        with pytest.raises(FileNotFoundError):
            Config.from_file(str(missing))

    def test_to_dict(self):
        """``to_dict`` serializes all sections and masks the API key."""
        config = Config(
            llm=LLMConfig(
                api_key="test_key",
                model="gpt-4",
            ),
            performance=PerformanceConfig(
                agent_max_rounds=15,
            ),
            output=OutputConfig(
                output_dir="test_output",
            ),
        )
        config_dict = config.to_dict()
        assert config_dict["llm"]["api_key"] == "***"  # the API key must be hidden
        assert config_dict["llm"]["model"] == "gpt-4"
        assert config_dict["performance"]["agent_max_rounds"] == 15
        assert config_dict["output"]["output_dir"] == "test_output"
        # Masking must only affect the serialized copy, not the live config.
        assert config.llm.api_key == "test_key"

    def test_save_to_file(self, tmp_path):
        """``save_to_file`` writes JSON with the API key masked."""
        config_file = tmp_path / "saved_config.json"
        config = Config(
            llm=LLMConfig(api_key="test_key"),
            performance=PerformanceConfig(agent_max_rounds=15),
        )
        config.save_to_file(str(config_file))
        assert config_file.exists()
        saved_dict = json.loads(config_file.read_text(encoding="utf-8"))
        assert saved_dict["llm"]["api_key"] == "***"
        assert saved_dict["performance"]["agent_max_rounds"] == 15

    def test_validate_success(self):
        """A config with a non-empty API key validates."""
        config = Config(
            llm=LLMConfig(api_key="test_key")
        )
        assert config.validate() is True

    def test_validate_missing_api_key(self):
        """Validation fails (returns False) when the API key is missing."""
        config = Config(
            llm=LLMConfig(api_key="test_key")
        )
        # The constructor rejects an empty key, so clear it after construction.
        config.llm.api_key = ""
        assert config.validate() is False
class TestGlobalConfig:
    """Tests for the process-wide configuration accessors.

    These tests install global state (``set_config`` / ``load_config_from_*``).
    Every test starts and ends with the global config reset, so that state
    cannot leak into later tests in this or other modules.
    """

    def setup_method(self, method):
        """Start every test from an unset global config."""
        set_config(None)

    def teardown_method(self, method):
        """Drop whatever global config the test installed."""
        set_config(None)

    def test_get_config(self):
        """With no global config set, ``get_config`` builds one from the environment."""
        env_vars = {
            "OPENAI_API_KEY": "global_test_key",
        }
        with patch.dict(os.environ, env_vars, clear=True):
            config = get_config()
            assert config is not None
            assert config.llm.api_key == "global_test_key"

    def test_set_config(self):
        """``set_config`` installs the object later returned by ``get_config``."""
        custom_config = Config(
            llm=LLMConfig(api_key="custom_key")
        )
        set_config(custom_config)
        config = get_config()
        assert config.llm.api_key == "custom_key"

    def test_load_config_from_env(self):
        """``load_config_from_env`` returns the loaded config and installs it globally."""
        env_vars = {
            "OPENAI_API_KEY": "env_global_key",
            "AGENT_MAX_ROUNDS": "25",
        }
        with patch.dict(os.environ, env_vars, clear=True):
            config = load_config_from_env()
            assert config.llm.api_key == "env_global_key"
            assert config.performance.agent_max_rounds == 25
        # The global config must now reflect the freshly loaded values.
        global_config = get_config()
        assert global_config.llm.api_key == "env_global_key"

    def test_load_config_from_file(self, tmp_path):
        """``load_config_from_file`` returns the loaded config and installs it globally."""
        config_file = tmp_path / "global_config.json"
        config_dict = {
            "llm": {
                "provider": "openai",
                "api_key": "file_global_key",
                "model": "gpt-4",
            }
        }
        config_file.write_text(json.dumps(config_dict), encoding="utf-8")
        config = load_config_from_file(str(config_file))
        assert config.llm.api_key == "file_global_key"
        # The global config must now reflect the freshly loaded values.
        global_config = get_config()
        assert global_config.llm.api_key == "file_global_key"