"""Stage 1: Ensemble Semantic Index Generation.

Generates N parallel LLM calls with different temperatures (e.g., 0.0, 0.3,
0.7), then deterministically merges the results via ensemble_merge (pure
Python, no LLM). The merged output includes confidence scores for each
concept and function_unit.

Outputs:
    - output/semantic_index_r1.json (T=0.0 raw)
    - output/semantic_index_r2.json (T=0.3 raw)
    - output/semantic_index_r3.json (T=0.7 raw)
    - output/semantic_index.json (ensemble-merged final)

Actual output paths come from ``config``. Raw versions are written in
ascending temperature order (r1 = lowest); the T values above are examples.
"""

import concurrent.futures
import json
import re
import sys
import time
import traceback
from pathlib import Path

import config
from ensemble_merge import ensemble_merge


# ---- Path Enumeration (for prompt embedding) ----

# Branch label for edges that are not an explicit decision branch.
_IMPLICIT = "(implicit)"

# Flat-tree node types that continue a chain instead of forking it.
_NON_BRANCHING = frozenset({"start", "process", "state", "action"})
# Node types a non-branching node may be implicitly chained to.
_CHAIN_TARGET_TYPES = _NON_BRANCHING | {"end", "decision"}


def _traverse_nested(
    node: dict, image_id: str, path_nodes: list, branch_taken: str | None
) -> list[dict]:
    """DFS traversal of a logic_tree_nested node, returning leaf path records.

    Args:
        node: Current nested node; reads ``id``, ``type``, ``name`` and
            ``children``.
        image_id: rid of the image the tree belongs to (embedded in path IDs).
        path_nodes: Node records from the root down to, but excluding,
            ``node``. Not mutated; a new list is built per level.
        branch_taken: Condition of the edge that led here (None for the
            root, ``"(implicit)"`` below non-decision nodes).

    Returns:
        One path record (see ``_make_path_record``) per reachable leaf.
    """
    path_nodes = path_nodes + [{
        "id": node.get("id", "?"),
        "type": node.get("type", "?"),
        "label": node.get("name", ""),
        "branch_taken": branch_taken,
    }]
    node_type = path_nodes[-1]["type"]
    children = node.get("children", [])
    # "end" nodes and childless nodes terminate a path.
    if node_type == "end" or not children:
        return [_make_path_record(path_nodes, image_id)]

    all_paths: list[dict] = []
    for child in children:
        if node_type == "decision":
            # Decision children are {condition, node} wrappers.
            condition = child.get("condition", "")
            child_node = child.get("node", child)
        else:
            # Other node types list their child nodes directly.
            condition = _IMPLICIT
            child_node = child
        all_paths.extend(
            _traverse_nested(child_node, image_id, path_nodes, condition)
        )
    return all_paths


def _make_path_record(path_nodes: list, image_id: str) -> dict:
    """Build a path record from a completed root-to-leaf node chain.

    ``path_id`` joins the image id and every node id, so distinct chains in
    the same image get distinct ids. Non-string node ids are stringified for
    the id only; ``node_ids`` keeps the original values.
    """
    node_ids = [n["id"] for n in path_nodes]
    return {
        "path_id": f"PATH-{image_id}-{'-'.join(map(str, node_ids))}",
        "nodes": path_nodes,
        "meaning": _describe_path(path_nodes),
        "image_id": image_id,
        "action_nodes": [n for n in path_nodes if n["type"] == "action"],
        "decision_nodes": [n for n in path_nodes if n["type"] == "decision"],
        "node_ids": node_ids,
    }


def enumerate_logic_tree_paths(nested_tree: dict, image_id: str = "") -> list[dict]:
    """Enumerate all root-to-leaf paths from a logic_tree_nested structure.

    Uses the nested tree directly (no flat-list adjacency). Decision nodes
    fork by ``{condition, node}`` branches; other nodes have direct children.
    Returns an empty list for an empty/None tree.
    """
    if not nested_tree:
        return []
    return _traverse_nested(nested_tree, image_id, [], None)


def _describe_path(path_nodes: list[dict]) -> str:
    """Return a human-readable ``label → branch → label ...`` description.

    Implicit edges are not shown. A missing/None label renders as an empty
    string instead of breaking the join.
    """
    parts = []
    for n in path_nodes:
        label = str(n.get("label") or "")
        branch = n.get("branch_taken")
        if branch and branch != _IMPLICIT:
            label = f"{label} → {branch}"
        parts.append(label)
    return " → ".join(parts)


def enumerate_all_paths(doc: dict) -> dict[str, list[dict]]:
    """Enumerate paths for all logic trees in the document.

    Prefers ``logic_tree_nested`` (a proper tree) and falls back to the flat
    ``logic_tree``. Images without a ``rid`` are skipped; a flat tree with no
    nodes maps to ``[]``. The input document is not modified.

    Returns:
        ``{image_id: [path, ...]}``.
    """
    result: dict[str, list[dict]] = {}
    for img in doc.get("image_analysis", []):
        rid = img.get("rid", "")
        if not rid:
            continue
        nested = img.get("logic_tree_nested")
        if nested:
            result[rid] = enumerate_logic_tree_paths(nested, image_id=rid)
            continue
        lt = img.get("logic_tree")
        if lt and lt.get("nodes"):
            result[rid] = _enumerate_flat_tree(lt, image_id=rid)
        elif lt:
            result[rid] = []
    return result


