"""Deterministic ensemble merge for semantic index generation.

All functions are pure Python with zero LLM calls, so they are fully testable
with mock data.

Cross-references N semantic_index outputs (generated with different
temperatures) and produces a single merged index with confidence scores.

Used by: step1_semantic_index.py
Tested by: tests/test_ensemble_merge.py
"""

import copy
from collections import defaultdict
from difflib import SequenceMatcher

# Sort rank for confidence labels (lower sorts first). Unknown labels get 3.
_CONFIDENCE_RANK = {"high": 0, "medium": 1, "low": 2}


def _confidence_rank(item: dict) -> int:
    """Return the sort rank of an item's "confidence" field (missing = low)."""
    return _CONFIDENCE_RANK.get(item.get("confidence", "low"), 3)


# =============================================================================
# Concept Name Similarity
# =============================================================================


def concept_name_similarity(name_a: str, name_b: str) -> float:
    """Compute similarity between two concept names for cross-version matching.

    Strategy (in order of precedence):
      1. Exact string match -> 1.0
      2. Substring containment (one name contains the other):
         - lengths comparable (shorter/longer >= 0.5) -> 0.875 up to < 0.90
         - lengths too different -> 0.55 (deliberately below the 0.7 threshold)
      3. Otherwise, SequenceMatcher ratio on the character sequences.

    Returns:
        float in [0.0, 1.0] where >= 0.7 means "likely the same concept".
    """
    if name_a == name_b:
        return 1.0

    if name_a in name_b or name_b in name_a:
        # The names differ, so the longer one is never empty here.
        shorter, longer = sorted((len(name_a), len(name_b)))
        len_ratio = shorter / longer
        # Only count containment as a match when the lengths are comparable,
        # so a short generic prefix does not absorb a much longer, more
        # specific name (e.g. "国内" vs "国内行车娱乐限制").
        if len_ratio >= 0.5:
            return 0.85 + 0.05 * len_ratio
        return 0.55

    return SequenceMatcher(None, name_a, name_b).ratio()


# =============================================================================
# Concept Clustering & Merging
# =============================================================================


def cluster_concepts(
    all_concepts_lists: list[list[dict]],
    similarity_threshold: float = 0.7,
) -> list[list[tuple[int, dict]]]:
    """Group concepts across ensemble versions by name similarity.

    Uses greedy single-pass clustering: each concept is compared against the
    seed (first member) of every existing cluster. If the best similarity is
    >= threshold the concept joins that cluster; otherwise it starts a new
    one. On equal similarity the earliest-created cluster wins. Concepts
    without a (non-empty) name are skipped.

    Args:
        all_concepts_lists: List of concept lists, one per ensemble version.
            all_concepts_lists[i] = concepts from version i.
        similarity_threshold: Minimum name similarity to join a cluster.

    Returns:
        List of clusters. Each cluster is a list of (version_idx, concept_dict).
    """
    clusters: list[list[tuple[int, dict]]] = []

    for version_idx, concepts in enumerate(all_concepts_lists):
        for concept in concepts:
            name = concept.get("name", "")
            if not name:
                continue

            best_cluster = None
            best_sim = 0.0
            for cluster in clusters:
                seed_name = cluster[0][1].get("name", "")
                sim = concept_name_similarity(name, seed_name)
                if sim > best_sim:
                    best_sim = sim
                    best_cluster = cluster

            if best_cluster is not None and best_sim >= similarity_threshold:
                best_cluster.append((version_idx, concept))
            else:
                clusters.append([(version_idx, concept)])

    return clusters


