"""
Unit tests for ensemble_merge.py.

Everything here is pure Python: no LLM calls and no file I/O. Each test builds
hard-coded mock concepts / function units and checks one piece of the merge
logic in isolation.
"""

import sys
from pathlib import Path

# Put the parent directory first on sys.path so that the `ensemble_merge`
# import below resolves independently of the current working directory.
sys.path.insert(0, str(Path(__file__).parent.parent))

from ensemble_merge import (  # noqa: E402  (must follow the sys.path tweak)
    concept_name_similarity,
    cluster_concepts,
    merge_concept_cluster,
    unit_node_jaccard,
    path_similarity,
    unit_similarity,
    cluster_function_units,
    pick_best_representative,
    compute_confidence_versions,
    ensemble_merge_concepts,
    ensemble_merge_function_units,
    ensemble_merge,
    _collect_logic_tree_nodes,
)

# Status markers used by the script-mode runner's report lines.
PASS = "[PASS]"
FAIL = "[FAIL]"


# ---- Mock helpers ----

def _mk_unit(unit_id, name, path, logic_tree_nodes, description="", sources=None):
    """Build a minimal function_unit dict for the tests.

    If ``sources`` is given, it is used verbatim and ``logic_tree_nodes`` is
    ignored. Otherwise exactly one source is synthesised:
    - a ``logic_tree`` source carrying ``logic_tree_nodes``, if any nodes
      were given;
    - a placeholder ``table`` source, if not.

    An empty ``description`` falls back to ``"desc for <name>"``.
    """
    if sources is not None:
        unit_sources = sources
    elif logic_tree_nodes:
        unit_sources = [{
            "image_id": "rId16",
            "type": "logic_tree",
            "logic_tree_nodes": logic_tree_nodes,
        }]
    else:
        unit_sources = [{
            "section": "3.1",
            "type": "table",
            "text_snippet": "test",
        }]

    return {
        "unit_id": unit_id,
        "name": name,
        "description": description or f"desc for {name}",
        "path": path,
        "sources": unit_sources,
    }


def _mk_concept(name, parent=None, aliases=None, defined_in=None):
    """Build a minimal concept dict.

    Falsy ``aliases`` becomes a fresh ``[]``; falsy ``defined_in`` becomes
    ``["3.1"]``.
    """
    return dict(
        name=name,
        aliases=aliases or [],
        defined_in=defined_in or ["3.1"],
        parent=parent,
    )


# =============================================================================
# Test 1: concept_name_similarity
# =============================================================================

def test_concept_name_similarity_exact():
    """A name compared with itself scores exactly 1.0."""
    for word in ("国内", "行车娱乐限制"):
        assert concept_name_similarity(word, word) == 1.0


def test_concept_name_similarity_substring():
    """A name that contains the other as a substring scores >= 0.85."""
    longer, shorter = "国内行车娱乐限制", "行车娱乐限制"
    sim = concept_name_similarity(longer, shorter)
    assert sim >= 0.85, f"expected >= 0.85, got {sim}"


def test_concept_name_similarity_different():
    """Unrelated names score below 0.7."""
    domestic, overseas = "国内", "海外"
    sim = concept_name_similarity(domestic, overseas)
    assert sim < 0.7, f"expected < 0.7, got {sim}"


def test_concept_name_similarity_seq_matcher():
    """Names that share characters but are not substrings land in (0.6, 0.95)."""
    short_name, long_name = "前台打断", "前台应用打断"
    sim = concept_name_similarity(short_name, long_name)
    assert 0.6 < sim < 0.95, f"expected 0.6-0.95, got {sim}"


# =============================================================================
# Test 2: _collect_logic_tree_nodes
# =============================================================================

def test_collect_logic_tree_nodes():
    """Node ids from a logic_tree source are returned as a set."""
    unit = _mk_unit("U1", "test", ["A"], ["n1", "n2", "n3"])
    assert _collect_logic_tree_nodes(unit) == {"n1", "n2", "n3"}


def test_collect_logic_tree_nodes_empty():
    """A unit with only a table source yields an empty set."""
    table_only = [{"section": "3.1", "type": "table"}]
    unit = _mk_unit("U2", "test", ["A"], [], sources=table_only)
    assert _collect_logic_tree_nodes(unit) == set()


# =============================================================================
# Test 3: unit_node_jaccard
# =============================================================================

def test_unit_node_jaccard_identical():
    """Equal node sets give Jaccard 1.0."""
    first = _mk_unit("U1", "a", ["A"], ["n1", "n2", "n3"])
    second = _mk_unit("U2", "b", ["A"], ["n1", "n2", "n3"])
    assert unit_node_jaccard(first, second) == 1.0