def _find_flat_root(nodes: list[dict]) -> dict:
    """Pick the DFS root: first start/state node, else first process node, else the first node."""
    for wanted in (("start", "state"), ("process",)):
        for n in nodes:
            if n["type"] in wanted:
                return n
    return nodes[0]


def _enumerate_flat_tree(tree: dict, image_id: str | None = None) -> list[dict]:
    """Fallback: enumerate root-to-leaf paths of a flat logic_tree.

    Edges come from ``_build_adjacency``: explicit decision branches plus
    implicit chain links between start/process/action/state nodes.

    Args:
        tree: Flat tree with a ``nodes`` list.
        image_id: rid used in path ids. If None, ``tree["image_id"]`` (or
            ``""``) is used for backward compatibility.

    Returns:
        Path records (see ``_make_path_record``). A path that would revisit
        a node (cycle) is dropped rather than recorded.
    """
    nodes = tree.get("nodes", [])
    if not nodes:
        return []
    if image_id is None:
        image_id = tree.get("image_id", "")

    node_map = {n["id"]: n for n in nodes}
    adj = _build_adjacency(nodes, node_map)
    paths: list[dict] = []

    def dfs(current_id, visited, path_nodes, branch_taken):
        # `visited` is per-path, so it only breaks cycles; shared sub-paths
        # reachable through different branches are still enumerated.
        if current_id in visited:
            return
        node = node_map.get(current_id)
        if node is None:
            return  # dangling branch target
        path_nodes = path_nodes + [{
            "id": current_id,
            "type": node["type"],
            "label": node.get("description") or node.get("condition", ""),
            "branch_taken": branch_taken,
        }]
        outgoing = adj.get(current_id, [])
        if not outgoing:
            paths.append(_make_path_record(path_nodes, image_id))
            return
        next_visited = visited | {current_id}
        for branch_val, target_id in outgoing:
            dfs(target_id, next_visited, path_nodes, branch_val)

    dfs(_find_flat_root(nodes)["id"], frozenset(), [], None)
    return paths


def _build_adjacency(nodes, node_map=None):
    """Build ``{node_id: [(branch_value, target_id), ...]}`` for a flat tree.

    Edges come from two sources:

    * Explicit: each entry of a node's ``branches`` list becomes an edge
      labelled with the branch ``value``.
    * Implicit: a start/process/state/action node without ``branches`` is
      chained to the next node in list order whose type is non-branching,
      ``end`` or ``decision`` (other node types are skipped over). No implicit
      edge is added when the next candidate is an explicit branch target,
      because that node begins a different decision branch.

    Args:
        nodes: Flat node list in document order.
        node_map: Unused; kept so existing callers keep working.
    """
    explicit_targets = {
        br["target"] for n in nodes for br in n.get("branches", [])
    }
    adj: dict = {}
    for i, node in enumerate(nodes):
        edges = adj.setdefault(node["id"], [])
        branches = node.get("branches", [])
        edges.extend((br["value"], br["target"]) for br in branches)
        if node["type"] not in _NON_BRANCHING or branches:
            continue
        for nxt in nodes[i + 1:]:
            if nxt["id"] in explicit_targets:
                break  # next node starts another decision branch
            if nxt["type"] in _CHAIN_TARGET_TYPES:
                edges.append((_IMPLICIT, nxt["id"]))
                break
            # Nodes of any other type are skipped when looking for a successor.
    return adj


def format_paths_for_prompt(all_paths: dict[str, list[dict]]) -> str:
    """Format enumerated paths as a readable list for the LLM prompt."""
    if not all_paths:
        return "(无逻辑树路径)"
    lines = []
    for image_id, paths in all_paths.items():
        lines.append(f"\n### {image_id} 的全部决策路径(共 {len(paths)} 条):")
        for i, path in enumerate(paths, 1):
            lines.extend((
                f"\n**路径 {i}** (ID: {path['path_id']})",
                f" 含义: {path['meaning']}",
                f" 节点: {path['node_ids']}",
                f" 决策节点: {[n['id'] for n in path['decision_nodes']]}",
                f" 动作节点: {[n['id'] for n in path['action_nodes']]}",
            ))
    return "\n".join(lines)


# ---- Document Formatting ----

def _format_sections(sections: list[dict]) -> list[str]:
    """Render sections (paragraphs, tables, image references) as prompt lines."""
    lines = []
    for i, section in enumerate(sections):
        source = section.get("source", f"(无标题-章节{i})")
        lines.append(f"\n--- Section: {source} ---")
        for block in section.get("blocks", []):
            block_type = block.get("type")
            if block_type == "para":
                lines.append(f"[段落 {block.get('index', '?')}] {block.get('text', '')}")
            elif block_type == "table":
                lines.append(f"[表格 {block.get('table', '?')}]")
                headers = block.get("headers", [])
                lines.append(f" 表头: {' | '.join(map(str, headers))}")
                for row in block.get("rows", []):
                    cell_texts = [
                        f"[行{c.get('row','?')}]{c.get('name','')}: {c.get('text','')}"
                        for c in row.get("columns", [])
                    ]
                    lines.append(f" {'; '.join(cell_texts)}")
        images = section.get("images", [])
        if images:
            lines.append(f" 图片引用: {', '.join(map(str, images))}")
    return lines


