fec4c09ee0
CI / test (push) Successful in 8s
doc_parser_skill: - New: verify_flowchart.py (flowchart validation) - Updated: LLM.py (multi-provider: DeepSeek + DashScope) - Updated: image_parser.py (logic tree support, external prompts) - Updated: SKILL.md, prompts/image_prompt.md conflict_detection_skill: - Updated: LLM.py (multi-provider sync) - Updated: detect_conflicts.py (logic tree text conversion) ir_generation_skill: - Replaced old scripts/LLM.py + ir_generator.py with standalone project - New: main.py, config.py, step1-3_*.py, ensemble_merge.py - New: prompts/, tests/ subdirectories tests: - New: acceptance/ test suite with schema validation - Fixed: conftest no longer globally skips non-acceptance tests - Updated: test_sample.py for new ir_generation structure Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
473 lines
20 KiB
Python
473 lines
20 KiB
Python
"""
|
|
Tests for ensemble_merge.py — all pure Python, no LLM calls, no file I/O.
|
|
|
|
Each test uses hardcoded mock data to verify one piece of the merge logic.
|
|
"""
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from ensemble_merge import (
|
|
concept_name_similarity,
|
|
cluster_concepts,
|
|
merge_concept_cluster,
|
|
unit_node_jaccard,
|
|
path_similarity,
|
|
unit_similarity,
|
|
cluster_function_units,
|
|
pick_best_representative,
|
|
compute_confidence_versions,
|
|
ensemble_merge_concepts,
|
|
ensemble_merge_function_units,
|
|
ensemble_merge,
|
|
_collect_logic_tree_nodes,
|
|
)
|
|
|
|
# Status markers printed by the standalone runner for each test result.
PASS, FAIL = "[PASS]", "[FAIL]"
|
|
|
|
# ---- Mock helpers ----
|
|
|
|
def _mk_unit(unit_id, name, path, logic_tree_nodes, description="", sources=None):
|
|
"""Create a minimal function_unit dict for testing."""
|
|
if sources is None:
|
|
srcs = []
|
|
if logic_tree_nodes:
|
|
srcs.append({
|
|
"image_id": "rId16",
|
|
"type": "logic_tree",
|
|
"logic_tree_nodes": logic_tree_nodes,
|
|
})
|
|
if not srcs:
|
|
srcs.append({
|
|
"section": "3.1",
|
|
"type": "table",
|
|
"text_snippet": "test",
|
|
})
|
|
else:
|
|
srcs = sources
|
|
return {
|
|
"unit_id": unit_id,
|
|
"name": name,
|
|
"description": description or f"desc for {name}",
|
|
"path": path,
|
|
"sources": srcs,
|
|
}
|
|
|
|
|
|
def _mk_concept(name, parent=None, aliases=None, defined_in=None):
|
|
"""Create a minimal concept dict for testing."""
|
|
return {
|
|
"name": name,
|
|
"aliases": aliases or [],
|
|
"defined_in": defined_in or ["3.1"],
|
|
"parent": parent,
|
|
}
|
|
|
|
|
|
# =============================================================================
|
|
# Test 1: concept_name_similarity
|
|
# =============================================================================
|
|
|
|
def test_concept_name_similarity_exact():
    """A name compared with itself scores exactly 1.0."""
    for word in ("国内", "行车娱乐限制"):
        assert concept_name_similarity(word, word) == 1.0
|
|
|
|
def test_concept_name_similarity_substring():
    """A name containing the other as a substring scores at least 0.85."""
    longer, shorter = "国内行车娱乐限制", "行车娱乐限制"
    score = concept_name_similarity(longer, shorter)
    assert score >= 0.85, f"expected >= 0.85, got {score}"
|
|
|
|
def test_concept_name_similarity_different():
    """Unrelated names (domestic vs. overseas) score below 0.7."""
    score = concept_name_similarity("国内", "海外")
    assert score < 0.7, f"expected < 0.7, got {score}"
|
|
|
|
def test_concept_name_similarity_seq_matcher():
    """Near-miss names that are not substrings get a partial score in (0.6, 0.95)."""
    score = concept_name_similarity("前台打断", "前台应用打断")
    low, high = 0.6, 0.95
    assert low < score < high, f"expected 0.6-0.95, got {score}"
|
|
|
|
|
|
# =============================================================================
|
|
# Test 2: _collect_logic_tree_nodes
|
|
# =============================================================================
|
|
|
|
def test_collect_logic_tree_nodes():
    """Node ids from a unit's logic_tree source are returned as a set."""
    expected = {"n1", "n2", "n3"}
    unit = _mk_unit("U1", "test", ["A"], ["n1", "n2", "n3"])
    assert _collect_logic_tree_nodes(unit) == expected
