fec4c09ee0
CI / test (push) Successful in 8s
doc_parser_skill: - New: verify_flowchart.py (flowchart validation) - Updated: LLM.py (multi-provider: DeepSeek + DashScope) - Updated: image_parser.py (logic tree support, external prompts) - Updated: SKILL.md, prompts/image_prompt.md conflict_detection_skill: - Updated: LLM.py (multi-provider sync) - Updated: detect_conflicts.py (logic tree text conversion) ir_generation_skill: - Replaced old scripts/LLM.py + ir_generator.py with standalone project - New: main.py, config.py, step1-3_*.py, ensemble_merge.py - New: prompts/, tests/ subdirectories tests: - New: acceptance/ test suite with schema validation - Fixed: conftest no longer globally skips non-acceptance tests - Updated: test_sample.py for new ir_generation structure Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
509 lines
19 KiB
Python
509 lines
19 KiB
Python
"""
|
|
Stage 2: Per Function Unit IR Extraction.

For each function unit from the semantic index, constructs a precision context
package and calls the LLM to extract detailed IR rules.

Runs multiple LLM calls in parallel (up to MAX_CONCURRENCY).

Output: output/ir_fragments.json
|
|
"""
|
|
|
|
import concurrent.futures
|
|
import json
|
|
import re
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import config
|
|
|
|
|
|
MAX_CONCURRENCY = 3  # Max parallel LLM calls; used as the ThreadPoolExecutor worker count in extract_all_rules
|
|
|
|
|
|
def load_semantic_index() -> dict:
    """Read the Stage 1 semantic index from ``config.SEMANTIC_INDEX_JSON``.

    Returns:
        Whatever ``config.load_json`` yields for that path.
    """
    index_path = config.SEMANTIC_INDEX_JSON
    return config.load_json(index_path)
|
|
|
|
|
|
def build_document_lookup(doc: dict) -> tuple[dict, dict, dict]:
    """Build lookup tables for fast context extraction from the document.

    Args:
        doc: Parsed input document. The keys read here are ``sections``,
            ``image_analysis`` and ``resolved_conflicts`` (each a list of
            dicts). Any of them may be missing or null.

    Returns:
        A 3-tuple ``(sections_by_source, image_by_rid, conflicts_by_section)``:

        * ``sections_by_source`` maps the first whitespace-delimited token of
          a section's ``source`` (e.g. ``"3.1.1"``) to the section dict.
          Sections whose source is empty or blank are skipped; when two
          sections share a key the later one wins.
        * ``image_by_rid`` maps each non-empty ``rid`` (e.g. ``"rId16"``) to
          its image-analysis entry.
        * ``conflicts_by_section`` maps the first token of a conflict's
          ``section`` to the list of conflicts carrying it. Conflicts without
          a usable section are grouped under the empty-string key.
    """

    def first_token(value) -> str:
        # A null or whitespace-only value has no tokens; indexing
        # ``split()[0]`` on it would raise IndexError, so map it to "".
        tokens = str(value).split() if value else []
        return tokens[0] if tokens else ""

    # sections_by_source: "3.1.1" -> section dict
    sections_by_source = {}
    for section in doc.get("sections") or []:
        key = first_token(section.get("source"))
        if key:
            sections_by_source[key] = section

    # image_by_rid: "rId16" -> image_analysis entry
    image_by_rid = {}
    for img in doc.get("image_analysis") or []:
        rid = img.get("rid", "")
        if rid:
            image_by_rid[rid] = img

    # Conflicts indexed by the leading token of their section
    conflicts_by_section = {}
    for conflict in doc.get("resolved_conflicts") or []:
        key = first_token(conflict.get("section"))
        conflicts_by_section.setdefault(key, []).append(conflict)

    return sections_by_source, image_by_rid, conflicts_by_section
|
|
|
|
|
|
def _source_section_key(src: dict) -> str:
    """Return the first whitespace-delimited token of ``src['section']`` ('' if none)."""
    tokens = str(src.get("section") or "").split()
    return tokens[0] if tokens else ""


def _find_section(section_key: str, sections_by_source: dict):
    """Look up a section by exact key, else the first key starting with ``section_key``."""
    section = sections_by_source.get(section_key)
    if section is not None:
        return section
    for key, candidate in sections_by_source.items():
        if key.startswith(section_key):
            return candidate
    return None


