Files
document_analyzer/skills/ir_generation_skill/tests/test_step1.py
T
pzhang_zywl fec4c09ee0
CI / test (push) Successful in 8s
sync: update all skills from latest workspace code
doc_parser_skill:
- New: verify_flowchart.py (flowchart validation)
- Updated: LLM.py (multi-provider: DeepSeek + DashScope)
- Updated: image_parser.py (logic tree support, external prompts)
- Updated: SKILL.md, prompts/image_prompt.md

conflict_detection_skill:
- Updated: LLM.py (multi-provider sync)
- Updated: detect_conflicts.py (logic tree text conversion)

ir_generation_skill:
- Replaced old scripts/LLM.py + ir_generator.py with standalone project
- New: main.py, config.py, step1-3_*.py, ensemble_merge.py
- New: prompts/, tests/ subdirectories

tests:
- New: acceptance/ test suite with schema validation
- Fixed: conftest no longer globally skips non-acceptance tests
- Updated: test_sample.py for new ir_generation structure

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-30 22:45:08 +08:00

371 lines
13 KiB
Python

"""
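Tests for Stage 1 (Semantic Index).

Validates that the generated semantic_index.json meets all completeness
and structural requirements, including the new iterative features:
- function_units have unique, non-empty unit_id values, names and path fields
- concepts have parent references (scope concepts must have none)
- sources name a section or reference existing images and logic tree nodes
- uncovered decision/action logic tree nodes are reported as warnings
- ensemble confidence fields are present and confidence_summary counts match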
"""
import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import config
# Status tags printed in front of each line of the self-check report.
PASS = "[PASS]"
FAIL = "[FAIL]"
WARN = "[WARN]"
def load_inputs():
    """Read the Stage 1 output and the parsed input document.

    Returns:
        A ``(semantic_index, document)`` tuple, as produced by
        ``config.load_json`` and ``config.load_input_document``.

    If the semantic index file cannot be found, a hint is printed and the
    process exits with status 1.
    """
    try:
        semantic_index = config.load_json(config.SEMANTIC_INDEX_JSON)
    except FileNotFoundError:
        notice = (
            f"{FAIL} semantic_index.json 未找到: {config.SEMANTIC_INDEX_JSON}",
            " 请先运行 step1_semantic_index.py",
        )
        for line in notice:
            print(line)
        sys.exit(1)

    document = config.load_input_document()
    return semantic_index, document
def build_image_index(doc: dict) -> dict[str, dict]:
    """Map each image ``rid`` to its ``image_analysis`` entry.

    Entries whose ``rid`` is missing or empty are left out. When several
    entries share a ``rid``, the last one in document order is kept.
    """
    return {
        entry.get("rid", ""): entry
        for entry in doc.get("image_analysis", [])
        if entry.get("rid", "")
    }
def build_logic_tree_node_index(doc: dict) -> dict[str, set[str]]:
    """Map each image ``rid`` to the set of node IDs in its logic tree.

    Only ``image_analysis`` entries with both a non-empty ``rid`` and a
    non-empty ``logic_tree`` are indexed. Every node must carry an ``id``.
    """
    index = {}
    for entry in doc.get("image_analysis", []):
        image_rid = entry.get("rid", "")
        tree = entry.get("logic_tree")
        if not (tree and image_rid):
            continue
        index[image_rid] = set(node["id"] for node in tree.get("nodes", []))
    return index
def check_unit_ids(units: list[dict]) -> list[str]:
    """Validate ``unit_id`` and ``name`` on every function unit.

    Reports an empty ``unit_id``, a ``unit_id`` already used by an earlier
    unit, and an empty ``name``. Returns the list of error messages.
    """
    problems = []
    known_ids = set()
    for position, unit in enumerate(units):
        unit_id = unit.get("unit_id", "")
        if unit_id:
            if unit_id in known_ids:
                problems.append(
                    f"function_unit[{position}]: unit_id '{unit_id}' 重复"
                )
        else:
            problems.append(f"function_unit[{position}]: unit_id 为空")
        known_ids.add(unit_id)

        if not unit.get("name", ""):
            problems.append(f"function_unit[{position}] ({unit_id}): name 为空")
    return problems
def check_unit_paths(units: list[dict]) -> list[str]:
    """Validate that every function unit carries a non-empty ``path`` list.

    A missing or falsy ``path`` is reported as empty; a truthy value that is
    not a ``list`` is reported as having the wrong type.
    """
    problems = []
    for unit in units:
        unit_id = unit.get("unit_id", "?")
        path_value = unit.get("path", [])
        if not path_value:
            problems.append(f"{unit_id}: path 字段为空或缺失")
            continue
        if not isinstance(path_value, list):
            problems.append(f"{unit_id}: path 必须是数组")
    return problems
def check_concept_parents(concepts: list[dict]) -> list[str]:
    """Validate the ``parent`` field of every concept.

    The two scope concepts ("国内" and "海外") must not declare a parent.
    Every other concept must declare one, and it must be the name of a
    concept present in ``concepts``. Returns the list of error messages.
    """
    top_level = {"国内", "海外"}
    known_names = set(item.get("name", "") for item in concepts)

    problems = []
    for item in concepts:
        name = item.get("name", "?")
        parent = item.get("parent", "")

        if name in top_level:
            # Scope concepts sit at the root and must stay parentless.
            if parent:
                problems.append(
                    f"scope 概念 '{name}' 不应有 parent (当前: '{parent}')"
                )
            continue

        if not parent:
            problems.append(f"概念 '{name}' 缺少 parent 字段")
        elif parent not in known_names:
            problems.append(
                f"概念 '{name}' 的 parent '{parent}' 不存在于 concepts 中"
            )
    return problems
def _check_logic_tree_source(
    label: str,
    src: dict,
    image_index: dict[str, dict],
    node_index: dict[str, set[str]],
) -> list[str]:
    """Return the problems found in a single ``logic_tree`` source entry."""
    image_id = src.get("image_id", "")
    if not image_id:
        return [f"{label}: logic_tree 缺少 image_id"]
    if image_id not in image_index:
        return [f"{label}: image_id '{image_id}' 在 image_analysis 中不存在"]

    node_ids = src.get("logic_tree_nodes", [])
    if not node_ids:
        return [f"{label}: logic_tree 类型但未提供 logic_tree_nodes"]
    if image_id not in node_index:
        # The image exists but has no indexed logic tree, so the node IDs
        # cannot be validated; this case is not reported.
        return []

    valid_nodes = node_index[image_id]
    return [
        f"{label}: 节点 '{nid}' 在 {image_id} 的逻辑树中不存在"
        for nid in node_ids
        if nid not in valid_nodes
    ]


def check_sources_exist(
    units: list[dict], image_index: dict[str, dict], node_index: dict[str, set[str]]
) -> list[str]:
    """Check that the sources of every function unit point to real content.

    Args:
        units: The ``function_units`` list of the semantic index.
        image_index: Mapping of image rid to its ``image_analysis`` entry.
        node_index: Mapping of image rid to the node IDs of its logic tree.

    Returns:
        A list of error messages. A unit is reported when it has no sources,
        when none of its sources is of a recognised type (``table``,
        ``para`` or ``logic_tree``), when a text source has no ``section``,
        or when a ``logic_tree`` source references a missing image, lists no
        nodes, or lists nodes absent from that image's logic tree.
    """
    errors = []
    for fu in units:
        uid = fu.get("unit_id", "?")
        sources = fu.get("sources", [])
        if not sources:
            errors.append(f"{uid}: sources 为空,必须至少引用一张图片或一段文字")
            continue

        has_recognised_source = False
        for j, src in enumerate(sources):
            label = f"{uid}.sources[{j}]"
            src_type = src.get("type", "")
            if src_type in ("table", "para"):
                has_recognised_source = True
                if not src.get("section", ""):
                    errors.append(f"{label}: 缺少 section")
            elif src_type == "logic_tree":
                has_recognised_source = True
                errors.extend(
                    _check_logic_tree_source(label, src, image_index, node_index)
                )
            # Sources of any other type are ignored here; a unit made up
            # solely of such sources is caught by the check below.

        if not has_recognised_source:
            errors.append(f"{uid}: 必须至少引用一个文本或图片来源")
    return errors
def _collect_node_types(doc: dict) -> dict[str, dict]:
    """Map image rid -> {node id: node type} for images that have a logic tree."""
    types_by_image = {}
    for img in doc.get("image_analysis", []):
        rid = img.get("rid", "")
        lt = img.get("logic_tree")
        # Same selection rule as the node index: entries need both a rid and
        # a non-empty logic tree, and a later entry replaces an earlier one.
        if lt and rid:
            types_by_image[rid] = {
                n["id"]: n.get("type", "?") for n in lt.get("nodes", [])
            }
    return types_by_image


def check_logic_tree_coverage(
    units: list[dict],
    node_index: dict[str, set[str]],
    doc: "dict | None" = None,
) -> list[str]:
    """Check that decision and action nodes in logic trees are covered.

    A node counts as covered when some source of some function unit names
    its image via ``image_id`` and lists the node in ``logic_tree_nodes``.

    Args:
        units: The ``function_units`` list of the semantic index.
        node_index: Mapping of image rid to the node IDs of its logic tree.
        doc: The parsed input document, used to look up node types. When
            omitted it is loaded through ``config.load_input_document()``,
            at most once and only if some image has uncovered nodes.

    Returns:
        One warning per image that has uncovered ``decision`` or ``action``
        nodes; the node IDs in each warning are listed in sorted order.
    """
    warnings = []
    node_types_by_image = None
    for image_id, all_nodes in node_index.items():
        referenced = {
            nid
            for fu in units
            for src in fu.get("sources", [])
            if src.get("image_id") == image_id
            for nid in src.get("logic_tree_nodes", [])
        }
        uncovered = all_nodes - referenced
        if not uncovered:
            continue

        if node_types_by_image is None:
            # Deferred so that a fully covered index never touches the disk,
            # and hoisted so the document is read once, not once per image.
            if doc is None:
                doc = config.load_input_document()
            node_types_by_image = _collect_node_types(doc)

        node_types = node_types_by_image.get(image_id, {})
        # Sorted because set iteration order is not stable between runs.
        decision_action_uncovered = sorted(
            (n for n in uncovered if node_types.get(n) in ("decision", "action")),
            key=str,
        )
        if decision_action_uncovered:
            warnings.append(
                f"{image_id}: {len(decision_action_uncovered)} 个 "
                f"decision/action 节点未被引用: {decision_action_uncovered}"
            )
    return warnings
def check_ensemble_confidence(units: list[dict]) -> list[str]:
    """Validate the ensemble metadata on every function unit.

    Each unit needs a ``confidence`` of ``high``, ``medium`` or ``low``, a
    truthy ``ensemble_support`` value, and a ``source_versions`` key (whose
    value is not inspected). Returns the list of error messages.
    """
    allowed_levels = {"high", "medium", "low"}
    problems = []
    for unit in units:
        unit_id = unit.get("unit_id", "?")

        level = unit.get("confidence", "")
        if level:
            if level not in allowed_levels:
                problems.append(
                    f"{unit_id}: confidence='{level}' 无效 (期望 high/medium/low)"
                )
        else:
            problems.append(f"{unit_id}: 缺少 confidence 字段")

        if not unit.get("ensemble_support", ""):
            problems.append(f"{unit_id}: 缺少 ensemble_support 字段")

        if "source_versions" not in unit:
            problems.append(f"{unit_id}: 缺少 source_versions 字段")
    return problems
def check_confidence_summary(si: dict) -> list[str]:
    """Compare ``confidence_summary`` with the actual confidence counts.

    The summary must report the number of function units and concepts and,
    for each, how many have ``high``, ``medium`` and ``low`` confidence. A
    key missing from the summary is treated as 0 for the comparison.
    Returns one error per mismatching counter, or a single error when the
    summary itself is missing or empty.
    """
    summary = si.get("confidence_summary", {})
    if not summary:
        return ["缺少 confidence_summary 字段"]

    units = si.get("function_units", [])
    concepts = si.get("concepts", [])

    def tally(items, level):
        """Count the items whose ``confidence`` equals ``level``."""
        return sum(1 for item in items if item.get("confidence") == level)

    # (summary key, actual value), in the order mismatches are reported.
    expected = [
        ("total_units", len(units)),
        ("high", tally(units, "high")),
        ("medium", tally(units, "medium")),
        ("low", tally(units, "low")),
        ("total_concepts", len(concepts)),
        ("concept_high", tally(concepts, "high")),
        ("concept_medium", tally(concepts, "medium")),
        ("concept_low", tally(concepts, "low")),
    ]
    return [
        f"confidence_summary.{key}={summary.get(key)} != 实际 {actual}"
        for key, actual in expected
        if summary.get(key, 0) != actual
    ]
def run_all_tests():
    """Run every Stage 1 check, print a report and return overall success.

    Loads the inputs through ``load_inputs``, runs the seven checks in a
    fixed order and prints a PASS/FAIL/WARN section for each, followed by a
    summary and basic statistics.

    Returns:
        ``True`` when no check produced an error. Logic tree coverage
        findings are warnings and do not affect the result.
    """
    rule = "=" * 60
    print(rule)
    print("Step 1 自检测试")
    print(rule)

    si, doc = load_inputs()
    units = si.get("function_units", [])
    concepts = si.get("concepts", [])
    image_index = build_image_index(doc)
    node_index = build_logic_tree_node_index(doc)

    all_errors = []
    all_warnings = []

    def found_problems(title, findings, sink, tag=FAIL, noun="错误"):
        """Print and record ``findings``; return whether there were any."""
        if not findings:
            return False
        print(f"\n{tag} {title}: {len(findings)} 个{noun}")
        for finding in findings:
            print(f" - {finding}")
        sink.extend(findings)
        return True

    # Test 1: unit_id and name validity
    if not found_problems("unit_id/name 检查", check_unit_ids(units), all_errors):
        print(f"\n{PASS} unit_id/name 检查: 全部通过 ({len(units)} 个功能单元)")

    # Test 2: path fields
    if not found_problems("path 字段检查", check_unit_paths(units), all_errors):
        print(f"\n{PASS} path 字段检查: 全部通过")

    # Test 3: concept parent references
    if not found_problems(
        "concept parent 检查", check_concept_parents(concepts), all_errors
    ):
        print(f"\n{PASS} concept parent 检查: 全部通过 ({len(concepts)} 个概念)")

    # Test 4: source references exist
    if not found_problems(
        "来源引用检查",
        check_sources_exist(units, image_index, node_index),
        all_errors,
    ):
        print(f"\n{PASS} 来源引用检查: 全部通过")

    # Test 5: logic tree coverage (findings are warnings, not errors)
    if not found_problems(
        "逻辑树节点覆盖率",
        check_logic_tree_coverage(units, node_index),
        all_warnings,
        tag=WARN,
        noun="警告",
    ):
        print(f"\n{PASS} 逻辑树节点覆盖率: 全部通过")

    # Test 6: ensemble confidence fields on function_units
    if not found_problems(
        "集成置信度字段", check_ensemble_confidence(units), all_errors
    ):
        print(f"\n{PASS} 集成置信度字段: 全部通过")

    # Test 7: confidence summary consistency
    if not found_problems(
        "confidence_summary 一致性", check_confidence_summary(si), all_errors
    ):
        cs = si.get("confidence_summary", {})
        print(f"\n{PASS} confidence_summary 一致性: "
              f"high={cs.get('high',0)}, medium={cs.get('medium',0)}, "
              f"low={cs.get('low',0)}")

    # Summary
    print(f"\n{rule}")
    total_failures = len(all_errors)
    total_warnings = len(all_warnings)
    if total_failures:
        print(f"{FAIL} 测试失败: {total_failures} 个错误, {total_warnings} 个警告")
        print("\n请检查 LLM 输出质量,可能需要调整 Prompt 并重新运行 step1_semantic_index.py")
    elif total_warnings:
        print(f"{WARN} 全部通过但有 {total_warnings} 个警告")
    else:
        print(f"{PASS} 所有测试通过!")

    print(f"\n统计:")
    print(f" 功能单元数: {len(units)}")
    print(f" 概念数: {len(concepts)}")
    print(f" 逻辑树图片数: {len(node_index)}")
    return total_failures == 0
if __name__ == "__main__":
    # The exit status mirrors the overall result: 0 when no check reported
    # an error, 1 otherwise.
    success = run_all_tests()
    if success:
        sys.exit(0)
    sys.exit(1)