"""
Stage 2: Per Function Unit IR Extraction.

For each function unit from the semantic index, constructs a precision context
package and calls the LLM to extract detailed IR rules.
Runs multiple LLM calls in parallel (up to MAX_CONCURRENCY).

Output: output/ir_fragments.json
"""
import concurrent.futures
import json
import re
import sys
import time
from pathlib import Path

import config

MAX_CONCURRENCY = 3  # Max parallel LLM calls


def _first_token(value) -> str:
    """Return the first whitespace-delimited token of *value*, or "" if none.

    Used to normalize section references such as "3.1.1 Some title" down to
    their leading key. Empty, None and whitespace-only values all yield "".
    """
    if not value:
        return ""
    tokens = str(value).split()
    return tokens[0] if tokens else ""


def _flatten_table_cells(block: dict) -> list[dict]:
    """Return every cell of a table block as a flat list of row/name/text dicts."""
    return [
        {
            "row": col.get("row"),
            "name": col.get("name"),
            "text": col.get("text"),
        }
        for row in block.get("rows", [])
        for col in row.get("columns", [])
    ]


def load_semantic_index() -> dict:
    """Load the semantic index from Stage 1."""
    return config.load_json(config.SEMANTIC_INDEX_JSON)


def build_document_lookup(doc: dict) -> tuple[dict, dict, dict]:
    """Build lookup structures for fast context extraction from the document.

    Returns:
        A 3-tuple ``(sections_by_source, image_by_rid, conflicts_by_section)``:

        * ``sections_by_source`` maps the first token of each section's
          ``source`` (e.g. "3.1.1") to the section dict. Sections whose
          source has no token are skipped.
        * ``image_by_rid`` maps a non-empty ``rid`` (e.g. "rId16") to its
          ``image_analysis`` entry.
        * ``conflicts_by_section`` maps the first token of each conflict's
          ``section`` to the list of conflicts; conflicts without a usable
          section are grouped under the key "".
    """
    # sections_by_source: "3.1.1" -> section dict
    sections_by_source = {}
    for section in doc.get("sections", []):
        key = _first_token(section.get("source", ""))
        if key:
            sections_by_source[key] = section

    # image_by_rid: "rId16" -> image_analysis entry
    image_by_rid = {}
    for img in doc.get("image_analysis", []):
        rid = img.get("rid", "")
        if rid:
            image_by_rid[rid] = img

    # Conflicts indexed by section
    conflicts_by_section = {}
    for conflict in doc.get("resolved_conflicts", []):
        key = _first_token(conflict.get("section", ""))
        conflicts_by_section.setdefault(key, []).append(conflict)

    return sections_by_source, image_by_rid, conflicts_by_section