def _flatten_table_cells(block: dict) -> list[dict]:
    """Flatten every cell of a table block into ``{row, name, text}`` dicts."""
    return [
        {"row": col.get("row"), "name": col.get("name"), "text": col.get("text")}
        for row in block.get("rows", [])
        for col in row.get("columns", [])
    ]


def _rows_matching(block: dict, row_num) -> list[dict]:
    """Collect the table rows containing a cell whose ``row`` equals ``row_num``."""
    headers = block.get("headers", [])
    matches = []
    for row in block.get("rows", []):
        columns = row.get("columns", [])
        if any(col.get("row") == row_num for col in columns):
            matches.append({
                "headers": headers,
                "cells": {col.get("name"): col.get("text") for col in columns},
                "row": row_num,
            })
    return matches


def extract_context_package(
    fu: dict, doc: dict, sections_by_source: dict, image_by_rid: dict,
    conflicts_by_section: dict
) -> dict:
    """Build a precision context package for a single function unit.

    Args:
        fu: Function unit from the semantic index. ``unit_id`` is required;
            ``name``, ``description``, ``path`` and ``sources`` are optional.
            Each source is a dict with a ``type`` of ``"para"``, ``"table"``
            or ``"logic_tree"``.
        doc: The input document. Not read here; kept so the signature stays
            unchanged for callers.
        sections_by_source: Section lookup keyed by leading section token.
        image_by_rid: Image-analysis lookup keyed by image id.
        conflicts_by_section: Resolved conflicts keyed by leading section token.

    Returns:
        A dict with ``unit_id``, ``unit_name``, ``unit_description``,
        ``unit_path``, ``texts``, ``tables``, ``logic_trees`` and
        ``resolved_conflicts``. Each section contributes its paragraphs and
        tables once. A table entry carries ``all_rows`` (every cell) and, when
        any table source named a row, ``rows`` with the highlighted matches.

    Raises:
        KeyError: If ``fu`` has no ``unit_id``.
    """
    logic_trees = []
    seen_images = set()

    # Pass 1: group text/table sources by section. A dict keeps first-seen
    # order, so the package (and the prompt built from it) is deterministic,
    # and every requested row of a section is remembered rather than only the
    # row of the first source that mentioned the section.
    rows_by_section: dict[str, list] = {}
    for src in fu.get("sources", []):
        src_type = src.get("type", "")

        if src_type in ("table", "para"):
            section_key = _source_section_key(src)
            if not section_key:
                continue
            wanted = rows_by_section.setdefault(section_key, [])
            row_num = src.get("row") if src_type == "table" else None
            if row_num is not None and row_num not in wanted:
                wanted.append(row_num)

        elif src_type == "logic_tree":
            image_id = src.get("image_id", "")
            if not image_id or image_id in seen_images:
                continue
            seen_images.add(image_id)

            img = image_by_rid.get(image_id)
            tree = img.get("logic_tree") if img else None
            if tree:
                logic_trees.append({
                    "image_id": image_id,
                    "description": img.get("description", ""),
                    "tree": tree,
                })

    # Pass 2: emit each referenced section's paragraphs and tables once.
    texts = []
    tables = []
    for section_key, wanted_rows in rows_by_section.items():
        section = _find_section(section_key, sections_by_source)
        if not section:
            continue
        for block in section.get("blocks", []):
            block_type = block.get("type")
            if block_type == "para":
                texts.append({"section": section_key, "text": block.get("text", "")})
            elif block_type == "table":
                table = {
                    "section": section_key,
                    "headers": block.get("headers", []),
                }
                if wanted_rows:
                    # Highlight only the rows the function unit points at.
                    table["rows"] = [
                        match
                        for row_num in wanted_rows
                        for match in _rows_matching(block, row_num)
                    ]
                table["all_rows"] = _flatten_table_cells(block)
                tables.append(table)

    # Attach the resolved conflicts of every referenced section, including
    # sections that had no match in ``sections_by_source``.
    relevant_conflicts = []
    for section_key in rows_by_section:
        relevant_conflicts.extend(conflicts_by_section.get(section_key, []))

    return {
        "unit_id": fu["unit_id"],
        "unit_name": fu.get("name", ""),
        "unit_description": fu.get("description", ""),
        "unit_path": fu.get("path", []),
        "texts": texts,
        "tables": tables,
        "logic_trees": logic_trees,
        "resolved_conflicts": relevant_conflicts,
    }