def merge_concept_cluster(
    cluster: list[tuple[int, dict]],
    total_versions: int,
) -> tuple[dict, str]:
    """Merge a single cluster of matched concepts into one concept dict.

    Rules:
      - name: Longest name (most specific). Tie-break by lower version_idx,
        then by position in the cluster.
      - aliases: Union of all aliases across versions (sorted).
      - defined_in: Union of all defined_in across versions (sorted).
      - parent: Most common non-empty parent (voting). Tie-break by the lowest
        version_idx that proposed the parent.

    Args:
        cluster: List of (version_idx, concept_dict) for one concept.
        total_versions: Total number of ensemble versions.

    Returns:
        (merged_concept_dict, confidence_level) where confidence is
        "high"/"medium"/"low". An empty cluster yields ({}, "low").
    """
    if not cluster:
        return {}, "low"

    # min() keeps the first of several equal keys, so members with the same
    # length and version fall back to cluster order.
    _, name_source = min(
        cluster,
        key=lambda member: (-len(member[1].get("name") or ""), member[0]),
    )
    best_name = name_source.get("name") or ""

    aliases = set()
    defined_in = set()
    for _, concept in cluster:
        aliases.update(concept.get("aliases") or [])
        defined_in.update(concept.get("defined_in") or [])

    parent_votes = defaultdict(int)
    first_version = {}  # parent -> lowest version_idx that voted for it
    for v_idx, concept in cluster:
        parent = concept.get("parent")
        if not parent:
            # None and "" both mean "no parent" and must not win the vote.
            continue
        parent_votes[parent] += 1
        first_version[parent] = min(v_idx, first_version.get(parent, v_idx))

    if parent_votes:
        best_parent = max(
            parent_votes,
            key=lambda p: (parent_votes[p], -first_version[p]),
        )
    else:
        best_parent = None

    versions = {v_idx for v_idx, _ in cluster}
    confidence = compute_confidence_versions(
        len(versions), total_versions, 0 in versions
    )

    merged = {
        "name": best_name,
        "aliases": sorted(aliases),
        "defined_in": sorted(defined_in),
        "parent": best_parent,
        "confidence": confidence,
    }
    return merged, confidence


# =============================================================================
# Unit Similarity Functions
# =============================================================================


def _collect_logic_tree_nodes(unit: dict) -> set[str]:
    """Extract the flattened set of all logic tree node IDs from a function_unit.

    Only sources whose "type" is "logic_tree" contribute nodes.
    """
    nodes = set()
    for src in unit.get("sources") or []:
        if src.get("type") == "logic_tree":
            nodes.update(src.get("logic_tree_nodes") or [])
    return nodes


def unit_node_jaccard(unit_a: dict, unit_b: dict) -> float:
    """Compute Jaccard similarity on logic tree node sets between two units.

    Jaccard(A, B) = |A ∩ B| / |A ∪ B|.

    Returns 0.0 if either unit (or both) has no logic tree nodes.
    """
    nodes_a = _collect_logic_tree_nodes(unit_a)
    nodes_b = _collect_logic_tree_nodes(unit_b)
    if not nodes_a or not nodes_b:
        return 0.0
    return len(nodes_a & nodes_b) / len(nodes_a | nodes_b)


def path_similarity(path_a: list[str], path_b: list[str]) -> float:
    """Compute similarity between two path arrays.

    Hybrid approach:
      - Sequential similarity (order-aware): SequenceMatcher on the paths
        joined with "|".
      - Set similarity (order-independent): Jaccard on path element sets.
      - Final score: 0.5 * seq_sim + 0.5 * set_sim

    Two empty paths are considered identical (1.0); exactly one empty path
    gives 0.0.

    Returns:
        float in [0.0, 1.0].
    """
    if not path_a and not path_b:
        return 1.0
    if not path_a or not path_b:
        return 0.0

    seq_sim = SequenceMatcher(None, "|".join(path_a), "|".join(path_b)).ratio()

    set_a = set(path_a)
    set_b = set(path_b)
    set_sim = len(set_a & set_b) / len(set_a | set_b)

    return 0.5 * seq_sim + 0.5 * set_sim


def unit_similarity(unit_a: dict, unit_b: dict) -> float:
    """Combined similarity between two function_units.

    Weighted combination:
      - 0.6 * unit_node_jaccard (primary signal: same logic tree nodes = same rule)
      - 0.4 * path_similarity (secondary signal: semantic agreement)

    Returns:
        float in [0.0, 1.0]. >= 0.5 means "likely the same function_unit".
    """
    # NOTE(review): when either unit has no logic tree nodes the Jaccard term
    # is 0, so the score is capped at 0.4 and such units can never reach the
    # default 0.5 clustering threshold. Confirm this is intended.
    node_sim = unit_node_jaccard(unit_a, unit_b)
    path_sim = path_similarity(unit_a.get("path", []), unit_b.get("path", []))
    return 0.6 * node_sim + 0.4 * path_sim