|
|
|
|
def test_collect_logic_tree_nodes_empty():
    """A unit whose only source is a table yields an empty node set."""
    table_only = [{"section": "3.1", "type": "table"}]
    unit = _mk_unit("U2", "test", ["A"], [], sources=table_only)
    assert _collect_logic_tree_nodes(unit) == set()
|
|
|
|
|
|
# =============================================================================
|
|
# Test 3: unit_node_jaccard
|
|
# =============================================================================
|
|
|
|
def test_unit_node_jaccard_identical():
    """Units carrying the same node set have Jaccard similarity 1.0."""
    left = _mk_unit("U1", "a", ["A"], ["n1", "n2", "n3"])
    right = _mk_unit("U2", "b", ["A"], ["n1", "n2", "n3"])
    assert unit_node_jaccard(left, right) == 1.0
|
|
|
|
def test_unit_node_jaccard_partial():
    """Overlapping node sets give |intersection| / |union|."""
    left = _mk_unit("U1", "a", ["A"], ["n1", "n2", "n3", "n4"])
    right = _mk_unit("U2", "b", ["A"], ["n1", "n2", "n3"])
    expected = 3 / 4  # 3 shared nodes over a union of 4
    assert abs(unit_node_jaccard(left, right) - expected) < 0.01
|
|
|
|
def test_unit_node_jaccard_disjoint():
    """Units that share no nodes score 0.0."""
    left = _mk_unit("U1", "a", ["A"], ["n1", "n2"])
    right = _mk_unit("U2", "b", ["B"], ["n3", "n4"])
    assert unit_node_jaccard(left, right) == 0.0
|
|
|
|
def test_unit_node_jaccard_both_empty():
    """Two units with no logic-tree nodes at all are expected to score 0.0."""
    left, right = (
        _mk_unit(uid, label, [root], [], sources=[{"section": "3.1", "type": "table"}])
        for uid, label, root in (("U1", "a", "A"), ("U2", "b", "B"))
    )
    assert unit_node_jaccard(left, right) == 0.0
|
|
|
|
|
|
# =============================================================================
|
|
# Test 4: path_similarity
|
|
# =============================================================================
|
|
|
|
def test_path_similarity_identical():
    """Equal concept paths score exactly 1.0."""
    path_a = ["国内", "系统限制", "前台打断"]
    path_b = ["国内", "系统限制", "前台打断"]
    assert path_similarity(path_a, path_b) == 1.0
|
|
|
|
def test_path_similarity_partial():
    """Paths sharing a two-segment prefix get a partial score in (0.4, 0.9)."""
    prefix = ["国内", "系统限制"]
    score = path_similarity(prefix + ["前台打断"], prefix + ["后台限制启动"])
    # The original author's estimate: 2/3 set overlap, sequential 3/5 ~= 0.6.
    assert 0.4 < score < 0.9, f"expected 0.4-0.9, got {score}"
|
|
|
|
def test_path_similarity_different():
    """Single-segment paths with different roots score below 0.7."""
    score = path_similarity(["国内"], ["海外"])
    assert score < 0.7, f"expected < 0.7, got {score}"
|
|
|
|
|
|
# =============================================================================
|
|
# Test 5: unit_similarity
|
|
# =============================================================================
|
|
|
|
def test_unit_similarity_identical():
    """A unit compared with itself scores above 0.99."""
    path = ["国内", "系统限制", "前台打断"]
    nodes = ["n1", "n2", "n3", "n8", "n19"]
    unit = _mk_unit("U1", "国内-系统限制-前台打断", path, nodes)
    assert unit_similarity(unit, unit) > 0.99
|
|
|
|
def test_unit_similarity_different():
    """Units differing in both path and nodes score below 0.3."""
    domestic = _mk_unit("U1", "a", ["国内", "系统限制", "前台打断"], ["n1", "n2", "n3"])
    overseas = _mk_unit("U2", "b", ["海外", "SDK限制"], ["n10", "n11", "n12"])
    assert unit_similarity(domestic, overseas) < 0.3
|
|
|
|
|
|
# =============================================================================
|
|
# Test 6: cluster_concepts
|
|
# =============================================================================
|
|
|
|
def test_cluster_concepts_identical():
    """Three identical versions produce one 3-member cluster per concept."""
    def one_version():
        return [
            _mk_concept("国内"),
            _mk_concept("海外"),
            _mk_concept("系统限制", parent="国内"),
        ]

    clusters = cluster_concepts([one_version() for _ in range(3)])
    # Expect exactly one cluster each for 国内, 海外 and 系统限制.
    assert len(clusters) == 3, f"expected 3 clusters, got {len(clusters)}"
    for group in clusters:
        assert len(group) == 3, f"expected each cluster to have 3 members, got {len(group)}"