|
|
|
|
|
|
def format_context_package(pkg: dict) -> str:
    """Format a context package as a readable string for the prompt.

    Args:
        pkg: Context package with optional ``texts``, ``tables``,
            ``logic_trees`` and ``resolved_conflicts`` lists.

    Returns:
        A newline-joined string with one section each for texts, tables and
        logic trees (an empty section prints a "none" marker), followed by a
        conflicts section only when conflicts are present. Missing or null
        fields are rendered as ``?`` or an empty string instead of raising.
    """
    parts = []

    # Texts
    texts = pkg.get("texts") or []
    parts.append("【文字段落】")
    for t in texts:
        parts.append(f"[{t.get('section', '?')}] {t.get('text', '')}")
    if not texts:
        parts.append("(无)")

    # Tables
    tables = pkg.get("tables") or []
    parts.append("\n【表格数据】")
    for i, tbl in enumerate(tables, start=1):
        parts.append(f"表格 {i} (section={tbl.get('section', '?')})")
        headers = tbl.get("headers", [])
        parts.append(f" 表头: {headers}")
        parts.append(" 全部行数据:")
        for row in tbl.get("all_rows", []):
            parts.append(
                f" 行{row.get('row','?')}[{row.get('name','?')}]: {row.get('text','')}"
            )
        # Highlight matched rows if any
        matched = tbl.get("rows", [])
        if matched:
            parts.append(" <重点关注行>:")
            for mr in matched:
                parts.append(f" 行{mr.get('row','?')}: {mr.get('cells', {})}")
    if not tables:
        parts.append("(无)")

    # Logic trees
    logic_trees = pkg.get("logic_trees") or []
    parts.append("\n【逻辑树】")
    for i, lt in enumerate(logic_trees, start=1):
        parts.append(f"逻辑树 {i} (image_id={lt.get('image_id', '?')})")
        # ``description`` may be present but null; slicing None would raise.
        description = lt.get("description") or ""
        parts.append(f" 描述: {description[:200]}")
        tree = lt.get("tree") or {}
        parts.append(f" 根: {tree.get('root', '?')}")
        parts.append(" 节点:")
        for node in tree.get("nodes") or []:
            nid = node.get("id", "?")
            ntype = node.get("type", "?")
            desc = node.get("description", "") or node.get("condition", "")
            parts.append(f" [{ntype}] {nid}: {desc}")
            for br in node.get("branches") or []:
                # Branches without value/target get the same '?' placeholder
                # the rest of this formatter uses.
                parts.append(f" → {br.get('value', '?')} → {br.get('target', '?')}")
    if not logic_trees:
        parts.append("(无)")

    # Conflicts
    conflicts = pkg.get("resolved_conflicts") or []
    if conflicts:
        parts.append("\n【图文冲突仲裁】")
        for c in conflicts:
            parts.append(
                f" [{c.get('conflict_type', '?')}] 以{c.get('source', '?')}为准: "
                f"{c.get('correction', '')}"
            )

    return "\n".join(parts)
|
|
|
|
|
|
def _escape_json_for_format(s: str) -> str:
    """Double every ``{`` and ``}`` in *s* so ``str.format()`` treats them as literals."""
    brace_map = {ord("{"): "{{", ord("}"): "}}"}
    return s.translate(brace_map)
|
|
|
|
|
|
def _render_prompt_template(template: str, pkg: dict, format_feedback: str = "") -> str:
    """Fill the Stage 2 prompt template with one context package.

    Values passed to ``str.format`` are inserted verbatim and never re-parsed
    as format fields, so they must NOT be brace-escaped: escaping them would
    put doubled ``{{``/``}}`` into the JSON the model reads. Only literal
    braces inside the template text itself need doubling.
    """

    def as_json(value) -> str:
        return json.dumps(value, ensure_ascii=False, indent=2)

    return template.format(
        unit_id=pkg["unit_id"],
        unit_name=pkg["unit_name"],
        unit_description=pkg["unit_description"],
        texts=as_json(pkg.get("texts", [])),
        tables=as_json(pkg.get("tables", [])),
        logic_trees=as_json(pkg.get("logic_trees", [])),
        resolved_conflicts=as_json(pkg.get("resolved_conflicts", [])),
        format_feedback=format_feedback,
    )


