268520d453
CI / test (pull_request) Successful in 11s
- step3 _normalize_rule: 将 function_unit_description 等非法 source type 标准化为 text - step1 覆盖反馈重试: 仅纳入实际提升覆盖率的 retry 结果,避免低质量输出稀释 ensemble - 新增 UT: test_normalize_source_invalid_type Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
988 lines
38 KiB
Python
988 lines
38 KiB
Python
"""
Stage 1: Ensemble Semantic Index Generation.

Generates N parallel LLM calls with different temperatures (e.g., 0.0, 0.3, 0.7),
then deterministically merges the results via ensemble_merge (pure Python, no LLM).
The merged output includes confidence scores for each concept and function_unit.

Outputs:
- output/semantic_index_r1.json (T=0.0 raw)
- output/semantic_index_r2.json (T=0.3 raw)
- output/semantic_index_r3.json (T=0.7 raw)
- output/semantic_index.json (ensemble-merged final)
"""
|
||
|
||
import concurrent.futures
|
||
import json
|
||
import re
|
||
import sys
|
||
import time
|
||
from pathlib import Path
|
||
|
||
import config
|
||
from ensemble_merge import ensemble_merge
|
||
|
||
|
||
# ---- Path Enumeration (for prompt embedding) ----
|
||
|
||
|
||
def _traverse_nested(node: dict, image_id: str, path_nodes: list,
                     branch_taken: str | None) -> list[dict]:
    """Walk a ``logic_tree_nested`` subtree depth-first and collect leaf paths.

    Args:
        node: Current tree node; ``id``, ``type``, ``name`` and ``children``
            are read with defaults, so missing keys are tolerated.
        image_id: Identifier of the image the tree belongs to.
        path_nodes: Node entries accumulated from the root down to (but not
            including) ``node``. The list is never mutated.
        branch_taken: Label of the edge that led into ``node`` (``None`` for
            the root).

    Returns:
        One path record per leaf reachable from ``node``.
    """
    kind = node.get("type", "?")
    step = {
        "id": node.get("id", "?"),
        "type": kind,
        "label": node.get("name", ""),
        "branch_taken": branch_taken,
    }
    # Build a fresh list so sibling branches never share state.
    trail = [*path_nodes, step]

    kids = node.get("children", [])
    # An explicit "end" node or a node without children terminates the path.
    if kind == "end" or not kids:
        return [_make_path_record(trail, image_id)]

    is_decision = kind == "decision"
    collected: list[dict] = []
    for kid in kids:
        if is_decision:
            # Decision children are {condition, node} wrappers.
            edge_label = kid.get("condition", "")
            target = kid.get("node", kid)
        else:
            # Every other node type lists its children directly.
            edge_label = "(implicit)"
            target = kid
        collected += _traverse_nested(target, image_id, trail, edge_label)

    return collected
|
||
|
||
|
||
def _make_path_record(path_nodes: list, image_id: str) -> dict:
    """Assemble the record that describes one complete root-to-leaf path.

    Args:
        path_nodes: Ordered node entries (dicts with ``id``, ``type``,
            ``label`` and ``branch_taken``) from the root to the leaf.
        image_id: Identifier of the image the logic tree belongs to.

    Returns:
        A dict holding the path id, the node chain, a readable description,
        the image id, the action/decision subsets and the plain id list.
    """

    def of_type(wanted: str) -> list:
        return [step for step in path_nodes if step["type"] == wanted]

    actions = of_type("action")
    decisions = of_type("decision")
    ids = [step["id"] for step in path_nodes]
    joined_ids = "-".join(ids)

    record = {"path_id": f"PATH-{image_id}-{joined_ids}"}
    record["nodes"] = path_nodes
    record["meaning"] = _describe_path(path_nodes)
    record["image_id"] = image_id
    record["action_nodes"] = actions
    record["decision_nodes"] = decisions
    record["node_ids"] = ids
    return record
|
||
|
||
|
||
def enumerate_logic_tree_paths(nested_tree: dict, image_id: str = "") -> list[dict]:
    """List every root-to-leaf path of a ``logic_tree_nested`` structure.

    The nested tree is walked directly (no flat-list adjacency is built).
    Decision nodes fork through their ``{condition, node}`` branches; all
    other nodes list their children directly.

    Args:
        nested_tree: Root node of the nested tree; a falsy value yields ``[]``.
        image_id: Identifier of the image the tree belongs to.

    Returns:
        The path records produced by the depth-first traversal.
    """
    return _traverse_nested(nested_tree, image_id, [], None) if nested_tree else []
|
||
|
||
|
||
def _describe_path(path_nodes: list[dict]) -> str:
|
||
"""Generate a human-readable description of a logic tree path."""
|
||
parts = []
|
||
for n in path_nodes:
|
||
label = n["label"]
|
||
if n["branch_taken"] and n["branch_taken"] != "(implicit)":
|
||
label = f"{label} → {n['branch_taken']}"
|
||
parts.append(label)
|
||
return " → ".join(parts)
|
||
|
||
|
||
def enumerate_all_paths(doc: dict) -> dict[str, list[dict]]:
    """Enumerate paths for all logic trees in the document.

    ``logic_tree_nested`` is used when available (a proper tree); otherwise
    the flat ``logic_tree`` is used. Images without a ``rid`` are skipped, and
    an image with neither tree produces no entry.

    Args:
        doc: Parsed document; only ``image_analysis`` is read. The document
            is not modified.

    Returns:
        ``{image_id: [path, ...]}``. A flat tree without nodes maps to ``[]``.
    """
    result: dict[str, list[dict]] = {}
    for img in doc.get("image_analysis", []):
        rid = img.get("rid", "")
        if not rid:
            continue

        nested = img.get("logic_tree_nested")
        if nested:
            result[rid] = enumerate_logic_tree_paths(nested, image_id=rid)
            continue

        flat = img.get("logic_tree")
        if not flat:
            continue
        if flat.get("nodes"):
            # Hand the image id over on a shallow copy instead of writing it
            # into the caller's document.
            result[rid] = _enumerate_flat_tree({**flat, "image_id": rid})
        else:
            result[rid] = []
    return result