def extract_context_package(
    fu: dict,
    doc: dict,
    sections_by_source: dict,
    image_by_rid: dict,
    conflicts_by_section: dict
) -> dict:
    """Build a precision context package for a single function unit.

    Walks ``fu["sources"]`` and collects, per referenced section, its
    paragraph texts and tables, plus the logic trees of referenced images and
    the resolved conflicts recorded for the visited sections.

    Args:
        fu: Function unit; must contain ``unit_id``. ``sources``, ``name``,
            ``description`` and ``path`` are optional.
        doc: The input document (not read directly here; the lookups below
            are derived from it).
        sections_by_source: Section key -> section dict.
        image_by_rid: Image rid -> image analysis entry.
        conflicts_by_section: Section key -> list of resolved conflicts.

    Returns:
        A dict with the unit's identity fields and the ``texts``, ``tables``,
        ``logic_trees`` and ``resolved_conflicts`` lists.

    Raises:
        KeyError: If ``fu`` has no ``unit_id``.
    """
    texts = []
    tables = []
    logic_trees = []
    # A dict is used as an insertion-ordered set so that the conflicts
    # gathered below come out in a reproducible order (iteration order of a
    # set of strings changes between interpreter runs).
    seen_sections: dict[str, None] = {}
    seen_images = set()

    for src in fu.get("sources", []):
        src_type = src.get("type", "")
        section_key = _first_token(src.get("section"))

        # --- Text source ---
        if src_type in ("table", "para") and section_key:
            # NOTE(review): each section is expanded only once, so a later
            # source pointing at a different row of the same section does not
            # get its row highlighted — confirm this is intended.
            if section_key in seen_sections:
                continue
            seen_sections[section_key] = None

            section = sections_by_source.get(section_key)
            if section is None:
                # Fuzzy match by prefix
                for key in sections_by_source:
                    if key.startswith(section_key):
                        section = sections_by_source[key]
                        break

            if section:
                for block in section.get("blocks", []):
                    block_type = block.get("type")
                    if block_type == "para":
                        texts.append({
                            "section": section_key,
                            "text": block.get("text", ""),
                        })
                    elif block_type == "table":
                        headers = block.get("headers", [])
                        row_num = src.get("row") if src_type == "table" else None
                        if row_num is not None:
                            # Extract only the specific row: a table row
                            # matches when any of its cells carries row_num.
                            matching_rows = []
                            for row in block.get("rows", []):
                                columns = row.get("columns", [])
                                if any(c.get("row") == row_num for c in columns):
                                    matching_rows.append({
                                        "headers": headers,
                                        "cells": {
                                            col.get("name"): col.get("text")
                                            for col in columns
                                        },
                                        "row": row_num,
                                    })
                            tables.append({
                                "section": section_key,
                                "headers": headers,
                                "rows": matching_rows,
                                "all_rows": _flatten_table_cells(block),
                            })
                        else:
                            # Include full table
                            tables.append({
                                "section": section_key,
                                "headers": headers,
                                "all_rows": _flatten_table_cells(block),
                            })

        # --- Logic tree source ---
        if src_type == "logic_tree":
            image_id = src.get("image_id", "")
            if not image_id or image_id in seen_images:
                continue
            seen_images.add(image_id)
            img = image_by_rid.get(image_id)
            if img:
                lt = img.get("logic_tree")
                if lt:
                    logic_trees.append({
                        "image_id": image_id,
                        "description": img.get("description", ""),
                        "tree": lt,
                    })

    # Include relevant resolved conflicts, in the order sections were visited
    relevant_conflicts = []
    for section_key in seen_sections:
        relevant_conflicts.extend(conflicts_by_section.get(section_key, []))

    return {
        "unit_id": fu["unit_id"],
        "unit_name": fu.get("name", ""),
        "unit_description": fu.get("description", ""),
        "unit_path": fu.get("path", []),
        "texts": texts,
        "tables": tables,
        "logic_trees": logic_trees,
        "resolved_conflicts": relevant_conflicts,
    }


def format_context_package(pkg: dict) -> str:
    """Format a context package as a readable string for the prompt.

    Emits four sections in order (texts, tables, logic trees, and — only when
    present — resolved conflicts), joined with newlines. Missing fields are
    rendered with "?" or an empty string instead of raising.
    """
    parts = []

    # Texts
    parts.append("【文字段落】")
    for t in pkg.get("texts", []):
        parts.append(f"[{t.get('section', '?')}] {t.get('text', '')}")
    if not pkg.get("texts"):
        parts.append("(无)")

    # Tables
    parts.append("\n【表格数据】")
    for i, tbl in enumerate(pkg.get("tables", [])):
        parts.append(f"表格 {i+1} (section={tbl.get('section', '?')})")
        headers = tbl.get("headers", [])
        parts.append(f" 表头: {headers}")
        parts.append(" 全部行数据:")
        for row in tbl.get("all_rows", []):
            parts.append(
                f" 行{row.get('row','?')}[{row.get('name','?')}]: {row.get('text','')}"
            )
        # Highlight matched rows if any
        matched = tbl.get("rows", [])
        if matched:
            parts.append(" <重点关注行>:")
            for mr in matched:
                parts.append(f" 行{mr.get('row','?')}: {mr.get('cells', {})}")
    if not pkg.get("tables"):
        parts.append("(无)")

    # Logic trees
    parts.append("\n【逻辑树】")
    for i, lt in enumerate(pkg.get("logic_trees", [])):
        parts.append(f"逻辑树 {i+1} (image_id={lt.get('image_id', '?')})")
        # `or ''` guards against an explicit None description.
        parts.append(f" 描述: {(lt.get('description') or '')[:200]}")
        tree = lt.get("tree", {})
        parts.append(f" 根: {tree.get('root', '?')}")
        parts.append(" 节点:")
        for node in tree.get("nodes", []):
            nid = node.get("id", "?")
            ntype = node.get("type", "?")
            desc = node.get("description", "") or node.get("condition", "")
            parts.append(f" [{ntype}] {nid}: {desc}")
            for br in node.get("branches", []):
                parts.append(f" → {br.get('value', '?')} → {br.get('target', '?')}")
    if not pkg.get("logic_trees"):
        parts.append("(无)")

    # Conflicts
    conflicts = pkg.get("resolved_conflicts", [])
    if conflicts:
        parts.append("\n【图文冲突仲裁】")
        for c in conflicts:
            parts.append(
                f" [{c.get('conflict_type', '?')}] 以{c.get('source', '?')}为准: "
                f"{c.get('correction', '')}"
            )

    return "\n".join(parts)