|
|
|
|
def test_cluster_concepts_name_variation():
    """Name variants of the same concept across versions form a single cluster."""
    names = ("国内行车娱乐限制", "行车娱乐限制", "国内行车娱乐限制")
    versions = [[_mk_concept(label, parent="国内")] for label in names]
    clusters = cluster_concepts(versions)
    assert len(clusters) == 1, f"expected 1 cluster, got {len(clusters)}"
    members = clusters[0]
    assert len(members) == 3, f"expected 3 members, got {len(members)}"
|
|
|
|
|
|
# =============================================================================
|
|
# Test 7: merge_concept_cluster
|
|
# =============================================================================
|
|
|
|
def test_merge_concept_cluster():
    """Merging a 3-version cluster keeps the shared parent and unions the aliases."""
    specs = [
        ("国内行车娱乐限制", ["限制"]),
        ("行车娱乐限制", ["行车限制"]),
        ("行车娱乐限制", ["限制"]),
    ]
    cluster = [
        (version_idx, _mk_concept(label, parent="国内", aliases=alias_list))
        for version_idx, (label, alias_list) in enumerate(specs)
    ]

    merged, confidence = merge_concept_cluster(cluster, 3)

    assert "行车娱乐限制" in merged["name"]
    assert merged["parent"] == "国内"
    assert set(merged["aliases"]) == {"限制", "行车限制"}
    assert confidence in ("high", "medium")
|
|
|
|
|
|
# =============================================================================
|
|
# Test 8: cluster_function_units
|
|
# =============================================================================
|
|
|
|
def test_cluster_function_units_all_agree():
    """The same unit reported by all three versions forms one 3-member cluster."""
    def make(description):
        # Fresh path/node lists per call so no two units share a list object.
        return _mk_unit(
            "U-001",
            "国内-系统限制-前台打断",
            ["国内", "系统限制", "前台打断"],
            ["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
            description,
        )

    full = "switch ON, system app, foreground, speed>=15, non-P, interrupt + toast"
    short = "switch ON, system app, foreground, interrupt"
    clusters = cluster_function_units([[make(full)], [make(full)], [make(short)]])

    assert len(clusters) == 1, f"expected 1 cluster, got {len(clusters)}"
    assert len(clusters[0]) == 3
|
|
|
|
def test_cluster_function_units_partial_agree():
    """Two agreeing versions cluster together; the third version's unit stands alone."""
    def interrupt_unit():
        return _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
                        ["n1", "n2", "n3", "n8", "n19"])

    ban_unit = _mk_unit("U-002", "禁止", ["国内", "系统限制", "后台限制启动"],
                        ["n5", "n6"])

    clusters = cluster_function_units([[interrupt_unit()], [interrupt_unit()], [ban_unit]])

    # Expect versions 0 and 1 grouped, version 2 on its own.
    assert len(clusters) == 2, f"expected 2 clusters, got {len(clusters)}"
    sizes = sorted(map(len, clusters))
    assert sizes == [1, 2], f"expected cluster sizes [1,2], got {sizes}"