|
||
|
||
|
||
def _enumerate_flat_tree(tree: dict) -> list[dict]:
|
||
"""Fallback: enumerate paths from flat logic_tree using adjacency.
|
||
Handles start/process/action/state nodes as implicit chain links.
|
||
"""
|
||
nodes = tree.get("nodes", [])
|
||
if not nodes:
|
||
return []
|
||
node_map = {n["id"]: n for n in nodes}
|
||
image_id = tree.get("image_id", "")
|
||
|
||
# Find root: first start/state node, or first process node, or first node
|
||
root = None
|
||
for n in nodes:
|
||
if n["type"] in ("start", "state"):
|
||
root = n
|
||
break
|
||
if root is None:
|
||
for n in nodes:
|
||
if n["type"] == "process":
|
||
root = n
|
||
break
|
||
if root is None:
|
||
root = nodes[0]
|
||
|
||
adj = _build_adjacency(nodes, node_map)
|
||
paths = []
|
||
|
||
def dfs(current_id, visited, path_nodes, branch_taken):
|
||
if current_id in visited:
|
||
return
|
||
new_visited = visited | {current_id}
|
||
node = node_map.get(current_id)
|
||
if node is None:
|
||
return
|
||
|
||
path_nodes = path_nodes + [{
|
||
"id": current_id,
|
||
"type": node["type"],
|
||
"label": node.get("description") or node.get("condition", ""),
|
||
"branch_taken": branch_taken,
|
||
}]
|
||
|
||
outgoing = adj.get(current_id, [])
|
||
if not outgoing:
|
||
action_nodes = [n for n in path_nodes if n["type"] == "action"]
|
||
decision_nodes = [n for n in path_nodes if n["type"] == "decision"]
|
||
node_ids = [n["id"] for n in path_nodes]
|
||
paths.append({
|
||
"path_id": f"PATH-{image_id}-{'-'.join(node_ids)}",
|
||
"nodes": path_nodes,
|
||
"meaning": _describe_path(path_nodes),
|
||
"image_id": image_id,
|
||
"action_nodes": action_nodes,
|
||
"decision_nodes": decision_nodes,
|
||
"node_ids": node_ids,
|
||
})
|
||
else:
|
||
for branch_val, target_id in outgoing:
|
||
dfs(target_id, new_visited, path_nodes, branch_val)
|
||
|
||
dfs(root["id"], set(), [], None)
|
||
return paths
|
||
|
||
|
||
def _build_adjacency(nodes, node_map):
|
||
"""Build {node_id: [(branch_value, target_id)]} adjacency for flat trees.
|
||
|
||
Handles: decision branches (explicit), non-branching nodes (implicit sequential).
|
||
"""
|
||
NON_BRANCHING = {"start", "process", "state", "action"}
|
||
|
||
adj = {}
|
||
has_explicit_incoming = set()
|
||
for n in nodes:
|
||
for br in n.get("branches", []):
|
||
has_explicit_incoming.add(br["target"])
|
||
|
||
for i, node in enumerate(nodes):
|
||
nid = node["id"]
|
||
adj.setdefault(nid, [])
|
||
|
||
# Explicit edges from decision nodes
|
||
for br in node.get("branches", []):
|
||
adj[nid].append((br["value"], br["target"]))
|
||
|
||
# Implicit edges for non-branching nodes (start/process/state/action)
|
||
if node["type"] in NON_BRANCHING and not node.get("branches"):
|
||
j = i + 1
|
||
targets = []
|
||
while j < len(nodes):
|
||
next_node = nodes[j]
|
||
next_nid = next_node["id"]
|
||
if next_nid in has_explicit_incoming:
|
||
break
|
||
if next_node["type"] in NON_BRANCHING | {"end"}:
|
||
targets.append(next_nid)
|
||
has_explicit_incoming.add(next_nid)
|
||
j += 1
|
||
continue
|
||
elif next_node["type"] == "decision":
|
||
if not targets:
|
||
targets.append(next_nid)
|
||
break
|
||
j += 1
|
||
for t in targets:
|
||
adj[nid].append(("(implicit)", t))
|
||
|
||
return adj
|
||
|
||
|
||
def format_paths_for_prompt(all_paths: dict[str, list[dict]]) -> str:
    """Render the enumerated paths as a readable list for the LLM prompt.

    Args:
        all_paths: ``{image_id: [path record, ...]}``.

    Returns:
        One heading per image followed by five lines per path, or a fixed
        placeholder string when ``all_paths`` is empty.
    """
    if not all_paths:
        return "(无逻辑树路径)"

    out: list[str] = []
    for image_id, paths in all_paths.items():
        out.append(f"\n### {image_id} 的全部决策路径(共 {len(paths)} 条):")
        for ordinal, path in enumerate(paths, start=1):
            out += [
                f"\n**路径 {ordinal}** (ID: {path['path_id']})",
                f" 含义: {path['meaning']}",
                f" 节点: {path['node_ids']}",
                f" 决策节点: {[n['id'] for n in path['decision_nodes']]}",
                f" 动作节点: {[n['id'] for n in path['action_nodes']]}",
            ]
    return "\n".join(out)