def _escape_json_for_format(s: str) -> str:
    """Double every curly brace in *s*.

    This is only appropriate for text that will itself be used as a
    ``str.format()`` *template*. Values passed as ``format()`` arguments are
    inserted verbatim and must NOT be escaped, otherwise the doubled braces
    appear literally in the result.
    """
    return s.replace("{", "{{").replace("}", "}}")


def build_prompt(pkg: dict, format_feedback: str = "") -> str:
    """Build the LLM prompt for a single function unit.

    Reads ``step2_ir_extraction.txt`` from ``config.PROMPTS_DIR`` and fills
    its placeholders with the unit's identity fields and JSON-serialized
    context.

    Args:
        pkg: Context package; must contain ``unit_id``, ``unit_name`` and
            ``unit_description``.
        format_feedback: Optional fix instructions appended on retries.

    Returns:
        The fully rendered prompt text.
    """
    template_path = Path(config.PROMPTS_DIR) / "step2_ir_extraction.txt"
    template = template_path.read_text(encoding="utf-8")

    def as_json(value) -> str:
        return json.dumps(value, ensure_ascii=False, indent=2)

    # str.format() does not re-interpret substituted values, so they are
    # passed as-is. Escaping their braces (as was done previously) made every
    # "{" / "}" of the JSON context show up doubled in the prompt.
    return template.format(
        unit_id=pkg["unit_id"],
        unit_name=pkg["unit_name"],
        unit_description=pkg["unit_description"],
        texts=as_json(pkg.get("texts", [])),
        tables=as_json(pkg.get("tables", [])),
        logic_trees=as_json(pkg.get("logic_trees", [])),
        resolved_conflicts=as_json(pkg.get("resolved_conflicts", [])),
        format_feedback=format_feedback,
    )


def extract_json_from_response(text: str) -> str:
    """Extract the JSON array substring from an LLM response.

    A fenced code block (three backticks, optionally tagged ``json``) holding
    an array is preferred. Otherwise the first ``[`` and its matching ``]``
    are located with a scan that ignores brackets inside JSON string
    literals.

    Raises:
        ValueError: If no ``[`` is present or the array is never closed.
    """
    m = re.search(r"```(?:json)?\s*(\[[\s\S]*?\])\s*```", text)
    if m:
        return m.group(1).strip()

    # Find outermost [ ... ]
    start = text.find("[")
    if start == -1:
        raise ValueError("No JSON array found in LLM response")

    depth = 0
    in_string = False
    escaped = False
    for i, ch in enumerate(text[start:], start):
        if in_string:
            # Inside a string literal only an unescaped quote is significant.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
            if depth == 0:
                return text[start : i + 1]

    raise ValueError("Unclosed JSON array in LLM response")