def _format_images(images: list[dict]) -> list[str]:
    """Render image analyses and their flat logic trees as prompt lines."""
    lines = []
    for img in images:
        rid = img.get("rid", "?")
        img_type = img.get("type", "?")
        lines.append(f"\n--- Image: {rid} (type={img_type}) ---")
        # `or ""` also covers an explicit null description.
        description = img.get("description") or ""
        lines.append(f" 描述: {description[:300]}")
        lt = img.get("logic_tree")
        if not lt:
            continue
        lines.append(f" 逻辑树根节点: {lt.get('root', '?')}")
        lines.append(" 节点详情:")
        for node in lt.get("nodes", []):
            nid = node.get("id", "?")
            ntype = node.get("type", "?")
            desc = node.get("description", "") or node.get("condition", "")
            lines.append(f" [{ntype}] {nid}: {desc}")
            for br in node.get("branches", []):
                lines.append(f" → {br['value']} → {br['target']}")
    return lines


def format_document_for_prompt(doc: dict) -> str:
    """Render the full parsed document as a readable string for the LLM prompt.

    Emits three parts: sections, image analyses (flat logic trees) and, when
    present, resolved text/image conflicts.
    """
    lines = ["=== SECTIONS ==="]
    lines.extend(_format_sections(doc.get("sections", [])))
    lines.append("\n\n=== IMAGE_ANALYSIS (流程图逻辑树) ===")
    lines.extend(_format_images(doc.get("image_analysis", [])))

    conflicts = doc.get("resolved_conflicts", [])
    if conflicts:
        lines.append("\n\n=== RESOLVED_CONFLICTS (图文冲突仲裁) ===")
        for c in conflicts:
            lines.append(
                f" [{c.get('conflict_type','?')}] {c.get('section','?')}: "
                f"以{c.get('source','?')}为准 — {c.get('correction','')}"
            )
    return "\n".join(lines)


# ---- Prompt Building ----

_PROMPT_PLACEHOLDER_RE = re.compile(r"\{(document_json|logic_tree_paths|feedback)\}")


def build_prompt(doc: dict, feedback: str = "", all_paths: dict | None = None) -> str:
    """Load the prompt template and inject the formatted document, paths and feedback.

    Placeholders ``{document_json}``, ``{logic_tree_paths}`` and
    ``{feedback}`` are substituted in a single pass, so placeholder-like text
    inside the injected document is never substituted a second time.

    Args:
        doc: Parsed input document.
        feedback: Optional re-prompt feedback; an empty string clears the slot.
        all_paths: Pre-computed ``enumerate_all_paths(doc)``; computed if None.
    """
    template_path = Path(config.PROMPTS_DIR) / "step1_semantic_index.txt"
    template = template_path.read_text(encoding="utf-8")
    if all_paths is None:
        all_paths = enumerate_all_paths(doc)
    values = {
        "document_json": format_document_for_prompt(doc),
        "logic_tree_paths": format_paths_for_prompt(all_paths),
        "feedback": feedback or "",
    }
    return _PROMPT_PLACEHOLDER_RE.sub(lambda m: values[m.group(1)], template)


# ---- Validation ----

# Section / table-row coverage threshold. These warnings are non-blocking.
_SECTION_COVERAGE_TARGET = 0.70

# Section titles matching any of these are treated as non-functional
# (background, glossary, changelog, ...) and excluded from coverage.
_NON_FUNCTIONAL_SECTION_PATTERNS = tuple(
    re.compile(p)
    for p in (
        r"编制.*变更.*日志", r"变更日志", r"文档背景", r"文档范围", r"术语解释",
        r"参考", r"附录", r"版本", r"变更记录", r"目录", r"前言", r"概述",
        r"简介", r"PRD", r"前置条件", r"依赖", r"行业规范", r"输入文件",
        r"后方输入", r"政策法规", r"相关文档", r"概要说明",
    )
)

# Concept names allowed to have parent=None without a warning.
_SCOPE_CONCEPT_NAMES = ("国内", "海外", "")

# Table column names used to describe an uncovered row in feedback.
_KEY_COLUMNS = ("功能", "三级功能", "一级功能", "功能名称")
_VALUE_COLUMNS = ("功能详细说明", "详细说明", "四级功能", "说明")


def _quick_validate(
    semantic_index: dict, doc: dict, all_paths: dict | None = None
) -> tuple[bool, dict]:
    """Validate a semantic index and return ``(passed, gaps)``.

    Only ``format_issues`` and ``missing_paths`` (logic-tree decision/action
    node coverage below ``config.COVERAGE_TARGET``) make ``passed`` False.
    ``parent_issues``, ``missing_concepts`` and ``coverage_warnings`` are
    informational. Path-consistency and shell-unit checks run only when
    ``all_paths`` is given.

    Side effects: prints section/table coverage ratios.
    """
    gaps = {
        "missing_paths": [],
        "missing_concepts": [],
        "format_issues": [],
        "parent_issues": [],
        "coverage_warnings": [],  # section/table coverage below threshold (non-blocking)
    }
    units = semantic_index.get("function_units", [])
    concepts = semantic_index.get("concepts", [])

    if not units:
        gaps["format_issues"].append("function_units 为空")
        return False, gaps

    all_nodes = _collect_logic_tree_nodes(doc)
    _check_unit_fields(units, gaps)
    _check_logic_tree_coverage(units, all_nodes, gaps)
    if all_paths is not None:
        _check_path_consistency(units, all_paths, gaps)
        _check_trivial_units(units, all_nodes, gaps)
    _check_concepts(concepts, doc, gaps)
    _check_section_and_table_coverage(units, doc, gaps)

    passed = not gaps["missing_paths"] and not gaps["format_issues"]
    return passed, gaps