|
||
|
||
|
||
# ---- Document Formatting ----
|
||
|
||
|
||
def format_document_for_prompt(doc: dict) -> str:
    """Render the full parsed document as a readable string for the LLM prompt.

    Three parts are emitted in order: the sections (paragraphs, tables, image
    references), the image analysis with each flat logic tree, and - only
    when present - the resolved conflicts.

    Args:
        doc: Parsed document. Missing keys are tolerated: blocks without a
            ``type`` are skipped, and a missing paragraph ``index``/``text``
            or a ``None`` image description is rendered as a placeholder or
            an empty string instead of raising.

    Returns:
        The newline-joined rendering.
    """
    lines = ["=== SECTIONS ==="]

    for i, section in enumerate(doc.get("sections", [])):
        source = section.get("source", f"(无标题-章节{i})")
        lines.append(f"\n--- Section: {source} ---")

        for block in section.get("blocks", []):
            block_type = block.get("type")
            if block_type == "para":
                lines.append(
                    f"[段落 {block.get('index', '?')}] {block.get('text', '')}"
                )
            elif block_type == "table":
                lines.append(f"[表格 {block.get('table', '?')}]")
                headers = block.get("headers", [])
                lines.append(f" 表头: {' | '.join(headers)}")
                for row in block.get("rows", []):
                    cell_texts = [
                        f"[行{c.get('row','?')}]{c.get('name','')}: {c.get('text','')}"
                        for c in row.get("columns", [])
                    ]
                    lines.append(f" {'; '.join(cell_texts)}")

        images = section.get("images", [])
        if images:
            lines.append(f" 图片引用: {', '.join(images)}")

    lines.append("\n\n=== IMAGE_ANALYSIS (流程图逻辑树) ===")
    for img in doc.get("image_analysis", []):
        rid = img.get("rid", "?")
        img_type = img.get("type", "?")
        lines.append(f"\n--- Image: {rid} (type={img_type}) ---")
        # "description" may be present but None; only the first 300 chars are kept.
        description = img.get("description") or ""
        lines.append(f" 描述: {description[:300]}")

        lt = img.get("logic_tree")
        if not lt:
            continue
        lines.append(f" 逻辑树根节点: {lt.get('root', '?')}")
        lines.append(" 节点详情:")
        for node in lt.get("nodes", []):
            nid = node.get("id", "?")
            ntype = node.get("type", "?")
            desc = node.get("description", "") or node.get("condition", "")
            lines.append(f" [{ntype}] {nid}: {desc}")
            for br in node.get("branches", []):
                lines.append(f" → {br['value']} → {br['target']}")

    conflicts = doc.get("resolved_conflicts", [])
    if conflicts:
        lines.append("\n\n=== RESOLVED_CONFLICTS (图文冲突仲裁) ===")
        for c in conflicts:
            lines.append(
                f" [{c.get('conflict_type','?')}] {c.get('section','?')}: "
                f"以{c.get('source','?')}为准 — {c.get('correction','')}"
            )

    return "\n".join(lines)
|
||
|
||
|
||
# ---- Prompt Building ----
|
||
|
||
|
||
def build_prompt(doc: dict, feedback: str = "", all_paths: dict | None = None) -> str:
    """Load the prompt template and inject the document, paths and feedback.

    The placeholders ``{document_json}``, ``{logic_tree_paths}`` and
    ``{feedback}`` are filled in a single pass over the template, so
    placeholder-like text inside the injected document or feedback is left
    untouched.

    Args:
        doc: Parsed input document.
        feedback: Optional feedback text; an empty value clears the placeholder.
        all_paths: Pre-computed path enumeration; computed from ``doc`` when
            ``None``.

    Returns:
        The fully substituted prompt text.
    """
    template_path = Path(config.PROMPTS_DIR) / "step1_semantic_index.txt"
    template = template_path.read_text(encoding="utf-8")

    if all_paths is None:
        all_paths = enumerate_all_paths(doc)

    substitutions = {
        "{document_json}": format_document_for_prompt(doc),
        "{logic_tree_paths}": format_paths_for_prompt(all_paths),
        "{feedback}": feedback or "",
    }
    placeholder = re.compile("|".join(re.escape(key) for key in substitutions))
    # A function replacement keeps backslashes in the injected text literal.
    return placeholder.sub(lambda match: substitutions[match.group(0)], template)