def _check_rule_fields(rules: list[dict]) -> tuple[bool, list[dict]]:
    """Validate that each rule has the required fields.

    Checks ``path``, ``precondition.geographic_scope``, the ``content`` of
    every ``user_interaction`` action (non-empty, no placeholder wording) and
    ``signal`` / ``operator`` / ``value`` of every trigger condition.
    Containers of an unexpected type are treated as empty instead of raising,
    so malformed LLM output is reported as a format failure.

    Returns:
        ``(passed, failures)`` where each failure is
        ``{rule_id, field, issue}`` and ``passed`` is True when there are
        none.
    """
    placeholders = ("文案由业务定义", "待定", "自定义")
    failures = []

    for j, rule in enumerate(rules):
        if not isinstance(rule, dict):
            failures.append({"rule_id": f"rule[{j}]", "field": "-", "issue": "规则不是 dict"})
            continue

        rid = rule.get("rule_id") or f"rule[{j}]"

        if not rule.get("path"):
            failures.append({"rule_id": rid, "field": "path", "issue": "缺少 path 字段(必填)"})

        precond = rule.get("precondition")
        if not isinstance(precond, dict):
            precond = {}
        if not precond.get("geographic_scope"):
            failures.append({
                "rule_id": rid,
                "field": "precondition.geographic_scope",
                "issue": "缺少 geographic_scope(必填)",
            })

        actions = rule.get("actions")
        if not isinstance(actions, list):
            actions = []
        for k, action in enumerate(actions):
            if not isinstance(action, dict):
                continue
            if action.get("type") != "user_interaction":
                continue
            content = action.get("content") or ""
            if not content:
                failures.append({
                    "rule_id": rid,
                    "field": f"actions[{k}].content",
                    "issue": "user_interaction 的 content 为空",
                })
                continue
            # Non-string content is stringified so the substring test cannot
            # raise TypeError.
            content_text = content if isinstance(content, str) else str(content)
            if any(ph in content_text for ph in placeholders):
                failures.append({
                    "rule_id": rid,
                    "field": f"actions[{k}].content",
                    "issue": f"content 包含占位符: '{content}'",
                })

        trigger = rule.get("trigger")
        if not isinstance(trigger, dict):
            trigger = {}
        conditions = trigger.get("conditions")
        if not isinstance(conditions, list):
            conditions = []
        for k, cond in enumerate(conditions):
            if not isinstance(cond, dict):
                continue
            if not cond.get("signal"):
                failures.append({
                    "rule_id": rid,
                    "field": f"trigger.conditions[{k}].signal",
                    "issue": "缺少 signal",
                })
            if not cond.get("operator"):
                failures.append({
                    "rule_id": rid,
                    "field": f"trigger.conditions[{k}].operator",
                    "issue": "缺少 operator",
                })
            # `value` may legitimately be falsy (0, False, ""), so only its
            # presence is checked.
            if "value" not in cond:
                failures.append({
                    "rule_id": rid,
                    "field": f"trigger.conditions[{k}].value",
                    "issue": "缺少 value",
                })

    return len(failures) == 0, failures


def _build_fix_prompt(failures: list[dict]) -> str:
    """Build a format-fix instruction block for the prompt.

    Returns "" when *failures* is empty; otherwise a markdown section listing
    one bullet per failure followed by a request to re-emit the full array.
    """
    if not failures:
        return ""
    lines = [
        "\n## 上一轮格式问题修正\n",
        "上一轮输出的规则存在以下格式问题,请修正后重新输出:\n",
    ]
    for f in failures:
        lines.append(f"- **{f['rule_id']}.{f['field']}**: {f['issue']}")
    lines.append("\n请修正以上所有问题,重新输出完整的规则数组。")
    return "\n".join(lines)


