239 lines
9.0 KiB
Python
239 lines
9.0 KiB
Python
|
|
# -*- coding: utf-8 -*-
"""
Unit tests for Phase 1: Backend Data Model + API Changes

Run: python -m pytest tests/test_phase1.py -v
"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import sys
|
||
|
|
import json
|
||
|
|
import tempfile
|
||
|
|
|
||
|
|
# Put the parent of this file's directory at the front of the import search
# path. Presumably this is the project root, so that the `web.main` imports
# used by the tests resolve when the file is run directly -- confirm.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
|
||
|
|
# ===========================================================================
|
||
|
|
# Task 1: SessionData model extension
|
||
|
|
# ===========================================================================
|
||
|
|
|
||
|
|
class TestSessionDataExtension:
    """Tests for the ``rounds`` and ``data_files`` fields of ``SessionData``.

    Covers the values a freshly constructed ``SessionData`` exposes and how
    ``SessionManager._reconstruct_session`` fills a session from a
    ``results.json`` file, in both the dict layout and the legacy plain-list
    layout.
    """

    @staticmethod
    def _write_results(session_dir, payload):
        """Create *session_dir* and dump *payload* into its ``results.json``.

        Args:
            session_dir: Directory to create; it must not exist yet.
            payload: JSON-serialisable object to write.
        """
        os.makedirs(session_dir)
        results_path = os.path.join(session_dir, "results.json")
        # Explicit encoding keeps the fixture independent of the platform's
        # locale default.
        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(payload, f)

    def test_rounds_initialized_to_empty_list(self):
        """Task 1.1: rounds attribute exists and defaults to []"""
        from web.main import SessionData

        session = SessionData("test-id")
        assert hasattr(session, "rounds")
        assert session.rounds == []
        assert isinstance(session.rounds, list)

    def test_data_files_initialized_to_empty_list(self):
        """Task 1.2: data_files attribute exists and defaults to []"""
        from web.main import SessionData

        session = SessionData("test-id")
        assert hasattr(session, "data_files")
        assert session.data_files == []
        assert isinstance(session.data_files, list)

    def test_existing_fields_unchanged(self):
        """Existing SessionData fields still work."""
        from web.main import SessionData

        session = SessionData("test-id")
        assert session.session_id == "test-id"
        assert session.is_running is False
        assert session.analysis_results == []
        assert session.current_round == 0
        assert session.max_rounds == 20

    def test_reconstruct_session_loads_rounds_and_data_files(self):
        """Task 1.3: _reconstruct_session loads rounds and data_files from results.json"""
        from web.main import SessionManager

        test_rounds = [{"round": 1, "reasoning": "test", "code": "x=1"}]
        test_data_files = [{"filename": "out.csv", "rows": 10}]
        results = {
            "analysis_results": [{"round": 1}],
            "rounds": test_rounds,
            "data_files": test_data_files,
        }

        with tempfile.TemporaryDirectory() as tmpdir:
            session_dir = os.path.join(tmpdir, "session_test123")
            self._write_results(session_dir, results)

            sm = SessionManager()
            session = sm._reconstruct_session("test123", session_dir)

            assert session.rounds == test_rounds
            assert session.data_files == test_data_files
            assert session.analysis_results == [{"round": 1}]

    def test_reconstruct_session_legacy_format(self):
        """Task 1.3: _reconstruct_session handles legacy list format gracefully"""
        from web.main import SessionManager

        # Legacy format: results.json is a plain list
        legacy_results = [{"round": 1, "code": "x=1"}]

        with tempfile.TemporaryDirectory() as tmpdir:
            session_dir = os.path.join(tmpdir, "session_legacy")
            self._write_results(session_dir, legacy_results)

            sm = SessionManager()
            session = sm._reconstruct_session("legacy", session_dir)

            assert session.analysis_results == legacy_results
            assert session.rounds == []
            assert session.data_files == []
|
||
|
|
|
||
|
|
|
||
|
|
# ===========================================================================
|
||
|
|
# Task 2: Status API response
|
||
|
|
# ===========================================================================
|
||
|
|
|
||
|
|
class TestStatusAPIResponse:
    """Checks on the status payload assembled from a ``SessionData``.

    The payload is built locally from session attributes; these tests do not
    call the HTTP endpoint itself.
    """

    @staticmethod
    def _status_payload(session):
        """Return the status dict for *session* with the fields under test."""
        return {
            "is_running": session.is_running,
            "log": "",
            "has_report": session.generated_report is not None,
            "rounds": session.rounds,
            "current_round": session.current_round,
            "max_rounds": session.max_rounds,
            "progress_percentage": session.progress_percentage,
            "status_message": session.status_message,
        }

    def test_status_response_contains_rounds(self):
        """Task 2.1: the status payload carries the session's rounds."""
        from web.main import SessionData, session_manager

        registered = SessionData("status-test")
        registered.rounds = [{"round": 1, "reasoning": "r1"}]
        with session_manager.lock:
            session_manager.sessions["status-test"] = registered

        try:
            # Simulate what the endpoint returns
            payload = self._status_payload(registered)
            assert "rounds" in payload
            assert payload["rounds"] == [{"round": 1, "reasoning": "r1"}]
        finally:
            # Always unregister so the shared manager is left as it was found.
            with session_manager.lock:
                del session_manager.sessions["status-test"]

    def test_status_backward_compat_fields(self):
        """Task 2.2: pre-existing payload fields keep their values."""
        from web.main import SessionData

        compat = SessionData("compat-test")
        compat.status_message = "分析中"
        compat.progress_percentage = 50.0
        compat.current_round = 5
        compat.max_rounds = 20

        payload = self._status_payload(compat)

        # Identity checks: these two must be the literal False, not just falsy.
        assert payload["is_running"] is False
        assert payload["has_report"] is False

        expected_values = {
            "progress_percentage": 50.0,
            "current_round": 5,
            "max_rounds": 20,
            "status_message": "分析中",
        }
        for field, value in expected_values.items():
            assert payload[field] == value

        assert "log" in payload