def _check_unit_fields(units: list[dict], gaps: dict) -> None:
    """Flag function_units missing a ``path`` or ``sources`` field."""
    for fu in units:
        uid = fu.get("unit_id", "?")
        if not fu.get("path"):
            gaps["format_issues"].append(f"{uid}: 缺少 path 字段")
        if not fu.get("sources"):
            gaps["format_issues"].append(f"{uid}: 缺少 sources")


def _check_logic_tree_coverage(units: list[dict], all_nodes: dict, gaps: dict) -> None:
    """Flag images whose decision/action nodes are referenced below config.COVERAGE_TARGET."""
    threshold = config.COVERAGE_TARGET
    referenced = _collect_referenced_nodes(units)
    for image_id, node_types in all_nodes.items():
        checkable = {
            nid for nid, ntype in node_types.items() if ntype in ("decision", "action")
        }
        if not checkable:
            continue
        covered = checkable & referenced.get(image_id, set())
        coverage = len(covered) / len(checkable)
        if coverage < threshold:
            missing = checkable - covered
            gaps["missing_paths"].append(
                f"{image_id}: 覆盖率 {coverage:.0%} < {threshold:.0%}, "
                f"未覆盖节点: {sorted(missing)}"
            )


def _check_path_consistency(units: list[dict], all_paths: dict, gaps: dict) -> None:
    """Require each logic_tree source's nodes to lie on a single enumerated path."""
    # Pre-build node-id sets once instead of per source.
    path_sets = {
        image_id: [set(p.get("node_ids", [])) for p in paths]
        for image_id, paths in all_paths.items()
    }
    for fu in units:
        uid = fu.get("unit_id", "?")
        for src in fu.get("sources") or []:
            if src.get("type") != "logic_tree":
                continue
            image_id = src.get("image_id", "")
            unit_nodes = set(src.get("logic_tree_nodes") or [])
            if not unit_nodes:
                continue
            if not any(unit_nodes <= ps for ps in path_sets.get(image_id, [])):
                gaps["format_issues"].append(
                    f"{uid}: logic_tree_nodes 不构成有效路径 "
                    f"(image={image_id}, nodes={sorted(unit_nodes)})"
                )


def _check_trivial_units(units: list[dict], all_nodes: dict, gaps: dict) -> None:
    """Flag logic-tree units with no action node and at most one decision node in any source."""
    for fu in units:
        uid = fu.get("unit_id", "?")
        has_logic_ref = has_action = has_deep_decision = False
        for src in fu.get("sources") or []:
            if src.get("type") != "logic_tree":
                continue
            has_logic_ref = True
            # Resolve node types within the source's own image: ids such as
            # "n1" repeat across images.
            node_types = all_nodes.get(src.get("image_id", ""))
            if node_types is None:
                # Unknown/missing image_id: search every image (later images win).
                node_types = {}
                for image_nodes in all_nodes.values():
                    node_types.update(image_nodes)
            kinds = [node_types.get(nid, "") for nid in src.get("logic_tree_nodes") or []]
            if "action" in kinds:
                has_action = True
            # A single decision is just the root switch; more means real logic.
            if kinds.count("decision") > 1:
                has_deep_decision = True
        if has_logic_ref and not has_action and not has_deep_decision:
            gaps["format_issues"].append(
                f"{uid}: 可能为空壳单元(仅有state/开关节点,无action或深层decision)"
            )


def _check_concepts(concepts: list[dict], doc: dict, gaps: dict) -> None:
    """Check concept parent links and the presence of scope concepts (国内/海外)."""
    concept_names = {c.get("name") for c in concepts}
    for c in concepts:
        parent = c.get("parent")  # None is allowed for scope-level concepts
        if parent is not None and parent not in concept_names:
            gaps["parent_issues"].append(
                f"concept '{c.get('name', '?')}' 的 parent '{parent}' 不存在"
            )
    # Only scope-level concepts (国内/海外) are expected to have parent=None.
    for c in concepts:
        if c.get("parent") is not None:
            continue
        name = c.get("name", "")
        if name not in _SCOPE_CONCEPT_NAMES:
            gaps["parent_issues"].append(
                f"concept '{name}' 的 parent 为 null,但它可能不是 scope 概念"
            )

    if "国内" not in concept_names:
        gaps["missing_concepts"].append("缺少 scope 概念: 国内")
    if "海外" not in concept_names and any(
        "海外" in (s.get("source") or "") for s in doc.get("sections", [])
    ):
        gaps["missing_concepts"].append("缺少 scope 概念: 海外")


def _is_functional_section(sec_name: str) -> bool:
    """Return True for a non-blank section title that matches no non-functional pattern."""
    if not sec_name.strip():
        return False
    return not any(p.search(sec_name) for p in _NON_FUNCTIONAL_SECTION_PATTERNS)