|
||
|
||
|
||
# ---- Validation ----
|
||
|
||
|
||
def _quick_validate(
|
||
semantic_index: dict, doc: dict, all_paths: dict | None = None
|
||
) -> tuple[bool, dict]:
|
||
"""Validate semantic index and return (passed, gaps).
|
||
|
||
Uses a single COVERAGE_TARGET threshold (default 0.95).
|
||
"""
|
||
gaps = {
|
||
"missing_paths": [],
|
||
"missing_concepts": [],
|
||
"format_issues": [],
|
||
"parent_issues": [],
|
||
"coverage_warnings": [], # section/table coverage below threshold (non-blocking)
|
||
}
|
||
|
||
units = semantic_index.get("function_units", [])
|
||
concepts = semantic_index.get("concepts", [])
|
||
|
||
# --- Check function_units non-empty ---
|
||
if not units:
|
||
gaps["format_issues"].append("function_units 为空")
|
||
return False, gaps
|
||
|
||
# --- Check each function_unit has path ---
|
||
for fu in units:
|
||
uid = fu.get("unit_id", "?")
|
||
if not fu.get("path"):
|
||
gaps["format_issues"].append(f"{uid}: 缺少 path 字段")
|
||
if not fu.get("sources"):
|
||
gaps["format_issues"].append(f"{uid}: 缺少 sources")
|
||
|
||
# --- Logic tree node coverage ---
|
||
all_nodes = _collect_logic_tree_nodes(doc)
|
||
referenced = _collect_referenced_nodes(units)
|
||
|
||
threshold = config.COVERAGE_TARGET
|
||
|
||
for image_id, node_set in all_nodes.items():
|
||
ref_set = referenced.get(image_id, set())
|
||
checkable = {
|
||
nid for nid, ntype in node_set.items()
|
||
if ntype in ("decision", "action")
|
||
}
|
||
if not checkable:
|
||
continue
|
||
covered = checkable & ref_set
|
||
coverage = len(covered) / len(checkable) if checkable else 1.0
|
||
|
||
if coverage < threshold:
|
||
missing = checkable - ref_set
|
||
gaps["missing_paths"].append(
|
||
f"{image_id}: 覆盖率 {coverage:.0%} < {threshold:.0%}, "
|
||
f"未覆盖节点: {sorted(missing)}"
|
||
)
|
||
|
||
# --- Check logic tree path consistency ---
|
||
# A unit's logic_tree_nodes must form a valid (connected) path in the tree.
|
||
if all_paths is not None:
|
||
for fu in units:
|
||
uid = fu.get("unit_id", "?")
|
||
for src in fu.get("sources", []):
|
||
if src.get("type") != "logic_tree":
|
||
continue
|
||
image_id = src.get("image_id", "")
|
||
unit_nodes = set(src.get("logic_tree_nodes", []))
|
||
if not unit_nodes:
|
||
continue
|
||
# Check if there exists a path containing all these nodes
|
||
valid = False
|
||
for path in all_paths.get(image_id, []):
|
||
path_nodes = set(path.get("node_ids", []))
|
||
if unit_nodes.issubset(path_nodes):
|
||
valid = True
|
||
break
|
||
if not valid:
|
||
gaps["format_issues"].append(
|
||
f"{uid}: logic_tree_nodes 不构成有效路径 "
|
||
f"(image={image_id}, nodes={sorted(unit_nodes)})"
|
||
)
|
||
|
||
# --- Check for trivial units (only state/switch nodes, no actions) ---
|
||
if all_paths is not None:
|
||
for fu in units:
|
||
uid = fu.get("unit_id", "?")
|
||
has_logic_ref = False
|
||
has_action = False
|
||
has_non_trivial_decision = False
|
||
for src in fu.get("sources", []):
|
||
if src.get("type") != "logic_tree":
|
||
continue
|
||
has_logic_ref = True
|
||
node_ids = src.get("logic_tree_nodes", [])
|
||
node_types = {}
|
||
for image_id, nset in all_nodes.items():
|
||
for nid in node_ids:
|
||
if nid in nset:
|
||
node_types[nid] = nset[nid]
|
||
for nid in node_ids:
|
||
ntype = node_types.get(nid, "")
|
||
if ntype == "action":
|
||
has_action = True
|
||
# Count decisions beyond first level (e.g., n1/n2 are just root+switch)
|
||
decisions = [nid for nid in node_ids
|
||
if node_types.get(nid, "") == "decision"]
|
||
if len(decisions) > 1:
|
||
has_non_trivial_decision = True
|
||
if has_logic_ref and not has_action and not has_non_trivial_decision:
|
||
gaps["format_issues"].append(
|
||
f"{uid}: 可能为空壳单元(仅有state/开关节点,无action或深层decision)"
|
||
)
|
||
|
||
# --- Concept parent validity ---
|
||
concept_names = {c["name"] for c in concepts}
|
||
for c in concepts:
|
||
name = c.get("name", "?")
|
||
parent = c.get("parent") # can be None for scope-level
|
||
if parent is not None and parent not in concept_names:
|
||
gaps["parent_issues"].append(
|
||
f"concept '{name}' 的 parent '{parent}' 不存在"
|
||
)
|
||
# Warn about scope-level concepts without parent=null
|
||
for c in concepts:
|
||
if c.get("parent") is not None:
|
||
continue
|
||
name = c.get("name", "")
|
||
# Scope-level concepts (国内/海外) should have parent=null
|
||
if name not in ("国内", "海外", ""):
|
||
gaps["parent_issues"].append(
|
||
f"concept '{name}' 的 parent 为 null,但它可能不是 scope 概念"
|
||
)
|
||
|
||
# --- Check for missing scope concepts ---
|
||
if "国内" not in concept_names:
|
||
gaps["missing_concepts"].append("缺少 scope 概念: 国内")
|
||
if "海外" not in concept_names and any(
|
||
"海外" in s.get("source", "") for s in doc.get("sections", [])
|
||
):
|
||
gaps["missing_concepts"].append("缺少 scope 概念: 海外")
|
||
|
||
# --- Section and table coverage ---
|
||
# Filter out non-functional sections (background, glossary, changelog, etc.)
|
||
non_functional_patterns = [
|
||
re.compile(p) for p in [
|
||
r"编制.*变更.*日志", r"变更日志", r"文档背景", r"文档范围",
|
||
r"术语解释", r"参考", r"附录", r"版本", r"变更记录",
|
||
r"目录", r"前言", r"概述", r"简介",
|
||
r"PRD", r"前置条件", r"依赖", r"行业规范", r"输入文件",
|
||
r"后方输入", r"政策法规", r"相关文档", r"概要说明",
|
||
]
|
||
]
|
||
|
||
def _is_functional_section(sec_name: str) -> bool:
|
||
if not sec_name.strip():
|
||
return False
|
||
# Check non-functional patterns first (even if section is numbered)
|
||
for pat in non_functional_patterns:
|
||
if pat.search(sec_name):
|
||
return False
|
||
# Numbered sections (e.g., "3.1.1") are functional
|
||
if re.match(r"^([\d.]+)", sec_name):
|
||
return True
|
||
return True
|
||
|
||
def _has_section_content(sec: dict) -> bool:
|
||
"""Check if a section has meaningful content (text >= 10 chars, table, or image).
|
||
|
||
A section is considered "empty" if all its text blocks have fewer than
|
||
10 characters and it contains no tables or images. These typically come
|
||
from image-only Word sections that doc_parser cannot extract text from.
|
||
"""
|
||
for block in sec.get("blocks", []):
|
||
blk_type = block.get("type", "")
|
||
if blk_type == "table":
|
||
return True
|
||
if blk_type in ("image", "figure", "picture"):
|
||
return True
|
||
text = block.get("text", "")
|
||
if isinstance(text, str) and len(text.strip()) >= 10:
|
||
return True
|
||
return False
|
||
|
||
func_sections = [
|
||
s for s in doc.get("sections", [])
|
||
if _is_functional_section(s.get("source", ""))
|
||
and _has_section_content(s)
|
||
]
|
||
covered_sections: set[str] = set()
|
||
for fu in units:
|
||
for src in fu.get("sources", []):
|
||
sec = src.get("section", "")
|
||
if sec:
|
||
covered_sections.add(sec)
|
||
|
||
# Use lower threshold for section/table coverage (70% vs 95% for logic trees)
|
||
SECTION_COVERAGE_TARGET = 0.70
|
||
|
||
section_cov = len(covered_sections) / max(len(func_sections), 1)
|
||
print(f" 章节覆盖率: {section_cov:.0%} ({len(covered_sections)}/{len(func_sections)} "
|
||
f"functional sections)", flush=True)
|
||
if section_cov < SECTION_COVERAGE_TARGET:
|
||
uncovered = [s["source"] for s in func_sections
|
||
if s["source"] not in covered_sections]
|
||
gaps["coverage_warnings"].append(
|
||
f"章节覆盖率 {section_cov:.0%} < {SECTION_COVERAGE_TARGET:.0%}, "
|
||
f"未覆盖: {uncovered[:5]}"
|
||
)
|
||
|
||
# Count table rows — only from functional sections with content
|
||
total_rows = sum(
|
||
len(b.get("rows", []))
|
||
for s in doc.get("sections", [])
|
||
if _is_functional_section(s.get("source", ""))
|
||
and _has_section_content(s)
|
||
for b in s.get("blocks", [])
|
||
if b.get("type") == "table"
|
||
)
|
||
covered_set: set[tuple] = set()
|
||
for fu in units:
|
||
for src in fu.get("sources", []):
|
||
if src.get("type") == "table" and src.get("row"):
|
||
covered_set.add((src.get("section", ""), src.get("row")))
|
||
covered_rows = len(covered_set)
|
||
# When there are no table rows to cover, skip check
|
||
if total_rows == 0:
|
||
row_cov = 1.0
|
||
else:
|
||
row_cov = covered_rows / total_rows
|
||
print(f" 表格行覆盖率: {row_cov:.0%} ({covered_rows}/{total_rows} rows)", flush=True)
|
||
if row_cov < SECTION_COVERAGE_TARGET:
|
||
# Collect specific missing rows with content for targeted feedback
|
||
missing_rows: list[dict] = []
|
||
for s in doc.get("sections", []):
|
||
if not _is_functional_section(s.get("source", "")):
|
||
continue
|
||
if not _has_section_content(s):
|
||
continue
|
||
sec_name = s.get("source", "").split()[0] if s.get("source") else "?"
|
||
for b in s.get("blocks", []):
|
||
if b.get("type") != "table":
|
||
continue
|
||
for row in b.get("rows", []):
|
||
rn = row.get("row")
|
||
if (sec_name, rn) not in covered_set:
|
||
key_col = ""
|
||
val_col = ""
|
||
for col in row.get("columns", []):
|
||
cn = col.get("name", "")
|
||
ct = col.get("text", "")[:100]
|
||
if cn in ("功能", "三级功能", "一级功能", "功能名称"):
|
||
key_col = ct
|
||
elif cn in ("功能详细说明", "详细说明", "四级功能", "说明"):
|
||
val_col = ct
|
||
if not key_col:
|
||
# Use first column as key
|
||
for col in row.get("columns", []):
|
||
key_col = col.get("text", "")[:60]
|
||
break
|
||
missing_rows.append({
|
||
"section": sec_name,
|
||
"row": rn,
|
||
"key": key_col,
|
||
"value": val_col,
|
||
})
|
||
gaps["coverage_warnings"].append(
|
||
f"表格行覆盖率 {row_cov:.0%} < {SECTION_COVERAGE_TARGET:.0%}, "
|
||
f"({covered_rows}/{total_rows} rows from functional sections)"
|
||
)
|
||
gaps["missing_table_rows"] = missing_rows
|
||
|
||
# Coverage warnings are non-blocking (depend on LLM prompt quality)
|
||
if gaps["coverage_warnings"]:
|
||
print(f" [WARN] 覆盖率低于 {SECTION_COVERAGE_TARGET:.0%} 阈值,但 pipeline 继续运行。"
|
||
f"请通过 Prompt 优化或反馈重试提升。", flush=True)
|
||
|
||
# Only format_issues and logic_tree missing_paths block the pipeline.
|
||
# parent_issues and coverage_warnings are non-blocking (LLM quality).
|
||
passed = (
|
||
not gaps["missing_paths"]
|
||
and not gaps["format_issues"]
|
||
)
|
||
return passed, gaps
|
||
|
||
|
||
def _build_coverage_feedback(gaps: dict) -> str:
|
||
"""Generate targeted feedback text for re-prompting when coverage is below threshold."""
|
||
parts = []
|
||
for item in gaps.get("coverage_warnings", []):
|
||
parts.append(f"- {item}")
|
||
|
||
# Include specific missing table rows with their content
|
||
missing_rows = gaps.get("missing_table_rows", [])
|
||
if missing_rows:
|
||
parts.append(f"\n### 以下具体表格行缺少对应 function_unit(共 {len(missing_rows)} 行):\n")
|
||
for mr in missing_rows:
|
||
sec = mr.get("section", "?")
|
||
rn = mr.get("row", "?")
|
||
key = mr.get("key", "")
|
||
val = mr.get("value", "")
|
||
parts.append(
|
||
f"- **章节 {sec}, 行 {rn}**: {key}"
|
||
+ (f" — {val}" if val else "")
|
||
)
|
||
|
||
if not parts:
|
||
return ""
|
||
|
||
return (
|
||
"\n## 关键覆盖反馈(上一轮 LLM 输出存在缺口,请重新处理)\n\n"
|
||
+ "\n".join(parts)
|
||
+ "\n\n"
|
||
"### 修复动作(必须执行)\n\n"
|
||
"1. **重新扫描上述每个缺失章节和表格行**,从文字和表格中提取所有可被测试的功能行为\n"
|
||
"2. **为上述每个缺失表格行创建独立的 function_unit**,不得合并不同行的规则\n"
|
||
"3. **每个 function_unit 必须引用具体的 section 号和 row 号**作为 source\n"
|
||
"4. **非功能章节可以跳过**(如背景、术语、变更日志),但行为规则章节必须覆盖\n"
|
||
"5. 输出中必须包含针对上述缺口的新 function_unit,**尤其是列出具体缺失的表格行**\n"
|
||
)
|
||
|
||
|
||
def _collect_logic_tree_nodes(doc: dict) -> dict[str, dict[str, str]]:
|
||
"""Return {image_id: {node_id: node_type}} for all logic trees."""
|
||
result = {}
|
||
for img in doc.get("image_analysis", []):
|
||
lt = img.get("logic_tree")
|
||
rid = img.get("rid", "")
|
||
if lt and rid:
|
||
result[rid] = {n["id"]: n["type"] for n in lt.get("nodes", [])}
|
||
return result
|
||
|
||
|
||
def _collect_referenced_nodes(units: list[dict]) -> dict[str, set[str]]:
|
||
"""Return {image_id: {referenced node_ids}} across all function_units."""
|
||
refs = {}
|
||
for fu in units:
|
||
for src in fu.get("sources", []):
|
||
if src.get("type") == "logic_tree":
|
||
image_id = src.get("image_id", "")
|
||
if image_id not in refs:
|
||
refs[image_id] = set()
|
||
refs[image_id].update(src.get("logic_tree_nodes", []))
|
||
return refs
|
||
|
||
|
||
# ---- LLM Calls ----
|
||
|
||
|
||
def extract_json_from_response(text: str) -> str:
    """Robustly extract the JSON text from an LLM response.

    The content of the first fenced code block (with or without a ``json``
    tag) is returned when one exists. Otherwise the first balanced
    ``{...}`` object is returned; braces inside JSON string literals are
    ignored while balancing.

    Args:
        text: Raw response text.

    Returns:
        The extracted JSON text (not parsed).

    Raises:
        ValueError: If the text has no ``{`` or the object is never closed.
    """
    fenced = re.search(r"```(?:json)?\s*([\s\S]*?)```", text)
    if fenced:
        return fenced.group(1).strip()

    start = text.find("{")
    if start == -1:
        raise ValueError("No JSON object found in LLM response")

    depth = 0
    in_string = False
    escaped = False
    for i in range(start, len(text)):
        ch = text[i]
        if in_string:
            # Inside a string only an unescaped quote is significant.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start : i + 1]

    raise ValueError("Unclosed JSON object in LLM response")