|
||
|
|
|
||
|
|
|
||
|
|
# ===========================================================================
|
||
|
|
# Task 4: Evidence extraction
|
||
|
|
# ===========================================================================
|
||
|
|
|
||
|
|
class TestEvidenceExtraction:
    """Tests for ``_extract_evidence_annotations``.

    Each test gives a session some rounds, passes paragraphs whose content
    may carry an ``evidence:round_N`` HTML comment, and checks the returned
    mapping of paragraph id to evidence rows.
    """

    @staticmethod
    def _annotations_for(session_id, rounds, paragraphs):
        """Build a session holding *rounds* and run the extractor on *paragraphs*."""
        from web.main import _extract_evidence_annotations, SessionData

        session = SessionData(session_id)
        session.rounds = rounds
        return _extract_evidence_annotations(paragraphs, session)

    def test_extract_evidence_basic(self):
        """Task 4.1: annotated paragraphs map to the rows of the referenced round."""
        rounds = [
            {"round": 1, "evidence_rows": [{"col": "val1"}]},
            {"round": 2, "evidence_rows": [{"col": "val2"}]},
        ]
        paragraphs = [
            {"id": "p-0", "type": "text", "content": "Some intro text"},
            {"id": "p-1", "type": "text", "content": "Analysis result <!-- evidence:round_1 -->"},
            {"id": "p-2", "type": "text", "content": "More analysis <!-- evidence:round_2 -->"},
        ]

        mapping = self._annotations_for("ev-test", rounds, paragraphs)

        assert "p-0" not in mapping  # no annotation
        assert mapping["p-1"] == [{"col": "val1"}]
        assert mapping["p-2"] == [{"col": "val2"}]

    def test_extract_evidence_no_annotations(self):
        """Task 4.1: paragraphs without annotations yield an empty mapping."""
        rounds = [{"round": 1, "evidence_rows": [{"a": 1}]}]
        paragraphs = [
            {"id": "p-0", "type": "text", "content": "No evidence here"},
        ]

        mapping = self._annotations_for("ev-test2", rounds, paragraphs)

        assert mapping == {}

    def test_extract_evidence_out_of_range_round(self):
        """Task 4.1: an annotation naming a round that does not exist is ignored."""
        rounds = [{"round": 1, "evidence_rows": [{"a": 1}]}]
        paragraphs = [
            {"id": "p-0", "type": "text", "content": "Ref to round 5 <!-- evidence:round_5 -->"},
        ]

        mapping = self._annotations_for("ev-test3", rounds, paragraphs)

        assert mapping == {}

    def test_extract_evidence_empty_evidence_rows(self):
        """Task 4.1: a referenced round with no evidence rows is left out."""
        rounds = [{"round": 1, "evidence_rows": []}]
        paragraphs = [
            {"id": "p-0", "type": "text", "content": "Has annotation <!-- evidence:round_1 -->"},
        ]

        mapping = self._annotations_for("ev-test4", rounds, paragraphs)

        assert mapping == {}

    def test_extract_evidence_whitespace_in_comment(self):
        """Task 4.1: whitespace variations inside the HTML comment are accepted."""
        rounds = [{"round": 1, "evidence_rows": [{"x": 42}]}]
        # NOTE(review): this annotation is spaced exactly like the ones in the
        # other tests, so as written it does not exercise a different
        # whitespace layout -- confirm the intended variant.
        paragraphs = [
            {"id": "p-0", "type": "text", "content": "Text <!-- evidence:round_1 -->"},
        ]

        mapping = self._annotations_for("ev-test5", rounds, paragraphs)

        assert mapping["p-0"] == [{"x": 42}]
|