def _has_section_content(sec: dict) -> bool:
    """Check if a section has meaningful content (text >= 10 chars, table, or image).

    A section is "empty" when every text block is shorter than 10 characters
    (after stripping) and it has no table/image/figure/picture block. Per the
    original note, such sections typically come from image-only Word sections
    whose text doc_parser cannot extract.
    """
    for block in sec.get("blocks", []):
        if block.get("type", "") in ("table", "image", "figure", "picture"):
            return True
        text = block.get("text", "")
        if isinstance(text, str) and len(text.strip()) >= 10:
            return True
    return False


def _section_keys(source: str) -> set[str]:
    """Return the ids a source may cite a section by: full title and leading token (e.g. "3.1.1")."""
    source = source.strip()
    return {source, source.split()[0]} if source else set()


def _is_cited(source: str, cited_sections: set[str]) -> bool:
    """True if any identifier of section ``source`` appears in ``cited_sections``."""
    return not _section_keys(source).isdisjoint(cited_sections)


def _row_key(value) -> str:
    """Normalize a table row id so int and str row numbers compare equal."""
    return str(value).strip()


def _collect_cited_sections_and_rows(units: list[dict]) -> tuple[set[str], set[tuple[str, str]]]:
    """Collect cited section ids and (section, row) pairs from table sources."""
    cited_sections: set[str] = set()
    cited_rows: set[tuple[str, str]] = set()
    for fu in units:
        for src in fu.get("sources") or []:
            sec = str(src.get("section") or "").strip()
            if sec:
                cited_sections.add(sec)
            row = src.get("row")
            if src.get("type") == "table" and row is not None and row != "":
                cited_rows.add((sec, _row_key(row)))
    return cited_sections, cited_rows


def _describe_missing_row(sec_name: str, rn, row: dict) -> dict:
    """Summarize an uncovered table row (key/value column text) for re-prompt feedback."""
    key_col = ""
    val_col = ""
    columns = row.get("columns", [])
    for col in columns:
        name = col.get("name", "")
        text = (col.get("text") or "")[:100]
        if name in _KEY_COLUMNS:
            key_col = text
        elif name in _VALUE_COLUMNS:
            val_col = text
    if not key_col and columns:
        # No recognised key column: fall back to the first column's text.
        key_col = (columns[0].get("text") or "")[:60]
    return {"section": sec_name, "row": rn, "key": key_col, "value": val_col}


def _check_section_and_table_coverage(units: list[dict], doc: dict, gaps: dict) -> None:
    """Record non-blocking section / table-row coverage warnings in ``gaps``.

    Only functional sections with content are counted. A section or table
    row counts as covered when a function_unit source cites it by its full
    title or by the title's leading token. Citations of sections/rows that
    do not exist are ignored, so ratios never exceed 100%; with nothing to
    cover the ratio is 100%.

    Side effects: prints both ratios; may append to
    ``gaps["coverage_warnings"]`` and set ``gaps["missing_table_rows"]``.
    """
    func_sections = [
        s for s in doc.get("sections", [])
        if _is_functional_section(s.get("source") or "") and _has_section_content(s)
    ]
    cited_sections, cited_rows = _collect_cited_sections_and_rows(units)

    # --- Section coverage ---
    uncovered = [
        s.get("source") for s in func_sections
        if not _is_cited(s.get("source") or "", cited_sections)
    ]
    n_func = len(func_sections)
    n_covered = n_func - len(uncovered)
    section_cov = n_covered / n_func if n_func else 1.0
    print(f" 章节覆盖率: {section_cov:.0%} ({n_covered}/{n_func} "
          f"functional sections)", flush=True)
    if section_cov < _SECTION_COVERAGE_TARGET:
        gaps["coverage_warnings"].append(
            f"章节覆盖率 {section_cov:.0%} < {_SECTION_COVERAGE_TARGET:.0%}, "
            f"未覆盖: {uncovered[:5]}"
        )

    # --- Table row coverage (rows of functional sections only) ---
    total_rows = 0
    covered_rows = 0
    missing_rows: list[dict] = []
    for s in func_sections:
        source = (s.get("source") or "").strip()
        keys = _section_keys(source)
        sec_name = source.split()[0]
        for b in s.get("blocks", []):
            if b.get("type") != "table":
                continue
            for row in b.get("rows", []):
                total_rows += 1
                rn = row.get("row")
                if any((k, _row_key(rn)) in cited_rows for k in keys):
                    covered_rows += 1
                else:
                    missing_rows.append(_describe_missing_row(sec_name, rn, row))

    row_cov = covered_rows / total_rows if total_rows else 1.0
    print(f" 表格行覆盖率: {row_cov:.0%} ({covered_rows}/{total_rows} rows)", flush=True)
    if row_cov < _SECTION_COVERAGE_TARGET:
        gaps["coverage_warnings"].append(
            f"表格行覆盖率 {row_cov:.0%} < {_SECTION_COVERAGE_TARGET:.0%}, "
            f"({covered_rows}/{total_rows} rows from functional sections)"
        )
        gaps["missing_table_rows"] = missing_rows

    # Coverage warnings are non-blocking (they depend on LLM prompt quality).
    if gaps["coverage_warnings"]:
        print(f" [WARN] 覆盖率低于 {_SECTION_COVERAGE_TARGET:.0%} 阈值,但 pipeline 继续运行。"
              f"请通过 Prompt 优化或反馈重试提升。", flush=True)