def extract_rules_for_unit(pkg: dict, max_retries: int | None = None) -> list[dict]:
    """Call the LLM for one function unit and return its IR rules.

    The response is parsed as a JSON array and validated with
    ``_check_rule_fields``. On a parse or validation failure the call is
    retried with the failures fed back through the prompt.

    Args:
        pkg: Context package produced by ``extract_context_package``.
        max_retries: Number of retries after the first attempt; defaults to
            ``config.MAX_RETRIES_PER_STAGE``.

    Returns:
        The validated rule list, or ``[]`` when every attempt failed parsing
        or validation (rules that parsed but failed validation are
        discarded).

    Raises:
        RuntimeError: If the LLM returns no content. This is not retried
            here and propagates to the caller, as does any error raised by
            the client itself.
    """
    if max_retries is None:
        max_retries = config.MAX_RETRIES_PER_STAGE

    client = config.llm_client()
    prompt = build_prompt(pkg)
    last_failures = []

    for attempt in range(max_retries + 1):
        # Append format feedback on retry
        if attempt > 0 and last_failures:
            fix_text = _build_fix_prompt(last_failures)
            prompt = build_prompt(pkg, format_feedback=fix_text)

        try:
            resp = client.chat.completions.create(
                model=config.MODEL_NAME,
                messages=[
                    {
                        "role": "system",
                        "content": "你是一个精确的 JSON 输出引擎。只输出合法的 JSON 数组。",
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=config.TEMPERATURE,
                max_tokens=config.MAX_TOKENS,
            )
            content = resp.choices[0].message.content
            if content is None:
                raise RuntimeError("LLM returned empty response")

            json_str = extract_json_from_response(content)
            rules = json.loads(json_str)
            if not isinstance(rules, list):
                raise ValueError(f"Expected JSON array, got {type(rules).__name__}")

            # Format validation
            passed, failures = _check_rule_fields(rules)
            if passed:
                return rules

            # Format issues found — retry with fix instructions
            print(f" 格式问题 ({len(failures)} 个): {[f['field'] for f in failures[:5]]}")
            last_failures = failures
            if attempt < max_retries:
                time.sleep(1)

        except (json.JSONDecodeError, ValueError) as e:
            print(f" JSON 解析失败 (尝试 {attempt + 1}): {e}")
            last_failures = [{"rule_id": "?", "field": "json", "issue": str(e)}]
            if attempt < max_retries:
                time.sleep(2)

    # Exhausted retries — nothing passed validation, so no rules are returned.
    print(f" WARN: {pkg['unit_id']} 格式修复耗尽了 {max_retries} 次重试")
    return []


def extract_all_rules(
    semantic_index: dict,
    doc: dict
) -> list[dict]:
    """Extract IR rules for all function units.

    Context packages are built serially, then ``extract_rules_for_unit`` is
    run for each one on a thread pool of ``MAX_CONCURRENCY`` workers.

    Returns:
        One fragment per unit, ``{unit_id, unit_name, rules}``, sorted by
        ``unit_id``. Units whose extraction raised get ``rules == []`` plus
        an ``error`` message.
    """
    sections_by_source, image_by_rid, conflicts_by_section = build_document_lookup(doc)
    function_units = semantic_index.get("function_units", [])

    print(f" 共 {len(function_units)} 个功能单元待处理")
    print(f" 最大并发: {MAX_CONCURRENCY}")

    # Build context packages (serial — fast)
    packages = [
        extract_context_package(
            fu, doc, sections_by_source, image_by_rid, conflicts_by_section
        )
        for fu in function_units
    ]

    # Run LLM calls in parallel
    fragments = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_CONCURRENCY) as executor:
        futures = {
            executor.submit(extract_rules_for_unit, pkg): (pkg["unit_id"], pkg["unit_name"])
            for pkg in packages
        }

        for future in concurrent.futures.as_completed(futures):
            uid, uname = futures[future]
            try:
                rules = future.result()
                fragments.append({
                    "unit_id": uid,
                    "unit_name": uname,
                    "rules": rules,
                })
                print(f" [OK] {uid} ({uname}): {len(rules)} 条规则")
            except Exception as e:
                # Boundary handler: one failing unit must not abort the rest;
                # the error is recorded on the fragment and reported by main().
                print(f" [FAIL] {uid} ({uname}): 失败 — {e}")
                fragments.append({
                    "unit_id": uid,
                    "unit_name": uname,
                    "rules": [],
                    "error": str(e),
                })

    # Sort by unit_id to maintain stable ordering (completion order varies)
    fragments.sort(key=lambda f: f["unit_id"])
    return fragments


def main():
    """Run Stage 2: load inputs, extract rules per unit, save the fragments.

    Exits with status 1 when the semantic index contains no function units.
    Fragments with neither rules nor an error are dropped before saving.
    """
    print("=" * 60)
    print("阶段二:逐功能单元 IR 提取")
    print("=" * 60)

    # 1. Load inputs
    print("\n[1/3] 加载输入...")
    semantic_index = load_semantic_index()
    doc = config.load_input_document()
    n_units = len(semantic_index.get("function_units", []))
    print(f" 语义索引: {n_units} 个功能单元")

    if n_units == 0:
        print("错误: 语义索引中无功能单元 (function_units 为空)。")
        print(" 请检查 step1_semantic_index 是否正确运行。")
        print(" 可能原因: LLM API Key 未配置、Prompt 不兼容、或输入文档格式异常。")
        sys.exit(1)

    # 2. Extract rules
    print("\n[2/3] 逐单元提取 IR 规则...")
    fragments = extract_all_rules(semantic_index, doc)

    # Filter out fragments with empty rules (LLM extraction failures);
    # fragments carrying an error are kept so the failure stays visible.
    empty_units = [f["unit_id"] for f in fragments if not f.get("rules") and not f.get("error")]
    if empty_units:
        print(f" [WARN] {len(empty_units)} 个单元规则为空,已过滤: {empty_units}")
    fragments = [f for f in fragments if f.get("rules") or f.get("error")]

    # 3. Save
    print("\n[3/3] 保存 IR 片段...")
    config.save_json(fragments, config.IR_FRAGMENTS_JSON)

    total_rules = sum(len(f["rules"]) for f in fragments)
    failed_units = [f for f in fragments if f.get("error")]
    print(f"\n完成! {len(fragments)} 个功能单元, 共 {total_rules} 条规则")
    if failed_units:
        print(f" [WARN] {len(failed_units)} 个单元提取失败: "
              f"{[f['unit_id'] for f in failed_units]}")
    print(f"输出: {config.IR_FRAGMENTS_JSON}")


if __name__ == "__main__":
    main()