# =============================================================================
# Function Unit Clustering & Merging
# =============================================================================


def cluster_function_units(
    all_units_lists: list[list[dict]],
    similarity_threshold: float = 0.5,
) -> list[list[tuple[int, dict]]]:
    """Group function_units across ensemble versions by content similarity.

    Versions are processed in list order, so units from earlier versions
    become cluster seeds (callers are expected to put the lowest-temperature
    version first). A unit joins the cluster containing its most similar
    existing member if that similarity is >= threshold; otherwise it starts a
    new cluster.

    Args:
        all_units_lists: List of unit lists, one per ensemble version.
        similarity_threshold: Minimum unit_similarity to join a cluster.

    Returns:
        List of clusters. Each cluster is a list of (version_idx, unit_dict).
    """
    clusters: list[list[tuple[int, dict]]] = []

    for version_idx, units in enumerate(all_units_lists):
        for unit in units:
            best_cluster = None
            best_sim = 0.0
            for cluster in clusters:
                # Compare against every member, not just the seed.
                cluster_sim = max(
                    unit_similarity(unit, member) for _, member in cluster
                )
                if cluster_sim > best_sim:
                    best_sim = cluster_sim
                    best_cluster = cluster

            if best_cluster is not None and best_sim >= similarity_threshold:
                best_cluster.append((version_idx, unit))
            else:
                clusters.append([(version_idx, unit)])

    return clusters


def pick_best_representative(
    cluster: list[tuple[int, dict]],
) -> dict:
    """Select the best function_unit from a cluster as the merged representative.

    Scoring formula (each term normalized to [0, 1]):
      - 0.35: Node count (more logic_tree_nodes = more complete trace)
      - 0.25: Source count (more sources = more evidence)
      - 0.20: Description length (longer = more detail, capped at 500 chars)
      - 0.20: Temperature rank (lower version_idx = lower temp = more stable)

    On equal scores the earliest cluster member wins.

    Returns:
        A deep copy of the winning unit dict, or {} for an empty cluster.
    """
    if not cluster:
        return {}

    desc_cap = 500

    def node_count(unit: dict) -> int:
        return len(_collect_logic_tree_nodes(unit))

    def source_count(unit: dict) -> int:
        return len(unit.get("sources") or [])

    def desc_length(unit: dict) -> int:
        return min(len(unit.get("description") or ""), desc_cap)

    # Normalizers; max(1, ...) below avoids dividing by zero when no member
    # has any nodes / sources / description. desc_length is already capped,
    # so the description term uses the same scale for value and maximum.
    max_nodes = max(node_count(unit) for _, unit in cluster)
    max_sources = max(source_count(unit) for _, unit in cluster)
    max_desc_len = max(desc_length(unit) for _, unit in cluster)
    max_version_idx = max(v_idx for v_idx, _ in cluster)
    rank_span = max(len(cluster), max_version_idx + 1)

    def score(member: tuple[int, dict]) -> float:
        v_idx, unit = member
        return (
            0.35 * (node_count(unit) / max(1, max_nodes))
            + 0.25 * (source_count(unit) / max(1, max_sources))
            + 0.20 * (desc_length(unit) / max(1, max_desc_len))
            + 0.20 * (1.0 - v_idx / rank_span)
        )

    _, winner = max(cluster, key=score)
    # Deep copy so callers can mutate the result without touching the inputs.
    return copy.deepcopy(winner)


def merge_unit_sources(
    cluster: list[tuple[int, dict]],
) -> list[dict]:
    """Union all sources from units in a cluster, with deduplication.

    Dedup key:
      - logic_tree sources: ("logic_tree", image_id)
      - any other source:   (type, section, row)

    When the same key appears more than once, the source with the most
    logic_tree_nodes is kept (first one wins on ties). Result order follows
    the first appearance of each key.

    Returns:
        List of deep-copied source dicts.
    """
    source_groups = defaultdict(list)
    for _, unit in cluster:
        for src in unit.get("sources") or []:
            src_type = src.get("type", "")
            if src_type == "logic_tree":
                key = ("logic_tree", src.get("image_id", ""))
            else:
                key = (src_type, src.get("section", ""), src.get("row", ""))
            source_groups[key].append(src)

    return [
        copy.deepcopy(
            max(group, key=lambda s: len(s.get("logic_tree_nodes") or []))
        )
        for group in source_groups.values()
    ]