def test_unit_node_jaccard_partial():
    """A 3-node subset of a 4-node set gives Jaccard 0.75."""
    first = _mk_unit("U1", "a", ["A"], ["n1", "n2", "n3", "n4"])
    second = _mk_unit("U2", "b", ["A"], ["n1", "n2", "n3"])
    score = unit_node_jaccard(first, second)
    # intersection = 3, union = 4
    assert abs(score - 0.75) < 0.01


def test_unit_node_jaccard_disjoint():
    """Disjoint node sets give Jaccard 0.0."""
    first = _mk_unit("U1", "a", ["A"], ["n1", "n2"])
    second = _mk_unit("U2", "b", ["B"], ["n3", "n4"])
    assert unit_node_jaccard(first, second) == 0.0


def test_unit_node_jaccard_both_empty():
    """Two units without any logic-tree nodes score 0.0."""
    first = _mk_unit("U1", "a", ["A"], [], sources=[{"section": "3.1", "type": "table"}])
    second = _mk_unit("U2", "b", ["B"], [], sources=[{"section": "3.1", "type": "table"}])
    assert unit_node_jaccard(first, second) == 0.0


# =============================================================================
# Test 4: path_similarity
# =============================================================================

def test_path_similarity_identical():
    """Equal paths score exactly 1.0."""
    left = ["国内", "系统限制", "前台打断"]
    right = list(left)
    assert path_similarity(left, right) == 1.0


def test_path_similarity_partial():
    """Paths sharing their first two of three segments get a mid-range score."""
    foreground = ["国内", "系统限制", "前台打断"]
    background = ["国内", "系统限制", "后台限制启动"]
    sim = path_similarity(foreground, background)
    assert 0.4 < sim < 0.9, f"expected 0.4-0.9, got {sim}"


def test_path_similarity_different():
    """Single-segment paths with different roots score below 0.7."""
    sim = path_similarity(["国内"], ["海外"])
    assert sim < 0.7, f"expected < 0.7, got {sim}"


# =============================================================================
# Test 5: unit_similarity
# =============================================================================

def test_unit_similarity_identical():
    """A unit compared with itself scores above 0.99."""
    unit = _mk_unit(
        "U1",
        "国内-系统限制-前台打断",
        ["国内", "系统限制", "前台打断"],
        ["n1", "n2", "n3", "n8", "n19"],
    )
    assert unit_similarity(unit, unit) > 0.99


def test_unit_similarity_different():
    """Units with different paths and disjoint nodes score below 0.3."""
    domestic = _mk_unit("U1", "a", ["国内", "系统限制", "前台打断"], ["n1", "n2", "n3"])
    overseas = _mk_unit("U2", "b", ["海外", "SDK限制"], ["n10", "n11", "n12"])
    assert unit_similarity(domestic, overseas) < 0.3


# =============================================================================
# Test 6: cluster_concepts
# =============================================================================

def test_cluster_concepts_identical():
    """Three identical versions give one 3-member cluster per concept."""
    def one_version():
        return [
            _mk_concept("国内"),
            _mk_concept("海外"),
            _mk_concept("系统限制", parent="国内"),
        ]

    clusters = cluster_concepts([one_version() for _ in range(3)])
    # Expect exactly 3 clusters: 国内, 海外, 系统限制
    assert len(clusters) == 3, f"expected 3 clusters, got {len(clusters)}"
    for c in clusters:
        assert len(c) == 3, f"expected each cluster to have 3 members, got {len(c)}"


def test_cluster_concepts_name_variation():
    """A prefixed name variant still clusters with the short form."""
    long_form = "国内行车娱乐限制"
    short_form = "行车娱乐限制"
    versions = [
        [_mk_concept(long_form, parent="国内")],
        [_mk_concept(short_form, parent="国内")],
        [_mk_concept(long_form, parent="国内")],
    ]
    clusters = cluster_concepts(versions)
    assert len(clusters) == 1, f"expected 1 cluster, got {len(clusters)}"
    assert len(clusters[0]) == 3, f"expected 3 members, got {len(clusters[0])}"


# =============================================================================
# Test 7: merge_concept_cluster
# =============================================================================