|
|
|
|
def test_cluster_function_units_all_disagree():
    """Three unrelated units from three versions give three separate clusters."""
    versions = [
        [_mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"], ["n1", "n2", "n3"])],
        [_mk_unit("U-002", "禁止", ["国内", "系统限制", "后台限制启动"], ["n5", "n6"])],
        [_mk_unit("U-003", "SDK", ["国内", "SDK限制"], ["n10", "n11"])],
    ]
    clusters = cluster_function_units(versions)
    assert len(clusters) == 3, f"expected 3 clusters, got {len(clusters)}"
|
|
|
|
|
|
# =============================================================================
|
|
# Test 9: pick_best_representative
|
|
# =============================================================================
|
|
|
|
def test_pick_best_representative_prefers_rich():
    """The richer unit is chosen even though the sparse one is from version 0."""
    sparse = _mk_unit(
        "U-001", "short", ["国内", "系统限制"],
        ["n1", "n2", "n3"],
        description="short desc",
    )
    rich = _mk_unit(
        "U-001", "detailed", ["国内", "系统限制", "前台打断"],
        ["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
        description="very detailed description of the full rule behavior " * 5,
    )

    best = pick_best_representative([(0, sparse), (1, rich)])

    # `rich` has more nodes, a longer path and a longer description; `sparse`
    # only has version index 0 (described by the author as the lower-temperature run).
    assert best["name"] == "detailed"
|
|
|
|
|
|
# =============================================================================
|
|
# Test 10: compute_confidence_versions
|
|
# =============================================================================
|
|
|
|
def test_confidence_high_unanimous():
    """Support from 3 of 3 versions, t0 included, rates "high"."""
    supporting, total, includes_t0 = 3, 3, True
    assert compute_confidence_versions(supporting, total, includes_t0) == "high"
|
|
|
|
def test_confidence_high_two_of_three_with_t0():
    """Support from 2 of 3 versions, t0 included, still rates "high"."""
    supporting, total, includes_t0 = 2, 3, True
    assert compute_confidence_versions(supporting, total, includes_t0) == "high"
|
|
|
|
def test_confidence_medium_two_of_three_without_t0():
    """Support from 2 of 3 versions without t0 drops to "medium"."""
    supporting, total, includes_t0 = 2, 3, False
    assert compute_confidence_versions(supporting, total, includes_t0) == "medium"
|
|
|
|
def test_confidence_low_one_of_three():
    """A single supporting version out of 3 (no t0) rates "low"."""
    supporting, total, includes_t0 = 1, 3, False
    assert compute_confidence_versions(supporting, total, includes_t0) == "low"
|
|
|
|
def test_confidence_high_all_two_versions():
    """With only two versions, unanimous support including t0 rates "high"."""
    supporting, total, includes_t0 = 2, 2, True
    assert compute_confidence_versions(supporting, total, includes_t0) == "high"
|
|
|
|
|
|
# =============================================================================
|
|
# Test 11: ensemble_merge_concepts
|
|
# =============================================================================
|
|
|
|
def test_ensemble_merge_concepts():
    """Concepts merge across versions, each annotated with confidence and support."""
    def version(variant):
        # Every version shares 国内 and 海外 and differs only in the third concept.
        return [_mk_concept("国内"), _mk_concept("海外"), variant]

    versions = [
        version(_mk_concept("国内行车娱乐限制", parent="国内")),
        version(_mk_concept("行车娱乐限制", parent="国内",
                            aliases=["限制"], defined_in=["3.1", "3.1.1"])),
        version(_mk_concept("行车娱乐限制", parent="国内")),
    ]

    merged = ensemble_merge_concepts(versions)

    # Expect 3 merged concepts, each supported by all 3 versions.
    assert len(merged) == 3, f"expected 3 merged concepts, got {len(merged)}"
    for concept in merged:
        assert "confidence" in concept
        assert "ensemble_support" in concept
        assert concept["ensemble_support"] == "3/3"
|
|
|
|
|
|
# =============================================================================
|
|
# Test 12: ensemble_merge_function_units
|
|
# =============================================================================
|
|
|
|
def test_ensemble_merge_function_units():
    """One unit reported by all 3 versions merges into a single high-confidence unit."""
    def interrupt_unit(nodes, description):
        return _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
                        nodes, description=description)

    full_nodes = ["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"]
    versions = [
        [interrupt_unit(list(full_nodes), "full description A")],
        [interrupt_unit(list(full_nodes), "full description B (more detail)")],
        [interrupt_unit(full_nodes[:-1], "partial description")],  # no n26
    ]

    merged = ensemble_merge_function_units(versions)

    assert len(merged) == 1, f"expected 1 unit, got {len(merged)}"
    unit = merged[0]
    assert unit["confidence"] == "high"
    assert unit["ensemble_support"] == "3/3"
    assert unit["source_versions"] == 3
    assert unit["unit_id"].startswith("FU-ENS-")
    # The version-1 unit (longest description) should be the representative.
    assert "more detail" in unit["description"]
|
|
|
|
|
|
# =============================================================================
|
|
# Test 13: ensemble_merge full integration
|
|
# =============================================================================
|
|
|
|
def test_ensemble_merge_full():
    """End-to-end merge of 3 versions: concepts, units and confidence summary."""
    feature = "行车娱乐限制"

    def concepts():
        return [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")]

    def interrupt():
        # The unit every version agrees on.
        return _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
                        ["n1", "n2", "n3", "n8", "n19", "n25", "n26"])

    per_version_units = (
        [interrupt(),
         _mk_unit("U-002", "后台禁止", ["国内", "系统限制", "后台限制启动"], ["n5", "n6"])],
        [interrupt(),
         _mk_unit("U-003", "SDK自定义", ["国内", "SDK限制", "自定义限制"], ["n10", "n11"])],
        [interrupt()],
    )
    versions = [
        {"feature_name": feature, "concepts": concepts(), "function_units": fus}
        for fus in per_version_units
    ]

    result = ensemble_merge(versions)

    assert result["feature_name"] == "行车娱乐限制"
    assert result["ensemble_versions"] == 3

    units = result["function_units"]
    concepts_out = result["concepts"]

    # Concepts: 国内 + 系统限制.
    assert len(concepts_out) == 2

    # Units: 打断 (3 versions -> high), 后台禁止 and SDK自定义 (1 version each -> low).
    assert len(units) == 3
    levels = [u["confidence"] for u in units]
    assert levels.count("high") == 1
    assert levels.count("low") == 2

    # Every merged unit carries the ensemble bookkeeping fields.
    for unit in units:
        for field in ("confidence", "ensemble_support", "source_versions"):
            assert field in unit

    summary = result["confidence_summary"]
    assert summary["total_units"] == 3
    assert summary["high"] == 1
    assert summary["low"] == 2
|
|
|
|
|
|
# =============================================================================
|
|
# Runner
|
|
# =============================================================================
|
|
|
|
def run_all_tests():
    """Run every test in this module without pytest and print a summary.

    Each test is called in order. An AssertionError or any other exception
    is reported as a failure, and execution moves on to the next test.

    Returns:
        True if every test passed, False otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("Ensemble Merge 测试 (纯 Python, 无 LLM)")
    print(banner)

    tests = [
        ("concept_name_similarity exact", test_concept_name_similarity_exact),
        ("concept_name_similarity substring", test_concept_name_similarity_substring),
        ("concept_name_similarity different", test_concept_name_similarity_different),
        ("concept_name_similarity seq_matcher", test_concept_name_similarity_seq_matcher),
        ("collect_logic_tree_nodes", test_collect_logic_tree_nodes),
        ("collect_logic_tree_nodes empty", test_collect_logic_tree_nodes_empty),
        ("unit_node_jaccard identical", test_unit_node_jaccard_identical),
        ("unit_node_jaccard partial", test_unit_node_jaccard_partial),
        ("unit_node_jaccard disjoint", test_unit_node_jaccard_disjoint),
        ("unit_node_jaccard both_empty", test_unit_node_jaccard_both_empty),
        ("path_similarity identical", test_path_similarity_identical),
        ("path_similarity partial", test_path_similarity_partial),
        ("path_similarity different", test_path_similarity_different),
        ("unit_similarity identical", test_unit_similarity_identical),
        ("unit_similarity different", test_unit_similarity_different),
        ("cluster_concepts identical", test_cluster_concepts_identical),
        ("cluster_concepts name variation", test_cluster_concepts_name_variation),
        ("merge_concept_cluster", test_merge_concept_cluster),
        ("cluster_function_units all_agree", test_cluster_function_units_all_agree),
        ("cluster_function_units partial_agree", test_cluster_function_units_partial_agree),
        ("cluster_function_units all_disagree", test_cluster_function_units_all_disagree),
        ("pick_best_representative", test_pick_best_representative_prefers_rich),
        ("confidence high unanimous", test_confidence_high_unanimous),
        ("confidence high 2/3 with t0", test_confidence_high_two_of_three_with_t0),
        ("confidence medium 2/3 no t0", test_confidence_medium_two_of_three_without_t0),
        ("confidence low 1/3", test_confidence_low_one_of_three),
        ("confidence high 2/2", test_confidence_high_all_two_versions),
        ("ensemble_merge_concepts", test_ensemble_merge_concepts),
        ("ensemble_merge_function_units", test_ensemble_merge_function_units),
        ("ensemble_merge full", test_ensemble_merge_full),
    ]

    def run_one(label, fn):
        """Run one test, print its status line, and return True on success."""
        try:
            fn()
            print(f" {PASS} {label}")
            return True
        except AssertionError as err:
            print(f" {FAIL} {label}: {err}")
            return False
        except Exception as err:  # report crashes as failures, keep going
            print(f" {FAIL} {label}: unexpected {type(err).__name__}: {err}")
            return False

    outcomes = [run_one(label, fn) for label, fn in tests]
    passed = sum(outcomes)
    failed = len(outcomes) - passed

    print("\n" + banner)
    if failed:
        print(f"{FAIL} {failed}/{passed + failed} 个测试失败")
    else:
        print(f"{PASS} 所有 {passed} 个测试通过!")
    print(banner)

    return failed == 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Exit with status 0 when every test passed and 1 otherwise.
    sys.exit(0 if run_all_tests() else 1)
|