""" Stage 1: Ensemble Semantic Index Generation. Generates N parallel LLM calls with different temperatures (e.g., 0.0, 0.3, 0.7), then deterministically merges the results via ensemble_merge (pure Python, no LLM). The merged output includes confidence scores for each concept and function_unit. Outputs: - output/semantic_index_r1.json (T=0.0 raw) - output/semantic_index_r2.json (T=0.3 raw) - output/semantic_index_r3.json (T=0.7 raw) - output/semantic_index.json (ensemble-merged final) """ import concurrent.futures import json import re import sys import time from pathlib import Path import config from ensemble_merge import ensemble_merge # ---- Path Enumeration (for prompt embedding) ---- def _traverse_nested(node: dict, image_id: str, path_nodes: list, branch_taken: str | None) -> list[dict]: """DFS traversal of a logic_tree_nested node, returning leaf path records.""" node_id = node.get("id", "?") node_type = node.get("type", "?") node_name = node.get("name", "") path_nodes = path_nodes + [{ "id": node_id, "type": node_type, "label": node_name, "branch_taken": branch_taken, }] if node_type == "end": return [_make_path_record(path_nodes, image_id)] children = node.get("children", []) if not children: return [_make_path_record(path_nodes, image_id)] all_paths = [] for child in children: # Decision nodes have {condition, node} wrappers; others are direct node dicts if node_type == "decision": condition = child.get("condition", "") child_node = child.get("node", child) else: condition = "(implicit)" child_node = child all_paths.extend( _traverse_nested(child_node, image_id, path_nodes, condition) ) return all_paths def _make_path_record(path_nodes: list, image_id: str) -> dict: """Build a path record from a completed node chain.""" action_nodes = [n for n in path_nodes if n["type"] == "action"] decision_nodes = [n for n in path_nodes if n["type"] == "decision"] node_ids = [n["id"] for n in path_nodes] return { "path_id": f"PATH-{image_id}-{'-'.join(node_ids)}", "nodes": path_nodes, "meaning": _describe_path(path_nodes), "image_id": image_id, "action_nodes": action_nodes, "decision_nodes": decision_nodes, "node_ids": node_ids, } def enumerate_logic_tree_paths(nested_tree: dict, image_id: str = "") -> list[dict]: """Enumerate all root-to-leaf paths from a logic_tree_nested structure. Uses the nested tree directly (no flat-list adjacency). Decision nodes fork by {condition, node} branches; other nodes have direct children. """ if not nested_tree: return [] return _traverse_nested(nested_tree, image_id, [], None) def _describe_path(path_nodes: list[dict]) -> str: """Generate a human-readable description of a logic tree path.""" parts = [] for n in path_nodes: label = n["label"] if n["branch_taken"] and n["branch_taken"] != "(implicit)": label = f"{label} → {n['branch_taken']}" parts.append(label) return " → ".join(parts) def enumerate_all_paths(doc: dict) -> dict[str, list[dict]]: """Enumerate paths for all logic trees in the document. Uses logic_tree_nested when available (proper tree), falling back to flat logic_tree. Returns {image_id: [path, ...]}. """ result = {} for img in doc.get("image_analysis", []): rid = img.get("rid", "") if not rid: continue nested = img.get("logic_tree_nested") if nested: result[rid] = enumerate_logic_tree_paths(nested, image_id=rid) else: lt = img.get("logic_tree") if lt and lt.get("nodes"): lt["image_id"] = rid result[rid] = _enumerate_flat_tree(lt) elif lt: result[rid] = [] return result def _enumerate_flat_tree(tree: dict) -> list[dict]: """Fallback: enumerate paths from flat logic_tree using adjacency. Handles start/process/action/state nodes as implicit chain links. """ nodes = tree.get("nodes", []) if not nodes: return [] node_map = {n["id"]: n for n in nodes} image_id = tree.get("image_id", "") # Find root: first start/state node, or first process node, or first node root = None for n in nodes: if n["type"] in ("start", "state"): root = n break if root is None: for n in nodes: if n["type"] == "process": root = n break if root is None: root = nodes[0] adj = _build_adjacency(nodes, node_map) paths = [] def dfs(current_id, visited, path_nodes, branch_taken): if current_id in visited: return new_visited = visited | {current_id} node = node_map.get(current_id) if node is None: return path_nodes = path_nodes + [{ "id": current_id, "type": node["type"], "label": node.get("description") or node.get("condition", ""), "branch_taken": branch_taken, }] outgoing = adj.get(current_id, []) if not outgoing: action_nodes = [n for n in path_nodes if n["type"] == "action"] decision_nodes = [n for n in path_nodes if n["type"] == "decision"] node_ids = [n["id"] for n in path_nodes] paths.append({ "path_id": f"PATH-{image_id}-{'-'.join(node_ids)}", "nodes": path_nodes, "meaning": _describe_path(path_nodes), "image_id": image_id, "action_nodes": action_nodes, "decision_nodes": decision_nodes, "node_ids": node_ids, }) else: for branch_val, target_id in outgoing: dfs(target_id, new_visited, path_nodes, branch_val) dfs(root["id"], set(), [], None) return paths def _build_adjacency(nodes, node_map): """Build {node_id: [(branch_value, target_id)]} adjacency for flat trees. Handles: decision branches (explicit), non-branching nodes (implicit sequential). """ NON_BRANCHING = {"start", "process", "state", "action"} adj = {} has_explicit_incoming = set() for n in nodes: for br in n.get("branches", []): has_explicit_incoming.add(br["target"]) for i, node in enumerate(nodes): nid = node["id"] adj.setdefault(nid, []) # Explicit edges from decision nodes for br in node.get("branches", []): adj[nid].append((br["value"], br["target"])) # Implicit edges for non-branching nodes (start/process/state/action) if node["type"] in NON_BRANCHING and not node.get("branches"): j = i + 1 targets = [] while j < len(nodes): next_node = nodes[j] next_nid = next_node["id"] if next_nid in has_explicit_incoming: break if next_node["type"] in NON_BRANCHING | {"end"}: targets.append(next_nid) has_explicit_incoming.add(next_nid) j += 1 continue elif next_node["type"] == "decision": if not targets: targets.append(next_nid) break j += 1 for t in targets: adj[nid].append(("(implicit)", t)) return adj def format_paths_for_prompt(all_paths: dict[str, list[dict]]) -> str: """Format enumerated paths as a readable list for the LLM prompt.""" if not all_paths: return "(无逻辑树路径)" lines = [] for image_id, paths in all_paths.items(): lines.append(f"\n### {image_id} 的全部决策路径(共 {len(paths)} 条):") for i, path in enumerate(paths, 1): lines.append(f"\n**路径 {i}** (ID: {path['path_id']})") lines.append(f" 含义: {path['meaning']}") lines.append(f" 节点: {path['node_ids']}") lines.append(f" 决策节点: {[n['id'] for n in path['decision_nodes']]}") lines.append(f" 动作节点: {[n['id'] for n in path['action_nodes']]}") return "\n".join(lines) # ---- Document Formatting ---- def format_document_for_prompt(doc: dict) -> str: """Render the full parsed document as a readable string for the LLM prompt.""" lines = [] lines.append("=== SECTIONS ===") for i, section in enumerate(doc.get("sections", [])): source = section.get("source", f"(无标题-章节{i})") lines.append(f"\n--- Section: {source} ---") for block in section.get("blocks", []): if block["type"] == "para": lines.append(f"[段落 {block['index']}] {block['text']}") elif block["type"] == "table": lines.append(f"[表格 {block.get('table', '?')}]") headers = block.get("headers", []) lines.append(f" 表头: {' | '.join(headers)}") for row in block.get("rows", []): cols = row.get("columns", []) cell_texts = [] for c in cols: cell_texts.append( f"[行{c.get('row','?')}]{c.get('name','')}: {c.get('text','')}" ) lines.append(f" {'; '.join(cell_texts)}") images = section.get("images", []) if images: lines.append(f" 图片引用: {', '.join(images)}") lines.append("\n\n=== IMAGE_ANALYSIS (流程图逻辑树) ===") for img in doc.get("image_analysis", []): rid = img.get("rid", "?") img_type = img.get("type", "?") lines.append(f"\n--- Image: {rid} (type={img_type}) ---") lines.append(f" 描述: {img.get('description', '')[:300]}") lt = img.get("logic_tree") if lt: lines.append(f" 逻辑树根节点: {lt.get('root', '?')}") lines.append(" 节点详情:") for node in lt.get("nodes", []): nid = node.get("id", "?") ntype = node.get("type", "?") desc = node.get("description", "") or node.get("condition", "") lines.append(f" [{ntype}] {nid}: {desc}") branches = node.get("branches", []) if branches: for br in branches: lines.append(f" → {br['value']} → {br['target']}") conflicts = doc.get("resolved_conflicts", []) if conflicts: lines.append("\n\n=== RESOLVED_CONFLICTS (图文冲突仲裁) ===") for c in conflicts: lines.append( f" [{c.get('conflict_type','?')}] {c.get('section','?')}: " f"以{c.get('source','?')}为准 — {c.get('correction','')}" ) return "\n".join(lines) # ---- Prompt Building ---- def build_prompt(doc: dict, feedback: str = "", all_paths: dict | None = None) -> str: """Load the prompt template and inject the formatted document + paths + feedback.""" template_path = Path(config.PROMPTS_DIR) / "step1_semantic_index.txt" template = template_path.read_text(encoding="utf-8") formatted_doc = format_document_for_prompt(doc) prompt = template.replace("{document_json}", formatted_doc) if all_paths is None: all_paths = enumerate_all_paths(doc) path_text = format_paths_for_prompt(all_paths) prompt = prompt.replace("{logic_tree_paths}", path_text) if feedback: prompt = prompt.replace("{feedback}", feedback) else: prompt = prompt.replace("{feedback}", "") return prompt # ---- Validation ---- def _quick_validate( semantic_index: dict, doc: dict, all_paths: dict | None = None ) -> tuple[bool, dict]: """Validate semantic index and return (passed, gaps). Uses a single COVERAGE_TARGET threshold (default 0.95). """ gaps = { "missing_paths": [], "missing_concepts": [], "format_issues": [], "parent_issues": [], } units = semantic_index.get("function_units", []) concepts = semantic_index.get("concepts", []) # --- Check function_units non-empty --- if not units: gaps["format_issues"].append("function_units 为空") return False, gaps # --- Check each function_unit has path --- for fu in units: uid = fu.get("unit_id", "?") if not fu.get("path"): gaps["format_issues"].append(f"{uid}: 缺少 path 字段") if not fu.get("sources"): gaps["format_issues"].append(f"{uid}: 缺少 sources") # --- Logic tree node coverage --- all_nodes = _collect_logic_tree_nodes(doc) referenced = _collect_referenced_nodes(units) threshold = config.COVERAGE_TARGET for image_id, node_set in all_nodes.items(): ref_set = referenced.get(image_id, set()) checkable = { nid for nid, ntype in node_set.items() if ntype in ("decision", "action") } if not checkable: continue covered = checkable & ref_set coverage = len(covered) / len(checkable) if checkable else 1.0 if coverage < threshold: missing = checkable - ref_set gaps["missing_paths"].append( f"{image_id}: 覆盖率 {coverage:.0%} < {threshold:.0%}, " f"未覆盖节点: {sorted(missing)}" ) # --- Check logic tree path consistency --- # A unit's logic_tree_nodes must form a valid (connected) path in the tree. if all_paths is not None: for fu in units: uid = fu.get("unit_id", "?") for src in fu.get("sources", []): if src.get("type") != "logic_tree": continue image_id = src.get("image_id", "") unit_nodes = set(src.get("logic_tree_nodes", [])) if not unit_nodes: continue # Check if there exists a path containing all these nodes valid = False for path in all_paths.get(image_id, []): path_nodes = set(path.get("node_ids", [])) if unit_nodes.issubset(path_nodes): valid = True break if not valid: gaps["format_issues"].append( f"{uid}: logic_tree_nodes 不构成有效路径 " f"(image={image_id}, nodes={sorted(unit_nodes)})" ) # --- Check for trivial units (only state/switch nodes, no actions) --- if all_paths is not None: for fu in units: uid = fu.get("unit_id", "?") has_logic_ref = False has_action = False has_non_trivial_decision = False for src in fu.get("sources", []): if src.get("type") != "logic_tree": continue has_logic_ref = True node_ids = src.get("logic_tree_nodes", []) node_types = {} for image_id, nset in all_nodes.items(): for nid in node_ids: if nid in nset: node_types[nid] = nset[nid] for nid in node_ids: ntype = node_types.get(nid, "") if ntype == "action": has_action = True # Count decisions beyond first level (e.g., n1/n2 are just root+switch) decisions = [nid for nid in node_ids if node_types.get(nid, "") == "decision"] if len(decisions) > 1: has_non_trivial_decision = True if has_logic_ref and not has_action and not has_non_trivial_decision: gaps["format_issues"].append( f"{uid}: 可能为空壳单元(仅有state/开关节点,无action或深层decision)" ) # --- Concept parent validity --- concept_names = {c["name"] for c in concepts} for c in concepts: name = c.get("name", "?") parent = c.get("parent") # can be None for scope-level if parent is not None and parent not in concept_names: gaps["parent_issues"].append( f"concept '{name}' 的 parent '{parent}' 不存在" ) # Warn about scope-level concepts without parent=null for c in concepts: if c.get("parent") is not None: continue name = c.get("name", "") # Scope-level concepts (国内/海外) should have parent=null if name not in ("国内", "海外", ""): gaps["parent_issues"].append( f"concept '{name}' 的 parent 为 null,但它可能不是 scope 概念" ) # --- Check for missing scope concepts --- if "国内" not in concept_names: gaps["missing_concepts"].append("缺少 scope 概念: 国内") if "海外" not in concept_names and any( "海外" in s.get("source", "") for s in doc.get("sections", []) ): gaps["missing_concepts"].append("缺少 scope 概念: 海外") passed = ( not gaps["missing_paths"] and not gaps["format_issues"] and not gaps["parent_issues"] ) return passed, gaps def _collect_logic_tree_nodes(doc: dict) -> dict[str, dict[str, str]]: """Return {image_id: {node_id: node_type}} for all logic trees.""" result = {} for img in doc.get("image_analysis", []): lt = img.get("logic_tree") rid = img.get("rid", "") if lt and rid: result[rid] = {n["id"]: n["type"] for n in lt.get("nodes", [])} return result def _collect_referenced_nodes(units: list[dict]) -> dict[str, set[str]]: """Return {image_id: {referenced node_ids}} across all function_units.""" refs = {} for fu in units: for src in fu.get("sources", []): if src.get("type") == "logic_tree": image_id = src.get("image_id", "") if image_id not in refs: refs[image_id] = set() refs[image_id].update(src.get("logic_tree_nodes", [])) return refs # ---- LLM Calls ---- def extract_json_from_response(text: str) -> str: """Robustly extract JSON from LLM response.""" m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text) if m: return m.group(1).strip() start = text.find("{") if start == -1: raise ValueError("No JSON object found in LLM response") depth = 0 for i in range(start, len(text)): if text[i] == "{": depth += 1 elif text[i] == "}": depth -= 1 if depth == 0: return text[start : i + 1] raise ValueError("Unclosed JSON object in LLM response") def call_llm(prompt: str, max_retries: int = 2, temperature: float | None = None) -> dict: """Send prompt to LLM, return parsed JSON dict. Args: temperature: Override config.TEMPERATURE. If None, uses config default. """ client = config.llm_client() temp = temperature if temperature is not None else config.TEMPERATURE for attempt in range(max_retries + 1): print(f" LLM 调用 T={temp} (尝试 {attempt + 1}/{max_retries + 1})...", flush=True) try: resp = client.chat.completions.create( model=config.MODEL_NAME, messages=[ { "role": "system", "content": "你是一个精确的 JSON 输出引擎。只输出合法的 JSON。", }, {"role": "user", "content": prompt}, ], temperature=temp, max_tokens=config.MAX_TOKENS, ) content = resp.choices[0].message.content if content is None: raise RuntimeError("LLM returned empty response") json_str = extract_json_from_response(content) return json.loads(json_str) except (json.JSONDecodeError, ValueError) as e: print(f" JSON 解析失败: {e}") if attempt < max_retries: time.sleep(2) raise RuntimeError("无法从 LLM 响应中解析 JSON") # ---- Ensemble Orchestration ---- def run_ensemble_semantic_index(doc: dict) -> dict: """Run N parallel LLM calls at different temperatures, then ensemble-merge. 1. Enumerate all logic tree paths (once). 2. Build the prompt (once — no iterative feedback needed). 3. Launch len(ENSEMBLE_TEMPERATURES) parallel LLM calls via ThreadPoolExecutor. 4. Collect all results. 5. Call ensemble_merge() for deterministic merge. 6. Validate final output with _quick_validate(). 7. Save individual version outputs + merged output. """ all_paths = enumerate_all_paths(doc) print(f" 已枚举逻辑树路径: {sum(len(v) for v in all_paths.values())} 条") prompt = build_prompt(doc, "", all_paths) print(f" Prompt 长度: {len(prompt)} 字符") temperatures = config.ENSEMBLE_TEMPERATURES print(f" 集成温度: {temperatures}") # Parallel LLM calls raw_results: list[tuple[int, float, dict]] = [] with concurrent.futures.ThreadPoolExecutor( max_workers=len(temperatures) ) as executor: future_to_meta = {} for i, temp in enumerate(temperatures): future = executor.submit(call_llm, prompt, 2, temp) future_to_meta[future] = (i, temp) for future in concurrent.futures.as_completed(future_to_meta): idx, temp = future_to_meta[future] try: si = future.result() n_units = len(si.get("function_units", [])) n_concepts = len(si.get("concepts", [])) print(f" T={temp}: {n_concepts} 概念, {n_units} 功能单元") raw_results.append((idx, temp, si)) except Exception as e: print(f" T={temp}: FAIL — {e}") raw_results.append((idx, temp, { "feature_name": "", "concepts": [], "function_units": [] })) if not raw_results: raise RuntimeError("所有集成的 LLM 调用均失败") # Sort by temperature for determinism raw_results.sort(key=lambda x: x[1]) semantic_indices = [r[2] for r in raw_results] # Save individual version outputs version_paths = { 0: config.SEMANTIC_INDEX_R1_JSON, 1: config.SEMANTIC_INDEX_R2_JSON, 2: config.SEMANTIC_INDEX_R3_JSON, } for i, si in enumerate(semantic_indices): out_path = version_paths.get(i) if out_path: config.save_json(si, out_path) print(f" 保存版本 {i} (T={temperatures[i]}): {out_path}") # Ensemble merge print(f"\n 集成合并 {len(semantic_indices)} 个版本...") merged = ensemble_merge(semantic_indices) merged["ensemble_temperatures"] = list(temperatures) # Validate passed, gaps = _quick_validate(merged, doc, all_paths) merged["validation_passed"] = passed merged["validation_gaps"] = { k: v for k, v in gaps.items() if v } # Print summary cs = merged.get("confidence_summary", {}) print(f" 合并后: {cs.get('total_concepts', 0)} 概念, " f"{cs.get('total_units', 0)} 功能单元") print(f" 置信度: high={cs.get('high', 0)}, medium={cs.get('medium', 0)}, " f"low={cs.get('low', 0)}") print(f" 验证: {'PASS' if passed else 'GAPS FOUND'}") if not passed: for k, v in gaps.items(): if v: print(f" {k}: {len(v)} 个问题") return merged # ---- Main ---- def main(): print("=" * 60) print("阶段一:集成语义索引 (Ensemble Semantic Index)") print("=" * 60) # 1. Load input print(f"\n[1/3] 加载输入文档: {config.INPUT_JSON}") doc = config.load_input_document() print(f" 已加载 {len(doc.get('sections', []))} 个 section, " f"{len(doc.get('image_analysis', []))} 张图片分析") # 2. Run ensemble generation + merge print(f"\n[2/3] 运行集成语义索引 ({len(config.ENSEMBLE_TEMPERATURES)} 个温度版本)...") merged_index = run_ensemble_semantic_index(doc) # 3. Save outputs print(f"\n[3/3] 保存最终语义索引: {config.SEMANTIC_INDEX_JSON}") config.save_json(merged_index, config.SEMANTIC_INDEX_JSON) # Also save path enumeration for downstream use all_paths = enumerate_all_paths(doc) config.save_json( {"logic_tree_paths": {k: v for k, v in all_paths.items()}}, config.PATH_ENUM_JSON, ) print(f" 路径枚举: {config.PATH_ENUM_JSON}") cs = merged_index.get("confidence_summary", {}) n_concepts = cs.get("total_concepts", len(merged_index.get("concepts", []))) n_units = cs.get("total_units", len(merged_index.get("function_units", []))) n_versions = merged_index.get("ensemble_versions", len(config.ENSEMBLE_TEMPERATURES)) print(f"\n完成! {n_versions} 版本集成, {n_concepts} 个概念, {n_units} 个功能单元.") print(f"输出: {config.SEMANTIC_INDEX_JSON}") if __name__ == "__main__": main()