def test_merge_concept_cluster():
    """Merging keeps the shared parent, unions aliases, and is high/medium confidence."""
    cluster = list(enumerate([
        _mk_concept("国内行车娱乐限制", parent="国内", aliases=["限制"]),
        _mk_concept("行车娱乐限制", parent="国内", aliases=["行车限制"]),
        _mk_concept("行车娱乐限制", parent="国内", aliases=["限制"]),
    ]))
    merged, conf = merge_concept_cluster(cluster, 3)
    assert "行车娱乐限制" in merged["name"]
    assert merged["parent"] == "国内"
    assert set(merged["aliases"]) == {"限制", "行车限制"}
    assert conf in ("high", "medium")


# =============================================================================
# Test 8: cluster_function_units
# =============================================================================

def test_cluster_function_units_all_agree():
    """Three versions of one unit (same path and nodes) form a single cluster."""
    def make(description):
        return _mk_unit(
            "U-001",
            "国内-系统限制-前台打断",
            ["国内", "系统限制", "前台打断"],
            ["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
            description,
        )

    full = "switch ON, system app, foreground, speed>=15, non-P, interrupt + toast"
    versions = [
        [make(full)],
        [make(full)],
        [make("switch ON, system app, foreground, interrupt")],
    ]
    clusters = cluster_function_units(versions)
    assert len(clusters) == 1, f"expected 1 cluster, got {len(clusters)}"
    assert len(clusters[0]) == 3


def test_cluster_function_units_partial_agree():
    """Two matching units share a cluster; the odd one out stands alone."""
    interrupt_a = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
                           ["n1", "n2", "n3", "n8", "n19"])
    interrupt_b = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
                           ["n1", "n2", "n3", "n8", "n19"])
    background = _mk_unit("U-002", "禁止", ["国内", "系统限制", "后台限制启动"],
                          ["n5", "n6"])
    clusters = cluster_function_units([[interrupt_a], [interrupt_b], [background]])
    # interrupt_a + interrupt_b together, background on its own
    assert len(clusters) == 2, f"expected 2 clusters, got {len(clusters)}"
    cluster_sizes = sorted(map(len, clusters))
    assert cluster_sizes == [1, 2], f"expected cluster sizes [1,2], got {cluster_sizes}"


