doc_parser_skill: - New: verify_flowchart.py (flowchart validation) - Updated: LLM.py (multi-provider: DeepSeek + DashScope) - Updated: image_parser.py (logic tree support, external prompts) - Updated: SKILL.md, prompts/image_prompt.md conflict_detection_skill: - Updated: LLM.py (multi-provider sync) - Updated: detect_conflicts.py (logic tree text conversion) ir_generation_skill: - Replaced old scripts/LLM.py + ir_generator.py with standalone project - New: main.py, config.py, step1-3_*.py, ensemble_merge.py - New: prompts/, tests/ subdirectories tests: - New: acceptance/ test suite with schema validation - Fixed: conftest no longer globally skips non-acceptance tests - Updated: test_sample.py for new ir_generation structure Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,472 @@
|
||||
"""
|
||||
Tests for ensemble_merge.py — all pure Python, no LLM calls, no file I/O.
|
||||
|
||||
Each test uses hardcoded mock data to verify one piece of the merge logic.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from ensemble_merge import (
|
||||
concept_name_similarity,
|
||||
cluster_concepts,
|
||||
merge_concept_cluster,
|
||||
unit_node_jaccard,
|
||||
path_similarity,
|
||||
unit_similarity,
|
||||
cluster_function_units,
|
||||
pick_best_representative,
|
||||
compute_confidence_versions,
|
||||
ensemble_merge_concepts,
|
||||
ensemble_merge_function_units,
|
||||
ensemble_merge,
|
||||
_collect_logic_tree_nodes,
|
||||
)
|
||||
|
||||
PASS = "[PASS]"
|
||||
FAIL = "[FAIL]"
|
||||
|
||||
# ---- Mock helpers ----
|
||||
|
||||
def _mk_unit(unit_id, name, path, logic_tree_nodes, description="", sources=None):
|
||||
"""Create a minimal function_unit dict for testing."""
|
||||
if sources is None:
|
||||
srcs = []
|
||||
if logic_tree_nodes:
|
||||
srcs.append({
|
||||
"image_id": "rId16",
|
||||
"type": "logic_tree",
|
||||
"logic_tree_nodes": logic_tree_nodes,
|
||||
})
|
||||
if not srcs:
|
||||
srcs.append({
|
||||
"section": "3.1",
|
||||
"type": "table",
|
||||
"text_snippet": "test",
|
||||
})
|
||||
else:
|
||||
srcs = sources
|
||||
return {
|
||||
"unit_id": unit_id,
|
||||
"name": name,
|
||||
"description": description or f"desc for {name}",
|
||||
"path": path,
|
||||
"sources": srcs,
|
||||
}
|
||||
|
||||
|
||||
def _mk_concept(name, parent=None, aliases=None, defined_in=None):
|
||||
"""Create a minimal concept dict for testing."""
|
||||
return {
|
||||
"name": name,
|
||||
"aliases": aliases or [],
|
||||
"defined_in": defined_in or ["3.1"],
|
||||
"parent": parent,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 1: concept_name_similarity
|
||||
# =============================================================================
|
||||
|
||||
def test_concept_name_similarity_exact():
|
||||
assert concept_name_similarity("国内", "国内") == 1.0
|
||||
assert concept_name_similarity("行车娱乐限制", "行车娱乐限制") == 1.0
|
||||
|
||||
def test_concept_name_similarity_substring():
|
||||
sim = concept_name_similarity("国内行车娱乐限制", "行车娱乐限制")
|
||||
assert sim >= 0.85, f"expected >= 0.85, got {sim}"
|
||||
|
||||
def test_concept_name_similarity_different():
|
||||
sim = concept_name_similarity("国内", "海外")
|
||||
assert sim < 0.7, f"expected < 0.7, got {sim}"
|
||||
|
||||
def test_concept_name_similarity_seq_matcher():
|
||||
sim = concept_name_similarity("前台打断", "前台应用打断")
|
||||
assert 0.6 < sim < 0.95, f"expected 0.6-0.95, got {sim}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 2: _collect_logic_tree_nodes
|
||||
# =============================================================================
|
||||
|
||||
def test_collect_logic_tree_nodes():
|
||||
unit = _mk_unit("U1", "test", ["A"], ["n1", "n2", "n3"])
|
||||
nodes = _collect_logic_tree_nodes(unit)
|
||||
assert nodes == {"n1", "n2", "n3"}
|
||||
|
||||
def test_collect_logic_tree_nodes_empty():
|
||||
unit = _mk_unit("U2", "test", ["A"], [], sources=[{"section": "3.1", "type": "table"}])
|
||||
nodes = _collect_logic_tree_nodes(unit)
|
||||
assert nodes == set()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 3: unit_node_jaccard
|
||||
# =============================================================================
|
||||
|
||||
def test_unit_node_jaccard_identical():
|
||||
u1 = _mk_unit("U1", "a", ["A"], ["n1", "n2", "n3"])
|
||||
u2 = _mk_unit("U2", "b", ["A"], ["n1", "n2", "n3"])
|
||||
assert unit_node_jaccard(u1, u2) == 1.0
|
||||
|
||||
def test_unit_node_jaccard_partial():
|
||||
u1 = _mk_unit("U1", "a", ["A"], ["n1", "n2", "n3", "n4"])
|
||||
u2 = _mk_unit("U2", "b", ["A"], ["n1", "n2", "n3"])
|
||||
# intersection=3, union=4
|
||||
assert abs(unit_node_jaccard(u1, u2) - 0.75) < 0.01
|
||||
|
||||
def test_unit_node_jaccard_disjoint():
|
||||
u1 = _mk_unit("U1", "a", ["A"], ["n1", "n2"])
|
||||
u2 = _mk_unit("U2", "b", ["B"], ["n3", "n4"])
|
||||
assert unit_node_jaccard(u1, u2) == 0.0
|
||||
|
||||
def test_unit_node_jaccard_both_empty():
|
||||
u1 = _mk_unit("U1", "a", ["A"], [], sources=[{"section": "3.1", "type": "table"}])
|
||||
u2 = _mk_unit("U2", "b", ["B"], [], sources=[{"section": "3.1", "type": "table"}])
|
||||
assert unit_node_jaccard(u1, u2) == 0.0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 4: path_similarity
|
||||
# =============================================================================
|
||||
|
||||
def test_path_similarity_identical():
|
||||
assert path_similarity(
|
||||
["国内", "系统限制", "前台打断"],
|
||||
["国内", "系统限制", "前台打断"],
|
||||
) == 1.0
|
||||
|
||||
def test_path_similarity_partial():
|
||||
sim = path_similarity(
|
||||
["国内", "系统限制", "前台打断"],
|
||||
["国内", "系统限制", "后台限制启动"],
|
||||
)
|
||||
# 2/3 set overlap, sequential 3/5 ≈ 0.6
|
||||
assert 0.4 < sim < 0.9, f"expected 0.4-0.9, got {sim}"
|
||||
|
||||
def test_path_similarity_different():
|
||||
sim = path_similarity(["国内"], ["海外"])
|
||||
assert sim < 0.7, f"expected < 0.7, got {sim}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 5: unit_similarity
|
||||
# =============================================================================
|
||||
|
||||
def test_unit_similarity_identical():
|
||||
u = _mk_unit("U1", "国内-系统限制-前台打断",
|
||||
["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19"])
|
||||
assert unit_similarity(u, u) > 0.99
|
||||
|
||||
def test_unit_similarity_different():
|
||||
u1 = _mk_unit("U1", "a", ["国内", "系统限制", "前台打断"], ["n1", "n2", "n3"])
|
||||
u2 = _mk_unit("U2", "b", ["海外", "SDK限制"], ["n10", "n11", "n12"])
|
||||
assert unit_similarity(u1, u2) < 0.3
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 6: cluster_concepts
|
||||
# =============================================================================
|
||||
|
||||
def test_cluster_concepts_identical():
|
||||
v0 = [_mk_concept("国内"), _mk_concept("海外"), _mk_concept("系统限制", parent="国内")]
|
||||
v1 = [_mk_concept("国内"), _mk_concept("海外"), _mk_concept("系统限制", parent="国内")]
|
||||
v2 = [_mk_concept("国内"), _mk_concept("海外"), _mk_concept("系统限制", parent="国内")]
|
||||
clusters = cluster_concepts([v0, v1, v2])
|
||||
# Should have exactly 3 clusters (国内, 海外, 系统限制)
|
||||
assert len(clusters) == 3, f"expected 3 clusters, got {len(clusters)}"
|
||||
for c in clusters:
|
||||
assert len(c) == 3, f"expected each cluster to have 3 members, got {len(c)}"
|
||||
|
||||
def test_cluster_concepts_name_variation():
|
||||
v0 = [_mk_concept("国内行车娱乐限制", parent="国内")]
|
||||
v1 = [_mk_concept("行车娱乐限制", parent="国内")]
|
||||
v2 = [_mk_concept("国内行车娱乐限制", parent="国内")]
|
||||
clusters = cluster_concepts([v0, v1, v2])
|
||||
assert len(clusters) == 1, f"expected 1 cluster, got {len(clusters)}"
|
||||
assert len(clusters[0]) == 3, f"expected 3 members, got {len(clusters[0])}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 7: merge_concept_cluster
|
||||
# =============================================================================
|
||||
|
||||
def test_merge_concept_cluster():
|
||||
cluster = [
|
||||
(0, _mk_concept("国内行车娱乐限制", parent="国内", aliases=["限制"])),
|
||||
(1, _mk_concept("行车娱乐限制", parent="国内", aliases=["行车限制"])),
|
||||
(2, _mk_concept("行车娱乐限制", parent="国内", aliases=["限制"])),
|
||||
]
|
||||
merged, conf = merge_concept_cluster(cluster, 3)
|
||||
assert "行车娱乐限制" in merged["name"]
|
||||
assert merged["parent"] == "国内"
|
||||
assert set(merged["aliases"]) == {"限制", "行车限制"}
|
||||
assert conf in ("high", "medium")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 8: cluster_function_units
|
||||
# =============================================================================
|
||||
|
||||
def test_cluster_function_units_all_agree():
|
||||
u0 = _mk_unit("U-001", "国内-系统限制-前台打断",
|
||||
["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
|
||||
"switch ON, system app, foreground, speed>=15, non-P, interrupt + toast")
|
||||
u1 = _mk_unit("U-001", "国内-系统限制-前台打断",
|
||||
["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
|
||||
"switch ON, system app, foreground, speed>=15, non-P, interrupt + toast")
|
||||
u2 = _mk_unit("U-001", "国内-系统限制-前台打断",
|
||||
["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
|
||||
"switch ON, system app, foreground, interrupt")
|
||||
clusters = cluster_function_units([[u0], [u1], [u2]])
|
||||
assert len(clusters) == 1, f"expected 1 cluster, got {len(clusters)}"
|
||||
assert len(clusters[0]) == 3
|
||||
|
||||
def test_cluster_function_units_partial_agree():
|
||||
u0 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19"])
|
||||
u1 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19"])
|
||||
u2 = _mk_unit("U-002", "禁止", ["国内", "系统限制", "后台限制启动"],
|
||||
["n5", "n6"])
|
||||
clusters = cluster_function_units([[u0], [u1], [u2]])
|
||||
# u0+u1 in one cluster, u2 in another
|
||||
assert len(clusters) == 2, f"expected 2 clusters, got {len(clusters)}"
|
||||
cluster_sizes = sorted(len(c) for c in clusters)
|
||||
assert cluster_sizes == [1, 2], f"expected cluster sizes [1,2], got {cluster_sizes}"
|
||||
|
||||
def test_cluster_function_units_all_disagree():
|
||||
u0 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"], ["n1", "n2", "n3"])
|
||||
u1 = _mk_unit("U-002", "禁止", ["国内", "系统限制", "后台限制启动"], ["n5", "n6"])
|
||||
u2 = _mk_unit("U-003", "SDK", ["国内", "SDK限制"], ["n10", "n11"])
|
||||
clusters = cluster_function_units([[u0], [u1], [u2]])
|
||||
assert len(clusters) == 3, f"expected 3 clusters, got {len(clusters)}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 9: pick_best_representative
|
||||
# =============================================================================
|
||||
|
||||
def test_pick_best_representative_prefers_rich():
|
||||
u0 = _mk_unit("U-001", "short", ["国内", "系统限制"],
|
||||
["n1", "n2", "n3"],
|
||||
description="short desc")
|
||||
u1 = _mk_unit("U-001", "detailed", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
|
||||
description="very detailed description of the full rule behavior " * 5)
|
||||
cluster = [(0, u0), (1, u1)]
|
||||
best = pick_best_representative(cluster)
|
||||
# u1 should win: more nodes, longer description, though u0 has lower temp
|
||||
assert best["name"] == "detailed"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 10: compute_confidence_versions
|
||||
# =============================================================================
|
||||
|
||||
def test_confidence_high_unanimous():
|
||||
assert compute_confidence_versions(3, 3, True) == "high"
|
||||
|
||||
def test_confidence_high_two_of_three_with_t0():
|
||||
assert compute_confidence_versions(2, 3, True) == "high"
|
||||
|
||||
def test_confidence_medium_two_of_three_without_t0():
|
||||
assert compute_confidence_versions(2, 3, False) == "medium"
|
||||
|
||||
def test_confidence_low_one_of_three():
|
||||
assert compute_confidence_versions(1, 3, False) == "low"
|
||||
|
||||
def test_confidence_high_all_two_versions():
|
||||
assert compute_confidence_versions(2, 2, True) == "high"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 11: ensemble_merge_concepts
|
||||
# =============================================================================
|
||||
|
||||
def test_ensemble_merge_concepts():
|
||||
v0 = [_mk_concept("国内"), _mk_concept("海外"),
|
||||
_mk_concept("国内行车娱乐限制", parent="国内")]
|
||||
v1 = [_mk_concept("国内"), _mk_concept("海外"),
|
||||
_mk_concept("行车娱乐限制", parent="国内",
|
||||
aliases=["限制"], defined_in=["3.1", "3.1.1"])]
|
||||
v2 = [_mk_concept("国内"), _mk_concept("海外"),
|
||||
_mk_concept("行车娱乐限制", parent="国内")]
|
||||
|
||||
merged = ensemble_merge_concepts([v0, v1, v2])
|
||||
# Should merge the 3 concepts across 3 versions into 3 clusters
|
||||
assert len(merged) == 3, f"expected 3 merged concepts, got {len(merged)}"
|
||||
for c in merged:
|
||||
assert "confidence" in c
|
||||
assert "ensemble_support" in c
|
||||
assert c["ensemble_support"] == "3/3"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 12: ensemble_merge_function_units
|
||||
# =============================================================================
|
||||
|
||||
def test_ensemble_merge_function_units():
|
||||
u0 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
|
||||
description="full description A")
|
||||
u1 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
|
||||
description="full description B (more detail)")
|
||||
u2 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25"],
|
||||
description="partial description")
|
||||
|
||||
merged = ensemble_merge_function_units([[u0], [u1], [u2]])
|
||||
assert len(merged) == 1, f"expected 1 unit, got {len(merged)}"
|
||||
unit = merged[0]
|
||||
assert unit["confidence"] == "high"
|
||||
assert unit["ensemble_support"] == "3/3"
|
||||
assert unit["source_versions"] == 3
|
||||
assert unit["unit_id"].startswith("FU-ENS-")
|
||||
# Should have picked u1 (more detail)
|
||||
assert "more detail" in unit["description"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 13: ensemble_merge full integration
|
||||
# =============================================================================
|
||||
|
||||
def test_ensemble_merge_full():
|
||||
v0 = {
|
||||
"feature_name": "行车娱乐限制",
|
||||
"concepts": [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")],
|
||||
"function_units": [
|
||||
_mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n25", "n26"]),
|
||||
_mk_unit("U-002", "后台禁止", ["国内", "系统限制", "后台限制启动"],
|
||||
["n5", "n6"]),
|
||||
],
|
||||
}
|
||||
v1 = {
|
||||
"feature_name": "行车娱乐限制",
|
||||
"concepts": [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")],
|
||||
"function_units": [
|
||||
_mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n25", "n26"]),
|
||||
_mk_unit("U-003", "SDK自定义", ["国内", "SDK限制", "自定义限制"],
|
||||
["n10", "n11"]),
|
||||
],
|
||||
}
|
||||
v2 = {
|
||||
"feature_name": "行车娱乐限制",
|
||||
"concepts": [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")],
|
||||
"function_units": [
|
||||
_mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n25", "n26"]),
|
||||
],
|
||||
}
|
||||
|
||||
result = ensemble_merge([v0, v1, v2])
|
||||
|
||||
assert result["feature_name"] == "行车娱乐限制"
|
||||
assert result["ensemble_versions"] == 3
|
||||
|
||||
units = result["function_units"]
|
||||
concepts = result["concepts"]
|
||||
|
||||
# Concepts: 国内 + 系统限制
|
||||
assert len(concepts) == 2
|
||||
|
||||
# Units: 打断 (3 versions → high), 后台禁止 (1 version → low), SDK (1 version → low)
|
||||
assert len(units) == 3
|
||||
|
||||
high_units = [u for u in units if u["confidence"] == "high"]
|
||||
low_units = [u for u in units if u["confidence"] == "low"]
|
||||
assert len(high_units) == 1
|
||||
assert len(low_units) == 2
|
||||
|
||||
# All units should have ensemble fields
|
||||
for u in units:
|
||||
assert "confidence" in u
|
||||
assert "ensemble_support" in u
|
||||
assert "source_versions" in u
|
||||
|
||||
# Confidence summary
|
||||
cs = result["confidence_summary"]
|
||||
assert cs["total_units"] == 3
|
||||
assert cs["high"] == 1
|
||||
assert cs["low"] == 2
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Runner
|
||||
# =============================================================================
|
||||
|
||||
def run_all_tests():
|
||||
print("=" * 60)
|
||||
print("Ensemble Merge 测试 (纯 Python, 无 LLM)")
|
||||
print("=" * 60)
|
||||
|
||||
tests = [
|
||||
("concept_name_similarity exact", test_concept_name_similarity_exact),
|
||||
("concept_name_similarity substring", test_concept_name_similarity_substring),
|
||||
("concept_name_similarity different", test_concept_name_similarity_different),
|
||||
("concept_name_similarity seq_matcher", test_concept_name_similarity_seq_matcher),
|
||||
("collect_logic_tree_nodes", test_collect_logic_tree_nodes),
|
||||
("collect_logic_tree_nodes empty", test_collect_logic_tree_nodes_empty),
|
||||
("unit_node_jaccard identical", test_unit_node_jaccard_identical),
|
||||
("unit_node_jaccard partial", test_unit_node_jaccard_partial),
|
||||
("unit_node_jaccard disjoint", test_unit_node_jaccard_disjoint),
|
||||
("unit_node_jaccard both_empty", test_unit_node_jaccard_both_empty),
|
||||
("path_similarity identical", test_path_similarity_identical),
|
||||
("path_similarity partial", test_path_similarity_partial),
|
||||
("path_similarity different", test_path_similarity_different),
|
||||
("unit_similarity identical", test_unit_similarity_identical),
|
||||
("unit_similarity different", test_unit_similarity_different),
|
||||
("cluster_concepts identical", test_cluster_concepts_identical),
|
||||
("cluster_concepts name variation", test_cluster_concepts_name_variation),
|
||||
("merge_concept_cluster", test_merge_concept_cluster),
|
||||
("cluster_function_units all_agree", test_cluster_function_units_all_agree),
|
||||
("cluster_function_units partial_agree", test_cluster_function_units_partial_agree),
|
||||
("cluster_function_units all_disagree", test_cluster_function_units_all_disagree),
|
||||
("pick_best_representative", test_pick_best_representative_prefers_rich),
|
||||
("confidence high unanimous", test_confidence_high_unanimous),
|
||||
("confidence high 2/3 with t0", test_confidence_high_two_of_three_with_t0),
|
||||
("confidence medium 2/3 no t0", test_confidence_medium_two_of_three_without_t0),
|
||||
("confidence low 1/3", test_confidence_low_one_of_three),
|
||||
("confidence high 2/2", test_confidence_high_all_two_versions),
|
||||
("ensemble_merge_concepts", test_ensemble_merge_concepts),
|
||||
("ensemble_merge_function_units", test_ensemble_merge_function_units),
|
||||
("ensemble_merge full", test_ensemble_merge_full),
|
||||
]
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
for name, test_fn in tests:
|
||||
try:
|
||||
test_fn()
|
||||
print(f" {PASS} {name}")
|
||||
passed += 1
|
||||
except AssertionError as e:
|
||||
print(f" {FAIL} {name}: {e}")
|
||||
failed += 1
|
||||
except Exception as e:
|
||||
print(f" {FAIL} {name}: unexpected {type(e).__name__}: {e}")
|
||||
failed += 1
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
if failed == 0:
|
||||
print(f"{PASS} 所有 {passed} 个测试通过!")
|
||||
else:
|
||||
print(f"{FAIL} {failed}/{passed + failed} 个测试失败")
|
||||
print(f"{'='*60}")
|
||||
|
||||
return failed == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
Tests for Stage 1 (Semantic Index).
|
||||
|
||||
Validates that the generated semantic_index.json meets all completeness
|
||||
and structural requirements, including the new iterative features:
|
||||
- function_units have path fields
|
||||
- concepts have parent references
|
||||
- logic tree node coverage meets thresholds
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
import config
|
||||
|
||||
|
||||
PASS = "[PASS]"
|
||||
FAIL = "[FAIL]"
|
||||
WARN = "[WARN]"
|
||||
|
||||
|
||||
def load_inputs():
|
||||
"""Load semantic_index.json and the original parsed document."""
|
||||
try:
|
||||
si = config.load_json(config.SEMANTIC_INDEX_JSON)
|
||||
except FileNotFoundError:
|
||||
print(f"{FAIL} semantic_index.json 未找到: {config.SEMANTIC_INDEX_JSON}")
|
||||
print(" 请先运行 step1_semantic_index.py")
|
||||
sys.exit(1)
|
||||
doc = config.load_input_document()
|
||||
return si, doc
|
||||
|
||||
|
||||
def build_image_index(doc: dict) -> dict[str, dict]:
|
||||
"""Build lookup: image rId -> image_analysis entry."""
|
||||
idx = {}
|
||||
for img in doc.get("image_analysis", []):
|
||||
rid = img.get("rid", "")
|
||||
if rid:
|
||||
idx[rid] = img
|
||||
return idx
|
||||
|
||||
|
||||
def build_logic_tree_node_index(doc: dict) -> dict[str, set[str]]:
|
||||
"""Build lookup: image rId -> set of all node IDs in that logic_tree."""
|
||||
idx = {}
|
||||
for img in doc.get("image_analysis", []):
|
||||
rid = img.get("rid", "")
|
||||
lt = img.get("logic_tree")
|
||||
if lt and rid:
|
||||
node_ids = {n["id"] for n in lt.get("nodes", [])}
|
||||
idx[rid] = node_ids
|
||||
return idx
|
||||
|
||||
|
||||
def check_unit_ids(units: list[dict]) -> list[str]:
|
||||
"""Check that every function_unit has a non-empty unit_id and name."""
|
||||
errors = []
|
||||
seen_ids = set()
|
||||
for i, fu in enumerate(units):
|
||||
uid = fu.get("unit_id", "")
|
||||
name = fu.get("name", "")
|
||||
if not uid:
|
||||
errors.append(f"function_unit[{i}]: unit_id 为空")
|
||||
elif uid in seen_ids:
|
||||
errors.append(f"function_unit[{i}]: unit_id '{uid}' 重复")
|
||||
seen_ids.add(uid)
|
||||
if not name:
|
||||
errors.append(f"function_unit[{i}] ({uid}): name 为空")
|
||||
return errors
|
||||
|
||||
|
||||
def check_unit_paths(units: list[dict]) -> list[str]:
|
||||
"""Check that every function_unit has a non-empty path array."""
|
||||
errors = []
|
||||
for fu in units:
|
||||
uid = fu.get("unit_id", "?")
|
||||
path = fu.get("path", [])
|
||||
if not path:
|
||||
errors.append(f"{uid}: path 字段为空或缺失")
|
||||
elif not isinstance(path, list):
|
||||
errors.append(f"{uid}: path 必须是数组")
|
||||
return errors
|
||||
|
||||
|
||||
def check_concept_parents(concepts: list[dict]) -> list[str]:
|
||||
"""Check that non-scope concepts have valid parent references."""
|
||||
errors = []
|
||||
concept_names = {c.get("name", "") for c in concepts}
|
||||
scope_concepts = {"国内", "海外"}
|
||||
|
||||
for c in concepts:
|
||||
name = c.get("name", "?")
|
||||
parent = c.get("parent", "")
|
||||
|
||||
if name in scope_concepts:
|
||||
# Scope concepts should have no parent
|
||||
if parent:
|
||||
errors.append(f"scope 概念 '{name}' 不应有 parent (当前: '{parent}')")
|
||||
else:
|
||||
# Non-scope concepts must have a parent
|
||||
if not parent:
|
||||
errors.append(f"概念 '{name}' 缺少 parent 字段")
|
||||
elif parent not in concept_names:
|
||||
errors.append(f"概念 '{name}' 的 parent '{parent}' 不存在于 concepts 中")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_sources_exist(
|
||||
units: list[dict], image_index: dict[str, dict], node_index: dict[str, set[str]]
|
||||
) -> list[str]:
|
||||
"""Check that all source references point to real content."""
|
||||
errors = []
|
||||
for fu in units:
|
||||
uid = fu.get("unit_id", "?")
|
||||
sources = fu.get("sources", [])
|
||||
if not sources:
|
||||
errors.append(f"{uid}: sources 为空,必须至少引用一张图片或一段文字")
|
||||
continue
|
||||
|
||||
has_text = False
|
||||
has_image = False
|
||||
|
||||
for j, src in enumerate(sources):
|
||||
src_type = src.get("type", "")
|
||||
if src_type in ("table", "para"):
|
||||
has_text = True
|
||||
section = src.get("section", "")
|
||||
if not section:
|
||||
errors.append(f"{uid}.sources[{j}]: 缺少 section")
|
||||
elif src_type == "logic_tree":
|
||||
has_image = True
|
||||
image_id = src.get("image_id", "")
|
||||
if not image_id:
|
||||
errors.append(f"{uid}.sources[{j}]: logic_tree 缺少 image_id")
|
||||
continue
|
||||
if image_id not in image_index:
|
||||
errors.append(
|
||||
f"{uid}.sources[{j}]: image_id '{image_id}' "
|
||||
f"在 image_analysis 中不存在"
|
||||
)
|
||||
continue
|
||||
node_ids = src.get("logic_tree_nodes", [])
|
||||
if node_ids and image_id in node_index:
|
||||
valid_nodes = node_index[image_id]
|
||||
for nid in node_ids:
|
||||
if nid not in valid_nodes:
|
||||
errors.append(
|
||||
f"{uid}.sources[{j}]: 节点 '{nid}' 在 "
|
||||
f"{image_id} 的逻辑树中不存在"
|
||||
)
|
||||
elif not node_ids:
|
||||
errors.append(
|
||||
f"{uid}.sources[{j}]: logic_tree 类型但未提供 logic_tree_nodes"
|
||||
)
|
||||
|
||||
if not has_text and not has_image:
|
||||
errors.append(f"{uid}: 必须至少引用一个文本或图片来源")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_logic_tree_coverage(
|
||||
units: list[dict], node_index: dict[str, set[str]]
|
||||
) -> list[str]:
|
||||
"""Check that decision and action nodes in logic trees are covered."""
|
||||
warnings = []
|
||||
for image_id, all_nodes in node_index.items():
|
||||
referenced = set()
|
||||
for fu in units:
|
||||
for src in fu.get("sources", []):
|
||||
if src.get("image_id") == image_id:
|
||||
for nid in src.get("logic_tree_nodes", []):
|
||||
referenced.add(nid)
|
||||
|
||||
uncovered = all_nodes - referenced
|
||||
if uncovered:
|
||||
doc = config.load_input_document()
|
||||
node_types = {}
|
||||
for img in doc.get("image_analysis", []):
|
||||
if img.get("rid") == image_id:
|
||||
lt = img.get("logic_tree", {})
|
||||
for n in lt.get("nodes", []):
|
||||
node_types[n["id"]] = n.get("type", "?")
|
||||
break
|
||||
|
||||
decision_action_uncovered = [
|
||||
n for n in uncovered if node_types.get(n) in ("decision", "action")
|
||||
]
|
||||
if decision_action_uncovered:
|
||||
warnings.append(
|
||||
f"{image_id}: {len(decision_action_uncovered)} 个 "
|
||||
f"decision/action 节点未被引用: {decision_action_uncovered}"
|
||||
)
|
||||
|
||||
return warnings
|
||||
|
||||
|
||||
def check_ensemble_confidence(units: list[dict]) -> list[str]:
|
||||
"""Check that every function_unit has confidence, ensemble_support, source_versions."""
|
||||
errors = []
|
||||
valid_conf = {"high", "medium", "low"}
|
||||
for fu in units:
|
||||
uid = fu.get("unit_id", "?")
|
||||
conf = fu.get("confidence", "")
|
||||
if not conf:
|
||||
errors.append(f"{uid}: 缺少 confidence 字段")
|
||||
elif conf not in valid_conf:
|
||||
errors.append(f"{uid}: confidence='{conf}' 无效 (期望 high/medium/low)")
|
||||
support = fu.get("ensemble_support", "")
|
||||
if not support:
|
||||
errors.append(f"{uid}: 缺少 ensemble_support 字段")
|
||||
if "source_versions" not in fu:
|
||||
errors.append(f"{uid}: 缺少 source_versions 字段")
|
||||
return errors
|
||||
|
||||
|
||||
def check_confidence_summary(si: dict) -> list[str]:
|
||||
"""Check that confidence_summary counts match actual unit/concept confidence."""
|
||||
errors = []
|
||||
cs = si.get("confidence_summary", {})
|
||||
if not cs:
|
||||
errors.append("缺少 confidence_summary 字段")
|
||||
return errors
|
||||
|
||||
units = si.get("function_units", [])
|
||||
concepts = si.get("concepts", [])
|
||||
|
||||
# Count actual confidence levels
|
||||
unit_high = sum(1 for u in units if u.get("confidence") == "high")
|
||||
unit_medium = sum(1 for u in units if u.get("confidence") == "medium")
|
||||
unit_low = sum(1 for u in units if u.get("confidence") == "low")
|
||||
concept_high = sum(1 for c in concepts if c.get("confidence") == "high")
|
||||
concept_medium = sum(1 for c in concepts if c.get("confidence") == "medium")
|
||||
concept_low = sum(1 for c in concepts if c.get("confidence") == "low")
|
||||
|
||||
if cs.get("total_units", 0) != len(units):
|
||||
errors.append(f"confidence_summary.total_units={cs.get('total_units')} != 实际 {len(units)}")
|
||||
if cs.get("high", 0) != unit_high:
|
||||
errors.append(f"confidence_summary.high={cs.get('high')} != 实际 {unit_high}")
|
||||
if cs.get("medium", 0) != unit_medium:
|
||||
errors.append(f"confidence_summary.medium={cs.get('medium')} != 实际 {unit_medium}")
|
||||
if cs.get("low", 0) != unit_low:
|
||||
errors.append(f"confidence_summary.low={cs.get('low')} != 实际 {unit_low}")
|
||||
if cs.get("total_concepts", 0) != len(concepts):
|
||||
errors.append(f"confidence_summary.total_concepts={cs.get('total_concepts')} != 实际 {len(concepts)}")
|
||||
if cs.get("concept_high", 0) != concept_high:
|
||||
errors.append(f"confidence_summary.concept_high={cs.get('concept_high')} != 实际 {concept_high}")
|
||||
if cs.get("concept_medium", 0) != concept_medium:
|
||||
errors.append(f"confidence_summary.concept_medium={cs.get('concept_medium')} != 实际 {concept_medium}")
|
||||
if cs.get("concept_low", 0) != concept_low:
|
||||
errors.append(f"confidence_summary.concept_low={cs.get('concept_low')} != 实际 {concept_low}")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
print("=" * 60)
|
||||
print("Step 1 自检测试")
|
||||
print("=" * 60)
|
||||
|
||||
si, doc = load_inputs()
|
||||
units = si.get("function_units", [])
|
||||
concepts = si.get("concepts", [])
|
||||
image_index = build_image_index(doc)
|
||||
node_index = build_logic_tree_node_index(doc)
|
||||
|
||||
all_errors = []
|
||||
all_warnings = []
|
||||
|
||||
# Test 1: unit_id and name validity
|
||||
errors = check_unit_ids(units)
|
||||
if errors:
|
||||
print(f"\n{FAIL} unit_id/name 检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} unit_id/name 检查: 全部通过 ({len(units)} 个功能单元)")
|
||||
|
||||
# Test 2: path fields
|
||||
errors = check_unit_paths(units)
|
||||
if errors:
|
||||
print(f"\n{FAIL} path 字段检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} path 字段检查: 全部通过")
|
||||
|
||||
# Test 3: concept parent references
|
||||
errors = check_concept_parents(concepts)
|
||||
if errors:
|
||||
print(f"\n{FAIL} concept parent 检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} concept parent 检查: 全部通过 ({len(concepts)} 个概念)")
|
||||
|
||||
# Test 4: source references exist
|
||||
errors = check_sources_exist(units, image_index, node_index)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 来源引用检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 来源引用检查: 全部通过")
|
||||
|
||||
# Test 5: Logic tree coverage
|
||||
warnings = check_logic_tree_coverage(units, node_index)
|
||||
if warnings:
|
||||
print(f"\n{WARN} 逻辑树节点覆盖率: {len(warnings)} 个警告")
|
||||
for w in warnings:
|
||||
print(f" - {w}")
|
||||
all_warnings.extend(warnings)
|
||||
else:
|
||||
print(f"\n{PASS} 逻辑树节点覆盖率: 全部通过")
|
||||
|
||||
# Test 6: Ensemble confidence fields on function_units
|
||||
errors = check_ensemble_confidence(units)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 集成置信度字段: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 集成置信度字段: 全部通过")
|
||||
|
||||
# Test 7: Confidence summary consistency
|
||||
errors = check_confidence_summary(si)
|
||||
if errors:
|
||||
print(f"\n{FAIL} confidence_summary 一致性: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
cs = si.get("confidence_summary", {})
|
||||
print(f"\n{PASS} confidence_summary 一致性: "
|
||||
f"high={cs.get('high',0)}, medium={cs.get('medium',0)}, "
|
||||
f"low={cs.get('low',0)}")
|
||||
|
||||
# Summary
|
||||
print(f"\n{'='*60}")
|
||||
total_failures = len(all_errors)
|
||||
total_warnings = len(all_warnings)
|
||||
|
||||
if total_failures == 0 and total_warnings == 0:
|
||||
print(f"{PASS} 所有测试通过!")
|
||||
elif total_failures == 0:
|
||||
print(f"{WARN} 全部通过但有 {total_warnings} 个警告")
|
||||
else:
|
||||
print(f"{FAIL} 测试失败: {total_failures} 个错误, {total_warnings} 个警告")
|
||||
print("\n请检查 LLM 输出质量,可能需要调整 Prompt 并重新运行 step1_semantic_index.py")
|
||||
|
||||
print(f"\n统计:")
|
||||
print(f" 功能单元数: {len(units)}")
|
||||
print(f" 概念数: {len(concepts)}")
|
||||
print(f" 逻辑树图片数: {len(node_index)}")
|
||||
|
||||
return total_failures == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
Tests for Stage 2 (IR Extraction).
|
||||
|
||||
Validates that ir_fragments.json meets quality and structural requirements:
|
||||
- All fragments have non-empty rules
|
||||
- All rules have path arrays
|
||||
- All rules have precondition.geographic_scope and precondition.screen_type
|
||||
- All trigger conditions have signal/operator/value
|
||||
- user_interaction content is non-empty and not a placeholder
|
||||
- No duplicate rule_ids (across all fragments)
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
import config
|
||||
|
||||
|
||||
PASS = "[PASS]"
|
||||
FAIL = "[FAIL]"
|
||||
WARN = "[WARN]"
|
||||
|
||||
# Forbidden placeholder phrases in user_interaction content
|
||||
FORBIDDEN_PLACEHOLDERS = [
|
||||
"文案由业务定义", "待定", "自定义", "TBD", "todo", "TODO"
|
||||
]
|
||||
|
||||
|
||||
def load_fragments():
|
||||
"""Load ir_fragments.json."""
|
||||
try:
|
||||
return config.load_json(config.IR_FRAGMENTS_JSON)
|
||||
except FileNotFoundError:
|
||||
print(f"{FAIL} ir_fragments.json 未找到: {config.IR_FRAGMENTS_JSON}")
|
||||
print(" 请先运行 step2_ir_extraction.py")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def check_non_empty_rules(fragments: list[dict]) -> list[str]:
|
||||
"""Every fragment must have at least one rule."""
|
||||
errors = []
|
||||
for f in fragments:
|
||||
uid = f.get("unit_id", "?")
|
||||
rules = f.get("rules", [])
|
||||
if not rules:
|
||||
if f.get("error"):
|
||||
errors.append(f"{uid}: 提取失败 — {f['error']}")
|
||||
else:
|
||||
errors.append(f"{uid}: rules 为空")
|
||||
return errors
|
||||
|
||||
|
||||
def check_rule_paths(fragments: list[dict]) -> list[str]:
|
||||
"""Every rule must have a non-empty path array."""
|
||||
errors = []
|
||||
for f in fragments:
|
||||
uid = f.get("unit_id", "?")
|
||||
for j, rule in enumerate(f.get("rules", [])):
|
||||
rid = rule.get("rule_id", f"rule[{j}]")
|
||||
path = rule.get("path", [])
|
||||
if not path:
|
||||
errors.append(f"{rid}: path 字段为空或缺失")
|
||||
elif not isinstance(path, list):
|
||||
errors.append(f"{rid}: path 必须是数组")
|
||||
return errors
|
||||
|
||||
|
||||
def check_precondition_fields(fragments: list[dict]) -> list[str]:
|
||||
"""Every rule must have precondition with geographic_scope and screen_type."""
|
||||
errors = []
|
||||
for f in fragments:
|
||||
uid = f.get("unit_id", "?")
|
||||
for j, rule in enumerate(f.get("rules", [])):
|
||||
rid = rule.get("rule_id", f"rule[{j}]")
|
||||
precond = rule.get("precondition", {})
|
||||
if not precond:
|
||||
errors.append(f"{rid}: precondition 缺失")
|
||||
continue
|
||||
if not precond.get("geographic_scope"):
|
||||
errors.append(f"{rid}: precondition.geographic_scope 缺失")
|
||||
if "screen_type" not in precond:
|
||||
errors.append(f"{rid}: precondition.screen_type 缺失")
|
||||
return errors
|
||||
|
||||
|
||||
def check_user_interaction_content(fragments: list[dict]) -> list[str]:
|
||||
"""user_interaction actions must have non-empty, non-placeholder content."""
|
||||
errors = []
|
||||
for f in fragments:
|
||||
uid = f.get("unit_id", "?")
|
||||
for j, rule in enumerate(f.get("rules", [])):
|
||||
rid = rule.get("rule_id", f"rule[{j}]")
|
||||
for k, action in enumerate(rule.get("actions", [])):
|
||||
if action.get("type") != "user_interaction":
|
||||
continue
|
||||
content = action.get("content", "")
|
||||
if not content:
|
||||
errors.append(
|
||||
f"{rid}.actions[{k}]: user_interaction 的 content 为空"
|
||||
)
|
||||
elif any(ph in content for ph in FORBIDDEN_PLACEHOLDERS):
|
||||
errors.append(
|
||||
f"{rid}.actions[{k}]: content 包含占位符: '{content}'"
|
||||
)
|
||||
return errors
|
||||
|
||||
|
||||
def check_sources_have_logic_tree_nodes(fragments: list[dict]) -> list[str]:
|
||||
"""Every rule should reference at least one logic tree node in its sources."""
|
||||
errors = []
|
||||
for f in fragments:
|
||||
uid = f.get("unit_id", "?")
|
||||
for j, rule in enumerate(f.get("rules", [])):
|
||||
rid = rule.get("rule_id", f"rule[{j}]")
|
||||
sources = rule.get("sources", [])
|
||||
has_logic_tree = any(
|
||||
src.get("type") == "logic_tree" and src.get("node_ids")
|
||||
for src in sources
|
||||
)
|
||||
if not has_logic_tree:
|
||||
has_text = any(
|
||||
src.get("type") in ("table", "para") for src in sources
|
||||
)
|
||||
if not has_text:
|
||||
errors.append(f"{rid}: sources 中既无逻辑树引用也无文字引用")
|
||||
return errors
|
||||
|
||||
|
||||
def check_trigger_conditions(fragments: list[dict]) -> list[str]:
|
||||
"""Every trigger condition must have signal, operator, value."""
|
||||
errors = []
|
||||
for f in fragments:
|
||||
uid = f.get("unit_id", "?")
|
||||
for j, rule in enumerate(f.get("rules", [])):
|
||||
rid = rule.get("rule_id", f"rule[{j}]")
|
||||
trigger = rule.get("trigger", {})
|
||||
conditions = trigger.get("conditions", [])
|
||||
|
||||
if trigger.get("event") is not None:
|
||||
continue
|
||||
|
||||
for k, cond in enumerate(conditions):
|
||||
signal = cond.get("signal", "")
|
||||
operator = cond.get("operator", "")
|
||||
has_value = "value" in cond
|
||||
|
||||
if not signal:
|
||||
errors.append(f"{rid}.condition[{k}]: 缺少 signal")
|
||||
if not operator:
|
||||
errors.append(f"{rid}.condition[{k}]: 缺少 operator")
|
||||
if not has_value:
|
||||
errors.append(f"{rid}.condition[{k}]: 缺少 value")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_duplicate_rule_ids(fragments: list[dict]) -> list[str]:
|
||||
"""Check for duplicate rule_ids across all fragments."""
|
||||
all_rule_ids = []
|
||||
for f in fragments:
|
||||
for rule in f.get("rules", []):
|
||||
rid = rule.get("rule_id", "")
|
||||
if rid:
|
||||
all_rule_ids.append(rid)
|
||||
|
||||
duplicates = [rid for rid, count in Counter(all_rule_ids).items() if count > 1]
|
||||
errors = []
|
||||
if duplicates:
|
||||
errors.append(f"重复 rule_id: {duplicates}")
|
||||
return errors
|
||||
|
||||
|
||||
def check_action_types(fragments: list[dict]) -> list[str]:
|
||||
"""Verify that actions have valid types."""
|
||||
valid_types = {"system", "user_interaction"}
|
||||
errors = []
|
||||
for f in fragments:
|
||||
for j, rule in enumerate(f.get("rules", [])):
|
||||
rid = rule.get("rule_id", f"rule[{j}]")
|
||||
for k, action in enumerate(rule.get("actions", [])):
|
||||
atype = action.get("type", "")
|
||||
if atype not in valid_types:
|
||||
errors.append(
|
||||
f"{rid}.action[{k}]: type='{atype}' 无效, "
|
||||
f"应为 {valid_types}"
|
||||
)
|
||||
if atype == "user_interaction" and "content" not in action:
|
||||
errors.append(
|
||||
f"{rid}.action[{k}]: user_interaction 类型缺少 content 字段"
|
||||
)
|
||||
return errors
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
print("=" * 60)
|
||||
print("Step 2 自检测试")
|
||||
print("=" * 60)
|
||||
|
||||
fragments = load_fragments()
|
||||
all_errors = []
|
||||
total_units = len(fragments)
|
||||
total_rules = sum(len(f.get("rules", [])) for f in fragments)
|
||||
|
||||
# Test 1: Non-empty rules
|
||||
errors = check_non_empty_rules(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 非空规则检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 非空规则检查: 全部通过 ({total_units} 个片段)")
|
||||
|
||||
# Test 2: Rule path arrays
|
||||
errors = check_rule_paths(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 规则 path 字段: {len(errors)} 个错误")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
if len(errors) > 10:
|
||||
print(f" ... 还有 {len(errors) - 10} 个")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 规则 path 字段: 全部通过")
|
||||
|
||||
# Test 3: Precondition fields
|
||||
errors = check_precondition_fields(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} precondition 字段: {len(errors)} 个错误")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
if len(errors) > 10:
|
||||
print(f" ... 还有 {len(errors) - 10} 个")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} precondition 字段: 全部通过")
|
||||
|
||||
# Test 4: user_interaction content
|
||||
errors = check_user_interaction_content(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} user_interaction content: {len(errors)} 个错误")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
if len(errors) > 10:
|
||||
print(f" ... 还有 {len(errors) - 10} 个")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} user_interaction content: 全部通过")
|
||||
|
||||
# Test 5: Sources have logic tree references
|
||||
errors = check_sources_have_logic_tree_nodes(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 来源节点引用: {len(errors)} 个规则缺少来源引用")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
if len(errors) > 10:
|
||||
print(f" ... 还有 {len(errors) - 10} 个")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 来源节点引用: 全部通过")
|
||||
|
||||
# Test 6: Trigger conditions completeness
|
||||
errors = check_trigger_conditions(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 触发条件完整性: {len(errors)} 个条件不完整")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
if len(errors) > 10:
|
||||
print(f" ... 还有 {len(errors) - 10} 个")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 触发条件完整性: 全部通过")
|
||||
|
||||
# Test 7: No duplicate rule_ids
|
||||
errors = check_duplicate_rule_ids(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} rule_id 唯一性: 发现重复")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} rule_id 唯一性: 全部通过")
|
||||
|
||||
# Test 8: Valid action types
|
||||
errors = check_action_types(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 动作类型检查: {len(errors)} 个问题")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 动作类型检查: 全部通过")
|
||||
|
||||
# Summary
|
||||
print(f"\n{'='*60}")
|
||||
total_failures = len(all_errors)
|
||||
|
||||
if total_failures == 0:
|
||||
print(f"{PASS} 所有测试通过!")
|
||||
else:
|
||||
print(f"{FAIL} 测试失败: {total_failures} 个错误")
|
||||
print("\n建议:")
|
||||
print(" 1. 检查 ir_fragments.json 中出错的规则")
|
||||
print(" 2. 如果某些功能单元的规则为空,检查上下文包是否丢失了关键信息")
|
||||
print(" 3. 调整 Prompt (prompts/step2_ir_extraction.txt) 后重新运行")
|
||||
|
||||
print(f"\n统计:")
|
||||
print(f" 功能单元数: {total_units}")
|
||||
print(f" 规则总数: {total_rules}")
|
||||
error_units = sum(1 for f in fragments if f.get("error"))
|
||||
if error_units:
|
||||
print(f" 提取失败的单元: {error_units}")
|
||||
|
||||
return total_failures == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Tests for Stage 2.5 (Branch Coverage Auto-Completion).
|
||||
|
||||
Validates:
|
||||
- Path enumeration exists and is non-empty
|
||||
- Auto-complete fragments have valid structure
|
||||
- No duplicate unit_ids in autocomplete fragments
|
||||
- Path coverage improved after autocomplete (if applicable)
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
import config
|
||||
|
||||
|
||||
PASS = "[PASS]"
|
||||
FAIL = "[FAIL]"
|
||||
WARN = "[WARN]"
|
||||
|
||||
|
||||
def load_path_enumeration():
|
||||
"""Load path_enumeration.json."""
|
||||
try:
|
||||
return config.load_json(config.PATH_ENUM_JSON)
|
||||
except FileNotFoundError:
|
||||
print(f"{FAIL} path_enumeration.json 未找到: {config.PATH_ENUM_JSON}")
|
||||
print(" 请先运行 step2_5_branch_coverage.py")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def load_autocomplete_fragments():
|
||||
"""Load ir_autocomplete_fragments.json, or return [] if absent."""
|
||||
path = config.IR_AUTOCOMPLETE_FRAGMENTS_JSON
|
||||
if not Path(path).exists():
|
||||
return None
|
||||
return config.load_json(path)
|
||||
|
||||
|
||||
def check_path_enumeration(data: dict) -> list[str]:
|
||||
"""Check path enumeration has valid structure."""
|
||||
errors = []
|
||||
paths = data.get("logic_tree_paths", {})
|
||||
if not paths:
|
||||
errors.append("logic_tree_paths 为空")
|
||||
total = data.get("total_paths", 0)
|
||||
if total <= 0:
|
||||
errors.append(f"total_paths = {total}, 期望 > 0")
|
||||
|
||||
for image_id, image_paths in paths.items():
|
||||
if not image_paths:
|
||||
errors.append(f"{image_id}: 路径列表为空")
|
||||
continue
|
||||
for i, p in enumerate(image_paths):
|
||||
if not p.get("path_id"):
|
||||
errors.append(f"{image_id}[{i}]: 缺少 path_id")
|
||||
if not p.get("image_id"):
|
||||
errors.append(f"{image_id}[{i}]: 缺少 image_id")
|
||||
if not p.get("node_ids"):
|
||||
errors.append(f"{image_id}[{i}]: 缺少 node_ids")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_autocomplete_fragments(fragments: list[dict] | None) -> list[str]:
|
||||
"""Check auto-complete fragments have valid structure."""
|
||||
if fragments is None:
|
||||
return ["ir_autocomplete_fragments.json 未生成 (可能无需补全)"]
|
||||
|
||||
errors = []
|
||||
seen_unit_ids = set()
|
||||
|
||||
for frag in fragments:
|
||||
uid = frag.get("unit_id", "")
|
||||
if not uid:
|
||||
errors.append("fragment 缺少 unit_id")
|
||||
continue
|
||||
if uid in seen_unit_ids:
|
||||
errors.append(f"unit_id '{uid}' 重复")
|
||||
seen_unit_ids.add(uid)
|
||||
|
||||
if not frag.get("auto_generated"):
|
||||
errors.append(f"{uid}: auto_generated 应为 true")
|
||||
|
||||
rules = frag.get("rules", [])
|
||||
for j, rule in enumerate(rules):
|
||||
rid = rule.get("rule_id", f"rule[{j}]")
|
||||
if not rule.get("path"):
|
||||
errors.append(f"{rid}: path 字段缺失")
|
||||
precond = rule.get("precondition", {})
|
||||
if not precond.get("geographic_scope"):
|
||||
errors.append(f"{rid}: precondition.geographic_scope 缺失")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
print("=" * 60)
|
||||
print("Step 2.5 自检测试")
|
||||
print("=" * 60)
|
||||
|
||||
all_errors = []
|
||||
|
||||
# Test 1: Path enumeration exists
|
||||
try:
|
||||
path_data = load_path_enumeration()
|
||||
except SystemExit:
|
||||
return False
|
||||
|
||||
errors = check_path_enumeration(path_data)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 路径枚举检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
total = path_data.get("total_paths", 0)
|
||||
n_images = len(path_data.get("logic_tree_paths", {}))
|
||||
print(f"\n{PASS} 路径枚举检查: {total} 条路径, {n_images} 个逻辑树")
|
||||
|
||||
# Test 2: Auto-complete fragments
|
||||
fragments = load_autocomplete_fragments()
|
||||
errors = check_autocomplete_fragments(fragments)
|
||||
|
||||
if fragments is None:
|
||||
print(f"\n{WARN} 自动补全片段: 未生成 (可能所有路径已覆盖)")
|
||||
elif errors:
|
||||
print(f"\n{FAIL} 自动补全片段检查: {len(errors)} 个错误")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
auto_rules = sum(len(f.get("rules", [])) for f in fragments)
|
||||
print(f"\n{PASS} 自动补全片段检查: "
|
||||
f"{len(fragments)} 个片段, {auto_rules} 条规则")
|
||||
|
||||
# Summary
|
||||
print(f"\n{'='*60}")
|
||||
total_failures = len(all_errors)
|
||||
|
||||
if total_failures == 0:
|
||||
print(f"{PASS} 所有测试通过!")
|
||||
else:
|
||||
print(f"{FAIL} 测试失败: {total_failures} 个错误")
|
||||
|
||||
return total_failures == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Tests for Stage 3 (Merge & Audit).
|
||||
|
||||
Validates:
|
||||
- ir_final.json exists and is well-formed
|
||||
- No duplicate rule_ids
|
||||
- All rule_ids follow new hierarchical naming convention
|
||||
- All rules have path arrays
|
||||
- ir_audit_report.md exists and contains all required sections
|
||||
"""
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
import config
|
||||
|
||||
|
||||
PASS = "[PASS]"
|
||||
FAIL = "[FAIL]"
|
||||
WARN = "[WARN]"
|
||||
|
||||
|
||||
def load_ir_final():
|
||||
"""Load ir_final.json."""
|
||||
try:
|
||||
return config.load_json(config.IR_FINAL_JSON)
|
||||
except FileNotFoundError:
|
||||
print(f"{FAIL} ir_final.json 未找到: {config.IR_FINAL_JSON}")
|
||||
print(" 请先运行 step3_merge_and_audit.py")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def load_audit_report():
|
||||
"""Load ir_audit_report.md if it exists."""
|
||||
try:
|
||||
with open(config.IR_AUDIT_REPORT_MD, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
print(f"{FAIL} ir_audit_report.md 未找到: {config.IR_AUDIT_REPORT_MD}")
|
||||
print(" 请先运行 step3_merge_and_audit.py")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def check_rule_ids(ir: dict) -> list[str]:
|
||||
"""Check for duplicate rule_ids and hierarchical naming convention.
|
||||
|
||||
Format: DRL-001-DOMESTIC-SYS-FG-INTERRUPT-01
|
||||
"""
|
||||
errors = []
|
||||
rules = ir.get("rules", [])
|
||||
rule_ids = [r.get("rule_id", "") for r in rules]
|
||||
|
||||
# No duplicates
|
||||
duplicates = [rid for rid, count in Counter(rule_ids).items() if count > 1]
|
||||
if duplicates:
|
||||
errors.append(f"重复 rule_id: {duplicates}")
|
||||
|
||||
# New hierarchical naming convention
|
||||
pattern = re.compile(
|
||||
r"^[A-Z]+-\d{3}-(DOMESTIC|OVERSEAS)-"
|
||||
r"(SYS|SDK|OTHER)-"
|
||||
r"(FG-INTERRUPT|BG-BLOCK|BG-PAUSE|NO-RESTRICT|SWITCH-OFF)-\d{2}$"
|
||||
)
|
||||
for rid in rule_ids:
|
||||
if rid and not pattern.match(rid):
|
||||
errors.append(
|
||||
f"rule_id 命名不规范: '{rid}' "
|
||||
f"(期望: FEATURE-SCOPE-METHOD-BEHAVIOR-NN)"
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_top_level_structure(ir: dict) -> list[str]:
|
||||
"""Check that ir_final has the required top-level fields."""
|
||||
errors = []
|
||||
for field in ["feature", "feature_id", "rules"]:
|
||||
if field not in ir:
|
||||
errors.append(f"ir_final 缺少顶层字段: {field}")
|
||||
|
||||
if not isinstance(ir.get("rules"), list):
|
||||
errors.append("ir_final.rules 必须是数组")
|
||||
elif len(ir["rules"]) == 0:
|
||||
errors.append("ir_final.rules 为空")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_rule_paths(rules: list[dict]) -> list[str]:
|
||||
"""Every rule must have a non-empty path array."""
|
||||
errors = []
|
||||
for rule in rules:
|
||||
rid = rule.get("rule_id", "?")
|
||||
path = rule.get("path", [])
|
||||
if not path:
|
||||
errors.append(f"{rid}: path 字段为空或缺失")
|
||||
return errors
|
||||
|
||||
|
||||
def check_rule_completeness(rules: list[dict]) -> list[str]:
|
||||
"""Check each rule has all required fields."""
|
||||
errors = []
|
||||
required_fields = [
|
||||
"rule_id", "description", "priority", "sources",
|
||||
"precondition", "trigger", "actions"
|
||||
]
|
||||
for i, rule in enumerate(rules):
|
||||
rid = rule.get("rule_id", f"rule[{i}]")
|
||||
for field in required_fields:
|
||||
if field not in rule:
|
||||
errors.append(f"{rid}: 缺少字段 '{field}'")
|
||||
if not rule.get("sources"):
|
||||
errors.append(f"{rid}: sources 为空")
|
||||
if not rule.get("actions"):
|
||||
errors.append(f"{rid}: actions 为空")
|
||||
# Check precondition fields
|
||||
precond = rule.get("precondition", {})
|
||||
if not precond.get("geographic_scope"):
|
||||
errors.append(f"{rid}: precondition.geographic_scope 缺失")
|
||||
if "screen_type" not in precond:
|
||||
errors.append(f"{rid}: precondition.screen_type 缺失")
|
||||
return errors
|
||||
|
||||
|
||||
def check_audit_report(report: str) -> list[str]:
|
||||
"""Check audit report has all required sections."""
|
||||
errors = []
|
||||
|
||||
required_sections = [
|
||||
"逻辑树路径覆盖率",
|
||||
"表格枚举覆盖",
|
||||
"开关状态",
|
||||
"一致性扫描报告",
|
||||
"自动补全摘要",
|
||||
"规则清单",
|
||||
]
|
||||
for section in required_sections:
|
||||
if section not in report:
|
||||
errors.append(f"审计报告缺少章节: {section}")
|
||||
|
||||
# Should have the human review notice
|
||||
if "人工审查" not in report:
|
||||
errors.append("审计报告缺少人工审查提示")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
print("=" * 60)
|
||||
print("Step 3 自检测试")
|
||||
print("=" * 60)
|
||||
|
||||
ir = load_ir_final()
|
||||
report = load_audit_report()
|
||||
rules = ir.get("rules", [])
|
||||
all_errors = []
|
||||
|
||||
# Test 1: Top-level structure
|
||||
errors = check_top_level_structure(ir)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 顶层结构检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 顶层结构检查: 通过 "
|
||||
f"(feature={ir.get('feature')}, feature_id={ir.get('feature_id')})")
|
||||
|
||||
# Test 2: rule_id uniqueness and naming
|
||||
errors = check_rule_ids(ir)
|
||||
if errors:
|
||||
print(f"\n{FAIL} rule_id 检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} rule_id 检查: 全部通过 ({len(rules)} 个唯一 ID, 层次化格式)")
|
||||
|
||||
# Test 3: Rule path fields
|
||||
errors = check_rule_paths(rules)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 规则 path 字段: {len(errors)} 个错误")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 规则 path 字段: 全部通过")
|
||||
|
||||
# Test 4: Rule field completeness
|
||||
errors = check_rule_completeness(rules)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 规则字段完整性: {len(errors)} 个错误")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
if len(errors) > 10:
|
||||
print(f" ... 还有 {len(errors) - 10} 个")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 规则字段完整性: 全部通过")
|
||||
|
||||
# Test 5: Audit report content
|
||||
errors = check_audit_report(report)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 审计报告检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 审计报告检查: 全部通过 (6 个章节)")
|
||||
|
||||
# Summary
|
||||
print(f"\n{'='*60}")
|
||||
total_failures = len(all_errors)
|
||||
|
||||
if total_failures == 0:
|
||||
print(f"{PASS} 所有测试通过!")
|
||||
print(f"\n最终交付物:")
|
||||
print(f" - {config.IR_FINAL_JSON} ({len(rules)} 条规则)")
|
||||
print(f" - {config.IR_AUDIT_REPORT_MD}")
|
||||
else:
|
||||
print(f"{FAIL} 测试失败: {total_failures} 个错误")
|
||||
print("\n建议: 检查 ir_fragments.json 和合并逻辑,修复问题后重新运行 step3_merge_and_audit.py")
|
||||
|
||||
return total_failures == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user