def compute_confidence_versions(
    versions_present: int,
    total_versions: int,
    includes_lowest_temp: bool = False,
) -> str:
    """Compute 3-level confidence based on cross-version agreement.

      - "high": Appears in all versions, OR in at least half of them including
        the lowest-temperature version (index 0).
      - "medium": Appears in at least half the versions, but neither in all of
        them nor in the lowest-temperature one.
      - "low": Appears in fewer than half, or total_versions is not positive.

    Args:
        versions_present: Number of versions this item appeared in.
        total_versions: Total number of ensemble versions.
        includes_lowest_temp: Whether the item appeared in version 0.
    """
    if total_versions <= 0:
        # No ensemble to agree with; also avoids a ZeroDivisionError.
        return "low"

    ratio = versions_present / total_versions
    if ratio >= 1.0:
        return "high"
    if ratio >= 0.5:
        return "high" if includes_lowest_temp else "medium"
    return "low"


def ensemble_merge_concepts(
    all_concepts_lists: list[list[dict]],
) -> list[dict]:
    """Merge concepts across all ensemble versions.

    Clusters concepts by name, merges each cluster, then validates parent
    references. If two clusters merge to the same name, only the first is
    kept.

    Returns:
        List of merged concept dicts, each with added "confidence" and
        "ensemble_support" fields, sorted by confidence (high first) then name.
    """
    total = len(all_concepts_lists)
    merged = []
    seen_names = set()

    for cluster in cluster_concepts(all_concepts_lists):
        concept, _ = merge_concept_cluster(cluster, total)
        name = concept.get("name", "")
        if not name or name in seen_names:
            continue
        support = len({v_idx for v_idx, _ in cluster})
        concept["ensemble_support"] = f"{support}/{total}"
        merged.append(concept)
        seen_names.add(name)

    # Sorted before validation so fuzzy parent repair sees a stable order.
    merged.sort(key=lambda c: (_confidence_rank(c), c.get("name", "")))
    return _validate_concept_parents(merged)


def _validate_concept_parents(
    concepts: list[dict],
    similarity_threshold: float = 0.7,
) -> list[dict]:
    """Post-merge: validate that every concept's parent exists in the list.

    A parent is dangling when it names no concept in the list, or names the
    concept itself. Strategy for dangling parents:
      1. Fuzzy match against the other concepts' names
         (concept_name_similarity >= similarity_threshold) -> fix reference.
         Candidates are tried in list order, first best match wins.
      2. No match -> set parent to None.
    Either way the concept's confidence is downgraded one level, because its
    original parent reference could not be confirmed.

    The list is modified in place, re-sorted, and returned.
    """
    # Ordered, de-duplicated names: iterating a set of strings would make
    # tie-breaking depend on hash randomization.
    ordered_names = []
    known_names = set()
    for concept in concepts:
        name = concept.get("name")
        if name and name not in known_names:
            known_names.add(name)
            ordered_names.append(name)

    for concept in concepts:
        parent = concept.get("parent")
        if parent is None:
            continue
        own_name = concept.get("name")
        is_self_reference = parent == own_name
        if parent in known_names and not is_self_reference:
            continue

        best_match = None
        best_sim = 0.0
        if not is_self_reference:
            for candidate in ordered_names:
                if candidate == own_name:
                    # A concept must never be repaired into its own parent.
                    continue
                sim = concept_name_similarity(parent, candidate)
                if sim > best_sim:
                    best_sim = sim
                    best_match = candidate

        if best_match is not None and best_sim >= similarity_threshold:
            concept["parent"] = best_match
        else:
            concept["parent"] = None
        concept["confidence"] = _downgrade_confidence(
            concept.get("confidence", "low")
        )

    # Re-sort because confidence levels may have changed.
    concepts.sort(key=lambda c: (_confidence_rank(c), c.get("name", "")))
    return concepts