def build_prompt(pkg: dict, format_feedback: str = "") -> str:
    """Build the LLM prompt for a single function unit.

    Args:
        pkg: Context package; ``unit_id``, ``unit_name`` and
            ``unit_description`` are required, the list fields default to empty.
        format_feedback: Optional fix instructions from a previous failed
            attempt, substituted into the template's ``{format_feedback}`` field.

    Returns:
        The template ``step2_ir_extraction.txt`` (read from
        ``config.PROMPTS_DIR``) with all fields filled in.

    Raises:
        KeyError: If ``pkg`` lacks a required key or the template names an
            unknown field.
        OSError: If the template file cannot be read.
    """
    template_path = Path(config.PROMPTS_DIR) / "step2_ir_extraction.txt"
    template = template_path.read_text(encoding="utf-8")
    return _render_prompt_template(template, pkg, format_feedback)
|
|
|
|
|
|
def _closing_bracket_index(text: str, start: int) -> int:
    """Return the index of the ``]`` that closes ``text[start]``, or -1.

    Brackets inside double-quoted strings (with backslash escapes) are ignored.
    """
    depth = 0
    in_string = False
    escaped = False
    for i in range(start, len(text)):
        ch = text[i]
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
            if depth == 0:
                return i
    return -1


def extract_json_from_response(text: str) -> str:
    """Extract the JSON array substring from an LLM response.

    A fenced code block containing an array is returned as-is. Otherwise the
    top-level ``[...]`` spans are scanned left to right and the first one
    that parses as JSON is returned. Nested arrays are never considered on
    their own, so a truncated response cannot yield an inner array.

    Args:
        text: Raw LLM response text.

    Returns:
        The array text. If no span parses, the first balanced span is
        returned so the caller's ``json.loads`` reports the parse error.

    Raises:
        ValueError: If the text contains no ``[``, or no balanced ``[...]``
            span can be found.
    """
    fenced = re.search(r"```(?:json)?\s*(\[[\s\S]*?\])\s*```", text)
    if fenced:
        return fenced.group(1).strip()

    start = text.find("[")
    if start == -1:
        raise ValueError("No JSON array found in LLM response")

    first_span = None
    while start != -1:
        end = _closing_bracket_index(text, start)
        if end == -1:
            break
        candidate = text[start:end + 1]
        try:
            json.loads(candidate)
        except json.JSONDecodeError:
            # e.g. a bracketed aside in prose; remember it and move on to the
            # next top-level span.
            if first_span is None:
                first_span = candidate
            start = text.find("[", end + 1)
        else:
            return candidate

    if first_span is not None:
        return first_span
    raise ValueError("Unclosed JSON array in LLM response")
