Files
pzhang_zywl 2cd02453ec
CI / test (pull_request) Successful in 8s
fix: step1 覆盖反馈重试增至 3 次 + 放宽质量门控 - Closes #75
- 重试次数 2→3,增加 LLM 补全机会
- 质量门控放宽:新增 sections 且无回归即采纳,不只严格要求覆盖率下降

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-02 18:35:06 +08:00

991 lines
38 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Stage 1: Ensemble Semantic Index Generation.
Generates N parallel LLM calls with different temperatures (e.g., 0.0, 0.3, 0.7),
then deterministically merges the results via ensemble_merge (pure Python, no LLM).
The merged output includes confidence scores for each concept and function_unit.
Outputs:
- output/semantic_index_r1.json (T=0.0 raw)
- output/semantic_index_r2.json (T=0.3 raw)
- output/semantic_index_r3.json (T=0.7 raw)
- output/semantic_index.json (ensemble-merged final)
"""
import concurrent.futures
import json
import re
import sys
import time
from pathlib import Path
import config
from ensemble_merge import ensemble_merge
# ---- Path Enumeration (for prompt embedding) ----
def _traverse_nested(node: dict, image_id: str, path_nodes: list,
                     branch_taken: str | None) -> list[dict]:
    """Walk a ``logic_tree_nested`` node depth-first and collect leaf paths.

    Args:
        node: Current nested-tree node (reads ``id``, ``type``, ``name``,
            ``children``).
        image_id: Identifier of the image the tree belongs to.
        path_nodes: Chain of node entries accumulated so far; not mutated.
        branch_taken: Label of the edge that led to ``node`` (``None`` at
            the root).

    Returns:
        One path record per leaf reachable from ``node``.
    """
    kind = node.get("type", "?")
    # Build a fresh list so sibling branches never share accumulated state.
    trail = path_nodes + [{
        "id": node.get("id", "?"),
        "type": kind,
        "label": node.get("name", ""),
        "branch_taken": branch_taken,
    }]

    # An explicit "end" node, or any node without children, closes the path.
    kids = [] if kind == "end" else node.get("children", [])
    if not kids:
        return [_make_path_record(trail, image_id)]

    is_decision = kind == "decision"
    collected = []
    for kid in kids:
        if is_decision:
            # Decision children are {condition, node} wrappers.
            edge_label = kid.get("condition", "")
            target = kid.get("node", kid)
        else:
            # Every other node type lists its children directly.
            edge_label = "(implicit)"
            target = kid
        collected += _traverse_nested(target, image_id, trail, edge_label)
    return collected
def _make_path_record(path_nodes: list, image_id: str) -> dict:
    """Assemble the record describing one completed root-to-leaf chain.

    Args:
        path_nodes: Ordered node entries (each with ``id``, ``type``,
            ``label``, ``branch_taken``).
        image_id: Identifier of the image the path belongs to.

    Returns:
        Dict with ``path_id``, ``nodes``, ``meaning``, ``image_id``,
        ``action_nodes``, ``decision_nodes`` and ``node_ids``.
    """
    def _of_type(wanted: str) -> list:
        return [entry for entry in path_nodes if entry["type"] == wanted]

    actions = _of_type("action")
    decisions = _of_type("decision")
    ids = [entry["id"] for entry in path_nodes]
    joined_ids = "-".join(ids)
    return dict(
        path_id=f"PATH-{image_id}-{joined_ids}",
        nodes=path_nodes,
        meaning=_describe_path(path_nodes),
        image_id=image_id,
        action_nodes=actions,
        decision_nodes=decisions,
        node_ids=ids,
    )
def enumerate_logic_tree_paths(nested_tree: dict, image_id: str = "") -> list[dict]:
    """Return every root-to-leaf path of a ``logic_tree_nested`` structure.

    The nested tree is walked directly (no flat adjacency list is built).
    An empty or missing tree yields an empty list.
    """
    return _traverse_nested(nested_tree, image_id, [], None) if nested_tree else []
def _describe_path(path_nodes: list[dict]) -> str:
"""Generate a human-readable description of a logic tree path."""
parts = []
for n in path_nodes:
label = n["label"]
if n["branch_taken"] and n["branch_taken"] != "(implicit)":
label = f"{label}{n['branch_taken']}"
parts.append(label)
return "".join(parts)
def enumerate_all_paths(doc: dict) -> dict[str, list[dict]]:
    """Enumerate paths for every logic tree found in ``doc["image_analysis"]``.

    ``logic_tree_nested`` is preferred when present; otherwise the flat
    ``logic_tree`` is used. Images without a ``rid`` are skipped. A flat tree
    that exists but has no nodes maps to an empty list.

    The input document is not modified: the flat tree is handed to the
    enumerator as a shallow copy carrying the image id, instead of writing
    ``image_id`` into the caller's dict.

    Returns:
        ``{image_id: [path_record, ...]}``.
    """
    result = {}
    for img in doc.get("image_analysis", []):
        rid = img.get("rid", "")
        if not rid:
            continue
        nested = img.get("logic_tree_nested")
        if nested:
            result[rid] = enumerate_logic_tree_paths(nested, image_id=rid)
            continue
        lt = img.get("logic_tree")
        if not lt:
            continue
        if lt.get("nodes"):
            # Shallow copy: the enumerator only needs to read "image_id".
            result[rid] = _enumerate_flat_tree({**lt, "image_id": rid})
        else:
            result[rid] = []
    return result
def _enumerate_flat_tree(tree: dict) -> list[dict]:
"""Fallback: enumerate paths from flat logic_tree using adjacency.
Handles start/process/action/state nodes as implicit chain links.
"""
nodes = tree.get("nodes", [])
if not nodes:
return []
node_map = {n["id"]: n for n in nodes}
image_id = tree.get("image_id", "")
# Find root: first start/state node, or first process node, or first node
root = None
for n in nodes:
if n["type"] in ("start", "state"):
root = n
break
if root is None:
for n in nodes:
if n["type"] == "process":
root = n
break
if root is None:
root = nodes[0]
adj = _build_adjacency(nodes, node_map)
paths = []
def dfs(current_id, visited, path_nodes, branch_taken):
if current_id in visited:
return
new_visited = visited | {current_id}
node = node_map.get(current_id)
if node is None:
return
path_nodes = path_nodes + [{
"id": current_id,
"type": node["type"],
"label": node.get("description") or node.get("condition", ""),
"branch_taken": branch_taken,
}]
outgoing = adj.get(current_id, [])
if not outgoing:
action_nodes = [n for n in path_nodes if n["type"] == "action"]
decision_nodes = [n for n in path_nodes if n["type"] == "decision"]
node_ids = [n["id"] for n in path_nodes]
paths.append({
"path_id": f"PATH-{image_id}-{'-'.join(node_ids)}",
"nodes": path_nodes,
"meaning": _describe_path(path_nodes),
"image_id": image_id,
"action_nodes": action_nodes,
"decision_nodes": decision_nodes,
"node_ids": node_ids,
})
else:
for branch_val, target_id in outgoing:
dfs(target_id, new_visited, path_nodes, branch_val)
dfs(root["id"], set(), [], None)
return paths
def _build_adjacency(nodes, node_map):
"""Build {node_id: [(branch_value, target_id)]} adjacency for flat trees.
Handles: decision branches (explicit), non-branching nodes (implicit sequential).
"""
NON_BRANCHING = {"start", "process", "state", "action"}
adj = {}
has_explicit_incoming = set()
for n in nodes:
for br in n.get("branches", []):
has_explicit_incoming.add(br["target"])
for i, node in enumerate(nodes):
nid = node["id"]
adj.setdefault(nid, [])
# Explicit edges from decision nodes
for br in node.get("branches", []):
adj[nid].append((br["value"], br["target"]))
# Implicit edges for non-branching nodes (start/process/state/action)
if node["type"] in NON_BRANCHING and not node.get("branches"):
j = i + 1
targets = []
while j < len(nodes):
next_node = nodes[j]
next_nid = next_node["id"]
if next_nid in has_explicit_incoming:
break
if next_node["type"] in NON_BRANCHING | {"end"}:
targets.append(next_nid)
has_explicit_incoming.add(next_nid)
j += 1
continue
elif next_node["type"] == "decision":
if not targets:
targets.append(next_nid)
break
j += 1
for t in targets:
adj[nid].append(("(implicit)", t))
return adj
def format_paths_for_prompt(all_paths: dict[str, list[dict]]) -> str:
    """Render the enumerated paths as a text listing for the LLM prompt.

    Each image gets a heading with its path count, followed by five lines
    per path (id, meaning, node ids, decision node ids, action node ids).
    An empty mapping yields a fixed placeholder string.
    """
    if not all_paths:
        return "(无逻辑树路径)"

    chunks = []
    for image_id, paths in all_paths.items():
        chunks.append(f"\n### {image_id} 的全部决策路径(共 {len(paths)} 条):")
        for ordinal, path in enumerate(paths, 1):
            chunks.append(f"\n**路径 {ordinal}** (ID: {path['path_id']})")
            chunks.append(f" 含义: {path['meaning']}")
            chunks.append(f" 节点: {path['node_ids']}")
            decision_ids = [n['id'] for n in path['decision_nodes']]
            chunks.append(f" 决策节点: {decision_ids}")
            action_ids = [n['id'] for n in path['action_nodes']]
            chunks.append(f" 动作节点: {action_ids}")
    return "\n".join(chunks)
# ---- Document Formatting ----
def format_document_for_prompt(doc: dict) -> str:
    """Render the parsed document as plain text for the LLM prompt.

    Three parts are emitted in order: the sections (paragraphs, tables and
    image references), the image analysis with each flat logic tree, and,
    when present, the resolved text/figure conflicts.
    """
    out = ["=== SECTIONS ==="]
    emit = out.append

    for position, section in enumerate(doc.get("sections", [])):
        title = section.get("source", f"(无标题-章节{position})")
        emit(f"\n--- Section: {title} ---")
        for block in section.get("blocks", []):
            kind = block["type"]
            if kind == "para":
                emit(f"[段落 {block['index']}] {block['text']}")
            elif kind == "table":
                emit(f"[表格 {block.get('table', '?')}]")
                emit(f" 表头: {' | '.join(block.get('headers', []))}")
                for row in block.get("rows", []):
                    # One output line per table row, cells joined by "; ".
                    cells = [
                        f"[行{c.get('row','?')}]{c.get('name','')}: {c.get('text','')}"
                        for c in row.get("columns", [])
                    ]
                    emit(f" {'; '.join(cells)}")
        image_refs = section.get("images", [])
        if image_refs:
            emit(f" 图片引用: {', '.join(image_refs)}")

    emit("\n\n=== IMAGE_ANALYSIS (流程图逻辑树) ===")
    for img in doc.get("image_analysis", []):
        emit(f"\n--- Image: {img.get('rid', '?')} (type={img.get('type', '?')}) ---")
        # Descriptions are capped at 300 characters.
        emit(f" 描述: {img.get('description', '')[:300]}")
        tree = img.get("logic_tree")
        if not tree:
            continue
        emit(f" 逻辑树根节点: {tree.get('root', '?')}")
        emit(" 节点详情:")
        for node in tree.get("nodes", []):
            summary = node.get("description", "") or node.get("condition", "")
            emit(f" [{node.get('type', '?')}] {node.get('id', '?')}: {summary}")
            outgoing = node.get("branches", [])
            if outgoing:
                for br in outgoing:
                    emit(f"{br['value']}{br['target']}")

    conflicts = doc.get("resolved_conflicts", [])
    if conflicts:
        emit("\n\n=== RESOLVED_CONFLICTS (图文冲突仲裁) ===")
        for c in conflicts:
            emit(
                f" [{c.get('conflict_type','?')}] {c.get('section','?')}: "
                f"{c.get('source','?')}为准 — {c.get('correction','')}"
            )
    return "\n".join(out)
# ---- Prompt Building ----
def build_prompt(doc: dict, feedback: str = "", all_paths: dict | None = None) -> str:
    """Fill the step-1 prompt template with the document, paths and feedback.

    The template ``step1_semantic_index.txt`` is read from
    ``config.PROMPTS_DIR``. Its ``{document_json}``, ``{logic_tree_paths}``
    and ``{feedback}`` placeholders are replaced in that order. Paths are
    enumerated from ``doc`` when ``all_paths`` is not supplied; a falsy
    ``feedback`` is rendered as an empty string.
    """
    template_file = Path(config.PROMPTS_DIR) / "step1_semantic_index.txt"
    prompt = template_file.read_text(encoding="utf-8").replace(
        "{document_json}", format_document_for_prompt(doc)
    )
    paths = enumerate_all_paths(doc) if all_paths is None else all_paths
    prompt = prompt.replace("{logic_tree_paths}", format_paths_for_prompt(paths))
    return prompt.replace("{feedback}", feedback if feedback else "")
# ---- Validation ----
def _quick_validate(
semantic_index: dict, doc: dict, all_paths: dict | None = None
) -> tuple[bool, dict]:
"""Validate semantic index and return (passed, gaps).
Uses a single COVERAGE_TARGET threshold (default 0.95).
"""
gaps = {
"missing_paths": [],
"missing_concepts": [],
"format_issues": [],
"parent_issues": [],
"coverage_warnings": [], # section/table coverage below threshold (non-blocking)
}
units = semantic_index.get("function_units", [])
concepts = semantic_index.get("concepts", [])
# --- Check function_units non-empty ---
if not units:
gaps["format_issues"].append("function_units 为空")
return False, gaps
# --- Check each function_unit has path ---
for fu in units:
uid = fu.get("unit_id", "?")
if not fu.get("path"):
gaps["format_issues"].append(f"{uid}: 缺少 path 字段")
if not fu.get("sources"):
gaps["format_issues"].append(f"{uid}: 缺少 sources")
# --- Logic tree node coverage ---
all_nodes = _collect_logic_tree_nodes(doc)
referenced = _collect_referenced_nodes(units)
threshold = config.COVERAGE_TARGET
for image_id, node_set in all_nodes.items():
ref_set = referenced.get(image_id, set())
checkable = {
nid for nid, ntype in node_set.items()
if ntype in ("decision", "action")
}
if not checkable:
continue
covered = checkable & ref_set
coverage = len(covered) / len(checkable) if checkable else 1.0
if coverage < threshold:
missing = checkable - ref_set
gaps["missing_paths"].append(
f"{image_id}: 覆盖率 {coverage:.0%} < {threshold:.0%}, "
f"未覆盖节点: {sorted(missing)}"
)
# --- Check logic tree path consistency ---
# A unit's logic_tree_nodes must form a valid (connected) path in the tree.
if all_paths is not None:
for fu in units:
uid = fu.get("unit_id", "?")
for src in fu.get("sources", []):
if src.get("type") != "logic_tree":
continue
image_id = src.get("image_id", "")
unit_nodes = set(src.get("logic_tree_nodes", []))
if not unit_nodes:
continue
# Check if there exists a path containing all these nodes
valid = False
for path in all_paths.get(image_id, []):
path_nodes = set(path.get("node_ids", []))
if unit_nodes.issubset(path_nodes):
valid = True
break
if not valid:
gaps["format_issues"].append(
f"{uid}: logic_tree_nodes 不构成有效路径 "
f"(image={image_id}, nodes={sorted(unit_nodes)})"
)
# --- Check for trivial units (only state/switch nodes, no actions) ---
if all_paths is not None:
for fu in units:
uid = fu.get("unit_id", "?")
has_logic_ref = False
has_action = False
has_non_trivial_decision = False
for src in fu.get("sources", []):
if src.get("type") != "logic_tree":
continue
has_logic_ref = True
node_ids = src.get("logic_tree_nodes", [])
node_types = {}
for image_id, nset in all_nodes.items():
for nid in node_ids:
if nid in nset:
node_types[nid] = nset[nid]
for nid in node_ids:
ntype = node_types.get(nid, "")
if ntype == "action":
has_action = True
# Count decisions beyond first level (e.g., n1/n2 are just root+switch)
decisions = [nid for nid in node_ids
if node_types.get(nid, "") == "decision"]
if len(decisions) > 1:
has_non_trivial_decision = True
if has_logic_ref and not has_action and not has_non_trivial_decision:
gaps["format_issues"].append(
f"{uid}: 可能为空壳单元(仅有state/开关节点,无action或深层decision"
)
# --- Concept parent validity ---
concept_names = {c["name"] for c in concepts}
for c in concepts:
name = c.get("name", "?")
parent = c.get("parent") # can be None for scope-level
if parent is not None and parent not in concept_names:
gaps["parent_issues"].append(
f"concept '{name}' 的 parent '{parent}' 不存在"
)
# Warn about scope-level concepts without parent=null
for c in concepts:
if c.get("parent") is not None:
continue
name = c.get("name", "")
# Scope-level concepts (国内/海外) should have parent=null
if name not in ("国内", "海外", ""):
gaps["parent_issues"].append(
f"concept '{name}' 的 parent 为 null,但它可能不是 scope 概念"
)
# --- Check for missing scope concepts ---
if "国内" not in concept_names:
gaps["missing_concepts"].append("缺少 scope 概念: 国内")
if "海外" not in concept_names and any(
"海外" in s.get("source", "") for s in doc.get("sections", [])
):
gaps["missing_concepts"].append("缺少 scope 概念: 海外")
# --- Section and table coverage ---
# Filter out non-functional sections (background, glossary, changelog, etc.)
non_functional_patterns = [
re.compile(p) for p in [
r"编制.*变更.*日志", r"变更日志", r"文档背景", r"文档范围",
r"术语解释", r"参考", r"附录", r"版本", r"变更记录",
r"目录", r"前言", r"概述", r"简介",
r"PRD", r"前置条件", r"依赖", r"行业规范", r"输入文件",
r"后方输入", r"政策法规", r"相关文档", r"概要说明",
]
]
def _is_functional_section(sec_name: str) -> bool:
if not sec_name.strip():
return False
# Check non-functional patterns first (even if section is numbered)
for pat in non_functional_patterns:
if pat.search(sec_name):
return False
# Numbered sections (e.g., "3.1.1") are functional
if re.match(r"^([\d.]+)", sec_name):
return True
return True
def _has_section_content(sec: dict) -> bool:
"""Check if a section has meaningful content (text >= 10 chars, table, or image).
A section is considered "empty" if all its text blocks have fewer than
10 characters and it contains no tables or images. These typically come
from image-only Word sections that doc_parser cannot extract text from.
"""
for block in sec.get("blocks", []):
blk_type = block.get("type", "")
if blk_type == "table":
return True
if blk_type in ("image", "figure", "picture"):
return True
text = block.get("text", "")
if isinstance(text, str) and len(text.strip()) >= 10:
return True
return False
func_sections = [
s for s in doc.get("sections", [])
if _is_functional_section(s.get("source", ""))
and _has_section_content(s)
]
covered_sections: set[str] = set()
for fu in units:
for src in fu.get("sources", []):
sec = src.get("section", "")
if sec:
covered_sections.add(sec)
# Use lower threshold for section/table coverage (70% vs 95% for logic trees)
SECTION_COVERAGE_TARGET = 0.70
section_cov = len(covered_sections) / max(len(func_sections), 1)
print(f" 章节覆盖率: {section_cov:.0%} ({len(covered_sections)}/{len(func_sections)} "
f"functional sections)", flush=True)
if section_cov < SECTION_COVERAGE_TARGET:
uncovered = [s["source"] for s in func_sections
if s["source"] not in covered_sections]
gaps["coverage_warnings"].append(
f"章节覆盖率 {section_cov:.0%} < {SECTION_COVERAGE_TARGET:.0%}, "
f"未覆盖: {uncovered[:5]}"
)
# Count table rows — only from functional sections with content
total_rows = sum(
len(b.get("rows", []))
for s in doc.get("sections", [])
if _is_functional_section(s.get("source", ""))
and _has_section_content(s)
for b in s.get("blocks", [])
if b.get("type") == "table"
)
covered_set: set[tuple] = set()
for fu in units:
for src in fu.get("sources", []):
if src.get("type") == "table" and src.get("row"):
covered_set.add((src.get("section", ""), src.get("row")))
covered_rows = len(covered_set)
# When there are no table rows to cover, skip check
if total_rows == 0:
row_cov = 1.0
else:
row_cov = covered_rows / total_rows
print(f" 表格行覆盖率: {row_cov:.0%} ({covered_rows}/{total_rows} rows)", flush=True)
if row_cov < SECTION_COVERAGE_TARGET:
# Collect specific missing rows with content for targeted feedback
missing_rows: list[dict] = []
for s in doc.get("sections", []):
if not _is_functional_section(s.get("source", "")):
continue
if not _has_section_content(s):
continue
sec_name = s.get("source", "").split()[0] if s.get("source") else "?"
for b in s.get("blocks", []):
if b.get("type") != "table":
continue
for row in b.get("rows", []):
rn = row.get("row")
if (sec_name, rn) not in covered_set:
key_col = ""
val_col = ""
for col in row.get("columns", []):
cn = col.get("name", "")
ct = col.get("text", "")[:100]
if cn in ("功能", "三级功能", "一级功能", "功能名称"):
key_col = ct
elif cn in ("功能详细说明", "详细说明", "四级功能", "说明"):
val_col = ct
if not key_col:
# Use first column as key
for col in row.get("columns", []):
key_col = col.get("text", "")[:60]
break
missing_rows.append({
"section": sec_name,
"row": rn,
"key": key_col,
"value": val_col,
})
gaps["coverage_warnings"].append(
f"表格行覆盖率 {row_cov:.0%} < {SECTION_COVERAGE_TARGET:.0%}, "
f"({covered_rows}/{total_rows} rows from functional sections)"
)
gaps["missing_table_rows"] = missing_rows
# Coverage warnings are non-blocking (depend on LLM prompt quality)
if gaps["coverage_warnings"]:
print(f" [WARN] 覆盖率低于 {SECTION_COVERAGE_TARGET:.0%} 阈值,但 pipeline 继续运行。"
f"请通过 Prompt 优化或反馈重试提升。", flush=True)
# Only format_issues and logic_tree missing_paths block the pipeline.
# parent_issues and coverage_warnings are non-blocking (LLM quality).
passed = (
not gaps["missing_paths"]
and not gaps["format_issues"]
)
return passed, gaps
def _build_coverage_feedback(gaps: dict) -> str:
"""Generate targeted feedback text for re-prompting when coverage is below threshold."""
parts = []
for item in gaps.get("coverage_warnings", []):
parts.append(f"- {item}")
# Include specific missing table rows with their content
missing_rows = gaps.get("missing_table_rows", [])
if missing_rows:
parts.append(f"\n### 以下具体表格行缺少对应 function_unit(共 {len(missing_rows)} 行):\n")
for mr in missing_rows:
sec = mr.get("section", "?")
rn = mr.get("row", "?")
key = mr.get("key", "")
val = mr.get("value", "")
parts.append(
f"- **章节 {sec}, 行 {rn}**: {key}"
+ (f"{val}" if val else "")
)
if not parts:
return ""
return (
"\n## 关键覆盖反馈(上一轮 LLM 输出存在缺口,请重新处理)\n\n"
+ "\n".join(parts)
+ "\n\n"
"### 修复动作(必须执行)\n\n"
"1. **重新扫描上述每个缺失章节和表格行**,从文字和表格中提取所有可被测试的功能行为\n"
"2. **为上述每个缺失表格行创建独立的 function_unit**,不得合并不同行的规则\n"
"3. **每个 function_unit 必须引用具体的 section 号和 row 号**作为 source\n"
"4. **非功能章节可以跳过**(如背景、术语、变更日志),但行为规则章节必须覆盖\n"
"5. 输出中必须包含针对上述缺口的新 function_unit,**尤其是列出具体缺失的表格行**\n"
)
def _collect_logic_tree_nodes(doc: dict) -> dict[str, dict[str, str]]:
"""Return {image_id: {node_id: node_type}} for all logic trees."""
result = {}
for img in doc.get("image_analysis", []):
lt = img.get("logic_tree")
rid = img.get("rid", "")
if lt and rid:
result[rid] = {n["id"]: n["type"] for n in lt.get("nodes", [])}
return result
def _collect_referenced_nodes(units: list[dict]) -> dict[str, set[str]]:
"""Return {image_id: {referenced node_ids}} across all function_units."""
refs = {}
for fu in units:
for src in fu.get("sources", []):
if src.get("type") == "logic_tree":
image_id = src.get("image_id", "")
if image_id not in refs:
refs[image_id] = set()
refs[image_id].update(src.get("logic_tree_nodes", []))
return refs
# ---- LLM Calls ----
def extract_json_from_response(text: str) -> str:
    """Extract the JSON payload from an LLM response.

    A fenced code block (three backticks, optionally tagged ``json``) wins
    when one is present: its stripped content is returned as-is. Otherwise
    the first balanced ``{...}`` object in the text is returned.

    Braces inside JSON string literals are ignored while balancing, so a
    value such as ``"}"`` no longer cuts the object short.

    Raises:
        ValueError: If the text contains no ``{`` or the object never closes.
    """
    fenced = re.search(r"```(?:json)?\s*([\s\S]*?)```", text)
    if fenced:
        return fenced.group(1).strip()

    start = text.find("{")
    if start == -1:
        raise ValueError("No JSON object found in LLM response")

    depth = 0
    in_string = False
    escaped = False
    for i in range(start, len(text)):
        ch = text[i]
        if in_string:
            # Inside a string only an unescaped quote changes state.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start : i + 1]
    raise ValueError("Unclosed JSON object in LLM response")
def call_llm(prompt: str, max_retries: int = 2,
             temperature: float | None = None) -> dict:
    """Send ``prompt`` to the LLM and return the parsed JSON object.

    Args:
        prompt: User message content.
        max_retries: Extra attempts made after a JSON extraction/parse
            failure (``max_retries + 1`` attempts in total).
        temperature: Override for ``config.TEMPERATURE``; ``None`` uses the
            config value.

    Returns:
        The decoded JSON object as a dict.

    Raises:
        RuntimeError: If the response content is ``None`` (not retried), or
            if every attempt failed to yield a JSON object.
        Exception: Whatever ``config.llm_client()`` or the API call raises,
            other than the parse errors handled here, is propagated.
    """
    try:
        client = config.llm_client()
    except Exception as e:
        print(f" LLM 客户端初始化失败: {e}", file=sys.stderr)
        print(f" 请检查: IR_PROVIDER={config.LLM_PROVIDER}, secrets.yaml 或环境变量", file=sys.stderr)
        raise

    temp = temperature if temperature is not None else config.TEMPERATURE
    total_attempts = max_retries + 1
    content = None
    last_error = None
    for attempt in range(total_attempts):
        print(f" LLM 调用 model={config.MODEL_NAME} T={temp} "
              f"(尝试 {attempt + 1}/{total_attempts})...", flush=True)
        # Reset per attempt so diagnostics never show stale text and the
        # handler below can always read the name.
        content = None
        try:
            resp = client.chat.completions.create(
                model=config.MODEL_NAME,
                messages=[
                    {
                        "role": "system",
                        "content": "你是一个精确的 JSON 输出引擎。只输出合法的 JSON。",
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=temp,
                max_tokens=config.MAX_TOKENS,
            )
            content = resp.choices[0].message.content
            if content is None:
                raise RuntimeError(
                    "LLM 返回空响应 (content=None)。可能是 API 配额不足或模型不可用。"
                )
            # Log the response length for diagnostics.
            print(f" 响应长度: {len(content)} 字符", flush=True)
            json_str = extract_json_from_response(content)
            result = json.loads(json_str)
            if not isinstance(result, dict):
                # A list/scalar top level cannot be used downstream; treat
                # it like a parse failure so the call is retried.
                raise ValueError(
                    f"LLM 返回的 JSON 顶层不是对象: {type(result).__name__}"
                )
            n_units = len(result.get("function_units", []))
            n_concepts = len(result.get("concepts", []))
            print(f" 提取: {n_concepts} 概念, {n_units} 功能单元", flush=True)
            return result
        except (json.JSONDecodeError, ValueError) as e:
            last_error = e
            print(f" JSON 解析失败: {e}", file=sys.stderr)
            # Show a snippet of what the LLM returned, for diagnosis.
            print(f" LLM 返回内容前 500 字符: {content[:500] if content else '(None)'}", file=sys.stderr)
            if attempt < max_retries:
                time.sleep(2)
    raise RuntimeError(
        f"无法从 LLM 响应中解析 JSON ({total_attempts} 次尝试均失败)。"
        f"最后返回内容前 500 字符: {content[:500] if content else '(None)'}"
    ) from last_error
# ---- Ensemble Orchestration ----
def run_ensemble_semantic_index(doc: dict, *, max_feedback_retries: int = 3) -> dict:
    """Run one LLM call per ensemble temperature, merge, validate and retry.

    Steps:
      1. Enumerate all logic tree paths and build the prompt (once).
      2. Launch one ``call_llm`` per ``config.ENSEMBLE_TEMPERATURES`` entry in
         a thread pool; a failed call contributes an empty placeholder.
      3. Order the results by temperature and save the first three as the
         individual version files.
      4. Merge with ``ensemble_merge`` and validate with ``_quick_validate``.
      5. While coverage feedback exists, re-prompt (up to
         ``max_feedback_retries`` times) and adopt a retry only when the
         re-merged result passes the quality gate.

    Args:
        doc: Parsed input document.
        max_feedback_retries: Upper bound on coverage-feedback retries.

    Returns:
        The merged index, annotated with ``ensemble_temperatures``,
        ``validation_passed`` and ``validation_gaps``.

    Raises:
        RuntimeError: If no call produced a result, or every result has an
            empty ``function_units`` list.
    """
    all_paths = enumerate_all_paths(doc)
    print(f" 已枚举逻辑树路径: {sum(len(v) for v in all_paths.values())}")
    prompt = build_prompt(doc, "", all_paths)
    print(f" Prompt 长度: {len(prompt)} 字符")
    temperatures = config.ENSEMBLE_TEMPERATURES
    print(f" 集成温度: {temperatures}")

    # Parallel LLM calls. max_workers must be positive, so an empty
    # temperature list reaches the RuntimeError below instead of a
    # ValueError from the executor.
    raw_results: list[tuple[int, float, dict]] = []
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=max(len(temperatures), 1)
    ) as executor:
        future_to_meta = {}
        for i, temp in enumerate(temperatures):
            future = executor.submit(call_llm, prompt, 2, temp)
            future_to_meta[future] = (i, temp)
        for future in concurrent.futures.as_completed(future_to_meta):
            idx, temp = future_to_meta[future]
            try:
                si = future.result()
                n_units = len(si.get("function_units", []))
                n_concepts = len(si.get("concepts", []))
                print(f" T={temp}: {n_concepts} 概念, {n_units} 功能单元")
                raw_results.append((idx, temp, si))
            except Exception as e:
                # Best effort: one failed temperature must not sink the run.
                print(f" T={temp}: FAIL — {e}")
                raw_results.append((idx, temp, {
                    "feature_name": "", "concepts": [], "function_units": []
                }))
    if not raw_results:
        raise RuntimeError("所有集成的 LLM 调用均失败")

    # At least one raw result must carry function_units.
    all_empty = all(
        len(r[2].get("function_units", [])) == 0 for r in raw_results
    )
    if all_empty:
        raise RuntimeError(
            "所有集成的 LLM 调用返回了空的 function_units。请检查:\n"
            " 1. API Key 是否配置正确 (secrets.yaml 或环境变量)\n"
            " 2. 输入文档格式是否与 Prompt 兼容\n"
            " 3. LLM 服务是否可访问"
        )

    # Completion order is arbitrary; sort by temperature, breaking ties by
    # submission index, so the merge input order is deterministic.
    raw_results.sort(key=lambda x: (x[1], x[0]))
    semantic_indices = [r[2] for r in raw_results]

    # Save individual version outputs (only the first three have a path).
    version_paths = {
        0: config.SEMANTIC_INDEX_R1_JSON,
        1: config.SEMANTIC_INDEX_R2_JSON,
        2: config.SEMANTIC_INDEX_R3_JSON,
    }
    for i, (_, temp, si) in enumerate(raw_results):
        out_path = version_paths.get(i)
        if out_path:
            config.save_json(si, out_path)
            # Report the temperature that produced this version, not the
            # config entry that happens to share its position.
            print(f" 保存版本 {i} (T={temp}): {out_path}")

    # Ensemble merge
    print(f"\n 集成合并 {len(semantic_indices)} 个版本...")
    merged = ensemble_merge(semantic_indices)
    # Labels of everything merged so far; accepted retries are appended.
    ensemble_labels = list(temperatures)
    merged["ensemble_temperatures"] = list(ensemble_labels)

    # Validate
    passed, gaps = _quick_validate(merged, doc, all_paths)
    merged["validation_passed"] = passed
    merged["validation_gaps"] = {
        k: v for k, v in gaps.items() if v
    }

    # Print summary
    cs = merged.get("confidence_summary", {})
    print(f" 合并后: {cs.get('total_concepts', 0)} 概念, "
          f"{cs.get('total_units', 0)} 功能单元")
    print(f" 置信度: high={cs.get('high', 0)}, medium={cs.get('medium', 0)}, "
          f"low={cs.get('low', 0)}")
    print(f" 验证: {'PASS' if passed else 'GAPS FOUND'}")
    if not passed:
        for k, v in gaps.items():
            if v:
                print(f" {k}: {len(v)} 个问题")

    # Feedback retry: re-run with coverage feedback, quality-gated.
    retry_count = 0
    while retry_count < max_feedback_retries:
        feedback = _build_coverage_feedback(gaps)
        if not feedback:
            break
        retry_count += 1
        print(f"\n 覆盖反馈重试 #{retry_count} (feedback长度={len(feedback)}字符)...", flush=True)
        try:
            # Record pre-retry coverage to gate quality.
            pre_warnings = len(gaps.get("coverage_warnings", []))
            pre_missing_rows = len(gaps.get("missing_table_rows", []))
            retry_prompt = build_prompt(doc, feedback, all_paths)
            print(f" 重试 prompt 长度: {len(retry_prompt)} 字符", flush=True)
            retry_result = call_llm(retry_prompt, max_retries=1, temperature=0.3)
            n_retry_units = len(retry_result.get("function_units", []))
            n_retry_concepts = len(retry_result.get("concepts", []))
            print(f" 重试返回: {n_retry_concepts} 概念, {n_retry_units} 功能单元", flush=True)
            if n_retry_units > 0:
                retry_sections = set()
                for fu in retry_result.get("function_units", []):
                    for src in fu.get("sources", []):
                        if src.get("section"):
                            retry_sections.add(src["section"])
                print(f" 重试新增 sections: {sorted(retry_sections)}", flush=True)

                # Quality gate: merge on a trial basis and compare coverage.
                trial_indices = semantic_indices + [retry_result]
                trial_merged = ensemble_merge(trial_indices)
                trial_passed, trial_gaps = _quick_validate(trial_merged, doc, all_paths)
                trial_warnings = len(trial_gaps.get("coverage_warnings", []))
                trial_missing = len(trial_gaps.get("missing_table_rows", []))
                improved = trial_warnings < pre_warnings or trial_missing < pre_missing_rows
                no_regression = trial_warnings <= pre_warnings and trial_missing <= pre_missing_rows
                # NOTE(review): this is true whenever the retry cites any
                # section at all, not only sections missing before — confirm
                # that is the intended gate.
                has_new_sections = len(retry_sections) > 0
                if improved or (no_regression and has_new_sections):
                    semantic_indices.append(retry_result)
                    merged = trial_merged
                    passed, gaps = trial_passed, trial_gaps
                    # Accumulate, so earlier accepted retries stay recorded.
                    ensemble_labels.append(f"feedback_retry_{retry_count}")
                    merged["ensemble_temperatures"] = list(ensemble_labels)
                    merged["validation_passed"] = passed
                    merged["validation_gaps"] = {
                        k: v for k, v in gaps.items() if v
                    }
                    print(f" 重试后验证 (已采纳): {'PASS' if passed else 'GAPS FOUND'} "
                          f"(warnings {pre_warnings}→{trial_warnings}, "
                          f"missing_rows {pre_missing_rows}→{trial_missing})", flush=True)
                else:
                    print(f" 重试结果未提升覆盖率,丢弃 "
                          f"(warnings {pre_warnings}→{trial_warnings}, "
                          f"missing_rows {pre_missing_rows}→{trial_missing})", flush=True)
        except Exception as e:
            # A failed retry keeps the already-merged result and stops retrying.
            print(f" 覆盖反馈重试失败: {e}", flush=True)
            import traceback
            traceback.print_exc()
            break
    return merged
# ---- Main ----
def main():
    """Run stage 1 end to end: load the document, build and save the index.

    Saves the merged semantic index and the path enumeration, then prints
    any validation gaps (without failing) and a final summary.
    """
    banner = "=" * 60
    print(banner)
    print("阶段一:集成语义索引 (Ensemble Semantic Index)")
    print(banner)

    # 1. Load input
    print(f"\n[1/3] 加载输入文档: {config.INPUT_JSON}")
    doc = config.load_input_document()
    section_count = len(doc.get('sections', []))
    image_count = len(doc.get('image_analysis', []))
    print(f" 已加载 {section_count} 个 section, "
          f"{image_count} 张图片分析")

    # 2. Run ensemble generation + merge
    print(f"\n[2/3] 运行集成语义索引 ({len(config.ENSEMBLE_TEMPERATURES)} 个温度版本)...")
    merged_index = run_ensemble_semantic_index(doc)

    # 3. Save outputs
    print(f"\n[3/3] 保存最终语义索引: {config.SEMANTIC_INDEX_JSON}")
    config.save_json(merged_index, config.SEMANTIC_INDEX_JSON)

    # The path enumeration is saved too, for downstream use.
    path_payload = {"logic_tree_paths": dict(enumerate_all_paths(doc))}
    config.save_json(path_payload, config.PATH_ENUM_JSON)
    print(f" 路径枚举: {config.PATH_ENUM_JSON}")

    # Prefer the merge's own summary; fall back to counting directly.
    summary = merged_index.get("confidence_summary", {})
    n_concepts = summary.get("total_concepts", len(merged_index.get("concepts", [])))
    n_units = summary.get("total_units", len(merged_index.get("function_units", [])))
    n_versions = merged_index.get("ensemble_versions", len(config.ENSEMBLE_TEMPERATURES))

    validation_ok = merged_index.get("validation_passed", True)
    if not validation_ok:
        print("\n注意: 语义索引验证发现以下问题 (非阻塞,pipeline 继续运行):")
        for category, issues in merged_index.get("validation_gaps", {}).items():
            for issue in issues:
                print(f" [{category}] {issue}")

    print(f"\n完成! {n_versions} 版本集成, {n_concepts} 个概念, {n_units} 个功能单元.")
    print(f"输出: {config.SEMANTIC_INDEX_JSON}")


if __name__ == "__main__":
    main()