二次重构,加入预设模板
This commit is contained in:
221
src/engines/ai_data_understanding.py
Normal file
221
src/engines/ai_data_understanding.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
真正的 AI 驱动数据理解引擎
|
||||
AI 只能看到表头和统计摘要,通过推理理解数据
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
import json
|
||||
from openai import OpenAI
|
||||
|
||||
from src.models import DataProfile, ColumnInfo
|
||||
from src.config import get_config
|
||||
from src.data_access import DataAccessLayer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def ai_understand_data(data_file: str) -> DataProfile:
|
||||
"""
|
||||
使用 AI 理解数据(只基于元数据,不看原始数据)
|
||||
|
||||
参数:
|
||||
data_file: 数据文件路径
|
||||
|
||||
返回:
|
||||
数据画像
|
||||
"""
|
||||
profile, _ = ai_understand_data_with_dal(data_file)
|
||||
return profile
|
||||
|
||||
|
||||
def ai_understand_data_with_dal(data_file: str):
|
||||
"""
|
||||
使用 AI 理解数据,同时返回 DataAccessLayer 以避免重复加载。
|
||||
|
||||
参数:
|
||||
data_file: 数据文件路径
|
||||
|
||||
返回:
|
||||
(DataProfile, DataAccessLayer) 元组
|
||||
"""
|
||||
# 1. 加载数据(AI 不可见)
|
||||
logger.info(f"加载数据: {data_file}")
|
||||
dal = DataAccessLayer.load_from_file(data_file)
|
||||
|
||||
# 2. 生成数据画像(元数据)
|
||||
logger.info("生成数据画像(元数据)")
|
||||
profile = dal.get_profile()
|
||||
|
||||
# 3. 准备给 AI 的信息(只有元数据)
|
||||
metadata = _prepare_metadata_for_ai(profile)
|
||||
|
||||
# 4. 调用 AI 分析
|
||||
logger.info("调用 AI 分析数据特征...")
|
||||
ai_analysis = _call_ai_for_analysis(metadata)
|
||||
|
||||
# 5. 更新数据画像
|
||||
profile.inferred_type = ai_analysis.get('data_type', 'unknown')
|
||||
profile.key_fields = ai_analysis.get('key_fields', {})
|
||||
profile.quality_score = ai_analysis.get('quality_score', 0.0)
|
||||
profile.summary = ai_analysis.get('summary', '')
|
||||
|
||||
return profile, dal
|
||||
|
||||
|
||||
def _prepare_metadata_for_ai(profile: DataProfile) -> Dict[str, Any]:
|
||||
"""
|
||||
准备给 AI 的元数据(不包含原始数据)
|
||||
|
||||
参数:
|
||||
profile: 数据画像
|
||||
|
||||
返回:
|
||||
元数据字典
|
||||
"""
|
||||
metadata = {
|
||||
"file_path": profile.file_path,
|
||||
"row_count": profile.row_count,
|
||||
"column_count": profile.column_count,
|
||||
"columns": []
|
||||
}
|
||||
|
||||
# 只提供列的元信息
|
||||
for col in profile.columns:
|
||||
col_info = {
|
||||
"name": col.name,
|
||||
"dtype": col.dtype,
|
||||
"missing_rate": col.missing_rate,
|
||||
"unique_count": col.unique_count,
|
||||
"sample_values": col.sample_values[:5] # 最多5个示例值
|
||||
}
|
||||
|
||||
# 如果有统计信息,也提供
|
||||
if col.statistics:
|
||||
col_info["statistics"] = col.statistics
|
||||
|
||||
metadata["columns"].append(col_info)
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def _call_ai_for_analysis(metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
调用 AI 分析数据特征
|
||||
|
||||
参数:
|
||||
metadata: 数据元信息
|
||||
|
||||
返回:
|
||||
AI 分析结果
|
||||
"""
|
||||
config = get_config()
|
||||
|
||||
# 创建 OpenAI 客户端
|
||||
client = OpenAI(
|
||||
api_key=config.llm.api_key,
|
||||
base_url=config.llm.base_url
|
||||
)
|
||||
|
||||
# 构建提示词
|
||||
prompt = f"""你是一个数据分析专家。我会给你一个数据集的元信息(表头、统计摘要),你需要分析这个数据集。
|
||||
|
||||
重要:你只能看到元信息,看不到原始数据行。请基于列名、数据类型、统计特征进行推理。
|
||||
|
||||
数据元信息:
|
||||
```json
|
||||
{json.dumps(metadata, ensure_ascii=False, indent=2)}
|
||||
```
|
||||
|
||||
请分析并回答以下问题:
|
||||
|
||||
1. 这是什么类型的数据?(工单数据/销售数据/用户数据/其他)
|
||||
2. 哪些是关键字段?每个字段的业务含义是什么?
|
||||
3. 数据质量如何?(0-100分)
|
||||
4. 用一段话总结这个数据集的特征
|
||||
|
||||
请以 JSON 格式返回结果:
|
||||
{{
|
||||
"data_type": "ticket/sales/user/other",
|
||||
"key_fields": {{
|
||||
"字段名1": "业务含义1",
|
||||
"字段名2": "业务含义2"
|
||||
}},
|
||||
"quality_score": 85.5,
|
||||
"summary": "数据集的总结描述"
|
||||
}}
|
||||
"""
|
||||
|
||||
try:
|
||||
# 调用 AI
|
||||
response = client.chat.completions.create(
|
||||
model=config.llm.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "你是一个数据分析专家,擅长从元数据推断数据特征。"},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=0.3,
|
||||
max_tokens=2000
|
||||
)
|
||||
|
||||
# 解析响应
|
||||
content = response.choices[0].message.content
|
||||
logger.info(f"AI 响应: {content[:200]}...")
|
||||
|
||||
# 尝试提取 JSON
|
||||
result = _extract_json_from_response(content)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AI 调用失败: {e}")
|
||||
# 返回默认值
|
||||
return {
|
||||
"data_type": "unknown",
|
||||
"key_fields": {},
|
||||
"quality_score": 0.0,
|
||||
"summary": f"AI 分析失败: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def _extract_json_from_response(content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
从 AI 响应中提取 JSON
|
||||
|
||||
参数:
|
||||
content: AI 响应内容
|
||||
|
||||
返回:
|
||||
解析后的 JSON 字典
|
||||
"""
|
||||
# 尝试直接解析
|
||||
try:
|
||||
return json.loads(content)
|
||||
except:
|
||||
pass
|
||||
|
||||
# 尝试提取 JSON 代码块
|
||||
import re
|
||||
json_match = re.search(r'```json\s*(.*?)\s*```', content, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
return json.loads(json_match.group(1))
|
||||
except:
|
||||
pass
|
||||
|
||||
# 尝试提取 {} 内容
|
||||
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
return json.loads(json_match.group(0))
|
||||
except:
|
||||
pass
|
||||
|
||||
# 如果都失败,返回默认值
|
||||
logger.warning("无法从 AI 响应中提取 JSON,使用默认值")
|
||||
return {
|
||||
"data_type": "unknown",
|
||||
"key_fields": {},
|
||||
"quality_score": 0.0,
|
||||
"summary": content[:500]
|
||||
}
|
||||
Reference in New Issue
Block a user