|
|
|
|
|
|
def _check_rule_fields(rules: list[dict]) -> tuple[bool, list[dict]]:
    """Validate that each rule has the required fields.

    The rules come straight from LLM output, so every container is
    type-checked before use: a wrongly typed field is reported as a failure
    instead of raising.

    Checks per rule:
        * the rule itself is a dict;
        * ``path`` is non-empty;
        * ``precondition`` is a dict with a non-empty ``geographic_scope``;
        * ``actions`` is a list; each ``user_interaction`` action has
          non-empty ``content`` that contains no placeholder text;
        * ``trigger`` is a dict and ``trigger.conditions`` a list; each dict
          condition has ``signal``, ``operator`` and a ``value`` key.

    Args:
        rules: Parsed rule list.

    Returns:
        ``(passed, failures)`` where ``passed`` is True when ``failures`` is
        empty. Each failure is ``{rule_id, field, issue}``.
    """
    failures = []
    placeholders = ("文案由业务定义", "待定", "自定义")

    def fail(rule_id, field: str, issue: str) -> None:
        failures.append({"rule_id": rule_id, "field": field, "issue": issue})

    for j, rule in enumerate(rules):
        if not isinstance(rule, dict):
            fail(f"rule[{j}]", "-", "规则不是 dict")
            continue
        rid = rule.get("rule_id") or f"rule[{j}]"

        if not rule.get("path"):
            fail(rid, "path", "缺少 path 字段(必填)")

        precond = rule.get("precondition") or {}
        if not isinstance(precond, dict):
            fail(rid, "precondition", "precondition 不是 dict")
        elif not precond.get("geographic_scope"):
            fail(rid, "precondition.geographic_scope", "缺少 geographic_scope(必填)")

        actions = rule.get("actions") or []
        if not isinstance(actions, list):
            fail(rid, "actions", "actions 不是 list")
            actions = []
        for k, action in enumerate(actions):
            if not isinstance(action, dict):
                continue
            if action.get("type") != "user_interaction":
                continue
            content = action.get("content") or ""
            if not content:
                fail(rid, f"actions[{k}].content", "user_interaction 的 content 为空")
            elif isinstance(content, str) and any(ph in content for ph in placeholders):
                # The placeholder scan only makes sense for string content.
                fail(rid, f"actions[{k}].content", f"content 包含占位符: '{content}'")

        trigger = rule.get("trigger") or {}
        if not isinstance(trigger, dict):
            fail(rid, "trigger", "trigger 不是 dict")
            conditions = []
        else:
            conditions = trigger.get("conditions") or []
            if not isinstance(conditions, list):
                fail(rid, "trigger.conditions", "trigger.conditions 不是 list")
                conditions = []
        for k, cond in enumerate(conditions):
            if not isinstance(cond, dict):
                continue
            if not cond.get("signal"):
                fail(rid, f"trigger.conditions[{k}].signal", "缺少 signal")
            if not cond.get("operator"):
                fail(rid, f"trigger.conditions[{k}].operator", "缺少 operator")
            # ``value`` may legitimately be falsy (0, False, ""), so test
            # for key presence rather than truthiness.
            if "value" not in cond:
                fail(rid, f"trigger.conditions[{k}].value", "缺少 value")

    return not failures, failures
|
|
|
|
|
|
def _build_fix_prompt(failures: list[dict]) -> str:
    """Render validation failures as a fix-instruction block for the prompt.

    Args:
        failures: Failure dicts, each with ``rule_id``, ``field`` and ``issue``.

    Returns:
        An empty string when there are no failures; otherwise a heading, one
        bullet per failure and a closing instruction, joined by newlines.
    """
    if not failures:
        return ""

    heading = [
        "\n## 上一轮格式问题修正\n",
        "上一轮输出的规则存在以下格式问题,请修正后重新输出:\n",
    ]
    bullets = [
        f"- **{item['rule_id']}.{item['field']}**: {item['issue']}"
        for item in failures
    ]
    closing = ["\n请修正以上所有问题,重新输出完整的规则数组。"]
    return "\n".join(heading + bullets + closing)
