二次重构,加入预设模板
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
"""Task execution engine using ReAct pattern."""
|
||||
"""Task execution engine using ReAct pattern — fully AI-driven."""
|
||||
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -11,6 +11,9 @@ from src.models.analysis_plan import AnalysisTask
|
||||
from src.models.analysis_result import AnalysisResult
|
||||
from src.tools.base import AnalysisTool
|
||||
from src.data_access import DataAccessLayer
|
||||
from src.config import get_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def execute_task(
|
||||
@@ -21,60 +24,45 @@ def execute_task(
|
||||
) -> AnalysisResult:
|
||||
"""
|
||||
Execute analysis task using ReAct pattern.
|
||||
|
||||
ReAct loop: Thought -> Action -> Observation -> repeat
|
||||
|
||||
Args:
|
||||
task: Analysis task to execute
|
||||
tools: Available analysis tools
|
||||
data_access: Data access layer for executing tools
|
||||
max_iterations: Maximum number of iterations
|
||||
|
||||
Returns:
|
||||
AnalysisResult with execution results
|
||||
|
||||
Requirements: FR-5.1
|
||||
AI decides which tools to call and with what parameters.
|
||||
No hardcoded heuristics — everything is AI-driven.
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Get API key
|
||||
api_key = os.getenv('OPENAI_API_KEY')
|
||||
config = get_config()
|
||||
api_key = config.llm.api_key
|
||||
|
||||
if not api_key:
|
||||
# Fallback to simple execution
|
||||
return _fallback_task_execution(task, tools, data_access)
|
||||
|
||||
client = OpenAI(api_key=api_key)
|
||||
|
||||
# Execution history
|
||||
|
||||
client = OpenAI(api_key=api_key, base_url=config.llm.base_url)
|
||||
|
||||
history = []
|
||||
visualizations = []
|
||||
|
||||
column_names = data_access.columns
|
||||
|
||||
try:
|
||||
for iteration in range(max_iterations):
|
||||
# Thought: AI decides next action
|
||||
thought_prompt = _build_thought_prompt(task, tools, history)
|
||||
|
||||
thought_response = client.chat.completions.create(
|
||||
model="gpt-4",
|
||||
prompt = _build_thought_prompt(task, tools, history, column_names)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=config.llm.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a data analyst executing analysis tasks. Use the ReAct pattern: think, act, observe."},
|
||||
{"role": "user", "content": thought_prompt}
|
||||
{"role": "system", "content": _system_prompt()},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=0.7,
|
||||
max_tokens=1000
|
||||
temperature=0.3,
|
||||
max_tokens=1200
|
||||
)
|
||||
|
||||
thought = _parse_thought_response(thought_response.choices[0].message.content)
|
||||
|
||||
thought = _parse_thought_response(response.choices[0].message.content)
|
||||
history.append({"type": "thought", "content": thought})
|
||||
|
||||
# Check if task is complete
|
||||
|
||||
if thought.get('is_completed', False):
|
||||
break
|
||||
|
||||
# Action: Execute selected tool
|
||||
|
||||
tool_name = thought.get('selected_tool')
|
||||
tool_params = thought.get('tool_params', {})
|
||||
|
||||
|
||||
if tool_name:
|
||||
tool = _find_tool(tools, tool_name)
|
||||
if tool:
|
||||
@@ -84,95 +72,125 @@ def execute_task(
|
||||
"tool": tool_name,
|
||||
"params": tool_params
|
||||
})
|
||||
|
||||
# Observation: Record result
|
||||
history.append({
|
||||
"type": "observation",
|
||||
"result": action_result
|
||||
})
|
||||
|
||||
# Track visualizations
|
||||
if 'visualization_path' in action_result:
|
||||
if isinstance(action_result, dict) and 'visualization_path' in action_result:
|
||||
visualizations.append(action_result['visualization_path'])
|
||||
|
||||
# Extract insights from history
|
||||
if isinstance(action_result, dict) and action_result.get('data', {}).get('chart_path'):
|
||||
visualizations.append(action_result['data']['chart_path'])
|
||||
else:
|
||||
history.append({
|
||||
"type": "observation",
|
||||
"result": {"error": f"Tool '{tool_name}' not found. Available: {[t.name for t in tools]}"}
|
||||
})
|
||||
|
||||
insights = extract_insights(history, client)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
|
||||
# Collect all observation data
|
||||
all_data = {}
|
||||
for entry in history:
|
||||
if entry['type'] == 'observation':
|
||||
result = entry.get('result', {})
|
||||
if isinstance(result, dict) and result.get('success', True):
|
||||
all_data[f"step_{len(all_data)}"] = result
|
||||
|
||||
return AnalysisResult(
|
||||
task_id=task.id,
|
||||
task_name=task.name,
|
||||
success=True,
|
||||
data=history[-1].get('result', {}) if history else {},
|
||||
data=all_data,
|
||||
visualizations=visualizations,
|
||||
insights=insights,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
logger.error(f"Task execution failed: {e}")
|
||||
return AnalysisResult(
|
||||
task_id=task.id,
|
||||
task_name=task.name,
|
||||
success=False,
|
||||
error=str(e),
|
||||
execution_time=execution_time
|
||||
execution_time=time.time() - start_time
|
||||
)
|
||||
|
||||
|
||||
def _system_prompt() -> str:
|
||||
return (
|
||||
"You are a data analyst executing analysis tasks by calling tools. "
|
||||
"You can ONLY see column names and tool descriptions — never raw data rows. "
|
||||
"You MUST call tools to get any data. Always respond with valid JSON. "
|
||||
"Use actual column names. Pick the right tool and parameters for the task."
|
||||
)
|
||||
|
||||
|
||||
|
||||
def _build_thought_prompt(
    task: AnalysisTask,
    tools: List[AnalysisTool],
    history: List[Dict[str, Any]],
    column_names: Optional[List[str]] = None
) -> str:
    """Build the user prompt for one ReAct thought step.

    Args:
        task: The analysis task being executed (description + expected output).
        tools: Tools the model may call; each contributes its name,
            description and JSON-schema parameter properties to the prompt.
        history: Prior thought/action/observation entries; only the last 8
            are replayed to keep the prompt bounded.
        column_names: Dataset column names, shown so the model uses real
            columns in tool parameters. Optional.

    Returns:
        The fully rendered prompt string, ending with the JSON response
        skeleton the model must fill in.
    """
    # One line per tool, including its parameter schema so the model can
    # supply correctly named arguments.
    tool_descriptions = "\n".join([
        f"- {tool.name}: {tool.description}\n  Parameters: {json.dumps(tool.parameters.get('properties', {}), ensure_ascii=False)}"
        for tool in tools
    ])

    columns_str = f"\nAvailable Data Columns: {', '.join(column_names)}\n" if column_names else ""

    # Replay the recent ReAct trace. Observations are truncated to 500 chars
    # and thoughts to 200 so a long run cannot blow up the context window.
    history_str = ""
    if history:
        for h in history[-8:]:
            if h['type'] == 'thought':
                content = h.get('content', {})
                history_str += f"\nThought: {content.get('reasoning', '')[:200]}"
            elif h['type'] == 'action':
                history_str += f"\nAction: {h.get('tool', '')}({json.dumps(h.get('params', {}), ensure_ascii=False)})"
            elif h['type'] == 'observation':
                result = h.get('result', {})
                result_str = json.dumps(result, ensure_ascii=False, default=str)[:500]
                history_str += f"\nObservation: {result_str}"

    actions_taken = sum(1 for h in history if h['type'] == 'action')

    return f"""Task: {task.description}
Expected Output: {task.expected_output}
{columns_str}
Available Tools:
{tool_descriptions}

Execution History:{history_str if history_str else " (none yet — start by calling a tool)"}

Actions taken: {actions_taken}

Instructions:
1. Pick the most relevant tool and call it with correct column names.
2. After each observation, decide if you need more data or can conclude.
3. Aim for 2-4 tool calls total to gather enough data.
4. When you have enough data, set is_completed=true and summarize findings in reasoning.

Respond ONLY with this JSON (no other text):
{{
  "reasoning": "your analysis reasoning",
  "is_completed": false,
  "selected_tool": "tool_name",
  "tool_params": {{"param": "value"}}
}}
"""
|
||||
|
||||
|
||||
def _parse_thought_response(response_text: str) -> Dict[str, Any]:
|
||||
"""Parse thought response from AI."""
|
||||
"""Parse AI thought response JSON."""
|
||||
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
return json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return {
|
||||
'reasoning': response_text,
|
||||
'is_completed': False,
|
||||
@@ -186,80 +204,78 @@ def call_tool(
|
||||
data_access: DataAccessLayer,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Call analysis tool and return result.
|
||||
|
||||
Args:
|
||||
tool: Tool to execute
|
||||
data_access: Data access layer
|
||||
**kwargs: Tool parameters
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
|
||||
Requirements: FR-5.2
|
||||
"""
|
||||
"""Call an analysis tool and return the result."""
|
||||
try:
|
||||
result = data_access.execute_tool(tool, **kwargs)
|
||||
return {
|
||||
'success': True,
|
||||
'data': result
|
||||
}
|
||||
return {'success': True, 'data': result}
|
||||
except Exception as e:
|
||||
return {
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
|
||||
def extract_insights(
    history: List[Dict[str, Any]],
    client: Optional[OpenAI] = None
) -> List[str]:
    """Extract insights from execution history using AI.

    Args:
        history: ReAct execution history (thought/action/observation entries).
        client: OpenAI client. When None, skips the AI call entirely and uses
            the rule-based fallback.

    Returns:
        A list of 3-5 insight strings. AI-generated when possible; otherwise
        the rule-based extraction from observation data.
    """
    if not client:
        # No AI available — derive insights directly from observation data.
        return _extract_insights_from_observations(history)

    config = get_config()
    # default=str makes non-JSON types (datetimes, numpy scalars) serializable;
    # the 4000-char cap bounds the prompt size.
    history_str = json.dumps(history, indent=2, ensure_ascii=False, default=str)[:4000]

    try:
        response = client.chat.completions.create(
            model=config.llm.model,
            messages=[
                {"role": "system", "content": "You are a data analyst. Extract key insights from analysis results. Respond in Chinese. Return a JSON array of 3-5 insight strings with specific numbers."},
                {"role": "user", "content": f"Execution history:\n{history_str}\n\nExtract 3-5 key data-driven insights as a JSON array of strings."}
            ],
            temperature=0.5,
            max_tokens=800
        )

        text = response.choices[0].message.content
        # The model may wrap the JSON array in prose; grab the first [...] span.
        json_match = re.search(r'\[.*\]', text, re.DOTALL)
        if json_match:
            parsed = json.loads(json_match.group())
            # Only accept a non-empty list; anything else falls through.
            if isinstance(parsed, list) and len(parsed) > 0:
                return parsed
    except Exception as e:
        # Best-effort: AI extraction failure must not fail the whole task.
        logger.warning(f"AI insight extraction failed: {e}")

    # AI call failed or returned nothing usable — use the rule-based fallback.
    return _extract_insights_from_observations(history)
|
||||
|
||||
|
||||
def _extract_insights_from_observations(history: List[Dict[str, Any]]) -> List[str]:
|
||||
"""Fallback: extract insights directly from observation data."""
|
||||
insights = []
|
||||
for entry in history:
|
||||
if entry['type'] != 'observation':
|
||||
continue
|
||||
result = entry.get('result', {})
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
data = result.get('data', result)
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
|
||||
if 'groups' in data:
|
||||
top = data['groups'][:3] if isinstance(data['groups'], list) else []
|
||||
if top:
|
||||
group_str = ', '.join(f"{g.get('group','?')}: {g.get('value',0)}" for g in top)
|
||||
insights.append(f"Top groups: {group_str}")
|
||||
if 'distribution' in data:
|
||||
dist = data['distribution'][:3] if isinstance(data['distribution'], list) else []
|
||||
if dist:
|
||||
dist_str = ', '.join(f"{d.get('value','?')}: {d.get('percentage',0):.1f}%" for d in dist)
|
||||
insights.append(f"Distribution: {dist_str}")
|
||||
if 'trend' in data:
|
||||
insights.append(f"Trend: {data['trend']}, growth rate: {data.get('growth_rate', 'N/A')}")
|
||||
if 'outlier_count' in data:
|
||||
insights.append(f"Outliers: {data['outlier_count']} ({data.get('outlier_percentage', 0):.1f}%)")
|
||||
if 'mean' in data and 'column' in data:
|
||||
insights.append(f"{data['column']}: mean={data['mean']:.2f}, median={data.get('median', 'N/A')}")
|
||||
|
||||
return insights[:5] if insights else ["Analysis completed"]
|
||||
|
||||
|
||||
def _find_tool(tools: List[AnalysisTool], tool_name: str) -> Optional[AnalysisTool]:
|
||||
@@ -275,42 +291,53 @@ def _fallback_task_execution(
|
||||
tools: List[AnalysisTool],
|
||||
data_access: DataAccessLayer
|
||||
) -> AnalysisResult:
|
||||
"""Simple fallback execution without AI."""
|
||||
"""Fallback execution without AI — runs required tools with minimal params."""
|
||||
start_time = time.time()
|
||||
|
||||
all_data = {}
|
||||
insights = []
|
||||
|
||||
try:
|
||||
# Execute first applicable tool
|
||||
for tool_name in task.required_tools:
|
||||
columns = data_access.columns
|
||||
tools_to_run = task.required_tools if task.required_tools else [t.name for t in tools[:3]]
|
||||
|
||||
for tool_name in tools_to_run:
|
||||
tool = _find_tool(tools, tool_name)
|
||||
if tool:
|
||||
result = call_tool(tool, data_access)
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
return AnalysisResult(
|
||||
task_id=task.id,
|
||||
task_name=task.name,
|
||||
success=result.get('success', False),
|
||||
data=result.get('data', {}),
|
||||
insights=[f"Executed {tool_name}"],
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
# No tools executed
|
||||
execution_time = time.time() - start_time
|
||||
if not tool:
|
||||
continue
|
||||
# Try calling with first column as a basic param
|
||||
params = _guess_minimal_params(tool, columns)
|
||||
if params:
|
||||
result = call_tool(tool, data_access, **params)
|
||||
if result.get('success'):
|
||||
all_data[tool_name] = result.get('data', {})
|
||||
|
||||
return AnalysisResult(
|
||||
task_id=task.id,
|
||||
task_name=task.name,
|
||||
success=False,
|
||||
error="No applicable tools found",
|
||||
execution_time=execution_time
|
||||
success=True,
|
||||
data=all_data,
|
||||
insights=insights or ["Fallback execution completed"],
|
||||
execution_time=time.time() - start_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
return AnalysisResult(
|
||||
task_id=task.id,
|
||||
task_name=task.name,
|
||||
success=False,
|
||||
error=str(e),
|
||||
execution_time=execution_time
|
||||
execution_time=time.time() - start_time
|
||||
)
|
||||
|
||||
|
||||
def _guess_minimal_params(tool: AnalysisTool, columns: List[str]) -> Optional[Dict[str, Any]]:
    """Guess minimal params for fallback — just pick the first column.

    Fills every *required* string parameter of the tool's JSON schema with
    the first available column name (or '' when there are no columns).
    Non-string required parameters are left unset; the subsequent tool call
    may still fail, which the fallback caller tolerates.

    Args:
        tool: Tool whose ``parameters`` JSON schema is inspected.
        columns: Dataset column names; only the first one is used.

    Returns:
        The guessed params dict, or None when no string parameter is required.
    """
    props = tool.parameters.get('properties', {})
    required = tool.parameters.get('required', [])
    # The original code special-cased names containing 'column' but assigned
    # the identical value in both branches — collapsed into one condition.
    default_value = columns[0] if columns else ''
    params = {
        name: default_value
        for name in required
        if props.get(name, {}).get('type') == 'string'
    }
    return params if params else None
|
||||
|
||||
Reference in New Issue
Block a user