|
||
|
||
|
||
def call_llm(prompt: str, max_retries: int = 2,
             temperature: float | None = None) -> dict:
    """Send the prompt to the LLM and return the parsed JSON object.

    A response that cannot be parsed into a JSON object is retried, with a
    2 second pause between attempts.

    Args:
        prompt: User prompt text.
        max_retries: Number of additional attempts after the first one.
        temperature: Override config.TEMPERATURE. If None, uses config default.

    Returns:
        The parsed top-level JSON object.

    Raises:
        RuntimeError: If the LLM returns no content (not retried), or if no
            attempt produced a parsable JSON object.
        Exception: Whatever ``config.llm_client()`` or the API call raises
            (other than ``ValueError``) is propagated unchanged.
    """
    try:
        client = config.llm_client()
    except Exception as e:
        print(f" LLM 客户端初始化失败: {e}", file=sys.stderr)
        print(f" 请检查: IR_PROVIDER={config.LLM_PROVIDER}, secrets.yaml 或环境变量", file=sys.stderr)
        raise

    temp = temperature if temperature is not None else config.TEMPERATURE

    # Bound up front so the diagnostics below never hit an unbound name.
    content = None
    last_error = None

    for attempt in range(max_retries + 1):
        content = None  # never report a previous attempt's response
        print(f" LLM 调用 model={config.MODEL_NAME} T={temp} "
              f"(尝试 {attempt + 1}/{max_retries + 1})...", flush=True)
        try:
            resp = client.chat.completions.create(
                model=config.MODEL_NAME,
                messages=[
                    {
                        "role": "system",
                        "content": "你是一个精确的 JSON 输出引擎。只输出合法的 JSON。",
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=temp,
                max_tokens=config.MAX_TOKENS,
            )
            content = resp.choices[0].message.content
            if content is None:
                raise RuntimeError(
                    "LLM 返回空响应 (content=None)。可能是 API 配额不足或模型不可用。"
                )

            # Log the response length for diagnostics
            print(f" 响应长度: {len(content)} 字符", flush=True)

            json_str = extract_json_from_response(content)
            result = json.loads(json_str)
            if not isinstance(result, dict):
                # A top-level array or scalar is unusable; treat it as a parse failure.
                raise ValueError(
                    f"LLM 返回的 JSON 顶层不是对象 (实际类型: {type(result).__name__})"
                )
            n_units = len(result.get("function_units", []))
            n_concepts = len(result.get("concepts", []))
            print(f" 提取: {n_concepts} 概念, {n_units} 功能单元", flush=True)
            return result

        except (json.JSONDecodeError, ValueError) as e:
            last_error = e
            print(f" JSON 解析失败: {e}", file=sys.stderr)
            # Show a snippet of what the LLM returned for diagnosis
            print(f" LLM 返回内容前 500 字符: {content[:500] if content else '(None)'}", file=sys.stderr)
            if attempt < max_retries:
                time.sleep(2)

    raise RuntimeError(
        f"无法从 LLM 响应中解析 JSON({max_retries + 1} 次尝试均失败)。"
        f"最后返回内容前 500 字符: {content[:500] if content else '(None)'}"
    ) from last_error
|
||
|
||
|
||
# ---- Ensemble Orchestration ----
|
||
|
||
|
||
def run_ensemble_semantic_index(doc: dict) -> dict:
    """Run N parallel LLM calls at different temperatures, then ensemble-merge.

    1. Enumerate all logic tree paths (once).
    2. Build the base prompt (once, without feedback).
    3. Launch len(ENSEMBLE_TEMPERATURES) parallel LLM calls via ThreadPoolExecutor.
    4. Collect all results; a failed call contributes an empty index.
    5. Save the individual version outputs (ordered by temperature).
    6. Call ensemble_merge() and validate with _quick_validate().
    7. While coverage feedback exists, re-prompt up to twice; a retry result
       is merged in only when it lowers the number of coverage warnings or
       missing table rows.

    Args:
        doc: Parsed input document.

    Returns:
        The merged index, extended with ``ensemble_temperatures``,
        ``validation_passed`` and ``validation_gaps``.

    Raises:
        ValueError: If ``config.ENSEMBLE_TEMPERATURES`` is empty.
        RuntimeError: If every ensemble call returned no function_units.
    """
    all_paths = enumerate_all_paths(doc)
    print(f" 已枚举逻辑树路径: {sum(len(v) for v in all_paths.values())} 条")

    prompt = build_prompt(doc, "", all_paths)
    print(f" Prompt 长度: {len(prompt)} 字符")

    temperatures = config.ENSEMBLE_TEMPERATURES
    print(f" 集成温度: {temperatures}")
    if not temperatures:
        # ThreadPoolExecutor(max_workers=0) would raise a less helpful ValueError.
        raise ValueError("config.ENSEMBLE_TEMPERATURES 为空,无法运行集成")

    # Parallel LLM calls
    raw_results: list[tuple[int, float, dict]] = []

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=len(temperatures)
    ) as executor:
        future_to_meta = {
            executor.submit(call_llm, prompt, 2, temp): (i, temp)
            for i, temp in enumerate(temperatures)
        }

        for future in concurrent.futures.as_completed(future_to_meta):
            idx, temp = future_to_meta[future]
            try:
                si = future.result()
                n_units = len(si.get("function_units", []))
                n_concepts = len(si.get("concepts", []))
                print(f" T={temp}: {n_concepts} 概念, {n_units} 功能单元")
                raw_results.append((idx, temp, si))
            except Exception as e:
                # Best effort: a failed call still takes part as an empty index.
                print(f" T={temp}: FAIL — {e}")
                raw_results.append((idx, temp, {
                    "feature_name": "", "concepts": [], "function_units": []
                }))

    # At least one raw result must carry function_units.
    all_empty = all(
        len(r[2].get("function_units", [])) == 0 for r in raw_results
    )
    if all_empty:
        raise RuntimeError(
            "所有集成的 LLM 调用返回了空的 function_units。请检查:\n"
            " 1. API Key 是否配置正确 (secrets.yaml 或环境变量)\n"
            " 2. 输入文档格式是否与 Prompt 兼容\n"
            " 3. LLM 服务是否可访问"
        )

    # Sort by temperature for determinism; the submission index breaks ties,
    # since completion order is arbitrary.
    raw_results.sort(key=lambda r: (r[1], r[0]))
    semantic_indices = [r[2] for r in raw_results]

    # Save individual version outputs
    version_paths = {
        0: config.SEMANTIC_INDEX_R1_JSON,
        1: config.SEMANTIC_INDEX_R2_JSON,
        2: config.SEMANTIC_INDEX_R3_JSON,
    }
    for i, (_, temp, si) in enumerate(raw_results):
        out_path = version_paths.get(i)
        if out_path:
            config.save_json(si, out_path)
            # Report the temperature that produced this version.
            print(f" 保存版本 {i} (T={temp}): {out_path}")

    # Ensemble merge
    print(f"\n 集成合并 {len(semantic_indices)} 个版本...")
    merged = ensemble_merge(semantic_indices)
    ensemble_labels = list(temperatures)
    merged["ensemble_temperatures"] = list(ensemble_labels)

    # Validate
    passed, gaps = _quick_validate(merged, doc, all_paths)
    merged["validation_passed"] = passed
    merged["validation_gaps"] = {
        k: v for k, v in gaps.items() if v
    }

    # Print summary
    cs = merged.get("confidence_summary", {})
    print(f" 合并后: {cs.get('total_concepts', 0)} 概念, "
          f"{cs.get('total_units', 0)} 功能单元")
    print(f" 置信度: high={cs.get('high', 0)}, medium={cs.get('medium', 0)}, "
          f"low={cs.get('low', 0)}")
    print(f" 验证: {'PASS' if passed else 'GAPS FOUND'}")
    if not passed:
        for k, v in gaps.items():
            if v:
                print(f" {k}: {len(v)} 个问题")

    # Feedback retry: re-run with coverage feedback (up to 2 retries, quality-gated)
    retry_count = 0
    while retry_count < 2:
        feedback = _build_coverage_feedback(gaps)
        if not feedback:
            break
        retry_count += 1
        print(f"\n 覆盖反馈重试 #{retry_count} (feedback长度={len(feedback)}字符)...", flush=True)
        try:
            # Record pre-retry coverage to gate quality.
            pre_warnings = len(gaps.get("coverage_warnings", []))
            pre_missing_rows = len(gaps.get("missing_table_rows", []))

            retry_prompt = build_prompt(doc, feedback, all_paths)
            print(f" 重试 prompt 长度: {len(retry_prompt)} 字符", flush=True)
            retry_result = call_llm(retry_prompt, max_retries=1, temperature=0.3)
            n_retry_units = len(retry_result.get("function_units", []))
            n_retry_concepts = len(retry_result.get("concepts", []))
            print(f" 重试返回: {n_retry_concepts} 概念, {n_retry_units} 功能单元", flush=True)
            if n_retry_units > 0:
                retry_sections = {
                    src["section"]
                    for fu in retry_result.get("function_units", [])
                    for src in fu.get("sources", [])
                    if src.get("section")
                }
                print(f" 重试新增 sections: {sorted(retry_sections)}", flush=True)

                # Quality gate: only include retry if it improves coverage
                trial_merged = ensemble_merge(semantic_indices + [retry_result])
                trial_passed, trial_gaps = _quick_validate(trial_merged, doc, all_paths)
                trial_warnings = len(trial_gaps.get("coverage_warnings", []))
                trial_missing = len(trial_gaps.get("missing_table_rows", []))
                if trial_warnings < pre_warnings or trial_missing < pre_missing_rows:
                    semantic_indices.append(retry_result)
                    merged = trial_merged
                    passed, gaps = trial_passed, trial_gaps
                    # Keep the labels of earlier accepted retries as well.
                    ensemble_labels.append(f"feedback_retry_{retry_count}")
                    merged["ensemble_temperatures"] = list(ensemble_labels)
                    merged["validation_passed"] = passed
                    merged["validation_gaps"] = {
                        k: v for k, v in gaps.items() if v
                    }
                    print(f" 重试后验证 (已采纳): {'PASS' if passed else 'GAPS FOUND'} "
                          f"(warnings {pre_warnings}→{trial_warnings}, "
                          f"missing_rows {pre_missing_rows}→{trial_missing})", flush=True)
                else:
                    print(f" 重试结果未提升覆盖率,丢弃 "
                          f"(warnings {pre_warnings}→{trial_warnings}, "
                          f"missing_rows {pre_missing_rows}→{trial_missing})", flush=True)
        except Exception as e:
            # The retry is optional: report the failure and keep the current merge.
            print(f" 覆盖反馈重试失败: {e}", flush=True)
            import traceback
            traceback.print_exc()
            break

    return merged
|
||
|
||
|
||
# ---- Main ----
|
||
|
||
|
||
def main():
    """Run Stage 1 end to end: load the input, build the index, save outputs.

    Writes the merged semantic index and the logic-tree path enumeration via
    ``config.save_json`` and prints progress plus any validation gaps.
    """
    banner = "=" * 60
    print(banner)
    print("阶段一:集成语义索引 (Ensemble Semantic Index)")
    print(banner)

    # 1. Load input
    print(f"\n[1/3] 加载输入文档: {config.INPUT_JSON}")
    doc = config.load_input_document()
    section_count = len(doc.get('sections', []))
    image_count = len(doc.get('image_analysis', []))
    print(f" 已加载 {section_count} 个 section, "
          f"{image_count} 张图片分析")

    # 2. Run ensemble generation + merge
    print(f"\n[2/3] 运行集成语义索引 ({len(config.ENSEMBLE_TEMPERATURES)} 个温度版本)...")
    merged_index = run_ensemble_semantic_index(doc)

    # 3. Save outputs
    print(f"\n[3/3] 保存最终语义索引: {config.SEMANTIC_INDEX_JSON}")
    config.save_json(merged_index, config.SEMANTIC_INDEX_JSON)

    # The path enumeration is saved as well for downstream use.
    all_paths = enumerate_all_paths(doc)
    config.save_json({"logic_tree_paths": dict(all_paths)}, config.PATH_ENUM_JSON)
    print(f" 路径枚举: {config.PATH_ENUM_JSON}")

    # Prefer the merge summary; fall back to counting the lists directly.
    summary = merged_index.get("confidence_summary", {})
    n_concepts = summary.get("total_concepts", len(merged_index.get("concepts", [])))
    n_units = summary.get("total_units", len(merged_index.get("function_units", [])))
    n_versions = merged_index.get("ensemble_versions", len(config.ENSEMBLE_TEMPERATURES))

    if not merged_index.get("validation_passed", True):
        print("\n注意: 语义索引验证发现以下问题 (非阻塞,pipeline 继续运行):")
        for category, issues in merged_index.get("validation_gaps", {}).items():
            for issue in issues:
                print(f" [{category}] {issue}")

    print(f"\n完成! {n_versions} 版本集成, {n_concepts} 个概念, {n_units} 个功能单元.")
    print(f"输出: {config.SEMANTIC_INDEX_JSON}")
|
||
|
||
|
||
# Run the Stage 1 pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|