def _build_coverage_feedback(gaps: dict) -> str:
    """Build targeted re-prompt feedback from coverage gaps; "" when there is none.

    Includes section/table coverage warnings, the blocking logic-tree
    coverage gaps (``missing_paths``), and each uncovered table row.
    """
    parts = [f"- {item}" for item in gaps.get("coverage_warnings", [])]
    parts.extend(f"- {item}" for item in gaps.get("missing_paths", []))

    missing_rows = gaps.get("missing_table_rows", [])
    if missing_rows:
        parts.append(f"\n### 以下具体表格行缺少对应 function_unit(共 {len(missing_rows)} 行):\n")
        for mr in missing_rows:
            sec = mr.get("section", "?")
            rn = mr.get("row", "?")
            key = mr.get("key", "")
            val = mr.get("value", "")
            parts.append(
                f"- **章节 {sec}, 行 {rn}**: {key}" + (f" — {val}" if val else "")
            )

    if not parts:
        return ""
    return (
        "\n## 关键覆盖反馈(上一轮 LLM 输出存在缺口,请重新处理)\n\n"
        + "\n".join(parts)
        + "\n\n"
        "### 修复动作(必须执行)\n\n"
        "1. **重新扫描上述每个缺失章节和表格行**,从文字和表格中提取所有可被测试的功能行为\n"
        "2. **为上述每个缺失表格行创建独立的 function_unit**,不得合并不同行的规则\n"
        "3. **每个 function_unit 必须引用具体的 section 号和 row 号**作为 source\n"
        "4. **非功能章节可以跳过**(如背景、术语、变更日志),但行为规则章节必须覆盖\n"
        "5. 输出中必须包含针对上述缺口的新 function_unit,**尤其是列出具体缺失的表格行**\n"
    )


def _collect_logic_tree_nodes(doc: dict) -> dict[str, dict[str, str]]:
    """Return ``{image_id: {node_id: node_type}}`` for all flat logic trees."""
    result = {}
    for img in doc.get("image_analysis", []):
        lt = img.get("logic_tree")
        rid = img.get("rid", "")
        if lt and rid:
            result[rid] = {n["id"]: n["type"] for n in lt.get("nodes", [])}
    return result


def _collect_referenced_nodes(units: list[dict]) -> dict[str, set[str]]:
    """Return ``{image_id: {referenced node_ids}}`` across all function_units."""
    refs: dict[str, set[str]] = {}
    for fu in units:
        for src in fu.get("sources") or []:
            if src.get("type") == "logic_tree":
                refs.setdefault(src.get("image_id", ""), set()).update(
                    src.get("logic_tree_nodes") or []
                )
    return refs


# ---- LLM Calls ----

_FENCED_JSON_RE = re.compile(r"```(?:json)?\s*([\s\S]*?)```", re.IGNORECASE)


def extract_json_from_response(text: str) -> str:
    """Extract the JSON text from an LLM response.

    Returns the contents of the first fenced code block (a ``json``/``JSON``
    tag is optional). Otherwise returns the first balanced ``{...}`` object;
    braces inside JSON string literals (including escaped quotes) are ignored.

    Raises:
        ValueError: no ``{`` found, or the object is never closed.
    """
    m = _FENCED_JSON_RE.search(text)
    if m:
        return m.group(1).strip()

    start = text.find("{")
    if start == -1:
        raise ValueError("No JSON object found in LLM response")

    depth = 0
    in_string = False
    escaped = False
    for i, ch in enumerate(text[start:], start=start):
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    raise ValueError("Unclosed JSON object in LLM response")