|
|
|
|
|
|
def extract_rules_for_unit(pkg: dict, max_retries: int | None = None) -> list[dict]:
    """Call the LLM for one function unit and return its IR rules.

    Each response is parsed and validated with ``_check_rule_fields``. On a
    parse or validation failure the prompt is rebuilt with fix instructions
    and the call is retried.

    Args:
        pkg: Context package for the unit.
        max_retries: Number of retries after the first attempt. Defaults to
            ``config.MAX_RETRIES_PER_STAGE``.

    Returns:
        The first rule list that passes validation. When retries are
        exhausted, the most recently parsed rule list is returned even though
        it has format issues; this is empty if no response ever parsed.

    Raises:
        RuntimeError: If the LLM returns a response with no content. This is
            not retried here.
    """
    if max_retries is None:
        max_retries = config.MAX_RETRIES_PER_STAGE
    client = config.llm_client()
    prompt = build_prompt(pkg)
    last_failures = []
    # Best-effort result: the latest rule list that parsed, kept so that
    # exhausting the retries does not throw away usable output.
    last_rules = []

    for attempt in range(max_retries + 1):
        # Append format feedback on retry
        if attempt > 0 and last_failures:
            fix_text = _build_fix_prompt(last_failures)
            prompt = build_prompt(pkg, format_feedback=fix_text)

        try:
            resp = client.chat.completions.create(
                model=config.MODEL_NAME,
                messages=[
                    {
                        "role": "system",
                        "content": "你是一个精确的 JSON 输出引擎。只输出合法的 JSON 数组。",
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=config.TEMPERATURE,
                max_tokens=config.MAX_TOKENS,
            )
            content = resp.choices[0].message.content
            if content is None:
                raise RuntimeError("LLM returned empty response")

            json_str = extract_json_from_response(content)
            rules = json.loads(json_str)
            if not isinstance(rules, list):
                raise ValueError(f"Expected JSON array, got {type(rules).__name__}")

            # Format validation
            passed, failures = _check_rule_fields(rules)
            if passed:
                return rules

            # Format issues found — retry with fix instructions
            print(f" 格式问题 ({len(failures)} 个): {[f['field'] for f in failures[:5]]}")
            last_failures = failures
            last_rules = rules
            if attempt < max_retries:
                time.sleep(1)

        except (json.JSONDecodeError, ValueError) as e:
            print(f" JSON 解析失败 (尝试 {attempt + 1}): {e}")
            last_failures = [{"rule_id": "?", "field": "json", "issue": str(e)}]
            if attempt < max_retries:
                time.sleep(2)

    # Exhausted retries — return what we have (even if imperfect)
    print(f" WARN: {pkg['unit_id']} 格式修复耗尽了 {max_retries} 次重试")
    return last_rules
|
|
|
|
|
|
def extract_all_rules(semantic_index: dict, doc: dict) -> list[dict]:
    """Extract IR rules for every function unit, running LLM calls in a thread pool.

    Args:
        semantic_index: Stage 1 output; its ``function_units`` list is processed.
        doc: The input document the context packages are built from.

    Returns:
        One fragment per unit, ``{unit_id, unit_name, rules}``, sorted by
        ``unit_id``. A unit whose extraction raised gets ``rules == []`` and
        an ``error`` string.
    """
    lookups = build_document_lookup(doc)
    units = semantic_index.get("function_units", [])

    print(f" 共 {len(units)} 个功能单元待处理")
    print(f" 最大并发: {MAX_CONCURRENCY}")

    # Context packages are built serially before any LLM call is submitted.
    packages = [extract_context_package(unit, doc, *lookups) for unit in units]

    # Fan the LLM calls out over at most MAX_CONCURRENCY worker threads.
    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_CONCURRENCY) as pool:
        pending = {
            pool.submit(extract_rules_for_unit, package): (
                package["unit_id"],
                package["unit_name"],
            )
            for package in packages
        }

        for finished in concurrent.futures.as_completed(pending):
            uid, uname = pending[finished]
            try:
                rules = finished.result()
                results.append({"unit_id": uid, "unit_name": uname, "rules": rules})
                print(f" [OK] {uid} ({uname}): {len(rules)} 条规则")
            except Exception as exc:
                print(f" [FAIL] {uid} ({uname}): 失败 — {exc}")
                results.append({
                    "unit_id": uid,
                    "unit_name": uname,
                    "rules": [],
                    "error": str(exc),
                })

    # Completion order is arbitrary; order by unit_id for stable output.
    return sorted(results, key=lambda fragment: fragment["unit_id"])
|
|
|
|
|
|
def main():
    """Run Stage 2: load inputs, extract rules per unit, save the fragments."""
    banner = "=" * 60
    print(banner)
    print("阶段二:逐功能单元 IR 提取")
    print(banner)

    # 1. Load inputs
    print("\n[1/3] 加载输入...")
    semantic_index = load_semantic_index()
    doc = config.load_input_document()
    unit_count = len(semantic_index.get("function_units", []))
    print(f" 语义索引: {unit_count} 个功能单元")

    # 2. Extract rules
    print("\n[2/3] 逐单元提取 IR 规则...")
    fragments = extract_all_rules(semantic_index, doc)

    # 3. Save
    print("\n[3/3] 保存 IR 片段...")
    config.save_json(fragments, config.IR_FRAGMENTS_JSON)

    # Summary: total rule count plus the ids of units that recorded an error.
    total_rules = sum(len(fragment["rules"]) for fragment in fragments)
    failed_ids = [fragment["unit_id"] for fragment in fragments if fragment.get("error")]
    print(f"\n完成! {len(fragments)} 个功能单元, 共 {total_rules} 条规则")
    if failed_ids:
        print(f" [WARN] {len(failed_ids)} 个单元提取失败: {failed_ids}")
    print(f"输出: {config.IR_FRAGMENTS_JSON}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the stage only when this file is executed as a script.
    main()
|