def _downgrade_confidence(current: str) -> str:
    """Drop confidence one level ("high" -> "medium"; anything else -> "low")."""
    return "medium" if current == "high" else "low"


def ensemble_merge_function_units(
    all_units_lists: list[list[dict]],
) -> list[dict]:
    """Merge function_units across all ensemble versions.

    1. Cluster units across versions.
    2. For each cluster: pick best, merge sources, compute confidence.
    3. Sort by confidence (high first), then by original unit_id.
    4. Reassign stable unit_ids: FU-ENS-001, FU-ENS-002, ...

    Returns:
        List of merged function_unit dicts with added "confidence",
        "ensemble_support" ("present/total"), "source_versions" (number of
        distinct versions in the cluster) and "original_unit_id" fields.
    """
    total = len(all_units_lists)
    merged = []

    for cluster in cluster_function_units(all_units_lists):
        unit = pick_best_representative(cluster)
        unit["sources"] = merge_unit_sources(cluster)

        versions = {v_idx for v_idx, _ in cluster}
        unit["confidence"] = compute_confidence_versions(
            len(versions), total, 0 in versions
        )
        unit["ensemble_support"] = f"{len(versions)}/{total}"
        unit["source_versions"] = len(versions)
        merged.append(unit)

    # str(... or "") keeps the sort safe when unit_id is missing or None.
    merged.sort(
        key=lambda u: (_confidence_rank(u), str(u.get("unit_id") or ""))
    )

    for position, unit in enumerate(merged, start=1):
        # Preserve the original unit_id for traceability.
        if "original_unit_id" not in unit:
            unit["original_unit_id"] = unit.get("unit_id", "")
        unit["unit_id"] = f"FU-ENS-{position:03d}"

    return merged


# =============================================================================
# Top-Level Ensemble Merge
# =============================================================================


def ensemble_merge(
    semantic_indices: list[dict],
) -> dict:
    """Merge N semantic index outputs into one ensemble result.

    Args:
        semantic_indices: List of semantic_index dicts from each temperature
            run. semantic_indices[0] should be the lowest-temperature version.

    Returns:
        Merged semantic_index dict with structure:
        {
            "feature_name": str,
            "ensemble_versions": int,
            "concepts": [...],
            "function_units": [...],
            "confidence_summary": {...},
        }
        For empty input, the lists and the summary are empty.
    """
    if not semantic_indices:
        return {
            "feature_name": "",
            "ensemble_versions": 0,
            "concepts": [],
            "function_units": [],
            "confidence_summary": {},
        }

    all_concepts = [si.get("concepts") or [] for si in semantic_indices]
    all_units = [si.get("function_units") or [] for si in semantic_indices]

    merged_concepts = ensemble_merge_concepts(all_concepts)
    merged_units = ensemble_merge_function_units(all_units)

    # Feature name: majority vote over non-empty names. max() returns the
    # first maximal key, so ties go to the earliest version.
    name_counts = defaultdict(int)
    for si in semantic_indices:
        feature = si.get("feature_name", "")
        if feature:
            name_counts[feature] += 1
    feature_name = max(name_counts, key=name_counts.get) if name_counts else ""

    unit_conf = defaultdict(int)
    for unit in merged_units:
        unit_conf[unit.get("confidence", "low")] += 1
    concept_conf = defaultdict(int)
    for concept in merged_concepts:
        concept_conf[concept.get("confidence", "low")] += 1

    return {
        "feature_name": feature_name,
        "ensemble_versions": len(semantic_indices),
        "concepts": merged_concepts,
        "function_units": merged_units,
        "confidence_summary": {
            "total_units": len(merged_units),
            "high": unit_conf.get("high", 0),
            "medium": unit_conf.get("medium", 0),
            "low": unit_conf.get("low", 0),
            "total_concepts": len(merged_concepts),
            "concept_high": concept_conf.get("high", 0),
            "concept_medium": concept_conf.get("medium", 0),
            "concept_low": concept_conf.get("low", 0),
        },
    }