def call_llm(prompt: str, max_retries: int = 2, temperature: float | None = None) -> dict:
    """Send ``prompt`` to the LLM and return the parsed JSON object.

    Parse failures (no/invalid JSON, or JSON that is not an object) are
    retried up to ``max_retries`` times with a 2 s pause. Client
    initialisation errors, API errors and empty (None) responses propagate
    immediately.

    Args:
        prompt: User message content.
        max_retries: Extra attempts after the first on parse failure.
        temperature: Override ``config.TEMPERATURE``; None uses the default.

    Raises:
        RuntimeError: empty response, or all attempts failed to parse.
    """
    try:
        client = config.llm_client()
    except Exception as e:
        print(f" LLM 客户端初始化失败: {e}", file=sys.stderr)
        print(f" 请检查: IR_PROVIDER={config.LLM_PROVIDER}, secrets.yaml 或环境变量", file=sys.stderr)
        raise

    temp = temperature if temperature is not None else config.TEMPERATURE
    content = None
    for attempt in range(max_retries + 1):
        print(f" LLM 调用 model={config.MODEL_NAME} T={temp} "
              f"(尝试 {attempt + 1}/{max_retries + 1})...", flush=True)
        try:
            resp = client.chat.completions.create(
                model=config.MODEL_NAME,
                messages=[
                    {
                        "role": "system",
                        "content": "你是一个精确的 JSON 输出引擎。只输出合法的 JSON。",
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=temp,
                max_tokens=config.MAX_TOKENS,
            )
            content = resp.choices[0].message.content
            if content is None:
                raise RuntimeError(
                    "LLM 返回空响应 (content=None)。可能是 API 配额不足或模型不可用。"
                )
            print(f" 响应长度: {len(content)} 字符", flush=True)
            result = json.loads(extract_json_from_response(content))
            if not isinstance(result, dict):
                # e.g. a bare list; treat like a parse failure so it is retried.
                raise ValueError(f"LLM 返回的 JSON 不是对象 (类型: {type(result).__name__})")
            n_units = len(result.get("function_units", []))
            n_concepts = len(result.get("concepts", []))
            print(f" 提取: {n_concepts} 概念, {n_units} 功能单元", flush=True)
            return result
        except (json.JSONDecodeError, ValueError) as e:
            print(f" JSON 解析失败: {e}", file=sys.stderr)
            # Show a snippet of what the LLM returned for diagnosis.
            print(f" LLM 返回内容前 500 字符: {content[:500] if content else '(None)'}",
                  file=sys.stderr)
            if attempt < max_retries:
                time.sleep(2)

    raise RuntimeError(
        f"无法从 LLM 响应中解析 JSON({max_retries + 1} 次尝试均失败)。"
        f"最后返回内容前 500 字符: {content[:500] if content else '(None)'}"
    )


# ---- Ensemble Orchestration ----

def _run_parallel_llm_calls(prompt: str, temperatures) -> list[tuple[int, float, dict]]:
    """Call the LLM once per temperature concurrently.

    Returns ``(config_index, temperature, semantic_index)`` tuples in
    completion order. A failed call yields an empty index so that every
    configured temperature still produces a version.
    """
    raw_results: list[tuple[int, float, dict]] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(temperatures)) as executor:
        future_to_meta = {
            executor.submit(call_llm, prompt, 2, temp): (i, temp)
            for i, temp in enumerate(temperatures)
        }
        for future in concurrent.futures.as_completed(future_to_meta):
            idx, temp = future_to_meta[future]
            try:
                si = future.result()
            except Exception as e:
                print(f" T={temp}: FAIL — {e}")
                si = {"feature_name": "", "concepts": [], "function_units": []}
            else:
                n_units = len(si.get("function_units", []))
                n_concepts = len(si.get("concepts", []))
                print(f" T={temp}: {n_concepts} 概念, {n_units} 功能单元")
            raw_results.append((idx, temp, si))
    return raw_results


def _save_version_outputs(raw_results: list[tuple[int, float, dict]]) -> None:
    """Save up to three raw versions to config.SEMANTIC_INDEX_R{1,2,3}_JSON, labelled with their own temperature."""
    version_paths = (
        config.SEMANTIC_INDEX_R1_JSON,
        config.SEMANTIC_INDEX_R2_JSON,
        config.SEMANTIC_INDEX_R3_JSON,
    )
    for i, ((_, temp, si), out_path) in enumerate(zip(raw_results, version_paths)):
        if out_path:
            config.save_json(si, out_path)
            print(f" 保存版本 {i} (T={temp}): {out_path}")


def _merge_and_validate(
    semantic_indices: list[dict], temperature_labels: list, doc: dict, all_paths: dict
) -> tuple[dict, bool, dict]:
    """Ensemble-merge the versions, validate, and attach validation fields.

    Returns ``(merged, passed, gaps)``; ``merged["validation_gaps"]`` keeps
    only non-empty gap categories.
    """
    merged = ensemble_merge(semantic_indices)
    merged["ensemble_temperatures"] = temperature_labels
    passed, gaps = _quick_validate(merged, doc, all_paths)
    merged["validation_passed"] = passed
    merged["validation_gaps"] = {k: v for k, v in gaps.items() if v}
    return merged, passed, gaps


def _print_merge_summary(merged: dict, passed: bool, gaps: dict) -> None:
    """Print merged counts, confidence distribution and validation status."""
    cs = merged.get("confidence_summary", {})
    print(f" 合并后: {cs.get('total_concepts', 0)} 概念, "
          f"{cs.get('total_units', 0)} 功能单元")
    print(f" 置信度: high={cs.get('high', 0)}, medium={cs.get('medium', 0)}, "
          f"low={cs.get('low', 0)}")
    print(f" 验证: {'PASS' if passed else 'GAPS FOUND'}")
    if not passed:
        for k, v in gaps.items():
            if v:
                print(f" {k}: {len(v)} 个问题")


def _feedback_retry(
    doc: dict, feedback: str, all_paths: dict,
    semantic_indices: list[dict], temperature_labels: list,
) -> dict | None:
    """Run one extra LLM call with coverage feedback and re-merge it with the versions.

    Returns the re-merged index, or None when the retry fails or returns no
    function_units (the caller then keeps its previous merge). Errors are
    printed with a traceback, not raised.
    """
    print(f"\n 覆盖反馈重试 (feedback长度={len(feedback)}字符)...", flush=True)
    try:
        retry_prompt = build_prompt(doc, feedback, all_paths)
        print(f" 重试 prompt 长度: {len(retry_prompt)} 字符", flush=True)
        retry_result = call_llm(retry_prompt, max_retries=1, temperature=0.3)
        n_retry_units = len(retry_result.get("function_units", []))
        n_retry_concepts = len(retry_result.get("concepts", []))
        print(f" 重试返回: {n_retry_concepts} 概念, {n_retry_units} 功能单元", flush=True)
        if n_retry_units == 0:
            return None

        retry_sections = {
            str(src["section"])
            for fu in retry_result.get("function_units", [])
            for src in fu.get("sources") or []
            if src.get("section")
        }
        print(f"     重试新增 sections: {sorted(retry_sections)}", flush=True)

        merged, passed, _ = _merge_and_validate(
            semantic_indices + [retry_result],
            temperature_labels + ["feedback_retry"],
            doc,
            all_paths,
        )
        print(f" 重试后验证: {'PASS' if passed else 'GAPS FOUND'}", flush=True)
        return merged
    except Exception as e:
        print(f" 覆盖反馈重试失败: {e}", flush=True)
        traceback.print_exc()
        return None


def run_ensemble_semantic_index(doc: dict) -> dict:
    """Run N parallel LLM calls at different temperatures, then ensemble-merge.

    1. Enumerate all logic tree paths (once).
    2. Build the prompt (once).
    3. Launch one LLM call per ``config.ENSEMBLE_TEMPERATURES`` entry in parallel.
    4. Order the results by (temperature, config index) so the output never
       depends on completion order, and save the raw versions.
    5. Deterministically merge with ``ensemble_merge`` and validate.
    6. If coverage feedback exists, run one feedback retry and re-merge.

    Raises:
        ValueError: ``config.ENSEMBLE_TEMPERATURES`` is empty.
        RuntimeError: every call returned no function_units.
    """
    all_paths = enumerate_all_paths(doc)
    print(f" 已枚举逻辑树路径: {sum(len(v) for v in all_paths.values())} 条")

    prompt = build_prompt(doc, "", all_paths)
    print(f" Prompt 长度: {len(prompt)} 字符")

    temperatures = config.ENSEMBLE_TEMPERATURES
    print(f" 集成温度: {temperatures}")
    if not temperatures:
        raise ValueError("config.ENSEMBLE_TEMPERATURES 为空,至少需要一个温度")

    raw_results = _run_parallel_llm_calls(prompt, temperatures)
    if all(not si.get("function_units") for _, _, si in raw_results):
        raise RuntimeError(
            "所有集成的 LLM 调用返回了空的 function_units。请检查:\n"
            " 1. API Key 是否配置正确 (secrets.yaml 或环境变量)\n"
            " 2. 输入文档格式是否与 Prompt 兼容\n"
            " 3. LLM 服务是否可访问"
        )

    # Tie-break on config index so equal temperatures still sort deterministically.
    raw_results.sort(key=lambda r: (r[1], r[0]))
    semantic_indices = [si for _, _, si in raw_results]
    # Temperatures listed in the same order as the merged versions.
    temperature_labels = [temp for _, temp, _ in raw_results]
    _save_version_outputs(raw_results)

    print(f"\n 集成合并 {len(semantic_indices)} 个版本...")
    merged, passed, gaps = _merge_and_validate(
        semantic_indices, temperature_labels, doc, all_paths
    )
    _print_merge_summary(merged, passed, gaps)

    # One coverage-feedback retry; coverage warnings alone (non-blocking) can trigger it.
    feedback = _build_coverage_feedback(gaps)
    if feedback:
        retried = _feedback_retry(doc, feedback, all_paths, semantic_indices, temperature_labels)
        if retried is not None:
            merged = retried
    return merged


# ---- Main ----

def main():
    """CLI entry: load the input document, build and save the merged semantic index and path enumeration."""
    print("=" * 60)
    print("阶段一:集成语义索引 (Ensemble Semantic Index)")
    print("=" * 60)

    # 1. Load input
    print(f"\n[1/3] 加载输入文档: {config.INPUT_JSON}")
    doc = config.load_input_document()
    print(f" 已加载 {len(doc.get('sections', []))} 个 section, "
          f"{len(doc.get('image_analysis', []))} 张图片分析")

    # 2. Run ensemble generation + merge
    print(f"\n[2/3] 运行集成语义索引 ({len(config.ENSEMBLE_TEMPERATURES)} 个温度版本)...")
    merged_index = run_ensemble_semantic_index(doc)

    # 3. Save outputs
    print(f"\n[3/3] 保存最终语义索引: {config.SEMANTIC_INDEX_JSON}")
    config.save_json(merged_index, config.SEMANTIC_INDEX_JSON)

    # Also save the path enumeration for downstream use.
    all_paths = enumerate_all_paths(doc)
    config.save_json({"logic_tree_paths": dict(all_paths)}, config.PATH_ENUM_JSON)
    print(f" 路径枚举: {config.PATH_ENUM_JSON}")

    cs = merged_index.get("confidence_summary", {})
    n_concepts = cs.get("total_concepts", len(merged_index.get("concepts", [])))
    n_units = cs.get("total_units", len(merged_index.get("function_units", [])))
    n_versions = merged_index.get("ensemble_versions", len(config.ENSEMBLE_TEMPERATURES))

    if not merged_index.get("validation_passed", True):
        print("\n注意: 语义索引验证发现以下问题 (非阻塞,pipeline 继续运行):")
        for category, issues in merged_index.get("validation_gaps", {}).items():
            for issue in issues:
                print(f" [{category}] {issue}")

    print(f"\n完成! {n_versions} 版本集成, {n_concepts} 个概念, {n_units} 个功能单元.")
    print(f"输出: {config.SEMANTIC_INDEX_JSON}")


if __name__ == "__main__":
    main()