def test_cluster_function_units_all_disagree():
    """Three unrelated units end up in three separate clusters."""
    versions = [
        [_mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"], ["n1", "n2", "n3"])],
        [_mk_unit("U-002", "禁止", ["国内", "系统限制", "后台限制启动"], ["n5", "n6"])],
        [_mk_unit("U-003", "SDK", ["国内", "SDK限制"], ["n10", "n11"])],
    ]
    clusters = cluster_function_units(versions)
    assert len(clusters) == 3, f"expected 3 clusters, got {len(clusters)}"


# =============================================================================
# Test 9: pick_best_representative
# =============================================================================

def test_pick_best_representative_prefers_rich():
    """The member with more nodes and a longer description is chosen."""
    sparse = _mk_unit("U-001", "short", ["国内", "系统限制"], ["n1", "n2", "n3"],
                      description="short desc")
    rich = _mk_unit(
        "U-001",
        "detailed",
        ["国内", "系统限制", "前台打断"],
        ["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
        description="very detailed description of the full rule behavior " * 5,
    )
    best = pick_best_representative([(0, sparse), (1, rich)])
    # `rich` should win even though `sparse` comes from version 0 (the
    # lower-temperature run).
    assert best["name"] == "detailed"


# =============================================================================
# Test 10: compute_confidence_versions
# =============================================================================

def test_confidence_high_unanimous():
    """3 of 3 versions, t0 flag True -> high."""
    level = compute_confidence_versions(3, 3, True)
    assert level == "high"


def test_confidence_high_two_of_three_with_t0():
    """2 of 3 versions, t0 flag True -> high."""
    level = compute_confidence_versions(2, 3, True)
    assert level == "high"


def test_confidence_medium_two_of_three_without_t0():
    """2 of 3 versions, t0 flag False -> medium."""
    level = compute_confidence_versions(2, 3, False)
    assert level == "medium"


def test_confidence_low_one_of_three():
    """1 of 3 versions, t0 flag False -> low."""
    level = compute_confidence_versions(1, 3, False)
    assert level == "low"


def test_confidence_high_all_two_versions():
    """2 of 2 versions, t0 flag True -> high."""
    level = compute_confidence_versions(2, 2, True)
    assert level == "high"
# =============================================================================
# Test 11: ensemble_merge_concepts
# =============================================================================

def test_ensemble_merge_concepts():
    """Three versions, one using a name variant, merge into 3 concepts with 3/3 support."""
    def base():
        return [_mk_concept("国内"), _mk_concept("海外")]

    versions = [
        base() + [_mk_concept("国内行车娱乐限制", parent="国内")],
        base() + [_mk_concept("行车娱乐限制", parent="国内", aliases=["限制"],
                              defined_in=["3.1", "3.1.1"])],
        base() + [_mk_concept("行车娱乐限制", parent="国内")],
    ]
    merged = ensemble_merge_concepts(versions)
    # 3 concepts per version across 3 versions -> 3 merged concepts
    assert len(merged) == 3, f"expected 3 merged concepts, got {len(merged)}"
    for concept in merged:
        assert "confidence" in concept
        assert "ensemble_support" in concept
        assert concept["ensemble_support"] == "3/3"


# =============================================================================
# Test 12: ensemble_merge_function_units
# =============================================================================

def test_ensemble_merge_function_units():
    """Three versions of one unit merge into a single high-confidence 3/3 unit.

    The merged unit must carry an ``FU-ENS-`` id and the most detailed description.
    """
    def interrupt_unit(nodes, description):
        return _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"], nodes,
                        description=description)

    full_nodes = ["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"]
    versions = [
        [interrupt_unit(list(full_nodes), "full description A")],
        [interrupt_unit(list(full_nodes), "full description B (more detail)")],
        # Third version is missing the last node ("n26").
        [interrupt_unit(full_nodes[:-1], "partial description")],
    ]
    merged = ensemble_merge_function_units(versions)
    assert len(merged) == 1, f"expected 1 unit, got {len(merged)}"
    unit = merged[0]
    expected = {"confidence": "high", "ensemble_support": "3/3", "source_versions": 3}
    for field, value in expected.items():
        assert unit[field] == value
    assert unit["unit_id"].startswith("FU-ENS-")
    # Version 1 carries the most detailed description and should be picked.
    assert "more detail" in unit["description"]
# =============================================================================
# Test 13: ensemble_merge full integration
# =============================================================================

def test_ensemble_merge_full():
    """End-to-end merge of three versions with shared and version-unique units.

    打断 appears in all 3 versions and should be high confidence. 后台禁止
    (only in v0) and SDK自定义 (only in v1) should each be low confidence.
    Every assert carries a message so the script-mode runner can report
    something useful on failure.
    """
    v0 = {
        "feature_name": "行车娱乐限制",
        "concepts": [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")],
        "function_units": [
            _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
                     ["n1", "n2", "n3", "n8", "n19", "n25", "n26"]),
            _mk_unit("U-002", "后台禁止", ["国内", "系统限制", "后台限制启动"],
                     ["n5", "n6"]),
        ],
    }
    v1 = {
        "feature_name": "行车娱乐限制",
        "concepts": [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")],
        "function_units": [
            _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
                     ["n1", "n2", "n3", "n8", "n19", "n25", "n26"]),
            _mk_unit("U-003", "SDK自定义", ["国内", "SDK限制", "自定义限制"],
                     ["n10", "n11"]),
        ],
    }
    v2 = {
        "feature_name": "行车娱乐限制",
        "concepts": [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")],
        "function_units": [
            _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
                     ["n1", "n2", "n3", "n8", "n19", "n25", "n26"]),
        ],
    }
    result = ensemble_merge([v0, v1, v2])
    assert result["feature_name"] == "行车娱乐限制", \
        f"unexpected feature_name: {result['feature_name']!r}"
    assert result["ensemble_versions"] == 3, \
        f"expected ensemble_versions 3, got {result['ensemble_versions']}"

    units = result["function_units"]
    concepts = result["concepts"]
    # Concepts: 国内 + 系统限制
    assert len(concepts) == 2, f"expected 2 concepts, got {len(concepts)}"
    # Units: 打断 (3 versions -> high), 后台禁止 (1 version -> low), SDK (1 version -> low)
    assert len(units) == 3, f"expected 3 units, got {len(units)}"
    high_units = [u for u in units if u["confidence"] == "high"]
    low_units = [u for u in units if u["confidence"] == "low"]
    assert len(high_units) == 1, f"expected 1 high-confidence unit, got {len(high_units)}"
    assert len(low_units) == 2, f"expected 2 low-confidence units, got {len(low_units)}"

    # Every merged unit must carry the ensemble bookkeeping fields.
    for u in units:
        for field in ("confidence", "ensemble_support", "source_versions"):
            assert field in u, f"unit {u.get('unit_id')!r} is missing {field!r}"

    # The confidence summary must agree with the per-unit counts above.
    cs = result["confidence_summary"]
    assert cs["total_units"] == 3, f"expected total_units 3, got {cs['total_units']}"
    assert cs["high"] == 1, f"expected summary high 1, got {cs['high']}"
    assert cs["low"] == 2, f"expected summary low 2, got {cs['low']}"


# =============================================================================
# Runner
# =============================================================================

def run_all_tests():
    """Run every test in this module without pytest and print a PASS/FAIL report.

    Tests run in the order of the explicit registry below. Any module-level
    ``test_*`` callable missing from the registry is appended afterwards,
    labelled by its function name, so a newly added test cannot be silently
    skipped in script mode.

    Returns:
        True if every test passed, False otherwise.
    """
    import traceback  # used only to locate failures of message-less asserts

    print("=" * 60)
    print("Ensemble Merge 测试 (纯 Python, 无 LLM)")
    print("=" * 60)

    tests = [
        ("concept_name_similarity exact", test_concept_name_similarity_exact),
        ("concept_name_similarity substring", test_concept_name_similarity_substring),
        ("concept_name_similarity different", test_concept_name_similarity_different),
        ("concept_name_similarity seq_matcher", test_concept_name_similarity_seq_matcher),
        ("collect_logic_tree_nodes", test_collect_logic_tree_nodes),
        ("collect_logic_tree_nodes empty", test_collect_logic_tree_nodes_empty),
        ("unit_node_jaccard identical", test_unit_node_jaccard_identical),
        ("unit_node_jaccard partial", test_unit_node_jaccard_partial),
        ("unit_node_jaccard disjoint", test_unit_node_jaccard_disjoint),
        ("unit_node_jaccard both_empty", test_unit_node_jaccard_both_empty),
        ("path_similarity identical", test_path_similarity_identical),
        ("path_similarity partial", test_path_similarity_partial),
        ("path_similarity different", test_path_similarity_different),
        ("unit_similarity identical", test_unit_similarity_identical),
        ("unit_similarity different", test_unit_similarity_different),
        ("cluster_concepts identical", test_cluster_concepts_identical),
        ("cluster_concepts name variation", test_cluster_concepts_name_variation),
        ("merge_concept_cluster", test_merge_concept_cluster),
        ("cluster_function_units all_agree", test_cluster_function_units_all_agree),
        ("cluster_function_units partial_agree", test_cluster_function_units_partial_agree),
        ("cluster_function_units all_disagree", test_cluster_function_units_all_disagree),
        ("pick_best_representative", test_pick_best_representative_prefers_rich),
        ("confidence high unanimous", test_confidence_high_unanimous),
        ("confidence high 2/3 with t0", test_confidence_high_two_of_three_with_t0),
        ("confidence medium 2/3 no t0", test_confidence_medium_two_of_three_without_t0),
        ("confidence low 1/3", test_confidence_low_one_of_three),
        ("confidence high 2/2", test_confidence_high_all_two_versions),
        ("ensemble_merge_concepts", test_ensemble_merge_concepts),
        ("ensemble_merge_function_units", test_ensemble_merge_function_units),
        ("ensemble_merge full", test_ensemble_merge_full),
    ]

    # Guard against registry drift: also run any test_* function that was
    # defined in this module but never added to the list above.
    registered = {fn for _, fn in tests}
    tests.extend(
        (attr_name, obj)
        for attr_name, obj in list(globals().items())
        if attr_name.startswith("test_") and callable(obj) and obj not in registered
    )

    passed = 0
    failed = 0
    for name, test_fn in tests:
        try:
            test_fn()
        except AssertionError as e:
            detail = str(e)
            if not detail:
                # A bare `assert` has no message; report where it failed instead
                # of printing an empty reason.
                frame = traceback.extract_tb(e.__traceback__)[-1]
                source_line = (frame.line or "").strip()
                detail = f"assertion failed at line {frame.lineno}: {source_line}"
            print(f" {FAIL} {name}: {detail}")
            failed += 1
        except Exception as e:
            # Top-level boundary: record unexpected errors and keep running the rest.
            print(f" {FAIL} {name}: unexpected {type(e).__name__}: {e}")
            failed += 1
        else:
            print(f" {PASS} {name}")
            passed += 1

    print(f"\n{'=' * 60}")
    if failed == 0:
        print(f"{PASS} 所有 {passed} 个测试通过!")
    else:
        print(f"{FAIL} {failed}/{passed + failed} 个测试失败")
    print(f"{'=' * 60}")
    return failed == 0


if __name__ == "__main__":
    success = run_all_tests()
    sys.exit(0 if success else 1)