222 lines
5.8 KiB
Python
222 lines
5.8 KiB
Python
"""
|
||
真正的 AI 驱动数据理解引擎
|
||
AI 只能看到表头和统计摘要,通过推理理解数据
|
||
"""
|
||
|
||
import logging
|
||
from typing import Dict, Any, List
|
||
import json
|
||
from openai import OpenAI
|
||
|
||
from src.models import DataProfile, ColumnInfo
|
||
from src.config import get_config
|
||
from src.data_access import DataAccessLayer
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def ai_understand_data(data_file: str) -> DataProfile:
|
||
"""
|
||
使用 AI 理解数据(只基于元数据,不看原始数据)
|
||
|
||
参数:
|
||
data_file: 数据文件路径
|
||
|
||
返回:
|
||
数据画像
|
||
"""
|
||
profile, _ = ai_understand_data_with_dal(data_file)
|
||
return profile
|
||
|
||
|
||
def ai_understand_data_with_dal(data_file: str):
|
||
"""
|
||
使用 AI 理解数据,同时返回 DataAccessLayer 以避免重复加载。
|
||
|
||
参数:
|
||
data_file: 数据文件路径
|
||
|
||
返回:
|
||
(DataProfile, DataAccessLayer) 元组
|
||
"""
|
||
# 1. 加载数据(AI 不可见)
|
||
logger.info(f"加载数据: {data_file}")
|
||
dal = DataAccessLayer.load_from_file(data_file)
|
||
|
||
# 2. 生成数据画像(元数据)
|
||
logger.info("生成数据画像(元数据)")
|
||
profile = dal.get_profile()
|
||
|
||
# 3. 准备给 AI 的信息(只有元数据)
|
||
metadata = _prepare_metadata_for_ai(profile)
|
||
|
||
# 4. 调用 AI 分析
|
||
logger.info("调用 AI 分析数据特征...")
|
||
ai_analysis = _call_ai_for_analysis(metadata)
|
||
|
||
# 5. 更新数据画像
|
||
profile.inferred_type = ai_analysis.get('data_type', 'unknown')
|
||
profile.key_fields = ai_analysis.get('key_fields', {})
|
||
profile.quality_score = ai_analysis.get('quality_score', 0.0)
|
||
profile.summary = ai_analysis.get('summary', '')
|
||
|
||
return profile, dal
|
||
|
||
|
||
def _prepare_metadata_for_ai(profile: DataProfile) -> Dict[str, Any]:
|
||
"""
|
||
准备给 AI 的元数据(不包含原始数据)
|
||
|
||
参数:
|
||
profile: 数据画像
|
||
|
||
返回:
|
||
元数据字典
|
||
"""
|
||
metadata = {
|
||
"file_path": profile.file_path,
|
||
"row_count": profile.row_count,
|
||
"column_count": profile.column_count,
|
||
"columns": []
|
||
}
|
||
|
||
# 只提供列的元信息
|
||
for col in profile.columns:
|
||
col_info = {
|
||
"name": col.name,
|
||
"dtype": col.dtype,
|
||
"missing_rate": col.missing_rate,
|
||
"unique_count": col.unique_count,
|
||
"sample_values": col.sample_values[:5] # 最多5个示例值
|
||
}
|
||
|
||
# 如果有统计信息,也提供
|
||
if col.statistics:
|
||
col_info["statistics"] = col.statistics
|
||
|
||
metadata["columns"].append(col_info)
|
||
|
||
return metadata
|
||
|
||
|
||
def _call_ai_for_analysis(metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
调用 AI 分析数据特征
|
||
|
||
参数:
|
||
metadata: 数据元信息
|
||
|
||
返回:
|
||
AI 分析结果
|
||
"""
|
||
config = get_config()
|
||
|
||
# 创建 OpenAI 客户端
|
||
client = OpenAI(
|
||
api_key=config.llm.api_key,
|
||
base_url=config.llm.base_url
|
||
)
|
||
|
||
# 构建提示词
|
||
prompt = f"""你是一个数据分析专家。我会给你一个数据集的元信息(表头、统计摘要),你需要分析这个数据集。
|
||
|
||
重要:你只能看到元信息,看不到原始数据行。请基于列名、数据类型、统计特征进行推理。
|
||
|
||
数据元信息:
|
||
```json
|
||
{json.dumps(metadata, ensure_ascii=False, indent=2)}
|
||
```
|
||
|
||
请分析并回答以下问题:
|
||
|
||
1. 这是什么类型的数据?(工单数据/销售数据/用户数据/其他)
|
||
2. 哪些是关键字段?每个字段的业务含义是什么?
|
||
3. 数据质量如何?(0-100分)
|
||
4. 用一段话总结这个数据集的特征
|
||
|
||
请以 JSON 格式返回结果:
|
||
{{
|
||
"data_type": "ticket/sales/user/other",
|
||
"key_fields": {{
|
||
"字段名1": "业务含义1",
|
||
"字段名2": "业务含义2"
|
||
}},
|
||
"quality_score": 85.5,
|
||
"summary": "数据集的总结描述"
|
||
}}
|
||
"""
|
||
|
||
try:
|
||
# 调用 AI
|
||
response = client.chat.completions.create(
|
||
model=config.llm.model,
|
||
messages=[
|
||
{"role": "system", "content": "你是一个数据分析专家,擅长从元数据推断数据特征。"},
|
||
{"role": "user", "content": prompt}
|
||
],
|
||
temperature=0.3,
|
||
max_tokens=2000
|
||
)
|
||
|
||
# 解析响应
|
||
content = response.choices[0].message.content
|
||
logger.info(f"AI 响应: {content[:200]}...")
|
||
|
||
# 尝试提取 JSON
|
||
result = _extract_json_from_response(content)
|
||
|
||
return result
|
||
|
||
except Exception as e:
|
||
logger.error(f"AI 调用失败: {e}")
|
||
# 返回默认值
|
||
return {
|
||
"data_type": "unknown",
|
||
"key_fields": {},
|
||
"quality_score": 0.0,
|
||
"summary": f"AI 分析失败: {str(e)}"
|
||
}
|
||
|
||
|
||
def _extract_json_from_response(content: str) -> Dict[str, Any]:
|
||
"""
|
||
从 AI 响应中提取 JSON
|
||
|
||
参数:
|
||
content: AI 响应内容
|
||
|
||
返回:
|
||
解析后的 JSON 字典
|
||
"""
|
||
# 尝试直接解析
|
||
try:
|
||
return json.loads(content)
|
||
except:
|
||
pass
|
||
|
||
# 尝试提取 JSON 代码块
|
||
import re
|
||
json_match = re.search(r'```json\s*(.*?)\s*```', content, re.DOTALL)
|
||
if json_match:
|
||
try:
|
||
return json.loads(json_match.group(1))
|
||
except:
|
||
pass
|
||
|
||
# 尝试提取 {} 内容
|
||
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
||
if json_match:
|
||
try:
|
||
return json.loads(json_match.group(0))
|
||
except:
|
||
pass
|
||
|
||
# 如果都失败,返回默认值
|
||
logger.warning("无法从 AI 响应中提取 JSON,使用默认值")
|
||
return {
|
||
"data_type": "unknown",
|
||
"key_fields": {},
|
||
"quality_score": 0.0,
|
||
"summary": content[:500]
|
||
}
|