From fec4c09ee08e194eb1c8254b42db922ac7e571f9 Mon Sep 17 00:00:00 2001 From: Peter Zhang <18501667167@qq.com> Date: Sat, 30 May 2026 22:45:08 +0800 Subject: [PATCH] sync: update all skills from latest workspace code doc_parser_skill: - New: verify_flowchart.py (flowchart validation) - Updated: LLM.py (multi-provider: DeepSeek + DashScope) - Updated: image_parser.py (logic tree support, external prompts) - Updated: SKILL.md, prompts/image_prompt.md conflict_detection_skill: - Updated: LLM.py (multi-provider sync) - Updated: detect_conflicts.py (logic tree text conversion) ir_generation_skill: - Replaced old scripts/LLM.py + ir_generator.py with standalone project - New: main.py, config.py, step1-3_*.py, ensemble_merge.py - New: prompts/, tests/ subdirectories tests: - New: acceptance/ test suite with schema validation - Fixed: conftest no longer globally skips non-acceptance tests - Updated: test_sample.py for new ir_generation structure Co-Authored-By: Claude Opus 4.7 --- .gitea/workflows/acceptance.yml | 54 + .gitignore | 4 + scripts/create_failure_issue.py | 10 +- .../conflict_detection_skill/scripts/LLM.py | 95 +- .../scripts/detect_conflicts.py | 85 +- skills/doc_parser_skill/SKILL.md | 5 +- .../doc_parser_skill/prompts/image_prompt.md | 203 +++ skills/doc_parser_skill/scripts/LLM.py | 95 +- .../doc_parser_skill/scripts/image_parser.py | 357 +++++- .../scripts/verify_flowchart.py | 384 ++++++ skills/ir_generation_skill/.gitignore | 9 + skills/ir_generation_skill/config.py | 137 +++ skills/ir_generation_skill/ensemble_merge.py | 593 +++++++++ skills/ir_generation_skill/main.py | 157 +++ .../prompts/step1_feedback.txt | 46 + .../prompts/step1_semantic_index.txt | 123 ++ .../prompts/step2_ir_extraction.txt | 200 +++ skills/ir_generation_skill/scripts/LLM.py | 105 -- .../scripts/ir_generator.py | 359 ------ .../step1_semantic_index.py | 717 +++++++++++ .../step2_5_branch_coverage.py | 399 ++++++ .../step2_ir_extraction.py | 508 ++++++++ .../step3_merge_and_audit.py | 1094 +++++++++++++++++ .../tests/test_ensemble_merge.py | 472 +++++++ .../ir_generation_skill/tests/test_step1.py | 370 ++++++ .../ir_generation_skill/tests/test_step2.py | 322 +++++ .../ir_generation_skill/tests/test_step2_5.py | 152 +++ .../ir_generation_skill/tests/test_step3.py | 232 ++++ tests/__init__.py | 1 + tests/acceptance/__init__.py | 1 + tests/acceptance/conftest.py | 186 +++ tests/acceptance/ir_schema.py | 325 +++++ tests/acceptance/report.py | 178 +++ tests/acceptance/test_main_health.py | 558 +++++++++ tests/test_sample.py | 15 +- 35 files changed, 8021 insertions(+), 530 deletions(-) create mode 100644 .gitea/workflows/acceptance.yml create mode 100644 skills/doc_parser_skill/prompts/image_prompt.md create mode 100644 skills/doc_parser_skill/scripts/verify_flowchart.py create mode 100644 skills/ir_generation_skill/.gitignore create mode 100644 skills/ir_generation_skill/config.py create mode 100644 skills/ir_generation_skill/ensemble_merge.py create mode 100644 skills/ir_generation_skill/main.py create mode 100644 skills/ir_generation_skill/prompts/step1_feedback.txt create mode 100644 skills/ir_generation_skill/prompts/step1_semantic_index.txt create mode 100644 skills/ir_generation_skill/prompts/step2_ir_extraction.txt delete mode 100644 skills/ir_generation_skill/scripts/LLM.py delete mode 100644 skills/ir_generation_skill/scripts/ir_generator.py create mode 100644 skills/ir_generation_skill/step1_semantic_index.py create mode 100644 skills/ir_generation_skill/step2_5_branch_coverage.py create mode 100644 skills/ir_generation_skill/step2_ir_extraction.py create mode 100644 skills/ir_generation_skill/step3_merge_and_audit.py create mode 100644 skills/ir_generation_skill/tests/test_ensemble_merge.py create mode 100644 skills/ir_generation_skill/tests/test_step1.py create mode 100644 skills/ir_generation_skill/tests/test_step2.py create mode 100644 skills/ir_generation_skill/tests/test_step2_5.py create mode 100644 skills/ir_generation_skill/tests/test_step3.py create mode 100644 tests/__init__.py create mode 100644 tests/acceptance/__init__.py create mode 100644 tests/acceptance/conftest.py create mode 100644 tests/acceptance/ir_schema.py create mode 100644 tests/acceptance/report.py create mode 100644 tests/acceptance/test_main_health.py diff --git a/.gitea/workflows/acceptance.yml b/.gitea/workflows/acceptance.yml new file mode 100644 index 0000000..15da188 --- /dev/null +++ b/.gitea/workflows/acceptance.yml @@ -0,0 +1,54 @@ +name: QE Acceptance Tests + +on: + workflow_dispatch: + inputs: + acceptance_runs: + description: 'Layer B stability runs (1 = skip stability testing)' + required: false + default: '1' + ir_path: + description: 'Path to IR JSON file (relative to workspace)' + required: false + default: 'output/ir_final.json' + parsed_path: + description: 'Path to _parsed.json or _updated.json (relative to workspace)' + required: false + default: 'output/车机娱乐系统禁止功能文档_精简_updated.json' + +jobs: + acceptance: + runs-on: shell + timeout-minutes: 30 + steps: + - name: Checkout main branch + run: | + git clone --depth 1 http://localhost:3000/pzhang_zywl/document_analyzer.git . + git checkout main + + - name: Install dependencies + run: pip install -r requirements.txt + + - name: Run QE Acceptance Tests + run: >- + python -m pytest tests/acceptance/ -v + --run-acceptance + --acceptance-runs=${{ github.event.inputs.acceptance_runs }} + --ir-path=${{ github.event.inputs.ir_path }} + --parsed-path=${{ github.event.inputs.parsed_path }} + --tb=long + env: + DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }} + + - name: Create issue on failure + if: failure() + env: + GITEA_API_TOKEN: ${{ secrets.GITEA_TOKEN }} + run: >- + python scripts/create_failure_issue.py + --sha "${{ github.sha }}" + --branch "main" + --run "${{ github.run_number }}" + --message "QE Acceptance Tests Failed" + --workflow "QE Acceptance" + --labels "acceptance-failure,agent-task" diff --git a/.gitignore b/.gitignore index d55d938..014408a 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,7 @@ output/ dist/ .runner *_output/ +*.png +*.jpg +acceptance-report.json +ir_final.json diff --git a/scripts/create_failure_issue.py b/scripts/create_failure_issue.py index c888aef..72be1fe 100644 --- a/scripts/create_failure_issue.py +++ b/scripts/create_failure_issue.py @@ -17,14 +17,18 @@ def main(): parser.add_argument("--run", required=True) parser.add_argument("--message", required=True) parser.add_argument("--api-token", default=os.environ.get("GITEA_API_TOKEN", "")) + parser.add_argument("--workflow", default="CI", help="Workflow name that triggered this (default: CI)") + parser.add_argument("--labels", default="ci-failure,agent-task", + help="Comma-separated labels for the issue (default: ci-failure,agent-task)") args = parser.parse_args() sha_short = args.sha[:7] run_url = f"{GITEA_URL}/{REPO}/actions/runs/{args.run}" + labels = [l.strip() for l in args.labels.split(",") if l.strip()] - title = f"CI Failure: {args.message[:80]}" + title = f"[{args.workflow}] Failure: {args.message[:80]}" body = ( - f"## CI 测试失败\n\n" + f"## {args.workflow} 测试失败\n\n" f"- **Commit:** {sha_short}\n" f"- **Branch:** {args.branch}\n" f"- **工作流运行:** {run_url}\n\n" @@ -38,7 +42,7 @@ def main(): payload = json.dumps({ "title": title, "body": body, - "labels": [], + "labels": labels, }).encode("utf-8") url = f"{GITEA_URL}/api/v1/repos/{REPO}/issues" diff --git a/skills/conflict_detection_skill/scripts/LLM.py b/skills/conflict_detection_skill/scripts/LLM.py index e6f2099..8fff911 100644 --- a/skills/conflict_detection_skill/scripts/LLM.py +++ b/skills/conflict_detection_skill/scripts/LLM.py @@ -1,38 +1,97 @@ import logging import os import time +from pathlib import Path from typing import Optional from openai import OpenAI logger = logging.getLogger(__name__) +# Resolve secrets file: priority 1) env OPENCLAW_SECRETS, +# 2) workspace-document-analyzer/config/ (relative to skills dir), +# 3) .openclaw/config/ +_SECRETS_FILE = None +for _candidate in ( + os.environ.get("OPENCLAW_SECRETS", ""), + Path(__file__).resolve().parents[3] / "config" / "secrets.yaml", + Path(__file__).resolve().parents[5] / ".openclaw" / "config" / "secrets.yaml", +): + if _candidate and Path(_candidate).exists(): + _SECRETS_FILE = Path(_candidate) + break +if _SECRETS_FILE is None: + _SECRETS_FILE = Path("") # empty fallback + + +def _load_secrets() -> dict: + """Load API keys from secrets.yaml, with env-var overrides.""" + secrets = {} + if _SECRETS_FILE.exists(): + try: + import yaml + with open(_SECRETS_FILE, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) or {} + for provider in ("deepseek", "dashscope"): + if provider in data and isinstance(data[provider], dict): + secrets[provider] = data[provider] + except ImportError: + logger.warning("pyyaml not installed, cannot read %s", _SECRETS_FILE) + except Exception as e: + logger.warning("Failed to load %s: %s", _SECRETS_FILE, e) + + # Env overrides + dk_env = os.environ.get("DEEPSEEK_API_KEY", "") + ds_env = os.environ.get("DASHSCOPE_API_KEY", "") + if dk_env: + secrets.setdefault("deepseek", {})["apiKey"] = dk_env + if ds_env: + secrets.setdefault("dashscope", {})["apiKey"] = ds_env + return secrets + class LLMClient: - """Low-level OpenAI-compatible LLM client with retry and token tracking. + """Multi-provider LLM client with retry and token tracking. + + Routes text models to DeepSeek, vision models to DashScope (Bailian). + Reads API keys from openclaw config/secrets.yaml, with env-var overrides. Usage:: llm = LLMClient() - content = llm.chat("qwen3.5-flash", [{"role": "user", "content": "Hello"}]) + content = llm.chat("deepseek-v4-pro", [{"role": "user", "content": "Hello"}]) print(llm.usage) """ IMAGE_MODEL = "qwen3-vl-plus" - TEXT_MODEL = "qwen3.5-flash-2026-02-23" + TEXT_MODEL = "deepseek-v4-flash" + + DASHSCOPE_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1" + DEEPSEEK_BASE = "https://api.deepseek.com/v1" + TIMEOUT = 120 MAX_RETRIES = 3 + _VISION_KEYWORDS = ("vl", "vision", "qwen-vl", "qwen3-vl") + def __init__( self, *, - base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1", timeout: int | None = None, ): - key = os.environ.get("DASHSCOPE_API_KEY", "") - if not key: - raise ValueError("DASHSCOPE_API_KEY environment variable is not set.") - self._client = OpenAI(api_key=key, base_url=base_url) + secrets = _load_secrets() + + ds_cfg = secrets.get("dashscope", {}) + dk_cfg = secrets.get("deepseek", {}) + + dashscope_key = ds_cfg.get("apiKey", "") + dashscope_url = ds_cfg.get("baseUrl", self.DASHSCOPE_BASE) + deepseek_key = dk_cfg.get("apiKey", "") + deepseek_url = dk_cfg.get("baseUrl", self.DEEPSEEK_BASE) + + self._ds_client = OpenAI(api_key=dashscope_key, base_url=dashscope_url) if dashscope_key else None + self._dk_client = OpenAI(api_key=deepseek_key, base_url=deepseek_url) if deepseek_key else None + self._timeout = timeout or self.TIMEOUT self._prompt_tokens = 0 self._completion_tokens = 0 @@ -49,7 +108,7 @@ class LLMClient: @staticmethod def estimate_tokens(text: str) -> int: """Quick token estimate. CJK ≈1.7/token, others ≈3.0/token.""" - cjk = sum(1 for c in text if '一' <= c <= '鿿' or ' ' <= c <= '〿') + cjk = sum(1 for c in text if '\u4e00' <= c <= '\u9fff' or '\u3000' <= c <= '\u303f') other = len(text) - cjk return max(1, int(cjk / 1.7 + other / 3.0)) @@ -58,6 +117,20 @@ class LLMClient: """Fixed estimate for one vision-model image (~500 tokens).""" return 500 + @staticmethod + def _is_vision_model(model: str) -> bool: + return any(kw in model.lower() for kw in LLMClient._VISION_KEYWORDS) + + def _get_client(self, model: str) -> OpenAI: + if self._is_vision_model(model): + if self._ds_client is None: + raise ValueError("DASHSCOPE_API_KEY not set but required for vision model") + return self._ds_client + else: + if self._dk_client is None: + raise ValueError("DEEPSEEK_API_KEY not set but required for text model") + return self._dk_client + def chat( self, model: str, messages: list[dict], *, timeout: int | None = None, response_format: dict | None = None, @@ -65,8 +138,10 @@ class LLMClient: """Send a chat completion request and return the response content. Automatically retries on failure and accumulates token usage. + Routes to DeepSeek for text, DashScope for vision. """ label = f"chat({model})" + client = self._get_client(model) def _call(): t0 = time.time() @@ -74,7 +149,7 @@ class LLMClient: if response_format is not None: kwargs["response_format"] = response_format kwargs["temperature"] = 0 - resp = self._client.chat.completions.create(**kwargs) + resp = client.chat.completions.create(**kwargs) content = resp.choices[0].message.content usg = resp.usage if usg: diff --git a/skills/conflict_detection_skill/scripts/detect_conflicts.py b/skills/conflict_detection_skill/scripts/detect_conflicts.py index ddafd33..96864a2 100644 --- a/skills/conflict_detection_skill/scripts/detect_conflicts.py +++ b/skills/conflict_detection_skill/scripts/detect_conflicts.py @@ -96,6 +96,77 @@ PROMPT_DETECT_CONFLICT = """你是一个文档一致性检查专家。以下内 """ +def _is_nested_tree(lt: dict) -> bool: + """Return True if logic_tree uses the nested children format.""" + return isinstance(lt.get("children"), list) + + +def _logic_tree_to_text(lt: dict) -> str: + """Convert logic_tree JSON to readable text for conflict detection. + + Supports both the new nested-tree format and the legacy flat-nodes format. + """ + if _is_nested_tree(lt): + return _nested_tree_to_text(lt) + return _flat_tree_to_text(lt) + + +def _nested_tree_to_text(tree: dict) -> str: + """Convert a nested flowchart tree to readable text.""" + lines: list[str] = [] + + def _walk(node: dict, indent: int = 0): + prefix = " " * indent + nid = node.get("id", "") + name = node.get("name", "") + ntype = node.get("type", "") + + type_label = { + "start": "起始", "end": "结束", "process": "处理", + "decision": "判断", "action": "动作", + }.get(ntype, ntype) + + lines.append(f"{prefix}[{type_label}] {nid}: {name}") + + if ntype == "decision": + for child in node.get("children", []): + cond = child.get("condition", "") + lines.append(f"{prefix} 分支 \"{cond}\":") + _walk(child["node"], indent + 2) + elif "children" in node: + for child in node.get("children", []): + _walk(child, indent + 1) + + _walk(tree) + return "\n".join(lines) + + +def _flat_tree_to_text(lt: dict) -> str: + """Convert legacy flat-nodes logic_tree to readable text.""" + lines: list[str] = [] + root = lt.get("root", "") + if root: + lines.append(f"根节点: {root}") + for node in lt.get("nodes", []): + nid = node.get("id", "") + ntype = node.get("type", "") + if ntype == "decision": + cond = node.get("condition", "") + branches = node.get("branches", []) + lines.append(f"判断节点 {nid}: 条件=\"{cond}\"") + for b in branches: + lines.append(f" - 分支 \"{b.get('value', '')}\" → {b.get('target', '')}") + elif ntype == "action": + lines.append(f"动作节点 {nid}: {node.get('description', '')}") + elif ntype == "state": + lines.append(f"状态节点 {nid}: {node.get('description', '')}") + elif ntype == "start": + lines.append(f"起始节点 {nid}: {node.get('description', '')}") + elif ntype == "end": + lines.append(f"结束节点 {nid}: {node.get('description', '')}") + return "\n".join(lines) + + def _build_text_for_section(sections: list[dict], section_name: str) -> str: """Build a single text block for the given section name.""" texts: list[str] = [] @@ -184,8 +255,9 @@ def detect_conflicts( img_type = img.get("type", "other") rid = img.get("rid", "") description = img.get("description", "").strip() + logic_tree = img.get("logic_tree_nested") or img.get("logic_tree") - if img_type not in DIAGRAM_TYPES or not description: + if img_type not in DIAGRAM_TYPES or (not description and not logic_tree): logger.info("Skip conflict check: rid=%s type=%s", rid, img_type) continue @@ -211,8 +283,17 @@ def detect_conflicts( logger.info(" [DRY RUN] would call LLM to detect conflicts") continue + # Enrich description with logic_tree if available + combined_desc = description + if logic_tree: + lt_text = _logic_tree_to_text(logic_tree) + if combined_desc: + combined_desc = f"[结构化逻辑树]\n{lt_text}\n\n[文字描述]\n{combined_desc}" + else: + combined_desc = f"[结构化逻辑树]\n{lt_text}" + prompt = PROMPT_DETECT_CONFLICT.format( - image_description=description, + image_description=combined_desc, text_description=text_content, section_name=section_name, ) diff --git a/skills/doc_parser_skill/SKILL.md b/skills/doc_parser_skill/SKILL.md index 53329b8..31726b4 100644 --- a/skills/doc_parser_skill/SKILL.md +++ b/skills/doc_parser_skill/SKILL.md @@ -29,7 +29,10 @@ description: 解析文档(.docx, .pdf)以提取图像和文本结构,并 该技能生成一个结构化JSON文件,文件名为输入文档的基本名称后跟'_parsed.json',包含: - `sections`:按标题分组的文档文本结构 - `image_sources`:从图像标识符到其在文档中位置的映射 -- `image_analysis`:由视觉大语言模型确定的每个图像的类型和内容描述 +- `image_analysis`:由视觉大语言模型确定的每个图像的类型、内容描述和(如适用)结构化逻辑树 + - `type`: 图片类型(flowchart/architecture/state/sequence/activity/other) + - `description`: 图片的文字描述 + - `logic_tree`(可选,仅图表类型):结构化逻辑树JSON,包含 `root`(根节点描述)和 `nodes` 数组。节点类型:`decision`(判断)、`action`(动作)、`state`(状态)、`start`(开始)、`end`(结束)。decision 节点包含 `condition` 和 `branches` 字段,其他节点包含 `description` 字段。 ## 集成点 diff --git a/skills/doc_parser_skill/prompts/image_prompt.md b/skills/doc_parser_skill/prompts/image_prompt.md new file mode 100644 index 0000000..84ad9bc --- /dev/null +++ b/skills/doc_parser_skill/prompts/image_prompt.md @@ -0,0 +1,203 @@ +请分析这张图片,判断类型并输出文字描述和(如适用)结构化逻辑树。 + +## 判断图片类型 + +如果是 **流程图 / 架构图 / 状态图 / 时序图 / 活动图**,你需要输出三项内容: +1. 类型标签 +2. **嵌套逻辑树 JSON**(见下方格式) +3. 文字描述 + +如果是 **其他类型**(UI原型图 / 界面截图 / 设计稿 / 手机屏幕截图 / 网页截图等),只输出类型标签和简要文字描述。 + +## 嵌套逻辑树 JSON 格式(仅流程图/架构图/状态图/时序图/活动图需要) + +**核心原则:用嵌套的 `children` 数组表达流程的层级关系,而不是用 id 引用。** 这种格式更贴近流程图的自然结构,每个节点的后续步骤直接嵌套在其下方。 + +### 节点类型 + +| 类型 | 含义 | 对应形状 | +|------|------|----------| +| `start` | 起始节点 | 椭圆/圆角矩形 | +| `end` | 结束节点 | 椭圆/圆角矩形 | +| `process` | 处理/状态节点 | 矩形/圆角矩形 | +| `decision` | 判断节点 | 菱形 | +| `action` | 动作节点 | 矩形 | + +### 非判断节点的 children 格式 + +对于 `start`、`end`、`process`、`action` 节点,`children` 是一个数组,包含后续步骤节点: + +```json +{ + "id": "n1", + "name": "节点名称", + "type": "process", + "children": [ + { + "id": "n2", + "name": "下一个步骤", + "type": "action", + "children": [...] + } + ] +} +``` + +### 判断节点的 children 格式 + +对于 `decision` 节点,`children` 是一个数组,每个元素包含 `condition`(分支条件)和 `node`(该分支对应的子节点): + +```json +{ + "id": "n5", + "name": "是否满足条件?", + "type": "decision", + "children": [ + { + "condition": "是", + "node": { + "id": "n6", + "name": "满足条件时的动作", + "type": "action", + "children": [...] + } + }, + { + "condition": "否", + "node": { + "id": "n7", + "name": "不满足条件时的动作", + "type": "action", + "children": [...] + } + } + ] +} +``` + +### 结束节点 + +`end` 节点没有 `children` 字段: + +```json +{ + "id": "n10", + "name": "流程结束", + "type": "end" +} +``` + +### 完整示例 + +```json +{ + "id": "n1", + "name": "开关状态", + "type": "start", + "children": [ + { + "id": "n2", + "name": "开启", + "type": "process", + "children": [ + { + "id": "n3", + "name": "是否在目标场景?", + "type": "decision", + "children": [ + { + "condition": "否", + "node": { + "id": "n4", + "name": "不受限", + "type": "end" + } + }, + { + "condition": "是", + "node": { + "id": "n5", + "name": "车速是否≥15km/h且持续5秒?", + "type": "decision", + "children": [ + { + "condition": "否", + "node": { + "id": "n6", + "name": "不受限", + "type": "end" + } + }, + { + "condition": "是", + "node": { + "id": "n7", + "name": "暂停功能", + "type": "action", + "children": [ + { + "id": "n8", + "name": "发起Toast提示", + "type": "end" + } + ] + } + } + ] + } + } + ] + } + ] + }, + { + "id": "n9", + "name": "关闭", + "type": "process", + "children": [ + { + "id": "n10", + "name": "不受限", + "type": "end" + } + ] + } + ] +} +``` + +### 规则 + +1. 每条从根节点到 `end` 节点的路径必须是完整的逻辑链 +2. `decision` 节点的 `children` 必须穷举所有分支(通常为"是/否"),每条分支包含 `condition` 和 `node` +3. 只有 `end` 节点没有 `children` 字段,其他所有节点都应该有 `children` +4. 节点 id 使用 "n1", "n2", "n3"... 格式,按流程图从上到下、从左到右的顺序编号 +5. 仔细阅读图片中的每个判断条件和分支走向,确保分支目标节点正确 +6. 如果流程图中某个分支的后续步骤在图片中没有展示,将其标记为 `end` 节点,`name` 设为"(图中未展示)" +7. **如果图片包含多个独立的流程图**(例如上半部分和下半部分分别描述不同场景),使用一个统一的 `process` 根节点将它们组织在一起。例如图片中有"策略A"和"策略B"两个流程,结构为: +```json +{ + "id": "n1", + "name": "策略总览", + "type": "process", + "children": [ + {"id": "n2", "name": "策略A流程", "type": "process", "children": [...]}, + {"id": "n3", "name": "策略B流程", "type": "process", "children": [...]} + ] +} +``` + +## 输出格式 + +**1. 类型标签(单独一行):** +type: + +**2. 逻辑树 JSON(仅上述5种类型,以 logic_tree: 开头,后跟 JSON 对象):** +logic_tree: +{...} + +**3. 文字描述(以 description: 开头):** +description: +该图片的详细文字描述。对于流程图/架构图等类型,这里提供自然语言总结;对于其他类型,这是唯一的描述内容。 + +不要输出 ``` 代码块包裹符号,不要输出 ---YAML--- 分隔符,不要添加任何额外的解释或问候语。 diff --git a/skills/doc_parser_skill/scripts/LLM.py b/skills/doc_parser_skill/scripts/LLM.py index e6f2099..8fff911 100644 --- a/skills/doc_parser_skill/scripts/LLM.py +++ b/skills/doc_parser_skill/scripts/LLM.py @@ -1,38 +1,97 @@ import logging import os import time +from pathlib import Path from typing import Optional from openai import OpenAI logger = logging.getLogger(__name__) +# Resolve secrets file: priority 1) env OPENCLAW_SECRETS, +# 2) workspace-document-analyzer/config/ (relative to skills dir), +# 3) .openclaw/config/ +_SECRETS_FILE = None +for _candidate in ( + os.environ.get("OPENCLAW_SECRETS", ""), + Path(__file__).resolve().parents[3] / "config" / "secrets.yaml", + Path(__file__).resolve().parents[5] / ".openclaw" / "config" / "secrets.yaml", +): + if _candidate and Path(_candidate).exists(): + _SECRETS_FILE = Path(_candidate) + break +if _SECRETS_FILE is None: + _SECRETS_FILE = Path("") # empty fallback + + +def _load_secrets() -> dict: + """Load API keys from secrets.yaml, with env-var overrides.""" + secrets = {} + if _SECRETS_FILE.exists(): + try: + import yaml + with open(_SECRETS_FILE, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) or {} + for provider in ("deepseek", "dashscope"): + if provider in data and isinstance(data[provider], dict): + secrets[provider] = data[provider] + except ImportError: + logger.warning("pyyaml not installed, cannot read %s", _SECRETS_FILE) + except Exception as e: + logger.warning("Failed to load %s: %s", _SECRETS_FILE, e) + + # Env overrides + dk_env = os.environ.get("DEEPSEEK_API_KEY", "") + ds_env = os.environ.get("DASHSCOPE_API_KEY", "") + if dk_env: + secrets.setdefault("deepseek", {})["apiKey"] = dk_env + if ds_env: + secrets.setdefault("dashscope", {})["apiKey"] = ds_env + return secrets + class LLMClient: - """Low-level OpenAI-compatible LLM client with retry and token tracking. + """Multi-provider LLM client with retry and token tracking. + + Routes text models to DeepSeek, vision models to DashScope (Bailian). + Reads API keys from openclaw config/secrets.yaml, with env-var overrides. Usage:: llm = LLMClient() - content = llm.chat("qwen3.5-flash", [{"role": "user", "content": "Hello"}]) + content = llm.chat("deepseek-v4-pro", [{"role": "user", "content": "Hello"}]) print(llm.usage) """ IMAGE_MODEL = "qwen3-vl-plus" - TEXT_MODEL = "qwen3.5-flash-2026-02-23" + TEXT_MODEL = "deepseek-v4-flash" + + DASHSCOPE_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1" + DEEPSEEK_BASE = "https://api.deepseek.com/v1" + TIMEOUT = 120 MAX_RETRIES = 3 + _VISION_KEYWORDS = ("vl", "vision", "qwen-vl", "qwen3-vl") + def __init__( self, *, - base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1", timeout: int | None = None, ): - key = os.environ.get("DASHSCOPE_API_KEY", "") - if not key: - raise ValueError("DASHSCOPE_API_KEY environment variable is not set.") - self._client = OpenAI(api_key=key, base_url=base_url) + secrets = _load_secrets() + + ds_cfg = secrets.get("dashscope", {}) + dk_cfg = secrets.get("deepseek", {}) + + dashscope_key = ds_cfg.get("apiKey", "") + dashscope_url = ds_cfg.get("baseUrl", self.DASHSCOPE_BASE) + deepseek_key = dk_cfg.get("apiKey", "") + deepseek_url = dk_cfg.get("baseUrl", self.DEEPSEEK_BASE) + + self._ds_client = OpenAI(api_key=dashscope_key, base_url=dashscope_url) if dashscope_key else None + self._dk_client = OpenAI(api_key=deepseek_key, base_url=deepseek_url) if deepseek_key else None + self._timeout = timeout or self.TIMEOUT self._prompt_tokens = 0 self._completion_tokens = 0 @@ -49,7 +108,7 @@ class LLMClient: @staticmethod def estimate_tokens(text: str) -> int: """Quick token estimate. CJK ≈1.7/token, others ≈3.0/token.""" - cjk = sum(1 for c in text if '一' <= c <= '鿿' or ' ' <= c <= '〿') + cjk = sum(1 for c in text if '\u4e00' <= c <= '\u9fff' or '\u3000' <= c <= '\u303f') other = len(text) - cjk return max(1, int(cjk / 1.7 + other / 3.0)) @@ -58,6 +117,20 @@ class LLMClient: """Fixed estimate for one vision-model image (~500 tokens).""" return 500 + @staticmethod + def _is_vision_model(model: str) -> bool: + return any(kw in model.lower() for kw in LLMClient._VISION_KEYWORDS) + + def _get_client(self, model: str) -> OpenAI: + if self._is_vision_model(model): + if self._ds_client is None: + raise ValueError("DASHSCOPE_API_KEY not set but required for vision model") + return self._ds_client + else: + if self._dk_client is None: + raise ValueError("DEEPSEEK_API_KEY not set but required for text model") + return self._dk_client + def chat( self, model: str, messages: list[dict], *, timeout: int | None = None, response_format: dict | None = None, @@ -65,8 +138,10 @@ class LLMClient: """Send a chat completion request and return the response content. Automatically retries on failure and accumulates token usage. + Routes to DeepSeek for text, DashScope for vision. """ label = f"chat({model})" + client = self._get_client(model) def _call(): t0 = time.time() @@ -74,7 +149,7 @@ class LLMClient: if response_format is not None: kwargs["response_format"] = response_format kwargs["temperature"] = 0 - resp = self._client.chat.completions.create(**kwargs) + resp = client.chat.completions.create(**kwargs) content = resp.choices[0].message.content usg = resp.usage if usg: diff --git a/skills/doc_parser_skill/scripts/image_parser.py b/skills/doc_parser_skill/scripts/image_parser.py index 443b7b4..afbb7eb 100644 --- a/skills/doc_parser_skill/scripts/image_parser.py +++ b/skills/doc_parser_skill/scripts/image_parser.py @@ -1,6 +1,8 @@ import base64 +import json import logging import os +import re from typing import Optional from LLM import LLMClient @@ -8,32 +10,56 @@ from LLM import LLMClient logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- -# Prompts +# Prompt loading # --------------------------------------------------------------------------- -PROMPT_IMAGE = """请分析这张图片,判断类型并输出文字描述。 +def _load_prompt() -> str: + """Load PROMPT_IMAGE from external file, falling back to inline default.""" + prompt_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "prompts") + prompt_path = os.path.join(prompt_dir, "image_prompt.md") + if os.path.isfile(prompt_path): + with open(prompt_path, "r", encoding="utf-8") as f: + return f.read() + + # Fallback inline prompt (nested tree format) + return """请分析这张图片,判断类型并输出文字描述和(如适用)结构化逻辑树。 ## 判断图片类型 -如果是 **流程图 / 架构图 / 状态图 / 时序图 / 活动图**,详细描述: -- 图中所有节点/步骤/状态/组件的名称 -- 所有连线/箭头/转换关系及其方向 -- 所有分支条件、判断逻辑和判断结果 -- 所有文字标注、注释、标签 -- 图的整体结构和逻辑流程 -- 如果图片包含多个子图,拆解描述 +如果是 **流程图 / 架构图 / 状态图 / 时序图 / 活动图**,你需要输出三项内容: +1. 类型标签 +2. **嵌套逻辑树 JSON**(见下方格式) +3. 文字描述 -如果是 **其他类型**(UI原型图 / 界面截图 / 设计稿 / 手机屏幕截图 / 网页截图等),简要描述图片内容。 +如果是 **其他类型**(UI原型图 / 界面截图 / 设计稿 / 手机屏幕截图 / 网页截图等),只输出类型标签和简要文字描述。 + +## 嵌套逻辑树 JSON 格式(仅流程图/架构图/状态图/时序图/活动图需要) + +**核心原则:用嵌套的 `children` 数组表达流程的层级关系,而不是用 id 引用。** + +节点类型:`start`(起始), `end`(结束), `process`(处理/状态), `decision`(判断), `action`(动作) + +非判断节点的 `children` 是子节点数组。`end` 节点无 `children`。 + +判断节点的 `children` 格式: +```json +{"condition": "是", "node": {"id": "n6", "name": "...", "type": "action", "children": [...]}} +``` + +每条从根到 `end` 的路径必须是完整逻辑链。decision 必须穷举所有分支。 +节点 id 使用 "n1", "n2", "n3"... 格式。 ## 输出格式 -**1. 类型标签(单独一行):** type: -**2. 文字描述:** -该图片的详细文字描述。 +logic_tree: +{...} -不要输出 ---YAML--- 分隔符或 YAML 内容,不要添加任何额外的解释或问候语。""" +description: +该图片的详细文字描述。""" + +PROMPT_IMAGE = _load_prompt() # --------------------------------------------------------------------------- @@ -41,7 +67,10 @@ type: # --------------------------------------------------------------------------- class ImageParser: - """Vision LLM wrapper for parsing images (type + description). + """Vision LLM wrapper for parsing images (type + description + logic_tree). + + The nested-tree ``logic_tree`` is stored alongside a backward-compatible + flat representation so downstream consumers are not broken. Usage:: @@ -49,7 +78,7 @@ class ImageParser: result = parser.parse_image("images/img1.png") """ - _VALID_TYPES = {"flowchart", "architecture", "state", "sequence", "activity", "text"} + _VALID_TYPES = {"flowchart", "architecture", "state", "sequence", "activity", "other"} def __init__(self, llm: LLMClient | None = None): self._llm = llm or LLMClient() @@ -59,9 +88,9 @@ class ImageParser: return self._llm.usage def parse_image(self, image_path: str) -> Optional[dict]: - """Parse an image and return its type and description (no YAML IR). + """Parse an image and return its type, description, and optional logic_tree. - Returns ``{type, description}``, or *None* for UI mockups. + Returns ``{type, description, [logic_tree], [logic_tree_nested]}``. """ logger.info("Parsing image: %s", image_path) @@ -84,34 +113,292 @@ class ImageParser: logger.error(str(e)) return {"type": "other", "description": "", "error": str(e)} - parsed = self._parse_type_and_description(content) + parsed = self._parse_response(content) if parsed is None: return None - return {"type": parsed[0], "description": parsed[1]} + ptype, description, logic_tree_nested = parsed + + result: dict = {"type": ptype, "description": description} + if logic_tree_nested is not None: + result["logic_tree_nested"] = logic_tree_nested + result["logic_tree"] = self._flatten_tree(logic_tree_nested) + return result # ---- internals ---------------------------------------------------------- - def _parse_type_and_description(self, content: str) -> Optional[tuple[str, str]]: - """Extract ``(type, description)`` from LLM response. + def _parse_response(self, content: str) -> Optional[tuple[str, str, Optional[dict]]]: + """Extract ``(type, description, logic_tree_nested)`` from LLM response. - Returns *None* for ``[[UI]]`` (skip). + Parses the nested-tree format. Returns *None* for unparseable content. """ content = content.strip() - if content == "[[UI]]" or content.startswith("[[UI]]"): - return None parsed_type = "other" - desc_lines: list[str] = [] - for line in content.splitlines(): - stripped = line.strip() - if (stripped.startswith("type:") or stripped.startswith("类型:")) and parsed_type == "other": - type_val = stripped.split(":", 1)[1].strip().lower() - if type_val in self._VALID_TYPES: - parsed_type = type_val - else: - desc_lines.append(line) + logic_tree = None + description = "" - return parsed_type, "\n".join(desc_lines).strip() + # --- type --- + type_match = re.search(r'(?:type|类型):\s*(\S+)', content) + if type_match: + type_val = type_match.group(1).strip().lower() + if type_val in self._VALID_TYPES: + parsed_type = type_val + + # --- logic_tree (anchored at line start) --- + lt_match = re.search(r'(?m)^logic_tree:\s*', content) + desc_match = re.search(r'(?m)^description:\s*', content) + + if lt_match: + lt_start = lt_match.end() + lt_end = desc_match.start() if desc_match and desc_match.start() > lt_start else len(content) + lt_raw = content[lt_start:lt_end].strip() + + # Try multiple JSON extraction strategies + logic_tree = self._extract_json(lt_raw) + + if logic_tree is not None: + is_valid, err_msg = self._validate_flowchart(logic_tree) + if not is_valid: + logger.warning("Flowchart validation warning: %s", err_msg) + else: + logger.info("Failed to extract logic_tree JSON. Raw block length=%d", len(lt_raw)) + logger.debug("Raw logic_tree block: %s", lt_raw[:500]) + elif parsed_type in self._VALID_TYPES - {"other"}: + logger.info("Diagram type=%s but no logic_tree: in response. Response length=%d", + parsed_type, len(content)) + logger.debug("Raw response (first 500): %s", content[:500]) + + # --- description --- + if desc_match: + description = content[desc_match.end():].strip() + else: + desc = content + if type_match: + desc = desc[type_match.end():] + desc = re.sub(r'(?m)^logic_tree:\s*\{.*?\}\s*', '', desc, flags=re.DOTALL) + description = desc.strip() + + return parsed_type, description, logic_tree + + @staticmethod + def _validate_flowchart(tree: dict) -> tuple[bool, str]: + """Validate a nested flowchart tree structure. + + Returns ``(is_valid, error_message)``. Non-fatal: returns ``False`` + with a warning message but the tree is still kept. + """ + if not isinstance(tree, dict): + return False, "logic_tree is not a dict" + + seen_ids: set[str] = set() + + def _walk(node: dict, depth: int = 0) -> tuple[bool, str]: + if depth > 20: + return False, f"Tree too deep (>20) at node {node.get('id', '?')}" + + nid = node.get("id", "") + if not nid: + return False, "Node missing 'id' field" + if not isinstance(nid, str): + return False, f"Node id must be string, got {type(nid).__name__}" + if nid in seen_ids: + return False, f"Duplicate node id: {nid}" + seen_ids.add(nid) + + ntype = node.get("type", "") + if ntype not in ("start", "end", "process", "decision", "action"): + return False, f"Unknown node type '{ntype}' at {nid}" + + if ntype == "end": + if "children" in node: + return False, f"End node {nid} should not have children" + return True, "" + + children = node.get("children") + if not children: + if ntype != "end": + return False, f"Non-end node {nid} ({ntype}) has no children" + return True, "" + + if not isinstance(children, list): + return False, f"children of {nid} is not a list" + + if ntype == "decision": + for child in children: + if not isinstance(child, dict): + return False, f"decision child of {nid} is not a dict" + if "condition" not in child: + return False, f"decision child of {nid} missing 'condition'" + if "node" not in child: + return False, f"decision child of {nid} missing 'node'" + ok, err = _walk(child["node"], depth + 1) + if not ok: + return False, err + else: + for child in children: + if not isinstance(child, dict): + return False, f"child of {nid} is not a dict" + ok, err = _walk(child, depth + 1) + if not ok: + return False, err + + return True, "" + + return _walk(tree) + + @staticmethod + def _flatten_tree(tree: dict) -> dict: + """Convert a nested flowchart tree into the legacy flat-nodes format. + + This preserves backward compatibility with downstream consumers + (conflict_detection_skill, ir_generator) that expect the flat format. + """ + nodes: list[dict] = [] + root_name = "" + + def _collect(node: dict): + nonlocal root_name + nid = node.get("id", "") + ntype = node.get("type", "") + name = node.get("name", "") + + if root_name == "" and "children" in node: + root_name = name + + if ntype == "decision": + branches = [] + for child in node.get("children", []): + branches.append({ + "value": child.get("condition", ""), + "target": child["node"].get("id", ""), + }) + _collect(child["node"]) + nodes.append({ + "id": nid, + "type": ntype, + "condition": name, + "branches": branches, + }) + elif ntype in ("action", "process", "state"): + nodes.append({ + "id": nid, + "type": ntype, + "description": name, + }) + for child in node.get("children", []): + _collect(child) + elif ntype == "start": + nodes.append({ + "id": nid, + "type": ntype, + "description": name, + }) + for child in node.get("children", []): + _collect(child) + # end nodes are collected but have no children + + _collect(tree) + + # Add end nodes from the nested tree + ends: list[dict] = [] + + def _collect_ends(node: dict): + if node.get("type") == "end": + ends.append({ + "id": node.get("id", ""), + "type": "end", + "description": node.get("name", ""), + }) + elif "children" in node: + for child in node.get("children", []): + if isinstance(child, dict): + if "node" in child: + _collect_ends(child["node"]) + else: + _collect_ends(child) + + _collect_ends(tree) + nodes.extend(ends) + + return {"root": root_name, "nodes": nodes} + + @staticmethod + def extract_paths(tree: dict) -> list[list[dict]]: + """Extract all root-to-leaf paths from a nested flowchart tree. + + Each path is a list of node dicts (each with id, name, type). + Returns a list of paths useful for human review and LLM verification. + """ + paths: list[list[dict]] = [] + + def _walk(node: dict, current_path: list[dict]): + entry = {"id": node.get("id", ""), "name": node.get("name", ""), "type": node.get("type", "")} + new_path = current_path + [entry] + + if node.get("type") == "end": + paths.append(new_path) + return + + children = node.get("children", []) + if not children: + paths.append(new_path) + return + + if node.get("type") == "decision": + for child in children: + _walk(child["node"], new_path) + else: + for child in children: + _walk(child, new_path) + + _walk(tree, []) + return paths + + @staticmethod + def paths_to_text(paths: list[list[dict]]) -> str: + """Render extracted paths as human-readable text for review.""" + lines: list[str] = [] + for i, path in enumerate(paths, 1): + steps = [] + for node in path: + if node["type"] == "decision": + steps.append(f"[判断] {node['name']}") + elif node["type"] == "end": + steps.append(f"[结束] {node['name']}") + else: + steps.append(f"[{node['type']}] {node['name']}") + lines.append(f"路径 {i}: {' -> '.join(steps)}") + return "\n".join(lines) + + @staticmethod + def _extract_json(text: str) -> Optional[dict]: + """Try multiple strategies to extract a JSON object from text. + + Returns the parsed dict or None. + """ + # Strategy 1: first { ... } pair (simple regex) + json_match = re.search(r'\{.*\}', text, re.DOTALL) + if json_match: + try: + return json.loads(json_match.group()) + except json.JSONDecodeError: + pass + + # Strategy 2: find balanced braces + start = text.find("{") + if start >= 0: + depth = 0 + for i in range(start, len(text)): + if text[i] == "{": + depth += 1 + elif text[i] == "}": + depth -= 1 + if depth == 0: + try: + return json.loads(text[start:i + 1]) + except json.JSONDecodeError: + break + return None @staticmethod def _mime_type(image_path: str) -> str: diff --git a/skills/doc_parser_skill/scripts/verify_flowchart.py b/skills/doc_parser_skill/scripts/verify_flowchart.py new file mode 100644 index 0000000..2834991 --- /dev/null +++ b/skills/doc_parser_skill/scripts/verify_flowchart.py @@ -0,0 +1,384 @@ +#!/usr/bin/env python3 +"""Verify flowchart logic trees for structural correctness and consistency. + +Usage:: + + python verify_flowchart.py [--llm] [--output-report REPORT.md] + +Performs three levels of checks: + +1. **Structural validation** — tree integrity, node uniqueness, leaf types +2. **Path extraction** — renders all root-to-leaf paths as readable text +3. **LLM consistency check** (opt-in with ``--llm``) — compares extracted paths + against the original text description for logical inconsistencies + +Outputs PASS/FAIL and a detailed report. +""" + +import argparse +import json +import logging +import os +import sys + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from image_parser import ImageParser +from LLM import LLMClient + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Prompt for LLM path-vs-description consistency check +# --------------------------------------------------------------------------- + +PROMPT_VERIFY_PATHS = """你是一个流程图审核专家。以下内容来自同一张流程图的解析结果: + +## 流程图路径(从嵌套逻辑树提取的所有根到叶路径) +``` +{paths_text} +``` + +## 原始文字描述 +``` +{description} +``` + +## 你的任务 +逐条检查每条路径是否与文字描述一致。重点关注: + +1. **分支方向错误**:路径中的判断分支走向是否与文字描述矛盾? + 例如:文字说"满足条件后退出",但路径中"是"分支走向了"不受限"。 +2. **缺失步骤**:路径中是否缺少文字描述中提到的关键步骤? +3. **冗余步骤**:路径中是否包含文字描述未提及的多余步骤? +4. **条件颠倒**:判断条件的"是/否"分支是否与文字描述相反? + +## 输出格式 + +如果**所有路径一致**,只输出: +``` +[[PATHS_CONSISTENT]] +``` + +如果**发现不一致**,输出 JSON 数组: +```json +[ + {{ + "path_index": 1, + "issue_type": "branch_error|missing_step|redundant_step|condition_reversed", + "severity": "high|medium|low", + "description": "用中文说明具体问题" + }} +] +``` + +注意:输出必须是严格合法的 JSON 数组,不要有尾随逗号,不要包含代码块包裹符号。 +""" + + +# --------------------------------------------------------------------------- +# Core verification logic +# --------------------------------------------------------------------------- + +def verify_parsed_json(parsed_path: str, *, use_llm: bool = False) -> dict: + """Load _parsed.json and verify all flowchart logic trees. + + Returns a report dict with keys: + - total_flowcharts: int + - passed: int + - failed: int + - results: list of per-flowchart results + """ + with open(parsed_path, "r", encoding="utf-8") as f: + data = json.load(f) + + image_analysis = data.get("image_analysis", []) + flowcharts = [img for img in image_analysis if img.get("type") == "flowchart"] + + report = { + "total_flowcharts": len(flowcharts), + "passed": 0, + "failed": 0, + "results": [], + } + + llm = LLMClient() if use_llm else None + + for img in flowcharts: + rid = img.get("rid", "unknown") + logger.info("Verifying flowchart: rid=%s", rid) + + result = _verify_single(img, llm) + report["results"].append(result) + + if result["structural_ok"] and (not use_llm or result.get("llm_ok", True)): + report["passed"] += 1 + else: + report["failed"] += 1 + + return report + + +def verify_flowchart_file(filepath: str, *, use_llm: bool = False) -> dict: + """Load a standalone flowchart JSON file and verify it.""" + with open(filepath, "r", encoding="utf-8") as f: + tree = json.load(f) + + img = {"logic_tree_nested": tree, "description": "", "rid": os.path.basename(filepath)} + llm = LLMClient() if use_llm else None + result = _verify_single(img, llm) + + return { + "total_flowcharts": 1, + "passed": 1 if result["structural_ok"] else 0, + "failed": 0 if result["structural_ok"] else 1, + "results": [result], + } + + +def _verify_single(img: dict, llm: LLMClient | None) -> dict: + """Verify a single flowchart image analysis entry.""" + rid = img.get("rid", "unknown") + description = img.get("description", "").strip() + + # Try nested format first, fall back to flat format + tree = img.get("logic_tree_nested") or img.get("logic_tree") + if tree is None: + return { + "rid": rid, + "structural_ok": False, + "errors": ["No logic_tree found"], + "paths_text": "", + "llm_issues": [], + } + + # Check if it's the new nested format or old flat format + is_nested = "children" in tree and isinstance(tree.get("children"), list) + + # --- Level 1: Structural validation --- + structural_ok = True + errors: list[str] = [] + + if is_nested: + ok, err = ImageParser._validate_flowchart(tree) + if not ok: + structural_ok = False + errors.append(f"Structure: {err}") + + # Extract paths + paths = ImageParser.extract_paths(tree) + paths_text = ImageParser.paths_to_text(paths) + errors.append(f"Path count: {len(paths)}") + else: + # Old flat format — basic check + nodes = tree.get("nodes", []) + ids = [n.get("id", "") for n in nodes] + if len(ids) != len(set(ids)): + structural_ok = False + errors.append("Structure: duplicate node ids in flat format") + + # Build simple path-like text for flat format + paths_text = _flat_to_text(tree) + + # --- Level 2: Path count sanity check --- + if is_nested and len(paths) == 0: + structural_ok = False + errors.append("No paths extracted from tree") + + # --- Level 3: LLM consistency check --- + llm_issues: list[dict] = [] + llm_ok = True + if llm and description and paths_text: + prompt = PROMPT_VERIFY_PATHS.format( + paths_text=paths_text, + description=description, + ) + try: + raw = llm.chat( + model=LLMClient.TEXT_MODEL, + messages=[{"role": "user", "content": prompt}], + ) + llm_issues = _parse_llm_issues(raw) + if llm_issues: + llm_ok = False + errors.append(f"LLM found {len(llm_issues)} issue(s)") + except RuntimeError as e: + errors.append(f"LLM check failed: {e}") + + return { + "rid": rid, + "structural_ok": structural_ok, + "errors": errors, + "paths_text": paths_text, + "llm_ok": llm_ok, + "llm_issues": llm_issues, + } + + +def _flat_to_text(tree: dict) -> str: + """Build path-like text from old flat-format logic_tree.""" + nodes = tree.get("nodes", []) + root = tree.get("root", "") + lines = [f"Root: {root}"] + + node_map = {n["id"]: n for n in nodes} + + def _trace(node_id: str, visited: set, path: list[str]) -> list[str]: + if node_id in visited: + path.append(f"[循环] {node_id}") + return path + visited.add(node_id) + node = node_map.get(node_id) + if node is None: + path.append(f"[缺失] {node_id}") + return path + ntype = node.get("type", "") + if ntype == "decision": + cond = node.get("condition", "") + for b in node.get("branches", []): + val = b.get("value", "") + tgt = b.get("target", "") + new_path = path + [f"[判断] {cond} → {val}"] + _trace(tgt, visited.copy(), new_path) + elif ntype == "end": + path.append(f"[结束] {node.get('description', '')}") + lines.append(" -> ".join(path)) + else: + path.append(f"[{ntype}] {node.get('description', '')}") + # Flat format doesn't have explicit children for non-decision nodes + # so we can't trace further + lines.append(" -> ".join(path)) + return path + + # Try to find start nodes + starts = [n for n in nodes if n.get("type") == "start"] + if starts: + for s in starts: + _trace(s["id"], set(), []) + else: + lines.append("(Cannot trace: no start node in flat format)") + + return "\n".join(lines) + + +def _parse_llm_issues(content: str) -> list[dict]: + """Parse LLM response for path consistency issues.""" + stripped = content.strip() + if "[[PATHS_CONSISTENT]]" in stripped: + return [] + + # Remove markdown code fences + if "```json" in stripped: + stripped = stripped.split("```json", 1)[1] + if "```" in stripped: + stripped = stripped.split("```", 1)[0] + elif "```" in stripped: + stripped = stripped.split("```", 1)[1] + if "```" in stripped: + stripped = stripped.split("```", 1)[0] + + stripped = stripped.strip() + if not stripped: + return [] + + try: + issues = json.loads(stripped) + if isinstance(issues, list): + return issues + return [] + except json.JSONDecodeError: + logger.debug("Failed to parse LLM issues: %s", stripped[:200]) + return [] + + +# --------------------------------------------------------------------------- +# Report rendering +# --------------------------------------------------------------------------- + +def print_report(report: dict) -> str: + """Print a human-readable verification report and return it as a string.""" + lines: list[str] = [] + lines.append("=" * 60) + lines.append("流程图校验报告") + lines.append("=" * 60) + lines.append(f"流程图总数: {report['total_flowcharts']}") + lines.append(f"通过: {report['passed']}") + lines.append(f"失败: {report['failed']}") + + overall = "PASS" if report["failed"] == 0 else "FAIL" + lines.append(f"总体结果: {overall}") + lines.append("") + + for i, r in enumerate(report["results"], 1): + rid = r["rid"] + status = "[PASS]" if r["structural_ok"] else "[FAIL]" + lines.append(f"[{i}] rid={rid} {status}") + for err in r.get("errors", []): + lines.append(f" - {err}") + + if r.get("paths_text"): + lines.append(" 路径:") + for path_line in r["paths_text"].split("\n"): + lines.append(f" {path_line}") + + llm_issues = r.get("llm_issues", []) + if llm_issues: + lines.append(" LLM发现的问题:") + for issue in llm_issues: + lines.append(f" [{issue.get('severity', '?')}] {issue.get('description', '')}") + lines.append("") + + report_text = "\n".join(lines) + print(report_text) + return report_text + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Verify flowchart logic trees for correctness.", + ) + parser.add_argument( + "input", metavar="FILE", + help="Path to _parsed.json or standalone flowchart JSON", + ) + parser.add_argument( + "--llm", action="store_true", + help="Run LLM consistency check (compares paths against text description)", + ) + parser.add_argument( + "--output-report", metavar="PATH", + help="Save verification report to a file", + ) + args = parser.parse_args() + + # Determine input type + with open(args.input, "r", encoding="utf-8") as f: + data = json.load(f) + + if "image_analysis" in data: + report = verify_parsed_json(args.input, use_llm=args.llm) + else: + report = verify_flowchart_file(args.input, use_llm=args.llm) + + report_text = print_report(report) + + if args.output_report: + with open(args.output_report, "w", encoding="utf-8") as f: + f.write(report_text) + logger.info("Report saved: %s", args.output_report) + + # Exit code: 0 for PASS, 1 for FAIL + if report["failed"] > 0: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/skills/ir_generation_skill/.gitignore b/skills/ir_generation_skill/.gitignore new file mode 100644 index 0000000..0efba25 --- /dev/null +++ b/skills/ir_generation_skill/.gitignore @@ -0,0 +1,9 @@ +# Generated output +output/ + +# Python +__pycache__/ +*.pyc + +# Console log +Console output.txt diff --git a/skills/ir_generation_skill/config.py b/skills/ir_generation_skill/config.py new file mode 100644 index 0000000..6f17ef2 --- /dev/null +++ b/skills/ir_generation_skill/config.py @@ -0,0 +1,137 @@ +""" +Shared configuration for the IR Generation pipeline. +Reads API keys from a secrets.yaml file, falling back to environment variables. +""" + +import os +import json +import yaml + +# ---- Paths ---- +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +WORKSPACE_DIR = os.path.dirname(BASE_DIR) +DOC_PARSER_OUTPUT = os.path.join(WORKSPACE_DIR, "doc_parser_skill", "output") +PROMPTS_DIR = os.path.join(BASE_DIR, "prompts") +TESTS_DIR = os.path.join(BASE_DIR, "tests") +OUTPUT_DIR = os.path.join(BASE_DIR, "output") + +# Input file (the parsed PRD JSON) +_DEFAULT_INPUT = os.path.join( + DOC_PARSER_OUTPUT, + "车机娱乐系统禁止功能文档_脱敏 v0.9_v2_updated.json", +) +INPUT_JSON = os.environ.get("IR_INPUT_JSON", _DEFAULT_INPUT) + + +def set_input_file(path: str) -> None: + """Override the default input JSON path.""" + global INPUT_JSON + INPUT_JSON = path + +# Secrets file (shared with workspace-document-analyzer) +# .openclaw/workspace/skills/ir_generation_new_skill -> .openclaw/workspace-document-analyzer +OPENCLAW_HOME = os.path.dirname(os.path.dirname(WORKSPACE_DIR)) +SECRETS_YAML = os.path.join( + OPENCLAW_HOME, "workspace-document-analyzer", "config", "secrets.yaml", +) + +# Intermediate outputs +SEMANTIC_INDEX_R1_JSON = os.path.join(OUTPUT_DIR, "semantic_index_r1.json") +SEMANTIC_INDEX_R2_JSON = os.path.join(OUTPUT_DIR, "semantic_index_r2.json") +SEMANTIC_INDEX_R3_JSON = os.path.join(OUTPUT_DIR, "semantic_index_r3.json") +SEMANTIC_INDEX_JSON = os.path.join(OUTPUT_DIR, "semantic_index.json") # merged final +IR_FRAGMENTS_JSON = os.path.join(OUTPUT_DIR, "ir_fragments.json") +PATH_ENUM_JSON = os.path.join(OUTPUT_DIR, "path_enumeration.json") +IR_AUTOCOMPLETE_FRAGMENTS_JSON = os.path.join(OUTPUT_DIR, "ir_autocomplete_fragments.json") + +# Final deliverables (placed in doc_parser output per spec) +IR_FINAL_JSON = os.path.join(DOC_PARSER_OUTPUT, "ir_final.json") +IR_AUDIT_REPORT_MD = os.path.join(DOC_PARSER_OUTPUT, "ir_audit_report.md") + +# ---- LLM API ---- +# Choose provider: "deepseek" | "dashscope" +LLM_PROVIDER = os.environ.get("IR_PROVIDER", "deepseek") + +# Model names per provider +PROVIDER_MODELS = { + "deepseek": os.environ.get("IR_MODEL", "deepseek-v4-flash"), + "dashscope": os.environ.get("IR_MODEL", "qwen-max"), +} +MODEL_NAME = PROVIDER_MODELS.get(LLM_PROVIDER, PROVIDER_MODELS["deepseek"]) + +# Maximum tokens for LLM responses +MAX_TOKENS = int(os.environ.get("IR_MAX_TOKENS", "16000")) +TEMPERATURE = float(os.environ.get("IR_TEMPERATURE", "0.1")) + +# ---- Iteration & Quality ---- +MAX_RETRIES_PER_STAGE = int(os.environ.get("IR_MAX_RETRIES", "3")) +COVERAGE_TARGET = float(os.environ.get("IR_COVERAGE_TARGET", "0.95")) + +# Stage 1 ensemble temperatures (parallel multi-temperature generation) +ENSEMBLE_TEMPERATURES = [ + float(os.environ.get("IR_ENSEMBLE_T1", "0.0")), + float(os.environ.get("IR_ENSEMBLE_T2", "0.3")), + float(os.environ.get("IR_ENSEMBLE_T3", "0.7")), +] + + +def _load_secrets() -> dict[str, dict[str, str]]: + """Load provider credentials from secrets.yaml. + + Returns a dict like: {"deepseek": {"apiKey": "...", "baseUrl": "..."}, ...} + """ + if os.path.isfile(SECRETS_YAML): + with open(SECRETS_YAML, "r", encoding="utf-8") as f: + return yaml.safe_load(f) or {} + return {} + + +def _get_provider_config(provider: str) -> dict[str, str]: + """Get {apiKey, baseUrl} for a provider from secrets, with env-var fallback.""" + secrets = _load_secrets() + entry = secrets.get(provider, {}) + + env_prefix = provider.upper() + api_key = ( + os.environ.get(f"{env_prefix}_API_KEY") + or entry.get("apiKey", "") + ) + base_url = ( + os.environ.get(f"{env_prefix}_BASE_URL") + or entry.get("baseUrl", "https://api.deepseek.com/v1") + ) + + if not api_key: + raise RuntimeError( + f"No API key found for provider '{provider}'. " + f"Check {SECRETS_YAML} or set {env_prefix}_API_KEY." + ) + return {"apiKey": api_key, "baseUrl": base_url} + + +def llm_client(): + """Return an OpenAI-compatible client configured from secrets.yaml.""" + from openai import OpenAI + + cfg = _get_provider_config(LLM_PROVIDER) + return OpenAI(base_url=cfg["baseUrl"], api_key=cfg["apiKey"]) + + +def load_input_document(path: str | None = None) -> dict: + """Load the parsed PRD JSON document.""" + path = path or INPUT_JSON + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + + +def save_json(data, path: str) -> None: + """Save data as formatted JSON.""" + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + +def load_json(path: str) -> dict: + """Load a JSON file.""" + with open(path, "r", encoding="utf-8") as f: + return json.load(f) diff --git a/skills/ir_generation_skill/ensemble_merge.py b/skills/ir_generation_skill/ensemble_merge.py new file mode 100644 index 0000000..82f8013 --- /dev/null +++ b/skills/ir_generation_skill/ensemble_merge.py @@ -0,0 +1,593 @@ +""" +Deterministic ensemble merge for semantic index generation. + +All functions are pure Python with zero LLM calls. Fully testable with mock data. + +Cross-references N semantic_index outputs (generated with different temperatures) +and produces a single merged index with confidence scores. + +Used by: step1_semantic_index.py +Tested by: tests/test_ensemble_merge.py +""" + +from collections import defaultdict +from difflib import SequenceMatcher + + +# ============================================================================= +# Concept Name Similarity +# ============================================================================= + +def concept_name_similarity(name_a: str, name_b: str) -> float: + """Compute similarity between two concept names for cross-version matching. + + Strategy (in order of precedence): + 1. Exact string match -> 1.0 + 2. Substring containment (one is a substring of the other) -> 0.9 + 3. SequenceMatcher ratio on character sequences -> 0.0-1.0 + + Returns: + float in [0.0, 1.0] where >= 0.7 means "likely the same concept". + """ + if name_a == name_b: + return 1.0 + + # Substring containment: one name is contained in the other + if name_a in name_b or name_b in name_a: + # Only count as similar if they're of comparable length + # (avoid matching "国内" with "国内行车娱乐限制") + len_ratio = min(len(name_a), len(name_b)) / max(len(name_a), len(name_b)) + if len_ratio >= 0.5: + return 0.85 + 0.05 * len_ratio # range 0.875-0.90 + return 0.55 # too different in length → below threshold + + return SequenceMatcher(None, name_a, name_b).ratio() + + +# ============================================================================= +# Concept Clustering & Merging +# ============================================================================= + +def cluster_concepts( + all_concepts_lists: list[list[dict]], + similarity_threshold: float = 0.7, +) -> list[list[tuple[int, dict]]]: + """Group concepts across ensemble versions by name similarity. + + Uses greedy single-pass clustering: for each concept, find the best-matching + existing cluster. If max similarity >= threshold, add to it; otherwise, + create a new cluster. + + Args: + all_concepts_lists: List of concept lists, one per ensemble version. + all_concepts_lists[i] = concepts from version i. + similarity_threshold: Minimum name similarity to join a cluster. + + Returns: + List of clusters. Each cluster is list of (version_idx, concept_dict). + """ + clusters = [] # type: list[list[tuple[int, dict]]] + + for version_idx, concepts in enumerate(all_concepts_lists): + for c in concepts: + name = c.get("name", "") + if not name: + continue + + best_cluster = None + best_sim = 0.0 + + for cluster in clusters: + # Compare against the first member of the cluster (seed) + seed_name = cluster[0][1].get("name", "") + sim = concept_name_similarity(name, seed_name) + if sim > best_sim: + best_sim = sim + best_cluster = cluster + + if best_cluster is not None and best_sim >= similarity_threshold: + best_cluster.append((version_idx, c)) + else: + clusters.append([(version_idx, c)]) + + return clusters + + +def merge_concept_cluster( + cluster: list[tuple[int, dict]], + total_versions: int, +) -> tuple[dict, str]: + """Merge a single cluster of matched concepts into one concept dict. + + Rules: + - name: Longest name (most specific). Tie-break by lower version_idx. + - aliases: Union of all aliases across versions. + - defined_in: Union of all defined_in across versions. + - parent: Most common non-null parent (voting). Tie-break by lower version_idx. + + Returns: + (merged_concept_dict, confidence_level) where confidence is "high"/"medium"/"low". + """ + if not cluster: + return {}, "low" + + # --- name: longest (most specific) --- + best_name = "" + best_name_len = 0 + for v_idx, c in cluster: + n = c.get("name", "") + if len(n) > best_name_len: + best_name = n + best_name_len = len(n) + elif len(n) == best_name_len and v_idx < cluster[0][0]: # lower version idx + best_name = n + + # --- aliases: union --- + aliases = set() + for _, c in cluster: + for a in c.get("aliases", []): + aliases.add(a) + + # --- defined_in: union --- + defined_in = set() + for _, c in cluster: + for d in c.get("defined_in", []): + defined_in.add(d) + + # --- parent: most common non-null parent (vote) --- + parent_votes = defaultdict(int) + for v_idx, c in cluster: + p = c.get("parent") + if p is not None: + parent_votes[p] += 1 + + if parent_votes: + best_parent = max(parent_votes, key=lambda p: (parent_votes[p], -1)) + else: + best_parent = None + + # --- confidence --- + versions_present = len({v_idx for v_idx, _ in cluster}) + confidence = compute_confidence_versions(versions_present, total_versions, + any(v_idx == 0 for v_idx, _ in cluster)) + + merged = { + "name": best_name, + "aliases": sorted(aliases), + "defined_in": sorted(defined_in), + "parent": best_parent, + "confidence": confidence, + } + return merged, confidence + + +# ============================================================================= +# Unit Similarity Functions +# ============================================================================= + +def _collect_logic_tree_nodes(unit: dict) -> set[str]: + """Extract the flattened set of all logic tree node IDs from a function_unit.""" + nodes = set() + for src in unit.get("sources", []): + if src.get("type") == "logic_tree": + nodes.update(src.get("logic_tree_nodes", [])) + return nodes + + +def unit_node_jaccard(unit_a: dict, unit_b: dict) -> float: + """Compute Jaccard similarity on logic tree node sets between two units. + + Jaccard(A, B) = |A ∩ B| / |A ∪ B|. Returns 0.0 if both have no nodes. + """ + nodes_a = _collect_logic_tree_nodes(unit_a) + nodes_b = _collect_logic_tree_nodes(unit_b) + + if not nodes_a and not nodes_b: + return 0.0 + if not nodes_a or not nodes_b: + return 0.0 + + intersection = nodes_a & nodes_b + union = nodes_a | nodes_b + return len(intersection) / len(union) + + +def path_similarity(path_a: list[str], path_b: list[str]) -> float: + """Compute similarity between two path arrays. + + Hybrid approach: + - Sequential similarity (order-aware): SequenceMatcher on joined strings. + - Set similarity (order-independent): Jaccard on path element sets. + - Final score: 0.5 * seq_sim + 0.5 * set_sim + + Returns: + float in [0.0, 1.0]. + """ + if not path_a and not path_b: + return 1.0 + if not path_a or not path_b: + return 0.0 + + # Sequential similarity + joined_a = "|".join(path_a) + joined_b = "|".join(path_b) + seq_sim = SequenceMatcher(None, joined_a, joined_b).ratio() + + # Set similarity + set_a = set(path_a) + set_b = set(path_b) + set_sim = len(set_a & set_b) / len(set_a | set_b) + + return 0.5 * seq_sim + 0.5 * set_sim + + +def unit_similarity(unit_a: dict, unit_b: dict) -> float: + """Combined similarity between two function_units. + + Weighted combination: + - 0.6 * unit_node_jaccard (primary signal: same logic tree nodes = same rule) + - 0.4 * path_similarity (secondary signal: semantic agreement) + + Returns: + float in [0.0, 1.0]. >= 0.5 means "likely the same function_unit". + """ + return 0.6 * unit_node_jaccard(unit_a, unit_b) + 0.4 * path_similarity( + unit_a.get("path", []), unit_b.get("path", []) + ) + + +# ============================================================================= +# Function Unit Clustering & Merging +# ============================================================================= + +def cluster_function_units( + all_units_lists: list[list[dict]], + similarity_threshold: float = 0.5, +) -> list[list[tuple[int, dict]]]: + """Group function_units across ensemble versions by content similarity. + + Lowest-temperature versions are processed first (most stable → cluster seeds). + Higher-temperature variants join existing clusters if similar enough. + + Args: + all_units_lists: List of unit lists, one per ensemble version. + similarity_threshold: Minimum unit_similarity to join a cluster. + + Returns: + List of clusters. Each cluster is list of (version_idx, unit_dict). + """ + clusters = [] # type: list[list[tuple[int, dict]]] + + for version_idx, units in enumerate(all_units_lists): + for unit in units: + best_cluster = None + best_sim = 0.0 + + for cluster in clusters: + # Compare against all members already in the cluster + cluster_sim = max( + unit_similarity(unit, existing_unit) + for (_, existing_unit) in cluster + ) + if cluster_sim > best_sim: + best_sim = cluster_sim + best_cluster = cluster + + if best_cluster is not None and best_sim >= similarity_threshold: + best_cluster.append((version_idx, unit)) + else: + clusters.append([(version_idx, unit)]) + + return clusters + + +def pick_best_representative( + cluster: list[tuple[int, dict]], +) -> dict: + """Select the best function_unit from a cluster as the merged representative. + + Scoring formula (all normalized to [0, 1]): + - 0.35: Node count (more logic_tree_nodes = more complete trace) + - 0.25: Source count (more sources = more evidence) + - 0.20: Description length (longer = more detail, capped at 500 chars) + - 0.20: Temperature rank (lower version_idx = lower temp = more stable) + + Returns a deep copy of the winning unit dict. + """ + if not cluster: + return {} + + # Compute max values for normalization + max_nodes = max( + len(_collect_logic_tree_nodes(unit)) for _, unit in cluster + ) + max_sources = max( + len(unit.get("sources", [])) for _, unit in cluster + ) + max_desc_len = max( + len(unit.get("description", "")) for _, unit in cluster + ) + max_version_idx = max(v_idx for v_idx, _ in cluster) + num_versions = len(cluster) + + def score(v_idx: int, unit: dict) -> float: + nodes = len(_collect_logic_tree_nodes(unit)) + sources = len(unit.get("sources", [])) + desc_len = min(len(unit.get("description", "")), 500) + temp_rank = 1.0 - (v_idx / max(num_versions, max_version_idx + 1)) + + return ( + 0.35 * (nodes / max(1, max_nodes)) + + 0.25 * (sources / max(1, max_sources)) + + 0.20 * (desc_len / max(1, max_desc_len)) + + 0.20 * temp_rank + ) + + best = max(cluster, key=lambda x: score(x[0], x[1])) + return dict(best[1]) # deep-ish copy (1 level) + + +def merge_unit_sources( + cluster: list[tuple[int, dict]], +) -> list[dict]: + """Union all sources from units in a cluster, deduplicating by (type, image_id, section). + + When the same source key appears in multiple versions, keeps the one with + the most logic_tree_nodes. + """ + # Group by dedup key + source_groups = defaultdict(list) + + for v_idx, unit in cluster: + for src in unit.get("sources", []): + # Build a dedup key + src_type = src.get("type", "") + if src_type == "logic_tree": + key = ("logic_tree", src.get("image_id", "")) + else: + key = (src_type, src.get("section", ""), src.get("row", "")) + + source_groups[key].append(src) + + # Pick best per group + result = [] + for key, sources in source_groups.items(): + # Pick the source with the most logic_tree_nodes (if any) + best = max(sources, key=lambda s: len(s.get("logic_tree_nodes", []))) + result.append(dict(best)) + + return result + + +def compute_confidence_versions( + versions_present: int, + total_versions: int, + includes_lowest_temp: bool = False, +) -> str: + """Compute 3-level confidence based on cross-version agreement. + + - "high": Appears in all versions, OR >= 2/3 with lowest-temp version (T=0.0). + - "medium": Appears in >= half the versions but not all. + - "low": Appears in fewer than half (singleton in ensemble). + + Args: + versions_present: Number of versions this item appeared in. + total_versions: Total number of ensemble versions. + includes_lowest_temp: Whether the item appeared in the T=0.0 version. + """ + ratio = versions_present / total_versions + + if ratio >= 1.0: + return "high" + if ratio >= 0.5 and includes_lowest_temp: + return "high" + if ratio >= 0.5: + return "medium" + return "low" + + +def ensemble_merge_concepts( + all_concepts_lists: list[list[dict]], +) -> list[dict]: + """Merge concepts across all ensemble versions. + + Returns: + List of merged concept dicts, each with added "confidence" field. + """ + total = len(all_concepts_lists) + clusters = cluster_concepts(all_concepts_lists) + merged = [] + seen_names = set() + + for cluster in clusters: + concept, confidence = merge_concept_cluster(cluster, total) + name = concept.get("name", "") + if name and name not in seen_names: + concept["ensemble_support"] = f"{len({v for v, _ in cluster})}/{total}" + merged.append(concept) + seen_names.add(name) + + # Sort: high confidence first, then by name + conf_order = {"high": 0, "medium": 1, "low": 2} + merged.sort(key=lambda c: (conf_order.get(c.get("confidence", "low"), 3), c.get("name", ""))) + + # Validate and fix parent references + merged = _validate_concept_parents(merged) + + return merged + + +def _validate_concept_parents(concepts: list[dict]) -> list[dict]: + """Post-merge: validate that every concept's parent exists in the list. + + Strategy for dangling parents: + 1. Fuzzy match (concept_name_similarity >= 0.7) → fix reference + 2. No match → set parent to null, downgrade confidence to "low" + """ + concept_names = {c["name"] for c in concepts} + conf_order = {"high": 0, "medium": 1, "low": 2} + + for c in concepts: + parent = c.get("parent") + if parent is None: + continue + if parent in concept_names: + continue + + # Dangling parent — try fuzzy match + best_match = None + best_sim = 0.0 + for name in concept_names: + sim = concept_name_similarity(parent, name) + if sim > best_sim: + best_sim = sim + best_match = name + + if best_match and best_sim >= 0.7: + c["parent"] = best_match + # Downgrade if match was fuzzy (not exact) + if best_sim < 1.0: + current_conf = c.get("confidence", "low") + c["confidence"] = _downgrade_confidence(current_conf) + else: + c["parent"] = None + c["confidence"] = _downgrade_confidence(c.get("confidence", "low")) + + # Re-sort after confidence changes + concepts.sort(key=lambda c: (conf_order.get(c.get("confidence", "low"), 3), c.get("name", ""))) + return concepts + + +def _downgrade_confidence(current: str) -> str: + """Drop confidence one level.""" + if current == "high": + return "medium" + return "low" + + +def ensemble_merge_function_units( + all_units_lists: list[list[dict]], +) -> list[dict]: + """Merge function_units across all ensemble versions. + + 1. Cluster units across versions. + 2. For each cluster: pick best, merge sources, compute confidence. + 3. Reassign stable unit_ids: FU-ENS-001, FU-ENS-002, ... + + Returns: + List of merged function_unit dicts with added "confidence", + "ensemble_support", "source_versions" fields. + """ + total = len(all_units_lists) + clusters = cluster_function_units(all_units_lists) + + merged = [] + for cluster in clusters: + # Pick best representative + best = pick_best_representative(cluster) + + # Merge sources from all cluster members + best["sources"] = merge_unit_sources(cluster) + + # Compute confidence + versions_present = len({v_idx for v_idx, _ in cluster}) + includes_t0 = any(v_idx == 0 for v_idx, _ in cluster) + confidence = compute_confidence_versions( + versions_present, total, includes_t0 + ) + + best["confidence"] = confidence + best["ensemble_support"] = f"{versions_present}/{total}" + best["source_versions"] = versions_present + + merged.append(best) + + # Sort by confidence desc, then by unit_id + conf_order = {"high": 0, "medium": 1, "low": 2} + merged.sort(key=lambda u: (conf_order.get(u.get("confidence", "low"), 3), + u.get("unit_id", ""))) + + # Reassign stable unit_ids + for i, unit in enumerate(merged): + # Preserve original unit_id for traceability + if "original_unit_id" not in unit: + unit["original_unit_id"] = unit.get("unit_id", "") + unit["unit_id"] = f"FU-ENS-{i + 1:03d}" + + return merged + + +# ============================================================================= +# Top-Level Ensemble Merge +# ============================================================================= + +def ensemble_merge( + semantic_indices: list[dict], +) -> dict: + """Merge N semantic index outputs into one ensemble result. + + Args: + semantic_indices: List of semantic_index dicts from each temperature run. + semantic_indices[0] should be the lowest-temperature version. + + Returns: + Merged semantic_index dict with structure: + { + "feature_name": str, + "ensemble_versions": int, + "concepts": [...], + "function_units": [...], + "confidence_summary": {...}, + } + """ + if not semantic_indices: + return { + "feature_name": "", + "ensemble_versions": 0, + "concepts": [], + "function_units": [], + "confidence_summary": {}, + } + + total = len(semantic_indices) + + # Extract concepts and function_units from each version + all_concepts = [si.get("concepts", []) for si in semantic_indices] + all_units = [si.get("function_units", []) for si in semantic_indices] + + # Merge + merged_concepts = ensemble_merge_concepts(all_concepts) + merged_units = ensemble_merge_function_units(all_units) + + # Feature name: majority vote across versions + feature_names = [si.get("feature_name", "") for si in semantic_indices] + name_counts = defaultdict(int) + for fn in feature_names: + if fn: + name_counts[fn] += 1 + feature_name = max(name_counts, key=name_counts.get) if name_counts else "" + + # Confidence summary + unit_conf = defaultdict(int) + for u in merged_units: + unit_conf[u.get("confidence", "low")] += 1 + concept_conf = defaultdict(int) + for c in merged_concepts: + concept_conf[c.get("confidence", "low")] += 1 + + return { + "feature_name": feature_name, + "ensemble_versions": total, + "concepts": merged_concepts, + "function_units": merged_units, + "confidence_summary": { + "total_units": len(merged_units), + "high": unit_conf.get("high", 0), + "medium": unit_conf.get("medium", 0), + "low": unit_conf.get("low", 0), + "total_concepts": len(merged_concepts), + "concept_high": concept_conf.get("high", 0), + "concept_medium": concept_conf.get("medium", 0), + "concept_low": concept_conf.get("low", 0), + }, + } diff --git a/skills/ir_generation_skill/main.py b/skills/ir_generation_skill/main.py new file mode 100644 index 0000000..f5160b3 --- /dev/null +++ b/skills/ir_generation_skill/main.py @@ -0,0 +1,157 @@ +""" +IR Generation Pipeline Orchestrator. + +Run all four stages sequentially: + python main.py [--skip-step1] [--skip-step2] [--skip-step2.5] [--skip-step3] [--test-only] + +The pipeline reads the parsed PRD JSON from doc_parser and produces: + - ir_final.json: the final IR rules + - ir_audit_report.md: completeness audit report for human review +""" + +import argparse +import os +import subprocess +import sys +from pathlib import Path + +import config + +BASE_DIR = Path(__file__).parent + + +def _subprocess_env(extra: dict | None = None) -> dict: + """Build environment dict for subprocesses, carrying forward overrides.""" + env = os.environ.copy() + env.update(extra or {}) + return env + + +def run_step(script_name: str, description: str, extra_env: dict | None = None) -> bool: + """Run a single pipeline step script, return True if it succeeded.""" + print(f"\n{'#' * 60}") + print(f"# {description}") + print(f"{'#' * 60}") + script_path = BASE_DIR / script_name + if not script_path.exists(): + print(f"错误: 脚本不存在 {script_path}") + return False + result = subprocess.run( + [sys.executable, str(script_path)], + cwd=str(BASE_DIR), + env=_subprocess_env(extra_env), + ) + return result.returncode == 0 + + +def run_test(test_name: str, description: str, extra_env: dict | None = None) -> bool: + """Run a test script, return True if all tests passed.""" + print(f"\n{'='*60}") + print(f"测试: {description}") + print(f"{'='*60}") + test_path = BASE_DIR / "tests" / test_name + if not test_path.exists(): + print(f"错误: 测试脚本不存在 {test_path}") + return False + result = subprocess.run( + [sys.executable, str(test_path)], + cwd=str(BASE_DIR), + env=_subprocess_env(extra_env), + ) + return result.returncode == 0 + + +def main(): + parser = argparse.ArgumentParser(description="IR Generation Pipeline") + parser.add_argument("--skip-step1", action="store_true", + help="跳过阶段一(语义索引)") + parser.add_argument("--skip-step2", action="store_true", + help="跳过阶段二(IR 提取)") + parser.add_argument("--skip-step2.5", "--skip-step2-5", action="store_true", + dest="skip_step2_5", + help="跳过阶段2.5(分支覆盖自动补全)") + parser.add_argument("--skip-step3", action="store_true", + help="跳过阶段三(合并与审计)") + parser.add_argument("--test-only", action="store_true", + help="仅运行测试,不调用 LLM") + parser.add_argument( + "--input", "-i", type=str, default=None, + help="输入 JSON 文件路径(覆盖默认的 doc_parser 输出)" + ) + parser.add_argument( + "--provider", "-p", type=str, default=None, + help="LLM provider: deepseek | dashscope(覆盖 IR_PROVIDER 环境变量)" + ) + args = parser.parse_args() + + # Build extra env vars for subprocesses + extra_env = {} + if args.input: + extra_env["IR_INPUT_JSON"] = args.input + print(f"输入文件: {args.input}") + if args.provider: + extra_env["IR_PROVIDER"] = args.provider + print(f"LLM Provider: {args.provider}") + + if args.test_only: + all_ok = True + all_ok &= run_test("test_step1.py", "Step 1 验证", extra_env) + all_ok &= run_test("test_step2.py", "Step 2 验证", extra_env) + all_ok &= run_test("test_step2_5.py", "Step 2.5 验证", extra_env) + all_ok &= run_test("test_step3.py", "Step 3 验证", extra_env) + sys.exit(0 if all_ok else 1) + + failures = [] + + # Stage 1 + if not args.skip_step1: + ok = run_step("step1_semantic_index.py", + "阶段一:宏观语义索引", extra_env) + if not ok: + failures.append("阶段一") + print("\n阶段一失败,停止流水线。修复后重试。") + sys.exit(1) + run_test("test_step1.py", "Step 1 验证", extra_env) + + # Stage 2 + if not args.skip_step2: + ok = run_step("step2_ir_extraction.py", + "阶段二:逐功能单元 IR 提取", extra_env) + if not ok: + failures.append("阶段二") + print("\n阶段二失败,停止流水线。修复后重试。") + sys.exit(1) + run_test("test_step2.py", "Step 2 验证", extra_env) + + # Stage 2.5 + if not args.skip_step2_5: + ok = run_step("step2_5_branch_coverage.py", + "阶段2.5:分支覆盖自动补全", extra_env) + if not ok: + failures.append("阶段2.5") + print("\n阶段2.5失败,停止流水线。修复后重试。") + sys.exit(1) + run_test("test_step2_5.py", "Step 2.5 验证", extra_env) + + # Stage 3 + if not args.skip_step3: + ok = run_step("step3_merge_and_audit.py", + "阶段三:确定性合并与完整性校验", extra_env) + if not ok: + failures.append("阶段三") + sys.exit(1) + run_test("test_step3.py", "Step 3 验证", extra_env) + + if failures: + print(f"\n失败阶段: {', '.join(failures)}") + sys.exit(1) + + print(f"\n{'='*60}") + print("流水线全部完成!") + print(f"最终 IR: {config.IR_FINAL_JSON}") + print(f"审计报告: {config.IR_AUDIT_REPORT_MD}") + print(f"{'='*60}") + + +if __name__ == "__main__": + main() diff --git a/skills/ir_generation_skill/prompts/step1_feedback.txt b/skills/ir_generation_skill/prompts/step1_feedback.txt new file mode 100644 index 0000000..7d70432 --- /dev/null +++ b/skills/ir_generation_skill/prompts/step1_feedback.txt @@ -0,0 +1,46 @@ +## 上一轮遗漏分析 + +上一轮生成的语义索引经过自动校验,发现以下问题需要修正: + +### 遗漏的逻辑树路径 +以下逻辑树决策路径未被任何 function_unit 覆盖,请为每条路径生成对应的 function_unit: +{missing_paths} + +### 遗漏的概念 +以下关键概念未在 concepts 列表中出现,请补充: +{missing_concepts} + +### 格式问题 +以下 function_unit 或 concept 的格式不符合要求: +{format_issues} + +### concept parent 问题 +以下概念的 parent 引用有问题(悬空引用或缺少 parent): +{parent_issues} + +--- + +请在本次生成中针对以上问题进行修正。注意: +1. 你不需要从头生成完整的语义索引,只需要输出**补充和修正**的部分 +2. function_units 的输出应只包含本次新增或修正的单元(已有的正确单元不需要重复) +3. concepts 的输出应只包含本次新增或修正的概念 +4. 如果格式问题中提到"空壳单元":删除该 unit,或将其合并到包含实际 action 的 unit 中。纯开关状态不是独立的功能行为 +5. 如果格式问题中提到"不构成有效路径":说明你引用了互斥分支上的节点。检查 logic_tree_nodes,确保它们都落在逻辑树的**同一条分支路径**上(例如 n4 是关闭分支,n8 是开启分支,不能共存) +6. 如果格式问题提到"缺少 path"或"缺少 sources":补充对应字段 + +## 输出格式 + +只输出 JSON: + +{ + "feature_name": "(与之前相同)", + "supplemental_function_units": [ + // 只放新增的或修正的 function_unit + ], + "supplemental_concepts": [ + // 只放新增的或修正的 concept + ], + "corrections": { + // 需要修正的已有项: { "unit_id或concept_name": { 修正后的字段 }, ... } + } +} diff --git a/skills/ir_generation_skill/prompts/step1_semantic_index.txt b/skills/ir_generation_skill/prompts/step1_semantic_index.txt new file mode 100644 index 0000000..99b687d --- /dev/null +++ b/skills/ir_generation_skill/prompts/step1_semantic_index.txt @@ -0,0 +1,123 @@ +你是吉利汽车车机系统(XX Auto)的产品需求分析师。你的任务是从行车娱乐限制功能 PRD 文档中提取"语义索引"——一份结构化、有层级的功能清单,而不是逐字翻译。 + +## 文档结构说明 + +下面是一份 Word 文档的解析结果,包含: + +1. **sections**:按章节组织的混合内容(段落 + 表格),每个 section 有 `source`(章节标题)、`blocks`(`para` 文本段落和 `table` 结构表格)、`images`(引用的图片 ID 列表) +2. **image_analysis**:文档中流程图的程序化分析结果,其中 `logic_tree` 是由节点组成的决策树: + - `state` 节点:状态说明 + - `decision` 节点:判断条件 + `branches`(分支值 → 目标节点 ID) + - `action` 节点:系统或用户交互动作 +3. **resolved_conflicts**:文档中图文冲突的仲裁结果,明确指出应以文字还是图片为准 + +## 文档全文 + +{document_json} + +## 你的任务 + +阅读整份文档后,输出一份 **语义索引 JSON**,包含: + +### 1. feature_name +功能名称,如"行车娱乐限制" + +### 2. concepts(带层级) +文档中定义或使用的关键概念列表。每个概念包含: +- `name`:概念的标准名称(必填) +- `aliases`:同义词/别名列表(如"行车娱乐限制"、"行车娱乐禁止") +- `defined_in`:定义该概念的章节号列表(如 ["3.1", "3.1.1"]) +- `parent`:父概念名称(字符串或 null)(必填) + +**概念层级规则(重要)**: +你必须按照以下 4 层结构组织概念,并为每个概念指定正确的 `parent`: +- **Level 0(地理范围)**: "国内"、"海外" — parent 为 null +- **Level 1(功能)**: "行车娱乐限制"、"行车娱乐禁止" — parent 为对应的 scope(如 "国内") +- **Level 2(限制方式)**: "系统限制"、"SDK限制"、"其他应用" — parent 为对应的 feature +- **Level 3(具体行为)**: "前台打断"、"后台限制启动"、"后台暂停功能"、"无限制" — parent 为对应的 method + +除了以上层级,还可以有"行车娱乐限制开关"、"车速条件"、"档位条件"、"Toast提示"等辅助概念,它们应有合理的 parent。 + +**重要约束:每个 concept 的 parent 值必须是 concepts 列表中已存在的另一个 concept 的 name,或者是 null。禁止引用不存在的概念名。** + +### 3. function_units(带路径) +文档中描述的所有主要功能行为的列表。**每个 function_unit 对应逻辑树中的一条叶子路径**。每个 function unit 包含: + +- `unit_id`:唯一标识,格式 "FU-001", "FU-002"... +- `name`:简短名称,如"国内-系统限制-前台-行车打断" +- `description`:1-3 句描述该规则的行为 +- `path`:层级路径数组,从高到低,如 `["国内", "系统限制", "前台打断"]`(必填)。**path 中的每个元素必须是 concepts 列表中已存在的概念名。** +- `sources`:该规则在文档中的来源锚点列表,每项包含: + - `section`:章节号 + - `type`:来源类型,`"table"` 或 `"para"` 或 `"logic_tree"` + - `row`:如果是表格行(从 1 开始) + - `text_snippet`:前 200 字的关键文字 + - `image_id`:如果是逻辑树来源,填写图片 rId + - `logic_tree_nodes`:如果是逻辑树来源,列出相关节点 ID 列表 + +## function_units 分解策略(重要) + +**按逻辑树的每条叶子路径生成一个 function_unit**: + +1. **叶子路径 = 从根节点到叶子节点(end 类型)的完整决策链**,包含路径上所有中间节点和叶子节点的最终动作 +2. **每条叶子路径对应一个 function_unit**:不同决策分支导向不同叶子节点 → 不同的 function_unit +3. **"不受限"叶子节点也必须建模**:即使 action 是"不执行任何限制操作",也要创建对应的 function_unit +4. **禁止合并不同叶子节点**:不要将多个不同叶子节点的结果合并到一个 function_unit(除非它们触发完全相同的动作且属于同一父分支) +5. **文字描述中的功能单独列出**:对于无法对应到逻辑树节点的功能(如纯文字描述的功能行为),用 table/para 类型 source,path 用语义路径 +6. **非流程图的图片也可能包含功能行为**:rId18 等图片的描述文本中可能包含功能规则(如"使用语音打开受限应用"),同样需要提取为 function_unit + +**重要:不要创建纯开关/状态的空壳 unit**。"开关开启"本身不是一个功能行为(它没有 action),它是其他单元的 precondition。如果一个 function_unit 的 path 只有 `["国内", "开关开启"]` 且 sources 中只有 n1/n2/n3 这样的根/开关节点,说明它不是真正的功能单元,不应该输出。 + +{feedback} + +## 权威性规则 + +1. **逻辑树(流程图)是权威来源**:逻辑树定义了功能的确切行为。识别 function_unit 时必须优先按逻辑树路径建模。文字和表格用于补充描述、提供确切措辞(如 Toast 文案),但不应覆盖或曲解逻辑树路径。 + +2. **logic_tree_nodes 必须构成有效路径**:每个 function_unit 引用的 logic_tree_nodes 列表,必须对应逻辑树中的**一条连通路径**。禁止将互斥分支上的节点混入同一个 source(例如 n4 是"开关关闭"分支,n8 是"开关开启"分支的下游节点,它们不能出现在同一 function_unit 中)。 + +3. **resolved_conflicts 中的仲裁是最终决定**:如果文档有图文冲突且已仲裁,严格按仲裁结果处理。 + +4. **逻辑树路径应全部覆盖**:下面是程序从文档逻辑树中枚举的全部决策路径,请逐一确认每条路径都有对应的 function_unit: + +{logic_tree_paths} + +## 关键要求 + +1. **必须覆盖所有逻辑树路径**:上面列出的每条路径必须被至少一个 function_unit 的 sources 引用。 + +2. **必须覆盖表格中的所有规则**:表格中列出的每种"限制方法"、"限制规则"都要有对应的 function_unit。 + +3. **区分"限制"与"禁止"**:文档中"行车娱乐限制"(前台应用打断)和"行车娱乐禁止"(后台应用启动限制)是两个不同的子场景,必须分别建模。 + +4. **区分不同应用类型**:系统限制、SDK 限制、其他应用的行为路径不同,必须分别建模。 + +5. **包含开关状态**:开关"开启"和"关闭"两种状态下的行为都要覆盖。 + +6. **概念和路径必须有层级**:每个 concept 指定正确的 parent;每个 function_unit 输出 path 数组。 + +## 输出格式 + +**只输出 JSON,不要有 markdown 代码块标记或其他文字**: + +{ + "feature_name": "...", + "concepts": [ + {"name": "国内", "aliases": [], "defined_in": ["2.7", "3.1"], "parent": null}, + {"name": "行车娱乐限制", "aliases": [], "defined_in": ["3.1", "3.1.1"], "parent": "国内"}, + ... + ], + "function_units": [ + { + "unit_id": "FU-001", + "name": "国内-系统限制-前台-行车打断", + "description": "...", + "path": ["国内", "系统限制", "前台打断"], + "sources": [ + {"section": "3.1.1", "type": "table", "row": 2, "text_snippet": "打断:车速>=15km/h且持续5秒后..."}, + {"image_id": "rId16", "type": "logic_tree", "logic_tree_nodes": ["n2","n3","n8","n19","n21","n23","n25","n26"]} + ] + }, + ... + ] +} diff --git a/skills/ir_generation_skill/prompts/step2_ir_extraction.txt b/skills/ir_generation_skill/prompts/step2_ir_extraction.txt new file mode 100644 index 0000000..797ab1a --- /dev/null +++ b/skills/ir_generation_skill/prompts/step2_ir_extraction.txt @@ -0,0 +1,200 @@ +你是吉利汽车车机系统的需求分析专家。你的任务是基于给定的精准上下文包,为单个功能单元(Function Unit)提取详细的 **IR 规则(Intermediate Representation Rule)**。 + +## 上下文 + +下面是一个功能单元的精准上下文包,包含了从原始需求文档中提取的相关文字、表格和逻辑树: + +### 功能单元概要 +- **unit_id**: {unit_id} +- **unit_name**: {unit_name} +- **unit_description**: {unit_description} + +### 相关文字段落 +{texts} + +### 相关表格 +{tables} + +### 相关逻辑树 +{logic_trees} + +### 图文冲突仲裁(如有) +{resolved_conflicts} + +## IR Schema + +你需要为这个功能单元输出一个 **规则数组(rules)**。每条规则遵循以下 schema: + +```json +{{ + "rule_id": "{unit_id}-DOMESTIC-SYS-FG-INTERRUPT-01", + "path": ["国内", "系统限制", "前台打断"], + "description": "国内车型,开关开启,系统限制类应用在前台,车速>=15km/h且持续>5秒且非P档时,系统打断应用前台进程、将应用调入后台,显示Toast'在行车状态下无法使用该应用'", + "priority": "P0", + "sources": [ + {{"type": "table", "section": "3.1.1", "row": 2, "text_snippet": "打断:车速>=15km/h且持续5秒后..."}}, + {{"type": "logic_tree", "image_id": "rId16", "node_ids": ["n2","n3","n8","n19","n21","n23","n25","n26"], "priority": "primary_source"}} + ], + "precondition": {{ + "geographic_scope": "国内", + "screen_type": "any", + "switch": "开启", + "app_type": "系统限制", + "app_state": "前台" + }}, + "trigger": {{ + "operator": "AND", + "conditions": [ + {{"signal": "车速", "operator": ">=", "value": 15, "unit": "km/h"}}, + {{"signal": "车速_持续时间", "operator": ">", "value": 5, "unit": "秒"}}, + {{"signal": "档位", "operator": "!=", "value": "P"}} + ] + }}, + "actions": [ + {{"type": "system", "description": "打断应用前台进程"}}, + {{"type": "system", "description": "将应用调入后台"}}, + {{"type": "user_interaction", "description": "显示Toast", "content": "在行车状态下无法使用该应用"}} + ] +}} +``` + +## 字段说明(必读) + +1. **rule_id**: 格式为 `{unit_id}-SCOPE-METHOD-BEHAVIOR-NN`,其中: + - SCOPE: DOMESTIC(国内)| OVERSEAS(海外) + - METHOD: SYS(系统限制)| SDK(SDK限制)| OTHER(其他应用) + - BEHAVIOR: FG-INTERRUPT(前台打断)| BG-BLOCK(后台限制启动)| BG-PAUSE(后台暂停功能)| NO-RESTRICT(无限制)| SWITCH-OFF(开关关闭) + - NN: 序号从 01 开始 + +2. **path**: 层级路径数组(必填)。从 scope 到 behavior 逐级列出,如 `["国内", "系统限制", "前台打断"]`。此字段用于程序化遍历所有功能点。 + +3. **description**: 完整但简洁地描述整个规则,必须包含:地理范围 + 开关状态 + 应用类型 + 前后台状态 + 触发条件 + 所有动作。人读取此字段即可设计测试用例。 + +4. **priority**: P0(核心安全规则)、P1(重要规则)、P2(边界情况)。 + +5. **sources**: 每条规则必须列出所有数据来源。逻辑树类型的 source 必须标记 `"priority": "primary_source"`。文字/表格类型的 source 标记 `"priority": "supplementary"`。**node_ids 必须列举该规则在逻辑树中经历的所有 decision 和 action 节点。** + +6. **precondition**: 规则生效的前置状态条件。必须包含以下字段: + - `geographic_scope`(必填):"国内" | "海外" + - `screen_type`(必填):"CSD" | "PSD" | "RFD" | "any"(如文档未区分屏幕类型则填 "any") + - `switch`:开关状态("开启" | "关闭") + - `app_type`:应用类型 + - `app_state`:应用前后台状态("前台" | "后台") + 如某字段不适用,可省略。 + +7. **trigger**: 触发条件对象: + - `operator`: "AND" | "OR" + - `conditions`: 条件数组,每个条件必须有 `signal`、`operator`、`value`。有单位加 `unit`。 + - 如为瞬时事件(用户点击),用 `event` 字段。 + +8. **actions**: 每个动作必须有 `type`("system" | "user_interaction")和 `description`。 + - `"user_interaction"` 类型必须有 `content` 字段,填写**确切的提示文案**。 + - **禁止使用占位符**:content 不能是"文案由业务定义"、"待定"、"自定义"等。如果文档中给出了文案,必须原样填入。如果文档确实未给出文案,填写 `"(文档未指定)"` 并标注。 + +## Few-shot 示例 + +### 示例 1:行车娱乐限制(前台打断) + +**输入上下文**:国内车型,开关开启,系统限制类应用在前台,车速>=15km/h且持续>5秒且非P档时,打断应用并显示Toast"在行车状态下无法使用该应用"。 + +**期望输出**: + +```json +{{ + "rule_id": "FU-001-DOMESTIC-SYS-FG-INTERRUPT-01", + "path": ["国内", "系统限制", "前台打断"], + "description": "国内车型,开关开启,系统限制类应用在前台,当车速>=15km/h且持续超过5秒且非P档时,系统打断应用前台进程、将应用调入后台,并弹出Toast提示'在行车状态下无法使用该应用'", + "priority": "P0", + "sources": [ + {{"type": "table", "section": "3.1.1", "row": 2, "text_snippet": "行车娱乐限制:目标应用/功能处于前台时 ○ 打断:车速>=15km/h且持续5秒后...", "priority": "supplementary"}}, + {{"type": "logic_tree", "image_id": "rId16", "node_ids": ["n2","n3","n8","n19","n21","n23","n25","n26"], "priority": "primary_source"}} + ], + "precondition": {{ + "geographic_scope": "国内", + "screen_type": "any", + "switch": "开启", + "app_type": "系统限制", + "app_state": "前台" + }}, + "trigger": {{ + "operator": "AND", + "conditions": [ + {{"signal": "车速", "operator": ">=", "value": 15, "unit": "km/h"}}, + {{"signal": "车速_持续时间", "operator": ">", "value": 5, "unit": "秒"}}, + {{"signal": "档位", "operator": "!=", "value": "P"}} + ] + }}, + "actions": [ + {{"type": "system", "description": "打断应用前台进程"}}, + {{"type": "system", "description": "将应用调入后台"}}, + {{"type": "user_interaction", "description": "显示Toast", "content": "在行车状态下无法使用该应用"}} + ] +}} +``` + +### 示例 2:行车娱乐禁止(后台启动拦截) + +**输入上下文**:国内车型,开关开启,应用在后台,非P档时阻止应用启动,提示"请在P挡时使用该功能/应用"。 + +**期望输出**: + +```json +{{ + "rule_id": "FU-002-DOMESTIC-SYS-BG-BLOCK-01", + "path": ["国内", "系统限制", "后台限制启动"], + "description": "国内车型,开关开启,目标应用处于后台,当用户尝试启动应用且档位非P档时,系统限制应用/功能启用,并弹出Toast提示'请在P挡时使用该功能/应用'", + "priority": "P0", + "sources": [ + {{"type": "table", "section": "3.1.1", "row": 2, "text_snippet": "行车娱乐禁止:目标应用/功能处于后台时 ○ 限制:非P挡时,限制目标应用/功能启用...", "priority": "supplementary"}}, + {{"type": "logic_tree", "image_id": "rId17", "node_ids": ["n1","n2","n5","n7"], "priority": "primary_source"}} + ], + "precondition": {{ + "geographic_scope": "国内", + "screen_type": "any", + "switch": "开启", + "app_state": "后台" + }}, + "trigger": {{ + "operator": "AND", + "conditions": [ + {{"signal": "应用请求启动", "operator": "==", "value": true}}, + {{"signal": "档位", "operator": "!=", "value": "P"}} + ] + }}, + "actions": [ + {{"type": "system", "description": "限制应用/功能启用"}}, + {{"type": "user_interaction", "description": "显示Toast", "content": "请在P挡时使用该功能/应用"}} + ] +}} +``` + +## 关键要求 + +1. **逻辑树为唯一权威来源**:触发条件和动作序列必须严格按逻辑树路径建模。文字/表格描述仅用于补充确切措辞(如 Toast 文案),不得覆盖或曲解逻辑树路径。在 sources 中,逻辑树类型标记 `"priority": "primary_source"`,文字/表格标记 `"priority": "supplementary"`。 + +2. **信号和数值必须精确**:禁止写"车速超过阈值",必须写 `{{"signal": "车速", "operator": ">=", "value": 15, "unit": "km/h"}}`。 + +3. **条件必须完整**:逻辑树中的每个 decision 条件必须对应 trigger.conditions 中的一条。如果文档说"车速>=15km/h 且持续超过5秒 且非P档",这三个条件必须全部出现。 + +4. **每条规则必须自包含**:人仅凭一条 rule JSON 就能设计出对应的测试用例。必须包含:geographic_scope、screen_type、开关状态、应用类型、前后台状态、完整触发条件、所有动作及确切 Toast 文案、来源引用。 + +5. **禁止占位符**:`"user_interaction"` 类型的 `content` 不能是"文案由业务定义"、"待定"、"自定义"。如文档确实未给出文案,填 `"(文档未指定)"`。 + +6. **逻辑树节点必须追踪**:在 sources 中列出该规则在逻辑树中经历的所有 decision 节点和 action 节点。 + +7. **多条规则**:如果一个功能单元包含多个独立行为分支,输出多条规则分别描述。 + +8. **开关关闭状态**:开关关闭时所有限制失效,这也必须作为一条规则输出(path: ["...", "开关关闭", "无限制"])。 + +{format_feedback} + +## 输出格式 + +**只输出 JSON 数组,不要有任何其他文字或 markdown 标记**: + +[ + {{ ... }}, + {{ ... }} +] + +注意:即使只有一个规则,也必须用数组格式 `[...]`。 diff --git a/skills/ir_generation_skill/scripts/LLM.py b/skills/ir_generation_skill/scripts/LLM.py deleted file mode 100644 index e6f2099..0000000 --- a/skills/ir_generation_skill/scripts/LLM.py +++ /dev/null @@ -1,105 +0,0 @@ -import logging -import os -import time -from typing import Optional - -from openai import OpenAI - -logger = logging.getLogger(__name__) - - -class LLMClient: - """Low-level OpenAI-compatible LLM client with retry and token tracking. - - Usage:: - - llm = LLMClient() - content = llm.chat("qwen3.5-flash", [{"role": "user", "content": "Hello"}]) - print(llm.usage) - """ - - IMAGE_MODEL = "qwen3-vl-plus" - TEXT_MODEL = "qwen3.5-flash-2026-02-23" - TIMEOUT = 120 - MAX_RETRIES = 3 - - def __init__( - self, - *, - base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1", - timeout: int | None = None, - ): - key = os.environ.get("DASHSCOPE_API_KEY", "") - if not key: - raise ValueError("DASHSCOPE_API_KEY environment variable is not set.") - self._client = OpenAI(api_key=key, base_url=base_url) - self._timeout = timeout or self.TIMEOUT - self._prompt_tokens = 0 - self._completion_tokens = 0 - - @property - def usage(self) -> dict: - """Return accumulated token counts as ``{prompt, completion, total}``.""" - return { - "prompt_tokens": self._prompt_tokens, - "completion_tokens": self._completion_tokens, - "total_tokens": self._prompt_tokens + self._completion_tokens, - } - - @staticmethod - def estimate_tokens(text: str) -> int: - """Quick token estimate. CJK ≈1.7/token, others ≈3.0/token.""" - cjk = sum(1 for c in text if '一' <= c <= '鿿' or ' ' <= c <= '〿') - other = len(text) - cjk - return max(1, int(cjk / 1.7 + other / 3.0)) - - @staticmethod - def estimate_image_tokens() -> int: - """Fixed estimate for one vision-model image (~500 tokens).""" - return 500 - - def chat( - self, model: str, messages: list[dict], *, timeout: int | None = None, - response_format: dict | None = None, - ) -> str: - """Send a chat completion request and return the response content. - - Automatically retries on failure and accumulates token usage. - """ - label = f"chat({model})" - - def _call(): - t0 = time.time() - kwargs = dict(model=model, messages=messages, timeout=timeout or self._timeout) - if response_format is not None: - kwargs["response_format"] = response_format - kwargs["temperature"] = 0 - resp = self._client.chat.completions.create(**kwargs) - content = resp.choices[0].message.content - usg = resp.usage - if usg: - self._prompt_tokens += usg.prompt_tokens - self._completion_tokens += usg.completion_tokens - elapsed = time.time() - t0 - logger.info("%s: %d chars in %.1fs", label, len(content) if content else 0, elapsed) - if not content: - raise RuntimeError("Empty response from LLM") - return content - - return self._retry(_call, label) - - def _retry(self, fn, label: str) -> str: - """Call *fn()* with exponential-backoff retry.""" - last_error: Optional[Exception] = None - for attempt in range(self.MAX_RETRIES): - try: - return fn() - except Exception as e: - last_error = e - logger.warning( - "%s error (attempt %d/%d): %s", - label, attempt + 1, self.MAX_RETRIES, e, - ) - if attempt < self.MAX_RETRIES - 1: - time.sleep(2 ** attempt) - raise RuntimeError(f"{label}: all retries exhausted") from last_error diff --git a/skills/ir_generation_skill/scripts/ir_generator.py b/skills/ir_generation_skill/scripts/ir_generator.py deleted file mode 100644 index cac1307..0000000 --- a/skills/ir_generation_skill/scripts/ir_generator.py +++ /dev/null @@ -1,359 +0,0 @@ -#!/usr/bin/env python3 -"""Generate JSON intermediate representation from ``_parsed.json`` or ``_updated.json``. - -Sends the JSON document directly to the LLM for analysis. If the document exceeds -``MAX_ANALYSIS_TOKENS``, sections are batched greedily without splitting any -individual section. Conflict corrections from ``resolved_conflicts`` are included -so the output respects user arbitration decisions. - -Usage:: - - python scripts/ir_generator.py output/_updated.json [output_dir] [--dry-run] - -Output: ``_ir.json`` -""" - -import argparse -import json -import logging -import os -import sys -import time - -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - -from LLM import LLMClient - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) - -logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Configuration -# --------------------------------------------------------------------------- - -RATE_LIMIT_DELAY = 0.5 -MAX_ANALYSIS_TOKENS = 6000 # max content size per LLM call - - -# --------------------------------------------------------------------------- -# Prompt -# --------------------------------------------------------------------------- - -PROMPT = """你是一个需求文档分析助手。请分析以下需求文档的JSON内容,输出结构化JSON。 - -## 已知修正(来自冲突检测) -以下内容已确认修正,生成JSON时请**使用修正后的值**,不要同时输出两个版本。 -{conflict_context} - -## 待分析内容(JSON格式) - -{content} - -## JSON字段说明 -- sections: 文档章节列表,每个章节含 source(章节标题)和 blocks(内容块数组) -- blocks: 类型含 para(段落,字段 text)和 table(表格,字段 rows,每行含 columns 数组) -- image_sources: 图片所在章节映射,key 为图片 rid -- image_analysis: 图片分析结果,每个含 rid、type(流程图/架构图/状态图等)、description -- resolved_conflicts: 已知修正列表,每个含 section、conflict_type、correction、source - -## 功能点定义 - -只有满足以下**全部条件**的才视为功能点: -1. 描述了一个**系统或软件要实现的具体行为**(有触发条件、执行动作、状态变化或逻辑规则) -2. 该行为直接由**系统或框架**执行(不是人的操作流程、管理流程) -3. 对用户或系统有**可观察的效果** - -**以下内容不是功能点,不要输出:** -- 术语/缩略词定义( -- 文档背景、范围说明(如 "本文档涵盖xxx") -- 变更日志、版本记录、编制人信息 -- 文档结构描述(如 "产品简介用户场景说明") -- 纯文本的概述、没有具体行为的介绍 - -## 决策树/流程图分解规则(重要) - -图片分析(image_analysis)中的流程图和决策树描述包含丰富的功能逻辑,**必须完全分解**: - -1. **每个叶子路径 = 一个独立 function**:从根节点到每个最终结果的完整路径,都拆成一个 function -2. **每个判断分支 = 一个独立 function**:菱形判断节点的每个分支方向和对应的结果,单独作为一个 function -3. **不同约束条件 = 不同 function**:例如"通过接入SDK限制"和"通过系统限制"是不同约束机制,必须分别列出 -4. **不要合并不同路径**:即使最终结果相同,只要到达路径不同,就是不同的 function - -## 输出格式 - -只输出功能点,每个功能点格式如下: - -{ - "function": "功能名称", - "source": { - "section": "章节名", - "location": "原文位置(如:正文第1段、表格1第2行、图片rId13)" - }, - "trigger": { - "type": "AND或者OR", - "conditions": [ - "触发条件1", - "触发条件2" - ] - }, - "actions": { - "场景/角色": [ - "动作1", - "动作2" - ] - } -} - -## 输出原则 - -1. **只输出功能点**,没有功能点就输出空数组 [] -2. 每个功能点**必须**包含 source.section 和 source.location -3. location 必须是具体的原文位置标签(如 "正文第1段"、"表格1"、"图片rId13") -4. **一个 function 只对应一种行为逻辑(一条完整路径)**。决策树中的每个分支路径(从根到叶子)必须拆成独立 function,conditions 中明确写出该路径上的所有判断条件和分支方向。 -5. **穷举所有分支**:流程图/决策树中的每一条分支路径都要输出对应的 function,不能遗漏任何子逻辑。 -6. 没有 trigger 或 actions 的字段直接**省略**,不要写 null 或空列表/空对象 -7. 所有功能点全部列出,**宁多勿漏** -8. **已知修正**中确认的信息,使用修正后的值 -9. 输出一个JSON数组,不要用 ```json 代码块包裹,直接输出纯JSON -""" - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _parse_llm_response(raw: str) -> list | dict | str | None: - """Parse JSON from LLM response, handling markdown code fences.""" - if raw is None: - return None - stripped = raw.strip() - if stripped.startswith("```"): - nl = stripped.find("\n") - stripped = stripped[nl + 1:] if nl != -1 else stripped[3:] - if stripped.endswith("```"): - stripped = stripped[:-3] - try: - return json.loads(stripped) - except json.JSONDecodeError: - logger.warning(" Failed to parse JSON, returning raw text") - return raw - - -def _build_conflict_context( - section_name: str | None, - resolved_conflicts: list[dict], -) -> str: - """Build conflict correction context for a section, or all if section_name is None.""" - if section_name is None: - relevant = resolved_conflicts - else: - relevant = [c for c in resolved_conflicts if c.get("section", "") == section_name] - if not relevant: - return "没有" - - lines: list[str] = [] - for c in relevant: - correction = c.get("correction", "") - conflict_type = c.get("conflict_type", "") - source = c.get("source", "") - lines.append(f"- 冲突类型:{conflict_type},依据:{source}") - lines.append(f" 修正后的值:{correction}") - - return "\n".join(lines) - - -# --------------------------------------------------------------------------- -# LLM analysis -# --------------------------------------------------------------------------- - -def _analyze_content( - content: str, - conflict_context: str, - llm: LLMClient, - *, - dry_run: bool = False, -) -> list[dict]: - """Send content to the LLM and return IR entries.""" - prompt = PROMPT.replace("{conflict_context}", conflict_context).replace("{content}", content) - - if dry_run: - est = llm.estimate_tokens(prompt) - logger.info(" [DRY RUN] prompt ~%d tokens", est) - return [] - - try: - raw = llm.chat( - model=LLMClient.TEXT_MODEL, - messages=[{"role": "user", "content": prompt}], - response_format={"type": "json_object"}, - ) - logger.info(" Response: %d chars", len(raw)) - except RuntimeError as e: - logger.error(" Analysis failed: %s", e) - return [] - - parsed = _parse_llm_response(raw) - if isinstance(parsed, list): - return parsed - elif isinstance(parsed, dict): - return [parsed] - else: - logger.warning(" Unparseable response, raw length: %d", len(raw)) - return [] - - -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - -def generate_ir( - parsed_path: str, - output_dir: str = "output", - *, - dry_run: bool = False, -) -> dict: - """Read parsed/updated JSON and generate JSON IR. - - Produces ``_ir.json`` in *output_dir*. - """ - with open(parsed_path, "r", encoding="utf-8") as f: - data = json.load(f) - - basename = os.path.splitext(os.path.basename(parsed_path))[0] - for suffix in ("_parsed", "_updated"): - if basename.endswith(suffix): - basename = basename[:-len(suffix)] - break - os.makedirs(output_dir, exist_ok=True) - - llm = LLMClient() - ir_output: list[dict] = [] - - sections = data.get("sections", []) - image_sources = data.get("image_sources", {}) - image_analysis = data.get("image_analysis", []) - resolved_conflicts = data.get("resolved_conflicts", []) - - # Build full document JSON to measure size - full_doc = { - "sections": sections, - "image_sources": image_sources, - "image_analysis": image_analysis, - } - full_json = json.dumps(full_doc, ensure_ascii=False) - total_chars = len(full_json) - logger.info("Total document JSON chars: %d", total_chars) - - if total_chars < MAX_ANALYSIS_TOKENS: - logger.info("Document fits in one request (< %d chars)", MAX_ANALYSIS_TOKENS) - conflict_ctx = _build_conflict_context(None, resolved_conflicts) - entries = _analyze_content(full_json, conflict_ctx, llm, dry_run=dry_run) - ir_output.extend(entries) - else: - logger.info("Document is large (>= %d chars), batching sections", MAX_ANALYSIS_TOKENS) - - # Filter to non-empty sections, measure effective size per section - # (section JSON + image_sources + image_analysis for images in that section) - sec_sizes = [] - for sec in sections: - if not sec.get("blocks"): - continue - sec_json = json.dumps(sec, ensure_ascii=False) - sec_chars = len(sec_json) - # Add image overhead for this section - sec_name = sec.get("source", "") - sec_rids = [rid for rid, src in image_sources.items() - if src.get("section", "") == sec_name] - if sec_rids: - overhead_doc = { - "image_sources": {rid: image_sources[rid] for rid in sec_rids}, - "image_analysis": [img for img in image_analysis - if img.get("rid", "") in sec_rids], - } - sec_chars += len(json.dumps(overhead_doc, ensure_ascii=False)) - sec_sizes.append((sec, sec_chars)) - - # Greedy batch: never split a section, keep adding until next exceeds limit - i = 0 - while i < len(sec_sizes): - batch = [] - batch_size = 0 - while i < len(sec_sizes) and batch_size + sec_sizes[i][1] <= MAX_ANALYSIS_TOKENS: - batch.append(sec_sizes[i][0]) - batch_size += sec_sizes[i][1] - i += 1 - - if not batch: - i += 1 - continue - - # Collect sections and their images for this batch - batch_names = [s.get("source", "") for s in batch] - batch_image_sources = { - rid: src for rid, src in image_sources.items() - if src.get("section", "") in batch_names - } - batch_images = [ - img for img in image_analysis - if image_sources.get(img.get("rid", ""), {}).get("section", "") in batch_names - ] - - batch_doc = { - "sections": batch, - "image_sources": batch_image_sources, - "image_analysis": batch_images, - } - batch_json = json.dumps(batch_doc, ensure_ascii=False) - - # Merge conflict contexts - ctx_parts = [] - for sn in batch_names: - ctx = _build_conflict_context(sn, resolved_conflicts) - if ctx != "没有": - ctx_parts.append(ctx) - conflict_ctx = "\n".join(ctx_parts) if ctx_parts else "没有" - - label = " + ".join(batch_names) - logger.info("Batch [%s]: %d sections, %d chars", label, len(batch), len(batch_json)) - entries = _analyze_content(batch_json, conflict_ctx, llm, dry_run=dry_run) - ir_output.extend(entries) - time.sleep(RATE_LIMIT_DELAY) - - # ---- save ---------------------------------------------------------------- - ir_path = os.path.join(output_dir, f"{basename}_ir.json") - os.makedirs(os.path.dirname(ir_path) or ".", exist_ok=True) - with open(ir_path, "w", encoding="utf-8") as f: - json.dump(ir_output, f, ensure_ascii=False, indent=2) - logger.info("Saved: %s (%d entries)", ir_path, len(ir_output)) - - # ---- summary ------------------------------------------------------------- - usg = llm.usage - logger.info("Tokens: %d prompt + %d completion = %d total", - usg["prompt_tokens"], usg["completion_tokens"], usg["total_tokens"]) - logger.info("Output: %s", ir_path) - - return {"ir": ir_output, "path": ir_path} - - -# --------------------------------------------------------------------------- -# CLI -# --------------------------------------------------------------------------- - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Generate JSON intermediate representation from parsed/updated JSON.", - ) - parser.add_argument("input", metavar="parsed.json", - help="Path to _parsed.json or _updated.json") - parser.add_argument("output_dir", nargs="?", default="output", metavar="output_dir", - help="Directory for output files (default: output/)") - parser.add_argument("--dry-run", action="store_true", - help="Print token estimates without calling the API.") - - args = parser.parse_args() - generate_ir(args.input, args.output_dir, dry_run=args.dry_run) diff --git a/skills/ir_generation_skill/step1_semantic_index.py b/skills/ir_generation_skill/step1_semantic_index.py new file mode 100644 index 0000000..621d75b --- /dev/null +++ b/skills/ir_generation_skill/step1_semantic_index.py @@ -0,0 +1,717 @@ +""" +Stage 1: Ensemble Semantic Index Generation. + +Generates N parallel LLM calls with different temperatures (e.g., 0.0, 0.3, 0.7), +then deterministically merges the results via ensemble_merge (pure Python, no LLM). +The merged output includes confidence scores for each concept and function_unit. + +Outputs: + - output/semantic_index_r1.json (T=0.0 raw) + - output/semantic_index_r2.json (T=0.3 raw) + - output/semantic_index_r3.json (T=0.7 raw) + - output/semantic_index.json (ensemble-merged final) +""" + +import concurrent.futures +import json +import re +import sys +import time +from pathlib import Path + +import config +from ensemble_merge import ensemble_merge + + +# ---- Path Enumeration (for prompt embedding) ---- + + +def _traverse_nested(node: dict, image_id: str, path_nodes: list, + branch_taken: str | None) -> list[dict]: + """DFS traversal of a logic_tree_nested node, returning leaf path records.""" + node_id = node.get("id", "?") + node_type = node.get("type", "?") + node_name = node.get("name", "") + + path_nodes = path_nodes + [{ + "id": node_id, + "type": node_type, + "label": node_name, + "branch_taken": branch_taken, + }] + + if node_type == "end": + return [_make_path_record(path_nodes, image_id)] + + children = node.get("children", []) + if not children: + return [_make_path_record(path_nodes, image_id)] + + all_paths = [] + for child in children: + # Decision nodes have {condition, node} wrappers; others are direct node dicts + if node_type == "decision": + condition = child.get("condition", "") + child_node = child.get("node", child) + else: + condition = "(implicit)" + child_node = child + + all_paths.extend( + _traverse_nested(child_node, image_id, path_nodes, condition) + ) + + return all_paths + + +def _make_path_record(path_nodes: list, image_id: str) -> dict: + """Build a path record from a completed node chain.""" + action_nodes = [n for n in path_nodes if n["type"] == "action"] + decision_nodes = [n for n in path_nodes if n["type"] == "decision"] + node_ids = [n["id"] for n in path_nodes] + + return { + "path_id": f"PATH-{image_id}-{'-'.join(node_ids)}", + "nodes": path_nodes, + "meaning": _describe_path(path_nodes), + "image_id": image_id, + "action_nodes": action_nodes, + "decision_nodes": decision_nodes, + "node_ids": node_ids, + } + + +def enumerate_logic_tree_paths(nested_tree: dict, image_id: str = "") -> list[dict]: + """Enumerate all root-to-leaf paths from a logic_tree_nested structure. + + Uses the nested tree directly (no flat-list adjacency). Decision nodes + fork by {condition, node} branches; other nodes have direct children. + """ + if not nested_tree: + return [] + return _traverse_nested(nested_tree, image_id, [], None) + + +def _describe_path(path_nodes: list[dict]) -> str: + """Generate a human-readable description of a logic tree path.""" + parts = [] + for n in path_nodes: + label = n["label"] + if n["branch_taken"] and n["branch_taken"] != "(implicit)": + label = f"{label} → {n['branch_taken']}" + parts.append(label) + return " → ".join(parts) + + +def enumerate_all_paths(doc: dict) -> dict[str, list[dict]]: + """Enumerate paths for all logic trees in the document. + + Uses logic_tree_nested when available (proper tree), falling back to + flat logic_tree. Returns {image_id: [path, ...]}. + """ + result = {} + for img in doc.get("image_analysis", []): + rid = img.get("rid", "") + if not rid: + continue + nested = img.get("logic_tree_nested") + if nested: + result[rid] = enumerate_logic_tree_paths(nested, image_id=rid) + else: + lt = img.get("logic_tree") + if lt and lt.get("nodes"): + lt["image_id"] = rid + result[rid] = _enumerate_flat_tree(lt) + elif lt: + result[rid] = [] + return result + + +def _enumerate_flat_tree(tree: dict) -> list[dict]: + """Fallback: enumerate paths from flat logic_tree using adjacency. + Handles start/process/action/state nodes as implicit chain links. + """ + nodes = tree.get("nodes", []) + if not nodes: + return [] + node_map = {n["id"]: n for n in nodes} + image_id = tree.get("image_id", "") + + # Find root: first start/state node, or first process node, or first node + root = None + for n in nodes: + if n["type"] in ("start", "state"): + root = n + break + if root is None: + for n in nodes: + if n["type"] == "process": + root = n + break + if root is None: + root = nodes[0] + + adj = _build_adjacency(nodes, node_map) + paths = [] + + def dfs(current_id, visited, path_nodes, branch_taken): + if current_id in visited: + return + new_visited = visited | {current_id} + node = node_map.get(current_id) + if node is None: + return + + path_nodes = path_nodes + [{ + "id": current_id, + "type": node["type"], + "label": node.get("description") or node.get("condition", ""), + "branch_taken": branch_taken, + }] + + outgoing = adj.get(current_id, []) + if not outgoing: + action_nodes = [n for n in path_nodes if n["type"] == "action"] + decision_nodes = [n for n in path_nodes if n["type"] == "decision"] + node_ids = [n["id"] for n in path_nodes] + paths.append({ + "path_id": f"PATH-{image_id}-{'-'.join(node_ids)}", + "nodes": path_nodes, + "meaning": _describe_path(path_nodes), + "image_id": image_id, + "action_nodes": action_nodes, + "decision_nodes": decision_nodes, + "node_ids": node_ids, + }) + else: + for branch_val, target_id in outgoing: + dfs(target_id, new_visited, path_nodes, branch_val) + + dfs(root["id"], set(), [], None) + return paths + + +def _build_adjacency(nodes, node_map): + """Build {node_id: [(branch_value, target_id)]} adjacency for flat trees. + + Handles: decision branches (explicit), non-branching nodes (implicit sequential). + """ + NON_BRANCHING = {"start", "process", "state", "action"} + + adj = {} + has_explicit_incoming = set() + for n in nodes: + for br in n.get("branches", []): + has_explicit_incoming.add(br["target"]) + + for i, node in enumerate(nodes): + nid = node["id"] + adj.setdefault(nid, []) + + # Explicit edges from decision nodes + for br in node.get("branches", []): + adj[nid].append((br["value"], br["target"])) + + # Implicit edges for non-branching nodes (start/process/state/action) + if node["type"] in NON_BRANCHING and not node.get("branches"): + j = i + 1 + targets = [] + while j < len(nodes): + next_node = nodes[j] + next_nid = next_node["id"] + if next_nid in has_explicit_incoming: + break + if next_node["type"] in NON_BRANCHING | {"end"}: + targets.append(next_nid) + has_explicit_incoming.add(next_nid) + j += 1 + continue + elif next_node["type"] == "decision": + if not targets: + targets.append(next_nid) + break + j += 1 + for t in targets: + adj[nid].append(("(implicit)", t)) + + return adj + + +def format_paths_for_prompt(all_paths: dict[str, list[dict]]) -> str: + """Format enumerated paths as a readable list for the LLM prompt.""" + if not all_paths: + return "(无逻辑树路径)" + + lines = [] + for image_id, paths in all_paths.items(): + lines.append(f"\n### {image_id} 的全部决策路径(共 {len(paths)} 条):") + for i, path in enumerate(paths, 1): + lines.append(f"\n**路径 {i}** (ID: {path['path_id']})") + lines.append(f" 含义: {path['meaning']}") + lines.append(f" 节点: {path['node_ids']}") + lines.append(f" 决策节点: {[n['id'] for n in path['decision_nodes']]}") + lines.append(f" 动作节点: {[n['id'] for n in path['action_nodes']]}") + return "\n".join(lines) + + +# ---- Document Formatting ---- + + +def format_document_for_prompt(doc: dict) -> str: + """Render the full parsed document as a readable string for the LLM prompt.""" + lines = [] + + lines.append("=== SECTIONS ===") + for i, section in enumerate(doc.get("sections", [])): + source = section.get("source", f"(无标题-章节{i})") + lines.append(f"\n--- Section: {source} ---") + + for block in section.get("blocks", []): + if block["type"] == "para": + lines.append(f"[段落 {block['index']}] {block['text']}") + elif block["type"] == "table": + lines.append(f"[表格 {block.get('table', '?')}]") + headers = block.get("headers", []) + lines.append(f" 表头: {' | '.join(headers)}") + for row in block.get("rows", []): + cols = row.get("columns", []) + cell_texts = [] + for c in cols: + cell_texts.append( + f"[行{c.get('row','?')}]{c.get('name','')}: {c.get('text','')}" + ) + lines.append(f" {'; '.join(cell_texts)}") + + images = section.get("images", []) + if images: + lines.append(f" 图片引用: {', '.join(images)}") + + lines.append("\n\n=== IMAGE_ANALYSIS (流程图逻辑树) ===") + for img in doc.get("image_analysis", []): + rid = img.get("rid", "?") + img_type = img.get("type", "?") + lines.append(f"\n--- Image: {rid} (type={img_type}) ---") + lines.append(f" 描述: {img.get('description', '')[:300]}") + + lt = img.get("logic_tree") + if lt: + lines.append(f" 逻辑树根节点: {lt.get('root', '?')}") + lines.append(" 节点详情:") + for node in lt.get("nodes", []): + nid = node.get("id", "?") + ntype = node.get("type", "?") + desc = node.get("description", "") or node.get("condition", "") + lines.append(f" [{ntype}] {nid}: {desc}") + branches = node.get("branches", []) + if branches: + for br in branches: + lines.append(f" → {br['value']} → {br['target']}") + + conflicts = doc.get("resolved_conflicts", []) + if conflicts: + lines.append("\n\n=== RESOLVED_CONFLICTS (图文冲突仲裁) ===") + for c in conflicts: + lines.append( + f" [{c.get('conflict_type','?')}] {c.get('section','?')}: " + f"以{c.get('source','?')}为准 — {c.get('correction','')}" + ) + + return "\n".join(lines) + + +# ---- Prompt Building ---- + + +def build_prompt(doc: dict, feedback: str = "", all_paths: dict | None = None) -> str: + """Load the prompt template and inject the formatted document + paths + feedback.""" + template_path = Path(config.PROMPTS_DIR) / "step1_semantic_index.txt" + template = template_path.read_text(encoding="utf-8") + + formatted_doc = format_document_for_prompt(doc) + prompt = template.replace("{document_json}", formatted_doc) + + if all_paths is None: + all_paths = enumerate_all_paths(doc) + path_text = format_paths_for_prompt(all_paths) + prompt = prompt.replace("{logic_tree_paths}", path_text) + + if feedback: + prompt = prompt.replace("{feedback}", feedback) + else: + prompt = prompt.replace("{feedback}", "") + + return prompt + + +# ---- Validation ---- + + +def _quick_validate( + semantic_index: dict, doc: dict, all_paths: dict | None = None +) -> tuple[bool, dict]: + """Validate semantic index and return (passed, gaps). + + Uses a single COVERAGE_TARGET threshold (default 0.95). + """ + gaps = { + "missing_paths": [], + "missing_concepts": [], + "format_issues": [], + "parent_issues": [], + } + + units = semantic_index.get("function_units", []) + concepts = semantic_index.get("concepts", []) + + # --- Check function_units non-empty --- + if not units: + gaps["format_issues"].append("function_units 为空") + return False, gaps + + # --- Check each function_unit has path --- + for fu in units: + uid = fu.get("unit_id", "?") + if not fu.get("path"): + gaps["format_issues"].append(f"{uid}: 缺少 path 字段") + if not fu.get("sources"): + gaps["format_issues"].append(f"{uid}: 缺少 sources") + + # --- Logic tree node coverage --- + all_nodes = _collect_logic_tree_nodes(doc) + referenced = _collect_referenced_nodes(units) + + threshold = config.COVERAGE_TARGET + + for image_id, node_set in all_nodes.items(): + ref_set = referenced.get(image_id, set()) + checkable = { + nid for nid, ntype in node_set.items() + if ntype in ("decision", "action") + } + if not checkable: + continue + covered = checkable & ref_set + coverage = len(covered) / len(checkable) if checkable else 1.0 + + if coverage < threshold: + missing = checkable - ref_set + gaps["missing_paths"].append( + f"{image_id}: 覆盖率 {coverage:.0%} < {threshold:.0%}, " + f"未覆盖节点: {sorted(missing)}" + ) + + # --- Check logic tree path consistency --- + # A unit's logic_tree_nodes must form a valid (connected) path in the tree. + if all_paths is not None: + for fu in units: + uid = fu.get("unit_id", "?") + for src in fu.get("sources", []): + if src.get("type") != "logic_tree": + continue + image_id = src.get("image_id", "") + unit_nodes = set(src.get("logic_tree_nodes", [])) + if not unit_nodes: + continue + # Check if there exists a path containing all these nodes + valid = False + for path in all_paths.get(image_id, []): + path_nodes = set(path.get("node_ids", [])) + if unit_nodes.issubset(path_nodes): + valid = True + break + if not valid: + gaps["format_issues"].append( + f"{uid}: logic_tree_nodes 不构成有效路径 " + f"(image={image_id}, nodes={sorted(unit_nodes)})" + ) + + # --- Check for trivial units (only state/switch nodes, no actions) --- + if all_paths is not None: + for fu in units: + uid = fu.get("unit_id", "?") + has_logic_ref = False + has_action = False + has_non_trivial_decision = False + for src in fu.get("sources", []): + if src.get("type") != "logic_tree": + continue + has_logic_ref = True + node_ids = src.get("logic_tree_nodes", []) + node_types = {} + for image_id, nset in all_nodes.items(): + for nid in node_ids: + if nid in nset: + node_types[nid] = nset[nid] + for nid in node_ids: + ntype = node_types.get(nid, "") + if ntype == "action": + has_action = True + # Count decisions beyond first level (e.g., n1/n2 are just root+switch) + decisions = [nid for nid in node_ids + if node_types.get(nid, "") == "decision"] + if len(decisions) > 1: + has_non_trivial_decision = True + if has_logic_ref and not has_action and not has_non_trivial_decision: + gaps["format_issues"].append( + f"{uid}: 可能为空壳单元(仅有state/开关节点,无action或深层decision)" + ) + + # --- Concept parent validity --- + concept_names = {c["name"] for c in concepts} + for c in concepts: + name = c.get("name", "?") + parent = c.get("parent") # can be None for scope-level + if parent is not None and parent not in concept_names: + gaps["parent_issues"].append( + f"concept '{name}' 的 parent '{parent}' 不存在" + ) + # Warn about scope-level concepts without parent=null + for c in concepts: + if c.get("parent") is not None: + continue + name = c.get("name", "") + # Scope-level concepts (国内/海外) should have parent=null + if name not in ("国内", "海外", ""): + gaps["parent_issues"].append( + f"concept '{name}' 的 parent 为 null,但它可能不是 scope 概念" + ) + + # --- Check for missing scope concepts --- + if "国内" not in concept_names: + gaps["missing_concepts"].append("缺少 scope 概念: 国内") + if "海外" not in concept_names and any( + "海外" in s.get("source", "") for s in doc.get("sections", []) + ): + gaps["missing_concepts"].append("缺少 scope 概念: 海外") + + passed = ( + not gaps["missing_paths"] + and not gaps["format_issues"] + and not gaps["parent_issues"] + ) + return passed, gaps + + +def _collect_logic_tree_nodes(doc: dict) -> dict[str, dict[str, str]]: + """Return {image_id: {node_id: node_type}} for all logic trees.""" + result = {} + for img in doc.get("image_analysis", []): + lt = img.get("logic_tree") + rid = img.get("rid", "") + if lt and rid: + result[rid] = {n["id"]: n["type"] for n in lt.get("nodes", [])} + return result + + +def _collect_referenced_nodes(units: list[dict]) -> dict[str, set[str]]: + """Return {image_id: {referenced node_ids}} across all function_units.""" + refs = {} + for fu in units: + for src in fu.get("sources", []): + if src.get("type") == "logic_tree": + image_id = src.get("image_id", "") + if image_id not in refs: + refs[image_id] = set() + refs[image_id].update(src.get("logic_tree_nodes", [])) + return refs + + +# ---- LLM Calls ---- + + +def extract_json_from_response(text: str) -> str: + """Robustly extract JSON from LLM response.""" + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text) + if m: + return m.group(1).strip() + + start = text.find("{") + if start == -1: + raise ValueError("No JSON object found in LLM response") + + depth = 0 + for i in range(start, len(text)): + if text[i] == "{": + depth += 1 + elif text[i] == "}": + depth -= 1 + if depth == 0: + return text[start : i + 1] + + raise ValueError("Unclosed JSON object in LLM response") + + +def call_llm(prompt: str, max_retries: int = 2, + temperature: float | None = None) -> dict: + """Send prompt to LLM, return parsed JSON dict. + + Args: + temperature: Override config.TEMPERATURE. If None, uses config default. + """ + client = config.llm_client() + temp = temperature if temperature is not None else config.TEMPERATURE + + for attempt in range(max_retries + 1): + print(f" LLM 调用 T={temp} (尝试 {attempt + 1}/{max_retries + 1})...", flush=True) + try: + resp = client.chat.completions.create( + model=config.MODEL_NAME, + messages=[ + { + "role": "system", + "content": "你是一个精确的 JSON 输出引擎。只输出合法的 JSON。", + }, + {"role": "user", "content": prompt}, + ], + temperature=temp, + max_tokens=config.MAX_TOKENS, + ) + content = resp.choices[0].message.content + if content is None: + raise RuntimeError("LLM returned empty response") + + json_str = extract_json_from_response(content) + return json.loads(json_str) + + except (json.JSONDecodeError, ValueError) as e: + print(f" JSON 解析失败: {e}") + if attempt < max_retries: + time.sleep(2) + + raise RuntimeError("无法从 LLM 响应中解析 JSON") + + +# ---- Ensemble Orchestration ---- + + +def run_ensemble_semantic_index(doc: dict) -> dict: + """Run N parallel LLM calls at different temperatures, then ensemble-merge. + + 1. Enumerate all logic tree paths (once). + 2. Build the prompt (once — no iterative feedback needed). + 3. Launch len(ENSEMBLE_TEMPERATURES) parallel LLM calls via ThreadPoolExecutor. + 4. Collect all results. + 5. Call ensemble_merge() for deterministic merge. + 6. Validate final output with _quick_validate(). + 7. Save individual version outputs + merged output. + """ + all_paths = enumerate_all_paths(doc) + print(f" 已枚举逻辑树路径: {sum(len(v) for v in all_paths.values())} 条") + + prompt = build_prompt(doc, "", all_paths) + print(f" Prompt 长度: {len(prompt)} 字符") + + temperatures = config.ENSEMBLE_TEMPERATURES + print(f" 集成温度: {temperatures}") + + # Parallel LLM calls + raw_results: list[tuple[int, float, dict]] = [] + + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(temperatures) + ) as executor: + future_to_meta = {} + for i, temp in enumerate(temperatures): + future = executor.submit(call_llm, prompt, 2, temp) + future_to_meta[future] = (i, temp) + + for future in concurrent.futures.as_completed(future_to_meta): + idx, temp = future_to_meta[future] + try: + si = future.result() + n_units = len(si.get("function_units", [])) + n_concepts = len(si.get("concepts", [])) + print(f" T={temp}: {n_concepts} 概念, {n_units} 功能单元") + raw_results.append((idx, temp, si)) + except Exception as e: + print(f" T={temp}: FAIL — {e}") + raw_results.append((idx, temp, { + "feature_name": "", "concepts": [], "function_units": [] + })) + + if not raw_results: + raise RuntimeError("所有集成的 LLM 调用均失败") + + # Sort by temperature for determinism + raw_results.sort(key=lambda x: x[1]) + semantic_indices = [r[2] for r in raw_results] + + # Save individual version outputs + version_paths = { + 0: config.SEMANTIC_INDEX_R1_JSON, + 1: config.SEMANTIC_INDEX_R2_JSON, + 2: config.SEMANTIC_INDEX_R3_JSON, + } + for i, si in enumerate(semantic_indices): + out_path = version_paths.get(i) + if out_path: + config.save_json(si, out_path) + print(f" 保存版本 {i} (T={temperatures[i]}): {out_path}") + + # Ensemble merge + print(f"\n 集成合并 {len(semantic_indices)} 个版本...") + merged = ensemble_merge(semantic_indices) + merged["ensemble_temperatures"] = list(temperatures) + + # Validate + passed, gaps = _quick_validate(merged, doc, all_paths) + merged["validation_passed"] = passed + merged["validation_gaps"] = { + k: v for k, v in gaps.items() if v + } + + # Print summary + cs = merged.get("confidence_summary", {}) + print(f" 合并后: {cs.get('total_concepts', 0)} 概念, " + f"{cs.get('total_units', 0)} 功能单元") + print(f" 置信度: high={cs.get('high', 0)}, medium={cs.get('medium', 0)}, " + f"low={cs.get('low', 0)}") + print(f" 验证: {'PASS' if passed else 'GAPS FOUND'}") + if not passed: + for k, v in gaps.items(): + if v: + print(f" {k}: {len(v)} 个问题") + + return merged + + +# ---- Main ---- + + +def main(): + print("=" * 60) + print("阶段一:集成语义索引 (Ensemble Semantic Index)") + print("=" * 60) + + # 1. Load input + print(f"\n[1/3] 加载输入文档: {config.INPUT_JSON}") + doc = config.load_input_document() + print(f" 已加载 {len(doc.get('sections', []))} 个 section, " + f"{len(doc.get('image_analysis', []))} 张图片分析") + + # 2. Run ensemble generation + merge + print(f"\n[2/3] 运行集成语义索引 ({len(config.ENSEMBLE_TEMPERATURES)} 个温度版本)...") + merged_index = run_ensemble_semantic_index(doc) + + # 3. Save outputs + print(f"\n[3/3] 保存最终语义索引: {config.SEMANTIC_INDEX_JSON}") + config.save_json(merged_index, config.SEMANTIC_INDEX_JSON) + + # Also save path enumeration for downstream use + all_paths = enumerate_all_paths(doc) + config.save_json( + {"logic_tree_paths": {k: v for k, v in all_paths.items()}}, + config.PATH_ENUM_JSON, + ) + print(f" 路径枚举: {config.PATH_ENUM_JSON}") + + cs = merged_index.get("confidence_summary", {}) + n_concepts = cs.get("total_concepts", len(merged_index.get("concepts", []))) + n_units = cs.get("total_units", len(merged_index.get("function_units", []))) + n_versions = merged_index.get("ensemble_versions", len(config.ENSEMBLE_TEMPERATURES)) + print(f"\n完成! {n_versions} 版本集成, {n_concepts} 个概念, {n_units} 个功能单元.") + print(f"输出: {config.SEMANTIC_INDEX_JSON}") + + +if __name__ == "__main__": + main() diff --git a/skills/ir_generation_skill/step2_5_branch_coverage.py b/skills/ir_generation_skill/step2_5_branch_coverage.py new file mode 100644 index 0000000..14f329c --- /dev/null +++ b/skills/ir_generation_skill/step2_5_branch_coverage.py @@ -0,0 +1,399 @@ +""" +Stage 2.5: Branch Coverage Auto-Completion. + +1. Enumerates all root-to-leaf paths in every logic tree +2. Compares paths against existing IR rules to find uncovered paths +3. Generates synthetic function_units for uncovered paths +4. Calls LLM (same extract_rules_for_unit) to produce rules for synthetic units +5. Iterates up to MAX_RETRIES_PER_STAGE rounds to reach COVERAGE_TARGET + +Outputs: + - output/path_enumeration.json + - output/ir_autocomplete_fragments.json +""" + +import concurrent.futures +import json +import time +from pathlib import Path + +import config + + +# ---- Path Enumeration (shared with step1, duplicated for module independence) ---- + + +def enumerate_all_paths(doc: dict) -> dict[str, list[dict]]: + """Enumerate all root-to-leaf paths for every logic tree.""" + from step1_semantic_index import enumerate_all_paths as _enum + return _enum(doc) + + +# ---- Coverage Analysis ---- + + +def find_referenced_path_ids(rules: list[dict]) -> dict[str, set[str]]: + """Map each rule to the set of logic tree nodes it references. + + Returns {rule_id: set of "image_id:node_id" pairs} + """ + result = {} + for rule in rules: + rid = rule.get("rule_id", "?") + refs = set() + for src in rule.get("sources", []): + if src.get("type") == "logic_tree": + image_id = src.get("image_id", "") + for nid in src.get("node_ids", []): + refs.add(f"{image_id}:{nid}") + result[rid] = refs + return result + + +def compute_path_coverage( + all_paths: dict[str, list[dict]], rules: list[dict] +) -> tuple[list[dict], list[dict], dict]: + """Compute coverage of enumerated paths by existing rules. + + Returns (covered_paths, uncovered_paths, stats). + A path is "covered" if at least one rule's node_ids form a superset + of the path's decision+action nodes for that image. + """ + # Build per-rule node sets keyed by image_id + rule_node_sets = {} # {rule_id: {image_id: set(node_ids)}} + for rule in rules: + rid = rule.get("rule_id", "?") + rule_node_sets[rid] = {} + for src in rule.get("sources", []): + if src.get("type") == "logic_tree": + image_id = src.get("image_id", "") + rule_node_sets[rid].setdefault(image_id, set()).update( + src.get("node_ids", []) + ) + + covered = [] + uncovered = [] + + for image_id, paths in all_paths.items(): + for path in paths: + # Get checkable nodes for this path (decision + action) + checkable = set( + n["id"] for n in path["nodes"] + if n["type"] in ("decision", "action") + ) + if not checkable: + # Path with no decision/action nodes — trivially covered + covered.append(path) + continue + + path_covered = False + for rid, img_sets in rule_node_sets.items(): + rule_nodes = img_sets.get(image_id, set()) + if checkable.issubset(rule_nodes): + path_covered = True + break + + if path_covered: + covered.append(path) + else: + uncovered.append(path) + + total = len(covered) + len(uncovered) + stats = { + "total_paths": total, + "covered_paths": len(covered), + "uncovered_paths": len(uncovered), + "coverage_pct": round(len(covered) / total * 100, 1) if total > 0 else 100.0, + } + return covered, uncovered, stats + + +# ---- Synthetic Function Unit Generation ---- + + +def generate_synthetic_unit(path: dict, unit_seq: int) -> dict: + """Create a synthetic function_unit from an uncovered logic tree path. + + Infers preconditions and trigger from the decision nodes along the path. + """ + node_map = {n["id"]: n for n in path["nodes"]} + + # Infer switch state from path + switch = _infer_switch_state(path) + + # Infer app_type from path + app_type = _infer_app_type(path) + + # Infer app_state from path + app_state = _infer_app_state(path) + + # Infer geographic_scope from section context + scope = _infer_scope(path) + + # Build description from path meaning + description = f"自动补全: {path.get('meaning', '')}" + if switch: + description = f"开关{switch}, {description}" + + # Build path list + path_labels = [] + if scope: + path_labels.append(scope) + if switch: + path_labels.append(f"开关{switch}") + if app_type: + path_labels.append(app_type) + if app_state: + path_labels.append(app_state) + # Add behavior from terminal action + action_nodes = path.get("action_nodes", []) + if action_nodes: + last_action = action_nodes[-1].get("label", "") + path_labels.append(last_action[:20]) + + unit_id = f"FU-AUTO-{path['image_id']}-{unit_seq:03d}" + seq = f"{unit_seq:03d}" + + return { + "unit_id": unit_id, + "name": f"自动补全-{path.get('meaning', '')[:60]}", + "description": description, + "path": path_labels, + "auto_generated": True, + "sources": [ + { + "section": "", + "type": "logic_tree", + "image_id": path["image_id"], + "logic_tree_nodes": path.get("node_ids", []), + } + ], + } + + +def _infer_switch_state(path: dict) -> str: + """Infer switch state from decision nodes in path.""" + for n in path["nodes"]: + label = n.get("label", "") + branch = n.get("branch_taken", "") + if "开关" in label and n["type"] == "decision": + if branch == "开启": + return "开启" + elif branch == "关闭": + return "关闭" + return "" + + +def _infer_app_type(path: dict) -> str: + """Infer app type from state nodes in path.""" + type_map = { + "其他应用": "其他应用", + "SDK限制": "SDK限制", + "通过接入SDK限制的应用": "SDK限制", + "系统限制": "系统限制", + "通过系统限制应用": "系统限制", + } + for n in path["nodes"]: + if n["type"] == "state": + for key, val in type_map.items(): + if key in n.get("label", ""): + return val + return "" + + +def _infer_app_state(path: dict) -> str: + """Infer app state (前台/后台) from decision nodes.""" + for n in path["nodes"]: + label = n.get("label", "") + branch = n.get("branch_taken", "") + if "前台" in label: + if branch == "是": + return "前台" + elif branch == "否": + return "后台" + return "" + + +def _infer_scope(path: dict) -> str: + """Infer geographic scope. Defaults to 国内.""" + return "国内" + + +# ---- LLM Extraction for Synthetic Units ---- + + +def extract_rules_for_synthetic_units( + synthetic_units: list[dict], doc: dict, max_retries: int | None = None +) -> list[dict]: + """Extract IR rules for synthetic function_units using step2's LLM logic.""" + from step2_ir_extraction import ( + build_document_lookup, + extract_context_package, + extract_rules_for_unit, + ) + + if max_retries is None: + max_retries = config.MAX_RETRIES_PER_STAGE + + sections_by_source, image_by_rid, conflicts_by_section = build_document_lookup(doc) + + fragments = [] + for unit in synthetic_units: + pkg = extract_context_package( + unit, doc, sections_by_source, image_by_rid, conflicts_by_section + ) + # Enrich pkg with unit's own path and description + pkg["unit_path"] = unit.get("path", []) + pkg["unit_description"] = unit.get("description", pkg["unit_description"]) + + try: + rules = extract_rules_for_unit(pkg, max_retries) + except Exception as e: + rules = [] + + fragments.append({ + "unit_id": unit["unit_id"], + "unit_name": unit.get("name", ""), + "rules": rules, + "auto_generated": True, + }) + print(f" {unit['unit_id']}: {len(rules)} 条规则") + + return fragments + + +# ---- Iterative Auto-Completion ---- + + +def run_autocomplete( + all_paths: dict[str, list[dict]], + existing_rules: list[dict], + doc: dict, +) -> tuple[list[dict], dict]: + """Run iterative auto-completion. Returns (autocomplete_fragments, final_stats).""" + print(f"\n 初始路径覆盖率分析...") + covered, uncovered, stats = compute_path_coverage(all_paths, existing_rules) + print(f" 覆盖: {stats['covered_paths']}/{stats['total_paths']} " + f"({stats['coverage_pct']}%)") + + if not uncovered: + print(f" 所有路径已覆盖,无需自动补全") + return [], stats + + print(f" 未覆盖路径: {len(uncovered)} 条") + + all_fragments = [] + best_stats = stats + + for round_n in range(1, config.MAX_RETRIES_PER_STAGE + 1): + if not uncovered: + break + + print(f"\n--- 自动补全 第 {round_n} 轮 ---") + print(f" 为 {len(uncovered)} 条未覆盖路径生成合成单元...") + + # Generate synthetic units + start_seq = (round_n - 1) * len(uncovered) + 1 + synthetic_units = [ + generate_synthetic_unit(path, start_seq + i) + for i, path in enumerate(uncovered) + ] + + # Extract rules via LLM + max_llm_workers = min(2, len(synthetic_units)) + if len(synthetic_units) <= 1: + fragments = extract_rules_for_synthetic_units(synthetic_units, doc) + else: + # Sequential to avoid flooding the API + fragments = extract_rules_for_synthetic_units(synthetic_units, doc) + + all_fragments.extend(fragments) + + # Re-compute coverage + all_rules = existing_rules + [ + rule for f in fragments for rule in f.get("rules", []) + ] + covered, uncovered, stats = compute_path_coverage(all_paths, all_rules) + print(f" 第 {round_n} 轮后覆盖: {stats['covered_paths']}/{stats['total_paths']} " + f"({stats['coverage_pct']}%)") + + if stats["coverage_pct"] > best_stats["coverage_pct"]: + best_stats = stats + + if stats["coverage_pct"] >= config.COVERAGE_TARGET * 100: + print(f" 达到目标覆盖率 {config.COVERAGE_TARGET:.0%},停止") + break + + # If coverage didn't improve, try a different approach next round + uncovered_decision_nodes = set() + for p in uncovered: + for n in p.get("decision_nodes", []): + uncovered_decision_nodes.add(n.get("label", "")) + if not uncovered_decision_nodes: + print(f" 无更多可补全路径,停止") + break + + return all_fragments, best_stats + + +# ---- Main ---- + + +def main(): + print("=" * 60) + print("阶段 2.5:分支覆盖自动补全") + print("=" * 60) + + # 1. Load inputs + print(f"\n[1/5] 加载输入...") + doc = config.load_input_document() + fragments = config.load_json(config.IR_FRAGMENTS_JSON) + + all_rules = [] + for f in fragments: + all_rules.extend(f.get("rules", [])) + + print(f" 已有规则: {len(all_rules)} 条") + + # 2. Enumerate paths + print(f"\n[2/5] 枚举逻辑树路径...") + all_paths = enumerate_all_paths(doc) + total_paths = sum(len(v) for v in all_paths.values()) + print(f" 共 {total_paths} 条路径") + + # Save path enumeration for downstream audit + path_enum_data = { + "logic_tree_paths": { + k: [{kk: vv for kk, vv in p.items() if kk != "nodes"} for p in v] + for k, v in all_paths.items() + }, + "total_paths": total_paths, + } + config.save_json(path_enum_data, config.PATH_ENUM_JSON) + + # 3. Run auto-completion + print(f"\n[3/5] 运行自动补全...") + autocomplete_fragments, final_stats = run_autocomplete( + all_paths, all_rules, doc + ) + + # 4. Save + print(f"\n[4/5] 保存自动补全片段...") + config.save_json( + autocomplete_fragments, config.IR_AUTOCOMPLETE_FRAGMENTS_JSON + ) + print(f" 输出: {config.IR_AUTOCOMPLETE_FRAGMENTS_JSON}") + print(f" 生成 {len(autocomplete_fragments)} 个补全片段") + + # 5. Summary + print(f"\n[5/5] 完成!") + print(f" 最终路径覆盖: {final_stats['covered_paths']}/{final_stats['total_paths']} " + f"({final_stats['coverage_pct']}%)") + + if final_stats["coverage_pct"] < config.COVERAGE_TARGET * 100: + remaining = final_stats["total_paths"] - final_stats["covered_paths"] + print(f" WARN: {remaining} 条路径仍未覆盖,将在审计报告中列出") + + +if __name__ == "__main__": + main() diff --git a/skills/ir_generation_skill/step2_ir_extraction.py b/skills/ir_generation_skill/step2_ir_extraction.py new file mode 100644 index 0000000..bcbde92 --- /dev/null +++ b/skills/ir_generation_skill/step2_ir_extraction.py @@ -0,0 +1,508 @@ +""" +Stage 2: Per Function Unit IR Extraction. + +For each function unit from the semantic index, constructs a precision context +package and calls the LLM to extract detailed IR rules. + +Runs multiple LLM calls in parallel (up to MAX_CONCURRENCY). + +Output: output/ir_fragments.json +""" + +import concurrent.futures +import json +import re +import sys +import time +from pathlib import Path + +import config + + +MAX_CONCURRENCY = 3 # Max parallel LLM calls + + +def load_semantic_index() -> dict: + """Load the semantic index from Stage 1.""" + return config.load_json(config.SEMANTIC_INDEX_JSON) + + +def build_document_lookup(doc: dict): + """Build lookup structures for fast context extraction from the document.""" + + # sections_by_source: "3.1.1" -> section dict + sections_by_source = {} + for section in doc.get("sections", []): + source = section.get("source", "") + # Normalize: extract leading number like "3.1.1" + parts = source.split() + if parts: + key = parts[0].strip() + sections_by_source[key] = section + + # image_by_rid: "rId16" -> image_analysis entry + image_by_rid = {} + for img in doc.get("image_analysis", []): + rid = img.get("rid", "") + if rid: + image_by_rid[rid] = img + + # Conflicts indexed by section + conflicts_by_section = {} + for c in doc.get("resolved_conflicts", []): + section = c.get("section", "") + key = section.split()[0] if section else "" + conflicts_by_section.setdefault(key, []).append(c) + + return sections_by_source, image_by_rid, conflicts_by_section + + +def extract_context_package( + fu: dict, doc: dict, sections_by_source: dict, image_by_rid: dict, + conflicts_by_section: dict +) -> dict: + """Build a precision context package for a single function unit.""" + texts = [] + tables = [] + logic_trees = [] + seen_sections = set() + seen_images = set() + + for src in fu.get("sources", []): + src_type = src.get("type", "") + section_key = src.get("section", "").split()[0] if src.get("section") else "" + + # --- Text source --- + if src_type in ("table", "para") and section_key: + if section_key in seen_sections: + continue + seen_sections.add(section_key) + + section = sections_by_source.get(section_key) + if section is None: + # Fuzzy match by prefix + for key in sections_by_source: + if key.startswith(section_key): + section = sections_by_source[key] + break + + if section: + for block in section.get("blocks", []): + if block["type"] == "para": + texts.append({ + "section": section_key, + "text": block["text"] + }) + elif block["type"] == "table": + row_num = src.get("row") if src_type == "table" else None + if row_num is not None: + # Extract only the specific row + matching_rows = [] + for r in block.get("rows", []): + for c in r.get("columns", []): + if c.get("row") == row_num: + matching_rows.append({ + "headers": block.get("headers", []), + "cells": { + col["name"]: col["text"] + for col in r["columns"] + }, + "row": row_num + }) + break + tables.append({ + "section": section_key, + "headers": block.get("headers", []), + "rows": matching_rows, + "all_rows": [ + { + "row": col.get("row"), + "name": col.get("name"), + "text": col.get("text") + } + for row in block.get("rows", []) + for col in row.get("columns", []) + ] + }) + else: + # Include full table + tables.append({ + "section": section_key, + "headers": block.get("headers", []), + "all_rows": [ + { + "row": col.get("row"), + "name": col.get("name"), + "text": col.get("text") + } + for row in block.get("rows", []) + for col in row.get("columns", []) + ] + }) + + # --- Logic tree source --- + if src_type == "logic_tree": + image_id = src.get("image_id", "") + if not image_id or image_id in seen_images: + continue + seen_images.add(image_id) + + img = image_by_rid.get(image_id) + if img: + lt = img.get("logic_tree") + if lt: + logic_trees.append({ + "image_id": image_id, + "description": img.get("description", ""), + "tree": lt + }) + + # Include relevant resolved conflicts + relevant_conflicts = [] + for section_key in seen_sections: + for c in conflicts_by_section.get(section_key, []): + relevant_conflicts.append(c) + + return { + "unit_id": fu["unit_id"], + "unit_name": fu.get("name", ""), + "unit_description": fu.get("description", ""), + "unit_path": fu.get("path", []), + "texts": texts, + "tables": tables, + "logic_trees": logic_trees, + "resolved_conflicts": relevant_conflicts + } + + +def format_context_package(pkg: dict) -> str: + """Format a context package as a readable string for the prompt.""" + parts = [] + + # Texts + parts.append("【文字段落】") + for i, t in enumerate(pkg.get("texts", [])): + parts.append(f"[{t.get('section', '?')}] {t.get('text', '')}") + if not pkg.get("texts"): + parts.append("(无)") + + # Tables + parts.append("\n【表格数据】") + for i, tbl in enumerate(pkg.get("tables", [])): + parts.append(f"表格 {i+1} (section={tbl.get('section', '?')})") + headers = tbl.get("headers", []) + parts.append(f" 表头: {headers}") + parts.append(" 全部行数据:") + for row in tbl.get("all_rows", []): + parts.append( + f" 行{row.get('row','?')}[{row.get('name','?')}]: {row.get('text','')}" + ) + # Highlight matched rows if any + matched = tbl.get("rows", []) + if matched: + parts.append(" <重点关注行>:") + for mr in matched: + parts.append(f" 行{mr.get('row','?')}: {mr.get('cells', {})}") + if not pkg.get("tables"): + parts.append("(无)") + + # Logic trees + parts.append("\n【逻辑树】") + for i, lt in enumerate(pkg.get("logic_trees", [])): + parts.append(f"逻辑树 {i+1} (image_id={lt.get('image_id', '?')})") + parts.append(f" 描述: {lt.get('description', '')[:200]}") + tree = lt.get("tree", {}) + parts.append(f" 根: {tree.get('root', '?')}") + parts.append(" 节点:") + for node in tree.get("nodes", []): + nid = node.get("id", "?") + ntype = node.get("type", "?") + desc = node.get("description", "") or node.get("condition", "") + parts.append(f" [{ntype}] {nid}: {desc}") + for br in node.get("branches", []): + parts.append(f" → {br['value']} → {br['target']}") + if not pkg.get("logic_trees"): + parts.append("(无)") + + # Conflicts + conflicts = pkg.get("resolved_conflicts", []) + if conflicts: + parts.append("\n【图文冲突仲裁】") + for c in conflicts: + parts.append( + f" [{c.get('conflict_type', '?')}] 以{c.get('source', '?')}为准: " + f"{c.get('correction', '')}" + ) + + return "\n".join(parts) + + +def _escape_json_for_format(s: str) -> str: + """Escape curly braces in a JSON string for use with str.format().""" + return s.replace("{", "{{").replace("}", "}}") + + +def build_prompt(pkg: dict, format_feedback: str = "") -> str: + """Build the LLM prompt for a single function unit.""" + template_path = Path(config.PROMPTS_DIR) / "step2_ir_extraction.txt" + template = template_path.read_text(encoding="utf-8") + + prompt = template.format( + unit_id=pkg["unit_id"], + unit_name=_escape_json_for_format(pkg["unit_name"]), + unit_description=_escape_json_for_format(pkg["unit_description"]), + texts=_escape_json_for_format( + json.dumps(pkg.get("texts", []), ensure_ascii=False, indent=2) + ), + tables=_escape_json_for_format( + json.dumps(pkg.get("tables", []), ensure_ascii=False, indent=2) + ), + logic_trees=_escape_json_for_format( + json.dumps(pkg.get("logic_trees", []), ensure_ascii=False, indent=2) + ), + resolved_conflicts=_escape_json_for_format( + json.dumps(pkg.get("resolved_conflicts", []), ensure_ascii=False, indent=2) + ), + format_feedback=_escape_json_for_format(format_feedback), + ) + return prompt + + +def extract_json_from_response(text: str) -> str: + """Extract JSON array from LLM response.""" + m = re.search(r"```(?:json)?\s*(\[[\s\S]*?\])\s*```", text) + if m: + return m.group(1).strip() + + # Find outermost [ ... ] + start = text.find("[") + if start == -1: + raise ValueError("No JSON array found in LLM response") + + depth = 0 + for i in range(start, len(text)): + if text[i] == "[": + depth += 1 + elif text[i] == "]": + depth -= 1 + if depth == 0: + return text[start : i + 1] + + raise ValueError("Unclosed JSON array in LLM response") + + +def _check_rule_fields(rules: list[dict]) -> tuple[bool, list[dict]]: + """Validate each rule has required fields. Returns (passed, failures). + + Each failure: {rule_id, field, issue} + """ + failures = [] + for j, rule in enumerate(rules): + if not isinstance(rule, dict): + failures.append({"rule_id": f"rule[{j}]", "field": "-", "issue": "规则不是 dict"}) + continue + rid = rule.get("rule_id") or f"rule[{j}]" + + if not rule.get("path"): + failures.append({"rule_id": rid, "field": "path", "issue": "缺少 path 字段(必填)"}) + + precond = rule.get("precondition") or {} + if not precond.get("geographic_scope"): + failures.append({"rule_id": rid, "field": "precondition.geographic_scope", "issue": "缺少 geographic_scope(必填)"}) + + for k, action in enumerate(rule.get("actions") or []): + if not isinstance(action, dict): + continue + if action.get("type") == "user_interaction": + content = action.get("content") or "" + if not content: + failures.append({ + "rule_id": rid, "field": f"actions[{k}].content", + "issue": "user_interaction 的 content 为空" + }) + elif any(ph in content for ph in ["文案由业务定义", "待定", "自定义"]): + failures.append({ + "rule_id": rid, "field": f"actions[{k}].content", + "issue": f"content 包含占位符: '{content}'" + }) + + trigger = rule.get("trigger") or {} + for k, cond in enumerate(trigger.get("conditions") or []): + if isinstance(cond, dict): + if not cond.get("signal"): + failures.append({ + "rule_id": rid, "field": f"trigger.conditions[{k}].signal", + "issue": "缺少 signal" + }) + if not cond.get("operator"): + failures.append({ + "rule_id": rid, "field": f"trigger.conditions[{k}].operator", + "issue": "缺少 operator" + }) + if "value" not in cond: + failures.append({ + "rule_id": rid, "field": f"trigger.conditions[{k}].value", + "issue": "缺少 value" + }) + + return len(failures) == 0, failures + + +def _build_fix_prompt(failures: list[dict]) -> str: + """Build a format-fix instruction block for the prompt.""" + if not failures: + return "" + + lines = [ + "\n## 上一轮格式问题修正\n", + "上一轮输出的规则存在以下格式问题,请修正后重新输出:\n", + ] + for f in failures: + lines.append(f"- **{f['rule_id']}.{f['field']}**: {f['issue']}") + + lines.append("\n请修正以上所有问题,重新输出完整的规则数组。") + return "\n".join(lines) + + +def extract_rules_for_unit(pkg: dict, max_retries: int | None = None) -> list[dict]: + """Call LLM for one function unit, return its IR rules. + + Includes format validation with auto-fix retries. + """ + if max_retries is None: + max_retries = config.MAX_RETRIES_PER_STAGE + client = config.llm_client() + prompt = build_prompt(pkg) + last_failures = [] + + for attempt in range(max_retries + 1): + # Append format feedback on retry + if attempt > 0 and last_failures: + fix_text = _build_fix_prompt(last_failures) + prompt = build_prompt(pkg, format_feedback=fix_text) + + try: + resp = client.chat.completions.create( + model=config.MODEL_NAME, + messages=[ + { + "role": "system", + "content": "你是一个精确的 JSON 输出引擎。只输出合法的 JSON 数组。", + }, + {"role": "user", "content": prompt}, + ], + temperature=config.TEMPERATURE, + max_tokens=config.MAX_TOKENS, + ) + content = resp.choices[0].message.content + if content is None: + raise RuntimeError("LLM returned empty response") + + json_str = extract_json_from_response(content) + rules = json.loads(json_str) + if not isinstance(rules, list): + raise ValueError(f"Expected JSON array, got {type(rules).__name__}") + + # Format validation + passed, failures = _check_rule_fields(rules) + if passed: + return rules + + # Format issues found — retry with fix instructions + print(f" 格式问题 ({len(failures)} 个): {[f['field'] for f in failures[:5]]}") + last_failures = failures + if attempt < max_retries: + time.sleep(1) + + except (json.JSONDecodeError, ValueError) as e: + print(f" JSON 解析失败 (尝试 {attempt + 1}): {e}") + last_failures = [{"rule_id": "?", "field": "json", "issue": str(e)}] + if attempt < max_retries: + time.sleep(2) + + # Exhausted retries — return what we have (even if imperfect) + print(f" WARN: {pkg['unit_id']} 格式修复耗尽了 {max_retries} 次重试") + return [] + + +def extract_all_rules( + semantic_index: dict, doc: dict +) -> list[dict]: + """Extract IR rules for all function units. Runs in parallel up to MAX_CONCURRENCY.""" + sections_by_source, image_by_rid, conflicts_by_section = build_document_lookup(doc) + function_units = semantic_index.get("function_units", []) + + print(f" 共 {len(function_units)} 个功能单元待处理") + print(f" 最大并发: {MAX_CONCURRENCY}") + + # Build context packages (serial — fast) + packages = [] + for fu in function_units: + pkg = extract_context_package( + fu, doc, sections_by_source, image_by_rid, conflicts_by_section + ) + packages.append(pkg) + + # Run LLM calls in parallel + fragments = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_CONCURRENCY) as executor: + futures = {} + for i, pkg in enumerate(packages): + future = executor.submit(extract_rules_for_unit, pkg) + futures[future] = (i, pkg["unit_id"], pkg["unit_name"]) + + for future in concurrent.futures.as_completed(futures): + i, uid, uname = futures[future] + try: + rules = future.result() + fragments.append({ + "unit_id": uid, + "unit_name": uname, + "rules": rules + }) + print(f" [OK] {uid} ({uname}): {len(rules)} 条规则") + except Exception as e: + print(f" [FAIL] {uid} ({uname}): 失败 — {e}") + fragments.append({ + "unit_id": uid, + "unit_name": uname, + "rules": [], + "error": str(e) + }) + + # Sort by unit_id to maintain stable ordering + fragments.sort(key=lambda f: f["unit_id"]) + return fragments + + +def main(): + print("=" * 60) + print("阶段二:逐功能单元 IR 提取") + print("=" * 60) + + # 1. Load inputs + print(f"\n[1/3] 加载输入...") + semantic_index = load_semantic_index() + doc = config.load_input_document() + n_units = len(semantic_index.get("function_units", [])) + print(f" 语义索引: {n_units} 个功能单元") + + # 2. Extract rules + print(f"\n[2/3] 逐单元提取 IR 规则...") + fragments = extract_all_rules(semantic_index, doc) + + # 3. Save + print(f"\n[3/3] 保存 IR 片段...") + config.save_json(fragments, config.IR_FRAGMENTS_JSON) + + total_rules = sum(len(f["rules"]) for f in fragments) + failed_units = [f for f in fragments if f.get("error")] + print(f"\n完成! {len(fragments)} 个功能单元, 共 {total_rules} 条规则") + if failed_units: + print(f" [WARN] {len(failed_units)} 个单元提取失败: " + f"{[f['unit_id'] for f in failed_units]}") + print(f"输出: {config.IR_FRAGMENTS_JSON}") + + +if __name__ == "__main__": + main() diff --git a/skills/ir_generation_skill/step3_merge_and_audit.py b/skills/ir_generation_skill/step3_merge_and_audit.py new file mode 100644 index 0000000..ddce256 --- /dev/null +++ b/skills/ir_generation_skill/step3_merge_and_audit.py @@ -0,0 +1,1094 @@ +""" +Stage 3: Deterministic Merge, Consistency Check & Completeness Audit. + +- Merges IR rule fragments (including autocomplete), deduplicating by trigger+actions. +- Reassigns stable hierarchical rule_ids. +- Runs consistency checks: naming uniformity, rule contradictions. +- Generates an audit report covering: + 1. Path coverage (vs enumerated logic tree paths) + 2. Table enumeration coverage + 3. Global switch state coverage + 4. Consistency scan report + 5. Auto-complete summary + 6. Final rule manifest + +Outputs: + - ir_final.json + - ir_audit_report.md +""" + +import hashlib +import json +import sys +from collections import defaultdict +from datetime import datetime +from pathlib import Path + +import config +from step2_5_branch_coverage import compute_path_coverage, enumerate_all_paths + +PASS = "[PASS]" +WARN = "[WARN]" +FAIL = "[FAIL]" + +# ---- Rule ID Generation ---- + +LABEL_MAP = { + "国内": "DOMESTIC", + "海外": "OVERSEAS", + "系统限制": "SYS", + "SDK限制": "SDK", + "SDK自定义限制": "SDK", + "其他应用": "OTHER", + "行车娱乐限制": "SYS", + "行车娱乐禁止": "SYS", + "前台打断": "FG-INTERRUPT", + "后台限制启动": "BG-BLOCK", + "后台禁止": "BG-BLOCK", + "后台暂停功能": "BG-PAUSE", + "后台允许": "BG-ALLOW", + "无限制": "NO-RESTRICT", + "开关关闭": "SWITCH-OFF", + "开关置灰": "SWITCH-GRAY", + "确认弹窗": "CONFIRM-DLG", + "风险确认弹窗": "CONFIRM-DLG", +} + + +def _path_to_rule_id_components(path: list[str]) -> tuple[str, str, str]: + """Extract (scope, method, behavior) from a path array. + + Falls back to "UNKNOWN" for unrecognized components. + """ + scope = "UNKNOWN" + method = "UNKNOWN" + behavior = "UNKNOWN" + + for segment in path: + mapped = LABEL_MAP.get(segment) + if mapped in ("DOMESTIC", "OVERSEAS"): + scope = mapped + elif mapped in ("SYS", "SDK", "OTHER"): + method = mapped + elif mapped in ("FG-INTERRUPT", "BG-BLOCK", "BG-PAUSE", + "NO-RESTRICT", "SWITCH-OFF"): + behavior = mapped + + return scope, method, behavior + + +# ---- Loading ---- + +def load_fragments() -> list[dict]: + """Load IR fragments from Stage 2.""" + return config.load_json(config.IR_FRAGMENTS_JSON) + + +def load_autocomplete_fragments() -> list[dict]: + """Load auto-complete fragments from Stage 2.5, or return [] if absent.""" + path = config.IR_AUTOCOMPLETE_FRAGMENTS_JSON + if not Path(path).exists(): + return [] + return config.load_json(path) + + +def load_semantic_index() -> dict: + """Load merged semantic index from Stage 1.""" + return config.load_json(config.SEMANTIC_INDEX_JSON) + + +def load_path_enumeration() -> dict: + """Load logic tree path enumeration, or return {} if absent.""" + path = config.PATH_ENUM_JSON + if not Path(path).exists(): + return {} + data = config.load_json(path) + return data.get("logic_tree_paths", {}) + + +# ---- Rule Merge ---- + +def rule_signature(rule: dict) -> str: + """Generate a dedup signature from path + trigger + actions.""" + path = rule.get("path", []) + trigger = rule.get("trigger", {}) + actions = rule.get("actions", []) + + conditions = sorted( + trigger.get("conditions", []), key=lambda c: c.get("signal", "") + ) + sorted_actions = sorted(actions, key=lambda a: a.get("description", "")) + + sig_data = { + "path": path, + "conditions": conditions, + "actions": sorted_actions, + } + sig_json = json.dumps(sig_data, ensure_ascii=False, sort_keys=True) + return hashlib.sha256(sig_json.encode()).hexdigest()[:16] + + +def merge_rules(fragments: list[dict], + autocomplete_fragments: list[dict] | None = None) -> list[dict]: + """Merge rules across all fragments, deduplicating by trigger+actions. + + Includes autocomplete fragments if provided. + """ + all_fragments = list(fragments) + if autocomplete_fragments: + all_fragments.extend(autocomplete_fragments) + + signature_map: dict[str, dict] = {} + order = [] + + for fragment in all_fragments: + for rule in fragment.get("rules", []): + sig = rule_signature(rule) + + if sig in signature_map: + existing = signature_map[sig] + existing_sources = existing.setdefault("sources", []) + for src in rule.get("sources", []): + if src not in existing_sources: + existing_sources.append(src) + if len(rule.get("description", "")) > len( + existing.get("description", "") + ): + existing["description"] = rule["description"] + # Merge path if present and different + rpath = rule.get("path", []) + epath = existing.get("path", []) + if rpath and not epath: + existing["path"] = rpath + else: + signature_map[sig] = dict(rule) + order.append(sig) + + merged = [signature_map[sig] for sig in order] + + total_before = sum(len(f.get("rules", [])) for f in all_fragments) + auto_before = sum( + len(f.get("rules", [])) for f in (autocomplete_fragments or []) + ) + print(f" 主片段规则: {total_before - auto_before} 条") + if auto_before: + print(f" 自动补全规则: {auto_before} 条") + print(f" 合并后: {len(merged)} 条 (去重 {total_before - len(merged)} 条)") + return merged + + +# ---- Rule ID Assignment ---- + +def assign_rule_ids(rules: list[dict], feature_id: str = "DRL-001") -> list[dict]: + """Reassign stable hierarchical rule_ids. + + Format: {feature_id}-SCOPE-METHOD-BEHAVIOR-NN + Example: DRL-001-DOMESTIC-SYS-FG-INTERRUPT-01 + """ + # Counter per (scope, method, behavior) key + counters: dict[tuple[str, str, str], int] = defaultdict(int) + + for rule in rules: + path = rule.get("path", []) + scope, method, behavior = _path_to_rule_id_components(path) + + # If path is missing, try to infer from precondition + if scope == "UNKNOWN": + precond = rule.get("precondition", {}) + geo = precond.get("geographic_scope", "") + scope = LABEL_MAP.get(geo, "DOMESTIC") + if method == "UNKNOWN": + precond = rule.get("precondition", {}) + at = precond.get("app_type", "") + method = LABEL_MAP.get(at, "SYS") + if behavior == "UNKNOWN": + precond = rule.get("precondition", {}) + sw = precond.get("switch", "") + if sw == "关闭": + behavior = "SWITCH-OFF" + else: + # Infer from actions + actions = rule.get("actions", []) + action_descs = " ".join( + a.get("description", "") for a in actions + ) + if "打断" in action_descs or "前台" in action_descs: + behavior = "FG-INTERRUPT" + elif "限制" in action_descs and "启动" in action_descs: + behavior = "BG-BLOCK" + elif "暂停" in action_descs: + behavior = "BG-PAUSE" + else: + behavior = "NO-RESTRICT" + + key = (scope, method, behavior) + counters[key] += 1 + seq = counters[key] + rule["rule_id"] = f"{feature_id}-{scope}-{method}-{behavior}-{seq:02d}" + + return rules + + +# ---- Consistency Checks ---- + +def _check_naming_consistency(rules: list[dict]) -> list[dict]: + """Check that app_type, app_state, switch values use unified terminology. + + Returns a list of inconsistency items. + """ + results = [] + + # Collect all values for each field + app_types = set() + app_states = set() + switches = set() + geo_scopes = set() + screen_types = set() + + for rule in rules: + precond = rule.get("precondition", {}) + if precond.get("app_type"): + app_types.add(precond["app_type"]) + if precond.get("app_state"): + app_states.add(precond["app_state"]) + if precond.get("switch"): + switches.add(precond["switch"]) + if precond.get("geographic_scope"): + geo_scopes.add(precond["geographic_scope"]) + if precond.get("screen_type"): + screen_types.add(precond["screen_type"]) + + # Known canonical values + canonical_app_types = {"系统限制", "SDK限制", "其他应用"} + canonical_app_states = {"前台", "后台"} + canonical_switches = {"开启", "关闭"} + canonical_geo = {"国内", "海外"} + canonical_screens = {"CSD", "PSD", "RFD", "any"} + + unknown_app_types = app_types - canonical_app_types + unknown_app_states = app_states - canonical_app_states + unknown_switches = switches - canonical_switches + unknown_geo = geo_scopes - canonical_geo + unknown_screens = screen_types - canonical_screens + + if unknown_app_types: + results.append({ + "field": "app_type", + "issue": f"非标准值: {sorted(unknown_app_types)}", + "expected": sorted(canonical_app_types), + "status": WARN, + }) + + if unknown_app_states: + results.append({ + "field": "app_state", + "issue": f"非标准值: {sorted(unknown_app_states)}", + "expected": sorted(canonical_app_states), + "status": WARN, + }) + + if unknown_switches: + results.append({ + "field": "switch", + "issue": f"非标准值: {sorted(unknown_switches)}", + "expected": sorted(canonical_switches), + "status": WARN, + }) + + if unknown_geo: + results.append({ + "field": "geographic_scope", + "issue": f"非标准值: {sorted(unknown_geo)}", + "expected": sorted(canonical_geo), + "status": WARN, + }) + + if unknown_screens: + results.append({ + "field": "screen_type", + "issue": f"非标准值: {sorted(unknown_screens)}", + "expected": sorted(canonical_screens), + "status": WARN, + }) + + # Also check for near-duplicates (e.g., "系统限制类" vs "系统限制") + similar_pairs = [] + for v1 in app_types: + for v2 in app_types: + if v1 < v2 and (v1 in v2 or v2 in v1): + similar_pairs.append(f"'{v1}' vs '{v2}'") + if similar_pairs: + results.append({ + "field": "app_type", + "issue": f"疑似同义异名: {', '.join(similar_pairs)}", + "status": WARN, + }) + + if not results: + results.append({ + "field": "all", + "issue": "所有字段术语统一", + "status": PASS, + }) + + return results + + +def _trigger_overlaps(t1: dict, t2: dict) -> bool: + """Check if two triggers have any overlapping signal conditions.""" + conds1 = t1.get("conditions", []) + conds2 = t2.get("conditions", []) + signals1 = {c.get("signal") for c in conds1 if isinstance(c, dict)} + signals2 = {c.get("signal") for c in conds2 if isinstance(c, dict)} + return bool(signals1 & signals2) + + +def _actions_conflict(a1: list[dict], a2: list[dict]) -> bool: + """Check if two action lists appear contradictory. + + "Contradictory" means: both have user_interaction with different content, + or one does system interrupt while the other does nothing. + """ + descs1 = {a.get("description", "") for a in a1} + descs2 = {a.get("description", "") for a in a2} + + # If one set is a subset of the other, no conflict — just less detail + if descs1.issubset(descs2) or descs2.issubset(descs1): + return False + + # Check for contradictory user_interaction content + contents1 = { + a.get("content", "") for a in a1 + if a.get("type") == "user_interaction" + } + contents2 = { + a.get("content", "") for a in a2 + if a.get("type") == "user_interaction" + } + if contents1 and contents2 and contents1 != contents2: + return True + + # Check if one has system actions and the other doesn't + has_sys1 = any(a.get("type") == "system" for a in a1) + has_sys2 = any(a.get("type") == "system" for a in a2) + if has_sys1 != has_sys2 and (contents1 or contents2): + return True + + return False + + +def _precondition_overlaps(p1: dict, p2: dict) -> bool: + """Check if two preconditions overlap significantly. + + Two preconditions overlap if they share the same scope, switch state, + and either share app_type or one is unspecified. + """ + if p1.get("geographic_scope") != p2.get("geographic_scope"): + return False + if p1.get("switch") != p2.get("switch"): + return False + # App type overlap (empty = any) + at1 = p1.get("app_type", "") + at2 = p2.get("app_type", "") + if at1 and at2 and at1 != at2: + return False + # App state overlap + as1 = p1.get("app_state", "") + as2 = p2.get("app_state", "") + if as1 and as2 and as1 != as2: + return False + return True + + +def _detect_contradictions(rules: list[dict]) -> list[dict]: + """Find pairs of rules with overlapping preconditions but contradictory actions. + + Returns a list of contradiction items with: + {rule_a, rule_b, conflict_point, resolvable, recommendation} + """ + contradictions = [] + + for i in range(len(rules)): + for j in range(i + 1, len(rules)): + r1 = rules[i] + r2 = rules[j] + rid1 = r1.get("rule_id", f"rule[{i}]") + rid2 = r2.get("rule_id", f"rule[{j}]") + + p1 = r1.get("precondition", {}) + p2 = r2.get("precondition", {}) + + if not _precondition_overlaps(p1, p2): + continue + + t1 = r1.get("trigger", {}) + t2 = r2.get("trigger", {}) + + if not _trigger_overlaps(t1, t2): + continue + + a1 = r1.get("actions", []) + a2 = r2.get("actions", []) + + if not _actions_conflict(a1, a2): + continue + + # Determine the conflict point + path1 = r1.get("path", []) + path2 = r2.get("path", []) + conflict_point = ( + f"相同前置状态 (scope={p1.get('geographic_scope')}, " + f"app={p1.get('app_type', 'any')}, " + f"state={p1.get('app_state', 'any')}) " + f"但行为路径不同: {path1} vs {path2}" + ) + + # Check if resolvable: if paths differ only at behavior level + # and the behaviors are non-overlapping (e.g., FG vs BG) + resolvable = False + if path1 and path2: + shared_prefix = [] + for a, b in zip(path1, path2): + if a == b: + shared_prefix.append(a) + else: + break + # If they share scope/method but differ on app_state → not a conflict + # they are different scenarios + if len(shared_prefix) >= 3: # scope + method + (app_state) + # Actually, same app_state with different behaviors could be a real conflict + pass + elif len(shared_prefix) >= 2: + # Same scope+method, different app_state → not a real conflict + resolvable = True + + contradictions.append({ + "rule_a": rid1, + "rule_b": rid2, + "conflict_point": conflict_point, + "rule_a_path": path1, + "rule_b_path": path2, + "resolvable": resolvable, + "recommendation": ( + "路径前缀不同,可能为不同场景的正常分支" + if resolvable + else "请人工确认是否为真正的矛盾,或合并规则" + ), + }) + + return contradictions + + +def _auto_resolve_contradictions( + contradictions: list[dict], doc: dict +) -> tuple[list[dict], list[dict]]: + """Attempt to auto-resolve contradictions using resolved_conflicts. + + Returns (resolved, unresolved). + """ + if not contradictions: + return [], [] + + resolved_conflicts = doc.get("resolved_conflicts", []) + resolved = [] + unresolved = [] + + for c in contradictions: + # Check if any resolved_conflict covers this + auto_fixed = False + for rc in resolved_conflicts: + rc_text = rc.get("correction", "") + rc.get("conflict_type", "") + # Simple heuristic: check if the conflict involves the sections mentioned + if any( + seg in rc_text + for seg in c.get("rule_a_path", []) + c.get("rule_b_path", []) + ): + c["auto_resolved_by"] = rc.get("correction", "") + c["resolvable"] = True + resolved.append(c) + auto_fixed = True + break + + if not auto_fixed: + unresolved.append(c) + + return resolved, unresolved + + +# ---- Path Coverage Audit ---- + +def audit_path_coverage( + doc: dict, rules: list[dict] +) -> tuple[list[dict], dict]: + """Audit logic tree path coverage (vs node coverage). + + Uses the same path enumeration and coverage computation as step 2.5. + Returns (results_list, stats_dict). + """ + all_paths = enumerate_all_paths(doc) + if not all_paths: + return [], {"total_paths": 0, "covered_paths": 0, + "uncovered_paths": 0, "coverage_pct": 100.0} + + covered, uncovered, stats = compute_path_coverage(all_paths, rules) + + results = [] + for image_id, paths in all_paths.items(): + # Compute per-image coverage + img_covered = [p for p in covered if p.get("image_id") == image_id] + img_uncovered = [p for p in uncovered if p.get("image_id") == image_id] + img_total = len(img_covered) + len(img_uncovered) + img_cov = ( + round(len(img_covered) / img_total * 100, 1) if img_total > 0 else 100.0 + ) + + status = PASS if img_cov >= 95 else (WARN if img_cov >= 70 else FAIL) + detail = f"{len(img_covered)}/{img_total} 路径被覆盖 ({img_cov}%)" + + if img_uncovered: + uncovered_meanings = [p.get("meaning", "?") for p in img_uncovered[:5]] + detail += f"; 未覆盖路径示例: {uncovered_meanings}" + if len(img_uncovered) > 5: + detail += f" ... 还有 {len(img_uncovered) - 5} 条" + + results.append({ + "check": f"逻辑树 {image_id} 路径覆盖率", + "status": status, + "coverage_pct": img_cov, + "detail": detail, + "image_id": image_id, + "uncovered_paths": img_uncovered, + }) + + return results, stats + + +# ---- Table Enumeration Audit ---- + +def find_table_enums(doc: dict) -> list[dict]: + """Find enumerated values in tables.""" + enums = [] + for section in doc.get("sections", []): + for block in section.get("blocks", []): + if block["type"] != "table": + continue + headers = block.get("headers", []) + if not headers: + continue + if "功能" in headers and "功能详细说明" in headers: + for row in block.get("rows", []): + cols = row.get("columns", []) + key_col = next( + (c for c in cols if c.get("name") == "功能"), None + ) + val_col = next( + (c for c in cols if c.get("name") == "功能详细说明"), None + ) + if key_col and val_col: + enums.append({ + "section": section.get("source", ""), + "row": key_col.get("row"), + "key": key_col.get("text", ""), + "value": val_col.get("text", ""), + }) + else: + first_col_name = headers[0] if headers else "" + values = [] + for row in block.get("rows", []): + for col in row.get("columns", []): + if col.get("name") == first_col_name: + values.append(col.get("text", "")) + if values: + enums.append({ + "section": section.get("source", ""), + "column": first_col_name, + "values": values, + }) + return enums + + +def audit_table_enums(rules: list[dict], doc: dict) -> list[dict]: + """Check if key enumerated values appear in rule preconditions.""" + results = [] + rule_preconditions = [rule.get("precondition", {}) for rule in rules] + + # App type coverage + app_types = {"系统限制", "SDK限制", "其他应用"} + found_app_types = set() + for precond in rule_preconditions: + at = precond.get("app_type", "") + if at: + found_app_types.add(at) + missing_types = app_types - found_app_types + results.append({ + "check": "应用类型枚举覆盖", + "status": PASS if not missing_types else WARN, + "detail": f"已覆盖: {found_app_types or '无'}" + + (f"; 未覆盖: {missing_types}" if missing_types else ""), + }) + + # App state coverage + app_states = {"前台", "后台"} + found_states = set() + for precond in rule_preconditions: + st = precond.get("app_state", "") + if st: + found_states.add(st) + missing_states = app_states - found_states + results.append({ + "check": "应用前后台状态覆盖", + "status": PASS if not missing_states else WARN, + "detail": f"已覆盖: {found_states or '无'}" + + (f"; 未覆盖: {missing_states}" if missing_states else ""), + }) + + # Trigger signal coverage + trigger_signals = set() + for rule in rules: + for cond in rule.get("trigger", {}).get("conditions", []): + signal = cond.get("signal", "") + if signal: + trigger_signals.add(signal) + + key_signals = {"车速", "档位", "车速_持续时间", "应用请求启动"} + missing_signals = key_signals - trigger_signals + results.append({ + "check": "触发信号覆盖(车速/档位/持续时间/启动请求)", + "status": PASS if not missing_signals else WARN, + "detail": f"已覆盖信号: {sorted(trigger_signals)}" + + (f"; 未覆盖: {missing_signals}" if missing_signals else ""), + }) + + return results + + +# ---- Switch Coverage Audit ---- + +def audit_switch_coverage(rules: list[dict]) -> list[dict]: + """Check that rules cover both switch ON and OFF states.""" + switch_on = False + switch_off = False + + for rule in rules: + sw = rule.get("precondition", {}).get("switch", "") + if sw == "开启": + switch_on = True + elif sw == "关闭": + switch_off = True + + status = PASS + detail_parts = [] + if switch_on: + detail_parts.append("开关=开启: 有规则覆盖") + else: + detail_parts.append("开关=开启: 未找到规则") + status = FAIL + if switch_off: + detail_parts.append("开关=关闭: 有规则覆盖") + else: + detail_parts.append("开关=关闭: 未找到规则") + status = FAIL + + return [{ + "check": "开关状态完整性(开启/关闭)", + "status": status, + "detail": "; ".join(detail_parts), + }] + + +# ---- Audit Report Generation ---- + +def generate_audit_report( + rules: list[dict], + doc: dict, + feature_name: str, + path_results: list[dict], + path_stats: dict, + enum_results: list[dict], + switch_results: list[dict], + consistency_results: list[dict], + contradictions: list[dict], + unresolved_contradictions: list[dict], + autocomplete_count: int, + path_conflicts: list[dict] | None = None, +) -> str: + """Generate ir_audit_report.md with all audit sections.""" + lines = [] + lines.append("# IR 完整性审计报告") + lines.append("") + lines.append(f"**功能**: {feature_name}") + lines.append(f"**规则总数**: {len(rules)}") + lines.append(f"**生成时间**: {datetime.now().isoformat()}") + lines.append("") + + # Human review notice + issue_count = sum( + 1 for r in path_results + enum_results + switch_results + consistency_results + if r["status"] in (WARN, FAIL) + ) + len(unresolved_contradictions) + len(path_conflicts or []) + + lines.append( + f"> **重要**: 请人工审查以下标记项。" + f"共 {issue_count} 项需要关注。" + ) + lines.append( + f'> 如无需修改,在对应项后标注 **"已确认"**。' + ) + lines.append("") + + # ---- Section 1: Path Coverage ---- + lines.append("## 1. 逻辑树路径覆盖率") + lines.append("") + lines.append( + f"**总体**: {path_stats.get('covered_paths', 0)}/" + f"{path_stats.get('total_paths', 0)} 路径已覆盖 " + f"({path_stats.get('coverage_pct', 0)}%)" + ) + lines.append("") + lines.append("| 图片 ID | 覆盖率 | 状态 | 详情 |") + lines.append("|---------|--------|------|------|") + for r in path_results: + lines.append( + f"| {r['image_id']} | {r['coverage_pct']}% " + f"| {r['status']} | {r['detail']} |" + ) + lines.append("") + + # Uncovered path details + for r in path_results: + uncovered = r.get("uncovered_paths", []) + if uncovered: + lines.append(f"### {r['image_id']} 未覆盖路径详情") + lines.append("") + for p in uncovered[:10]: + meaning = p.get("meaning", "?") + node_ids = p.get("node_ids", []) + lines.append( + f"- **路径**: {meaning} " + f"(节点: {' → '.join(node_ids)})" + ) + if len(uncovered) > 10: + lines.append(f"- ... 还有 {len(uncovered) - 10} 条未覆盖路径") + lines.append("") + + # ---- Section 2: Table Enumeration ---- + lines.append("## 2. 表格枚举覆盖") + lines.append("") + lines.append("| 检查项 | 状态 | 详情 |") + lines.append("|--------|------|------|") + for r in enum_results: + lines.append(f"| {r['check']} | {r['status']} | {r['detail']} |") + lines.append("") + + # ---- Section 3: Switch Coverage ---- + lines.append("## 3. 全局开关状态覆盖") + lines.append("") + lines.append("| 检查项 | 状态 | 详情 |") + lines.append("|--------|------|------|") + for r in switch_results: + lines.append(f"| {r['check']} | {r['status']} | {r['detail']} |") + lines.append("") + + # ---- Section 4: Consistency Scan ---- + lines.append("## 4. 一致性扫描报告") + lines.append("") + + lines.append("### 4.1 术语统一性") + lines.append("") + lines.append("| 字段 | 状态 | 详情 |") + lines.append("|------|------|------|") + for r in consistency_results: + expected = r.get("expected", []) + expected_str = f" (期望: {expected})" if expected else "" + lines.append( + f"| {r['field']} | {r['status']} | {r['issue']}{expected_str} |" + ) + lines.append("") + + lines.append("### 4.2 规则矛盾检测") + lines.append("") + if contradictions: + auto_resolved = [c for c in contradictions if c.get("auto_resolved_by")] + remaining = [c for c in contradictions if not c.get("auto_resolved_by")] + + if auto_resolved: + lines.append(f"**自动解决**: {len(auto_resolved)} 项 (通过图文冲突仲裁)") + lines.append("") + for c in auto_resolved: + lines.append( + f"- {c['rule_a']} vs {c['rule_b']}: " + f"已按仲裁 '**{c['auto_resolved_by']}**' 处理" + ) + lines.append("") + + if remaining: + lines.append(f"**需人工确认**: {len(remaining)} 项") + lines.append("") + for c in remaining: + lines.append(f"### 矛盾: {c['rule_a']} vs {c['rule_b']}") + lines.append(f"- **冲突点**: {c['conflict_point']}") + lines.append(f"- **路径A**: {c.get('rule_a_path', [])}") + lines.append(f"- **路径B**: {c.get('rule_b_path', [])}") + lines.append(f"- **建议**: {c.get('recommendation', '请人工判断')}") + lines.append(f'- [ ] 已确认 (标注 **"已确认"**)') + lines.append("") + else: + lines.append("未检测到规则矛盾。") + lines.append("") + + # ---- Section 4.3: Path Conflicts ---- + lines.append("### 4.3 路径冲突(同path不同行为)") + lines.append("") + if path_conflicts: + for pc in path_conflicts: + lines.append(f"- **Path**: {' > '.join(pc['path'])}") + lines.append(f" - 规则: {', '.join(pc['rule_ids'])}") + lines.append(f" - 不同行为数: {pc['distinct_behaviors']}") + lines.append(f" - 建议: {pc['suggestion']}") + lines.append("") + else: + lines.append("未检测到路径冲突。") + lines.append("") + + # ---- Section 5: Auto-Complete Summary ---- + lines.append("## 5. 自动补全摘要") + lines.append("") + if autocomplete_count > 0: + lines.append(f"- 自动补全片段数: {autocomplete_count}") + lines.append( + f"- 补全后路径覆盖率: " + f"{path_stats.get('coverage_pct', 0)}%" + ) + lines.append(f"- 自动生成的规则已合并到最终规则集中") + else: + lines.append("- 未执行自动补全(所有路径已被手动覆盖,或未运行 step2.5)") + lines.append("") + + # ---- Section 6: Rule Manifest ---- + lines.append("## 6. 规则清单") + lines.append("") + lines.append("| rule_id | Priority | Path | 简述 |") + lines.append("|---------|----------|------|------|") + for rule in rules: + rid = rule.get("rule_id", "?") + pri = rule.get("priority", "?") + path_str = " > ".join(rule.get("path", [])) + desc = rule.get("description", "")[:60] + lines.append(f"| {rid} | {pri} | {path_str} | {desc} |") + lines.append("") + + return "\n".join(lines) + + +def _extract_config_defaults(doc: dict, semantic_index: dict) -> dict: + """Extract configuration defaults (e.g. switch default states) from document. + + Scans table text for patterns like "默认开启"/"默认关闭" and checks + semantic_index concepts for "默认" keywords. + """ + defaults = {} + + # Scan document tables for default config values + for section in doc.get("sections", []): + for block in section.get("blocks", []): + if block["type"] != "table": + continue + for row in block.get("rows", []): + row_texts = [] + for col in row.get("columns", []): + row_texts.append(f"{col.get('name','')}: {col.get('text','')}") + combined = " ".join(row_texts) + + if "行车娱乐限制开关" in combined or "开关" in combined: + if "默认开启" in combined or "默认状态:开启" in combined: + defaults["行车娱乐限制开关"] = { + "default": "开启", + "section": section.get("source", "").split()[0] + if section.get("source") else "", + } + elif "默认关闭" in combined or "默认状态:关闭" in combined: + defaults["行车娱乐限制开关"] = { + "default": "关闭", + "section": section.get("source", "").split()[0] + if section.get("source") else "", + } + + # Supplement from semantic_index concepts + for concept in semantic_index.get("concepts", []): + name = concept.get("name", "") + if "默认" in name: + if "开启" in name: + defaults.setdefault("行车娱乐限制开关", {})["default"] = "开启" + elif "关闭" in name: + defaults.setdefault("行车娱乐限制开关", {})["default"] = "关闭" + + return defaults + + +def _detect_path_conflicts(rules: list[dict]) -> list[dict]: + """Detect rules that share the same path triplet but have different behaviors. + + Returns list of conflict items for the audit report. + """ + from collections import defaultdict + + path_groups = defaultdict(list) + for rule in rules: + path_key = tuple(rule.get("path", [])) + path_groups[path_key].append(rule) + + conflicts = [] + for path_key, group in path_groups.items(): + if len(group) <= 1: + continue + # Check if rules in the same path have different trigger/action signatures + signatures = set() + for r in group: + trigger = r.get("trigger", {}) + actions = tuple( + a.get("description", "") for a in r.get("actions", []) + ) + sig = ( + tuple(sorted( + (c.get("signal",""), c.get("operator",""), str(c.get("value",""))) + for c in trigger.get("conditions", []) + )), + actions, + ) + signatures.add(sig) + + if len(signatures) > 1: + # Same path, different behaviors → potential organization issue + conflicts.append({ + "status": "WARN", + "type": "path_collision", + "path": list(path_key), + "rule_ids": [r["rule_id"] for r in group], + "count": len(group), + "distinct_behaviors": len(signatures), + "suggestion": "多条规则共享相同path但行为不同,考虑拆分path或使用更细粒度的叶子路径", + }) + + return conflicts + + +# ---- Main ---- + +def main(): + print("=" * 60) + print("阶段三:确定性合并、一致性校验与完整性审计") + print("=" * 60) + + # 1. Load inputs + print(f"\n[1/7] 加载输入...") + fragments = load_fragments() + autocomplete_fragments = load_autocomplete_fragments() + doc = config.load_input_document() + semantic_index = load_semantic_index() + path_enum = load_path_enumeration() + + feature_name = semantic_index.get("feature_name", "行车娱乐限制") + feature_id = "DRL-001" + print(f" 功能: {feature_name} ({feature_id})") + print(f" 主片段: {len(fragments)}") + if autocomplete_fragments: + print(f" 自动补全片段: {len(autocomplete_fragments)}") + + # 2. Merge rules + print(f"\n[2/7] 合并去重...") + merged_rules = merge_rules(fragments, autocomplete_fragments) + + # 3. Reassign rule IDs + print(f"\n[3/7] 重分配 rule_id (层次化格式)...") + final_rules = assign_rule_ids(merged_rules, feature_id) + print(f" 已分配 {len(final_rules)} 个稳定 ID") + + # Show ID examples + if final_rules: + sample_ids = [r["rule_id"] for r in final_rules[:3]] + print(f" 示例: {sample_ids}") + + # 4. Consistency checks + print(f"\n[4/7] 一致性扫描...") + consistency_results = _check_naming_consistency(final_rules) + + n_warns = sum(1 for r in consistency_results if r["status"] == WARN) + if n_warns: + print(f" {WARN} {n_warns} 个术语不一致问题") + else: + print(f" {PASS} 术语统一") + + contradictions = _detect_contradictions(final_rules) + resolved, unresolved = _auto_resolve_contradictions(contradictions, doc) + + if resolved: + print(f" {PASS} 自动解决 {len(resolved)} 个矛盾") + if unresolved: + print(f" {WARN} {len(unresolved)} 个矛盾需要人工确认") + for c in unresolved: + print(f" - {c['rule_a']} vs {c['rule_b']}: {c['conflict_point'][:80]}") + if not contradictions: + print(f" {PASS} 未检测到规则矛盾") + + path_conflicts = _detect_path_conflicts(final_rules) + if path_conflicts: + print(f" {WARN} {len(path_conflicts)} 个 path 冲突(同path不同行为)") + else: + print(f" {PASS} 无 path 冲突") + + # 5. Generate audit report + print(f"\n[5/7] 生成审计报告...") + path_results, path_stats = audit_path_coverage(doc, final_rules) + enum_results = audit_table_enums(final_rules, doc) + switch_results = audit_switch_coverage(final_rules) + + report = generate_audit_report( + final_rules, doc, feature_name, + path_results, path_stats, + enum_results, switch_results, + consistency_results, + contradictions, unresolved, + len(autocomplete_fragments), + path_conflicts, + ) + + # 6. Extract config defaults from document + config_defaults = _extract_config_defaults(doc, semantic_index) + if config_defaults: + print(f" 配置默认值: {list(config_defaults.keys())}") + + # 7. Save outputs + print(f"\n[7/7] 保存输出...") + ir_final = { + "feature": feature_name, + "feature_id": feature_id, + "rules": final_rules, + } + if config_defaults: + ir_final["config_defaults"] = config_defaults + config.save_json(ir_final, config.IR_FINAL_JSON) + print(f" IR: {config.IR_FINAL_JSON}") + + with open(config.IR_AUDIT_REPORT_MD, "w", encoding="utf-8") as f: + f.write(report) + print(f" 审计报告: {config.IR_AUDIT_REPORT_MD}") + + # Summary + print(f"\n完成!") + issue_count = ( + sum(1 for r in path_results + enum_results + switch_results + if r["status"] in (WARN, FAIL)) + + n_warns + + len(unresolved) + + len(path_conflicts) + ) + print(f" 规则: {len(final_rules)} 条") + print(f" 路径覆盖: {path_stats.get('coverage_pct', 0)}%") + print(f" 审计问题: {issue_count} 个需要关注") + + if issue_count > 0: + print(f"\n 请查看 {config.IR_AUDIT_REPORT_MD} 并审查标记项。") + + +if __name__ == "__main__": + main() diff --git a/skills/ir_generation_skill/tests/test_ensemble_merge.py b/skills/ir_generation_skill/tests/test_ensemble_merge.py new file mode 100644 index 0000000..5b4da5a --- /dev/null +++ b/skills/ir_generation_skill/tests/test_ensemble_merge.py @@ -0,0 +1,472 @@ +""" +Tests for ensemble_merge.py — all pure Python, no LLM calls, no file I/O. + +Each test uses hardcoded mock data to verify one piece of the merge logic. +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from ensemble_merge import ( + concept_name_similarity, + cluster_concepts, + merge_concept_cluster, + unit_node_jaccard, + path_similarity, + unit_similarity, + cluster_function_units, + pick_best_representative, + compute_confidence_versions, + ensemble_merge_concepts, + ensemble_merge_function_units, + ensemble_merge, + _collect_logic_tree_nodes, +) + +PASS = "[PASS]" +FAIL = "[FAIL]" + +# ---- Mock helpers ---- + +def _mk_unit(unit_id, name, path, logic_tree_nodes, description="", sources=None): + """Create a minimal function_unit dict for testing.""" + if sources is None: + srcs = [] + if logic_tree_nodes: + srcs.append({ + "image_id": "rId16", + "type": "logic_tree", + "logic_tree_nodes": logic_tree_nodes, + }) + if not srcs: + srcs.append({ + "section": "3.1", + "type": "table", + "text_snippet": "test", + }) + else: + srcs = sources + return { + "unit_id": unit_id, + "name": name, + "description": description or f"desc for {name}", + "path": path, + "sources": srcs, + } + + +def _mk_concept(name, parent=None, aliases=None, defined_in=None): + """Create a minimal concept dict for testing.""" + return { + "name": name, + "aliases": aliases or [], + "defined_in": defined_in or ["3.1"], + "parent": parent, + } + + +# ============================================================================= +# Test 1: concept_name_similarity +# ============================================================================= + +def test_concept_name_similarity_exact(): + assert concept_name_similarity("国内", "国内") == 1.0 + assert concept_name_similarity("行车娱乐限制", "行车娱乐限制") == 1.0 + +def test_concept_name_similarity_substring(): + sim = concept_name_similarity("国内行车娱乐限制", "行车娱乐限制") + assert sim >= 0.85, f"expected >= 0.85, got {sim}" + +def test_concept_name_similarity_different(): + sim = concept_name_similarity("国内", "海外") + assert sim < 0.7, f"expected < 0.7, got {sim}" + +def test_concept_name_similarity_seq_matcher(): + sim = concept_name_similarity("前台打断", "前台应用打断") + assert 0.6 < sim < 0.95, f"expected 0.6-0.95, got {sim}" + + +# ============================================================================= +# Test 2: _collect_logic_tree_nodes +# ============================================================================= + +def test_collect_logic_tree_nodes(): + unit = _mk_unit("U1", "test", ["A"], ["n1", "n2", "n3"]) + nodes = _collect_logic_tree_nodes(unit) + assert nodes == {"n1", "n2", "n3"} + +def test_collect_logic_tree_nodes_empty(): + unit = _mk_unit("U2", "test", ["A"], [], sources=[{"section": "3.1", "type": "table"}]) + nodes = _collect_logic_tree_nodes(unit) + assert nodes == set() + + +# ============================================================================= +# Test 3: unit_node_jaccard +# ============================================================================= + +def test_unit_node_jaccard_identical(): + u1 = _mk_unit("U1", "a", ["A"], ["n1", "n2", "n3"]) + u2 = _mk_unit("U2", "b", ["A"], ["n1", "n2", "n3"]) + assert unit_node_jaccard(u1, u2) == 1.0 + +def test_unit_node_jaccard_partial(): + u1 = _mk_unit("U1", "a", ["A"], ["n1", "n2", "n3", "n4"]) + u2 = _mk_unit("U2", "b", ["A"], ["n1", "n2", "n3"]) + # intersection=3, union=4 + assert abs(unit_node_jaccard(u1, u2) - 0.75) < 0.01 + +def test_unit_node_jaccard_disjoint(): + u1 = _mk_unit("U1", "a", ["A"], ["n1", "n2"]) + u2 = _mk_unit("U2", "b", ["B"], ["n3", "n4"]) + assert unit_node_jaccard(u1, u2) == 0.0 + +def test_unit_node_jaccard_both_empty(): + u1 = _mk_unit("U1", "a", ["A"], [], sources=[{"section": "3.1", "type": "table"}]) + u2 = _mk_unit("U2", "b", ["B"], [], sources=[{"section": "3.1", "type": "table"}]) + assert unit_node_jaccard(u1, u2) == 0.0 + + +# ============================================================================= +# Test 4: path_similarity +# ============================================================================= + +def test_path_similarity_identical(): + assert path_similarity( + ["国内", "系统限制", "前台打断"], + ["国内", "系统限制", "前台打断"], + ) == 1.0 + +def test_path_similarity_partial(): + sim = path_similarity( + ["国内", "系统限制", "前台打断"], + ["国内", "系统限制", "后台限制启动"], + ) + # 2/3 set overlap, sequential 3/5 ≈ 0.6 + assert 0.4 < sim < 0.9, f"expected 0.4-0.9, got {sim}" + +def test_path_similarity_different(): + sim = path_similarity(["国内"], ["海外"]) + assert sim < 0.7, f"expected < 0.7, got {sim}" + + +# ============================================================================= +# Test 5: unit_similarity +# ============================================================================= + +def test_unit_similarity_identical(): + u = _mk_unit("U1", "国内-系统限制-前台打断", + ["国内", "系统限制", "前台打断"], + ["n1", "n2", "n3", "n8", "n19"]) + assert unit_similarity(u, u) > 0.99 + +def test_unit_similarity_different(): + u1 = _mk_unit("U1", "a", ["国内", "系统限制", "前台打断"], ["n1", "n2", "n3"]) + u2 = _mk_unit("U2", "b", ["海外", "SDK限制"], ["n10", "n11", "n12"]) + assert unit_similarity(u1, u2) < 0.3 + + +# ============================================================================= +# Test 6: cluster_concepts +# ============================================================================= + +def test_cluster_concepts_identical(): + v0 = [_mk_concept("国内"), _mk_concept("海外"), _mk_concept("系统限制", parent="国内")] + v1 = [_mk_concept("国内"), _mk_concept("海外"), _mk_concept("系统限制", parent="国内")] + v2 = [_mk_concept("国内"), _mk_concept("海外"), _mk_concept("系统限制", parent="国内")] + clusters = cluster_concepts([v0, v1, v2]) + # Should have exactly 3 clusters (国内, 海外, 系统限制) + assert len(clusters) == 3, f"expected 3 clusters, got {len(clusters)}" + for c in clusters: + assert len(c) == 3, f"expected each cluster to have 3 members, got {len(c)}" + +def test_cluster_concepts_name_variation(): + v0 = [_mk_concept("国内行车娱乐限制", parent="国内")] + v1 = [_mk_concept("行车娱乐限制", parent="国内")] + v2 = [_mk_concept("国内行车娱乐限制", parent="国内")] + clusters = cluster_concepts([v0, v1, v2]) + assert len(clusters) == 1, f"expected 1 cluster, got {len(clusters)}" + assert len(clusters[0]) == 3, f"expected 3 members, got {len(clusters[0])}" + + +# ============================================================================= +# Test 7: merge_concept_cluster +# ============================================================================= + +def test_merge_concept_cluster(): + cluster = [ + (0, _mk_concept("国内行车娱乐限制", parent="国内", aliases=["限制"])), + (1, _mk_concept("行车娱乐限制", parent="国内", aliases=["行车限制"])), + (2, _mk_concept("行车娱乐限制", parent="国内", aliases=["限制"])), + ] + merged, conf = merge_concept_cluster(cluster, 3) + assert "行车娱乐限制" in merged["name"] + assert merged["parent"] == "国内" + assert set(merged["aliases"]) == {"限制", "行车限制"} + assert conf in ("high", "medium") + + +# ============================================================================= +# Test 8: cluster_function_units +# ============================================================================= + +def test_cluster_function_units_all_agree(): + u0 = _mk_unit("U-001", "国内-系统限制-前台打断", + ["国内", "系统限制", "前台打断"], + ["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"], + "switch ON, system app, foreground, speed>=15, non-P, interrupt + toast") + u1 = _mk_unit("U-001", "国内-系统限制-前台打断", + ["国内", "系统限制", "前台打断"], + ["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"], + "switch ON, system app, foreground, speed>=15, non-P, interrupt + toast") + u2 = _mk_unit("U-001", "国内-系统限制-前台打断", + ["国内", "系统限制", "前台打断"], + ["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"], + "switch ON, system app, foreground, interrupt") + clusters = cluster_function_units([[u0], [u1], [u2]]) + assert len(clusters) == 1, f"expected 1 cluster, got {len(clusters)}" + assert len(clusters[0]) == 3 + +def test_cluster_function_units_partial_agree(): + u0 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"], + ["n1", "n2", "n3", "n8", "n19"]) + u1 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"], + ["n1", "n2", "n3", "n8", "n19"]) + u2 = _mk_unit("U-002", "禁止", ["国内", "系统限制", "后台限制启动"], + ["n5", "n6"]) + clusters = cluster_function_units([[u0], [u1], [u2]]) + # u0+u1 in one cluster, u2 in another + assert len(clusters) == 2, f"expected 2 clusters, got {len(clusters)}" + cluster_sizes = sorted(len(c) for c in clusters) + assert cluster_sizes == [1, 2], f"expected cluster sizes [1,2], got {cluster_sizes}" + +def test_cluster_function_units_all_disagree(): + u0 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"], ["n1", "n2", "n3"]) + u1 = _mk_unit("U-002", "禁止", ["国内", "系统限制", "后台限制启动"], ["n5", "n6"]) + u2 = _mk_unit("U-003", "SDK", ["国内", "SDK限制"], ["n10", "n11"]) + clusters = cluster_function_units([[u0], [u1], [u2]]) + assert len(clusters) == 3, f"expected 3 clusters, got {len(clusters)}" + + +# ============================================================================= +# Test 9: pick_best_representative +# ============================================================================= + +def test_pick_best_representative_prefers_rich(): + u0 = _mk_unit("U-001", "short", ["国内", "系统限制"], + ["n1", "n2", "n3"], + description="short desc") + u1 = _mk_unit("U-001", "detailed", ["国内", "系统限制", "前台打断"], + ["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"], + description="very detailed description of the full rule behavior " * 5) + cluster = [(0, u0), (1, u1)] + best = pick_best_representative(cluster) + # u1 should win: more nodes, longer description, though u0 has lower temp + assert best["name"] == "detailed" + + +# ============================================================================= +# Test 10: compute_confidence_versions +# ============================================================================= + +def test_confidence_high_unanimous(): + assert compute_confidence_versions(3, 3, True) == "high" + +def test_confidence_high_two_of_three_with_t0(): + assert compute_confidence_versions(2, 3, True) == "high" + +def test_confidence_medium_two_of_three_without_t0(): + assert compute_confidence_versions(2, 3, False) == "medium" + +def test_confidence_low_one_of_three(): + assert compute_confidence_versions(1, 3, False) == "low" + +def test_confidence_high_all_two_versions(): + assert compute_confidence_versions(2, 2, True) == "high" + + +# ============================================================================= +# Test 11: ensemble_merge_concepts +# ============================================================================= + +def test_ensemble_merge_concepts(): + v0 = [_mk_concept("国内"), _mk_concept("海外"), + _mk_concept("国内行车娱乐限制", parent="国内")] + v1 = [_mk_concept("国内"), _mk_concept("海外"), + _mk_concept("行车娱乐限制", parent="国内", + aliases=["限制"], defined_in=["3.1", "3.1.1"])] + v2 = [_mk_concept("国内"), _mk_concept("海外"), + _mk_concept("行车娱乐限制", parent="国内")] + + merged = ensemble_merge_concepts([v0, v1, v2]) + # Should merge the 3 concepts across 3 versions into 3 clusters + assert len(merged) == 3, f"expected 3 merged concepts, got {len(merged)}" + for c in merged: + assert "confidence" in c + assert "ensemble_support" in c + assert c["ensemble_support"] == "3/3" + + +# ============================================================================= +# Test 12: ensemble_merge_function_units +# ============================================================================= + +def test_ensemble_merge_function_units(): + u0 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"], + ["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"], + description="full description A") + u1 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"], + ["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"], + description="full description B (more detail)") + u2 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"], + ["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25"], + description="partial description") + + merged = ensemble_merge_function_units([[u0], [u1], [u2]]) + assert len(merged) == 1, f"expected 1 unit, got {len(merged)}" + unit = merged[0] + assert unit["confidence"] == "high" + assert unit["ensemble_support"] == "3/3" + assert unit["source_versions"] == 3 + assert unit["unit_id"].startswith("FU-ENS-") + # Should have picked u1 (more detail) + assert "more detail" in unit["description"] + + +# ============================================================================= +# Test 13: ensemble_merge full integration +# ============================================================================= + +def test_ensemble_merge_full(): + v0 = { + "feature_name": "行车娱乐限制", + "concepts": [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")], + "function_units": [ + _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"], + ["n1", "n2", "n3", "n8", "n19", "n25", "n26"]), + _mk_unit("U-002", "后台禁止", ["国内", "系统限制", "后台限制启动"], + ["n5", "n6"]), + ], + } + v1 = { + "feature_name": "行车娱乐限制", + "concepts": [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")], + "function_units": [ + _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"], + ["n1", "n2", "n3", "n8", "n19", "n25", "n26"]), + _mk_unit("U-003", "SDK自定义", ["国内", "SDK限制", "自定义限制"], + ["n10", "n11"]), + ], + } + v2 = { + "feature_name": "行车娱乐限制", + "concepts": [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")], + "function_units": [ + _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"], + ["n1", "n2", "n3", "n8", "n19", "n25", "n26"]), + ], + } + + result = ensemble_merge([v0, v1, v2]) + + assert result["feature_name"] == "行车娱乐限制" + assert result["ensemble_versions"] == 3 + + units = result["function_units"] + concepts = result["concepts"] + + # Concepts: 国内 + 系统限制 + assert len(concepts) == 2 + + # Units: 打断 (3 versions → high), 后台禁止 (1 version → low), SDK (1 version → low) + assert len(units) == 3 + + high_units = [u for u in units if u["confidence"] == "high"] + low_units = [u for u in units if u["confidence"] == "low"] + assert len(high_units) == 1 + assert len(low_units) == 2 + + # All units should have ensemble fields + for u in units: + assert "confidence" in u + assert "ensemble_support" in u + assert "source_versions" in u + + # Confidence summary + cs = result["confidence_summary"] + assert cs["total_units"] == 3 + assert cs["high"] == 1 + assert cs["low"] == 2 + + +# ============================================================================= +# Runner +# ============================================================================= + +def run_all_tests(): + print("=" * 60) + print("Ensemble Merge 测试 (纯 Python, 无 LLM)") + print("=" * 60) + + tests = [ + ("concept_name_similarity exact", test_concept_name_similarity_exact), + ("concept_name_similarity substring", test_concept_name_similarity_substring), + ("concept_name_similarity different", test_concept_name_similarity_different), + ("concept_name_similarity seq_matcher", test_concept_name_similarity_seq_matcher), + ("collect_logic_tree_nodes", test_collect_logic_tree_nodes), + ("collect_logic_tree_nodes empty", test_collect_logic_tree_nodes_empty), + ("unit_node_jaccard identical", test_unit_node_jaccard_identical), + ("unit_node_jaccard partial", test_unit_node_jaccard_partial), + ("unit_node_jaccard disjoint", test_unit_node_jaccard_disjoint), + ("unit_node_jaccard both_empty", test_unit_node_jaccard_both_empty), + ("path_similarity identical", test_path_similarity_identical), + ("path_similarity partial", test_path_similarity_partial), + ("path_similarity different", test_path_similarity_different), + ("unit_similarity identical", test_unit_similarity_identical), + ("unit_similarity different", test_unit_similarity_different), + ("cluster_concepts identical", test_cluster_concepts_identical), + ("cluster_concepts name variation", test_cluster_concepts_name_variation), + ("merge_concept_cluster", test_merge_concept_cluster), + ("cluster_function_units all_agree", test_cluster_function_units_all_agree), + ("cluster_function_units partial_agree", test_cluster_function_units_partial_agree), + ("cluster_function_units all_disagree", test_cluster_function_units_all_disagree), + ("pick_best_representative", test_pick_best_representative_prefers_rich), + ("confidence high unanimous", test_confidence_high_unanimous), + ("confidence high 2/3 with t0", test_confidence_high_two_of_three_with_t0), + ("confidence medium 2/3 no t0", test_confidence_medium_two_of_three_without_t0), + ("confidence low 1/3", test_confidence_low_one_of_three), + ("confidence high 2/2", test_confidence_high_all_two_versions), + ("ensemble_merge_concepts", test_ensemble_merge_concepts), + ("ensemble_merge_function_units", test_ensemble_merge_function_units), + ("ensemble_merge full", test_ensemble_merge_full), + ] + + passed = 0 + failed = 0 + for name, test_fn in tests: + try: + test_fn() + print(f" {PASS} {name}") + passed += 1 + except AssertionError as e: + print(f" {FAIL} {name}: {e}") + failed += 1 + except Exception as e: + print(f" {FAIL} {name}: unexpected {type(e).__name__}: {e}") + failed += 1 + + print(f"\n{'='*60}") + if failed == 0: + print(f"{PASS} 所有 {passed} 个测试通过!") + else: + print(f"{FAIL} {failed}/{passed + failed} 个测试失败") + print(f"{'='*60}") + + return failed == 0 + + +if __name__ == "__main__": + success = run_all_tests() + sys.exit(0 if success else 1) diff --git a/skills/ir_generation_skill/tests/test_step1.py b/skills/ir_generation_skill/tests/test_step1.py new file mode 100644 index 0000000..7047234 --- /dev/null +++ b/skills/ir_generation_skill/tests/test_step1.py @@ -0,0 +1,370 @@ +""" +Tests for Stage 1 (Semantic Index). + +Validates that the generated semantic_index.json meets all completeness +and structural requirements, including the new iterative features: +- function_units have path fields +- concepts have parent references +- logic tree node coverage meets thresholds +""" + +import json +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) +import config + + +PASS = "[PASS]" +FAIL = "[FAIL]" +WARN = "[WARN]" + + +def load_inputs(): + """Load semantic_index.json and the original parsed document.""" + try: + si = config.load_json(config.SEMANTIC_INDEX_JSON) + except FileNotFoundError: + print(f"{FAIL} semantic_index.json 未找到: {config.SEMANTIC_INDEX_JSON}") + print(" 请先运行 step1_semantic_index.py") + sys.exit(1) + doc = config.load_input_document() + return si, doc + + +def build_image_index(doc: dict) -> dict[str, dict]: + """Build lookup: image rId -> image_analysis entry.""" + idx = {} + for img in doc.get("image_analysis", []): + rid = img.get("rid", "") + if rid: + idx[rid] = img + return idx + + +def build_logic_tree_node_index(doc: dict) -> dict[str, set[str]]: + """Build lookup: image rId -> set of all node IDs in that logic_tree.""" + idx = {} + for img in doc.get("image_analysis", []): + rid = img.get("rid", "") + lt = img.get("logic_tree") + if lt and rid: + node_ids = {n["id"] for n in lt.get("nodes", [])} + idx[rid] = node_ids + return idx + + +def check_unit_ids(units: list[dict]) -> list[str]: + """Check that every function_unit has a non-empty unit_id and name.""" + errors = [] + seen_ids = set() + for i, fu in enumerate(units): + uid = fu.get("unit_id", "") + name = fu.get("name", "") + if not uid: + errors.append(f"function_unit[{i}]: unit_id 为空") + elif uid in seen_ids: + errors.append(f"function_unit[{i}]: unit_id '{uid}' 重复") + seen_ids.add(uid) + if not name: + errors.append(f"function_unit[{i}] ({uid}): name 为空") + return errors + + +def check_unit_paths(units: list[dict]) -> list[str]: + """Check that every function_unit has a non-empty path array.""" + errors = [] + for fu in units: + uid = fu.get("unit_id", "?") + path = fu.get("path", []) + if not path: + errors.append(f"{uid}: path 字段为空或缺失") + elif not isinstance(path, list): + errors.append(f"{uid}: path 必须是数组") + return errors + + +def check_concept_parents(concepts: list[dict]) -> list[str]: + """Check that non-scope concepts have valid parent references.""" + errors = [] + concept_names = {c.get("name", "") for c in concepts} + scope_concepts = {"国内", "海外"} + + for c in concepts: + name = c.get("name", "?") + parent = c.get("parent", "") + + if name in scope_concepts: + # Scope concepts should have no parent + if parent: + errors.append(f"scope 概念 '{name}' 不应有 parent (当前: '{parent}')") + else: + # Non-scope concepts must have a parent + if not parent: + errors.append(f"概念 '{name}' 缺少 parent 字段") + elif parent not in concept_names: + errors.append(f"概念 '{name}' 的 parent '{parent}' 不存在于 concepts 中") + + return errors + + +def check_sources_exist( + units: list[dict], image_index: dict[str, dict], node_index: dict[str, set[str]] +) -> list[str]: + """Check that all source references point to real content.""" + errors = [] + for fu in units: + uid = fu.get("unit_id", "?") + sources = fu.get("sources", []) + if not sources: + errors.append(f"{uid}: sources 为空,必须至少引用一张图片或一段文字") + continue + + has_text = False + has_image = False + + for j, src in enumerate(sources): + src_type = src.get("type", "") + if src_type in ("table", "para"): + has_text = True + section = src.get("section", "") + if not section: + errors.append(f"{uid}.sources[{j}]: 缺少 section") + elif src_type == "logic_tree": + has_image = True + image_id = src.get("image_id", "") + if not image_id: + errors.append(f"{uid}.sources[{j}]: logic_tree 缺少 image_id") + continue + if image_id not in image_index: + errors.append( + f"{uid}.sources[{j}]: image_id '{image_id}' " + f"在 image_analysis 中不存在" + ) + continue + node_ids = src.get("logic_tree_nodes", []) + if node_ids and image_id in node_index: + valid_nodes = node_index[image_id] + for nid in node_ids: + if nid not in valid_nodes: + errors.append( + f"{uid}.sources[{j}]: 节点 '{nid}' 在 " + f"{image_id} 的逻辑树中不存在" + ) + elif not node_ids: + errors.append( + f"{uid}.sources[{j}]: logic_tree 类型但未提供 logic_tree_nodes" + ) + + if not has_text and not has_image: + errors.append(f"{uid}: 必须至少引用一个文本或图片来源") + + return errors + + +def check_logic_tree_coverage( + units: list[dict], node_index: dict[str, set[str]] +) -> list[str]: + """Check that decision and action nodes in logic trees are covered.""" + warnings = [] + for image_id, all_nodes in node_index.items(): + referenced = set() + for fu in units: + for src in fu.get("sources", []): + if src.get("image_id") == image_id: + for nid in src.get("logic_tree_nodes", []): + referenced.add(nid) + + uncovered = all_nodes - referenced + if uncovered: + doc = config.load_input_document() + node_types = {} + for img in doc.get("image_analysis", []): + if img.get("rid") == image_id: + lt = img.get("logic_tree", {}) + for n in lt.get("nodes", []): + node_types[n["id"]] = n.get("type", "?") + break + + decision_action_uncovered = [ + n for n in uncovered if node_types.get(n) in ("decision", "action") + ] + if decision_action_uncovered: + warnings.append( + f"{image_id}: {len(decision_action_uncovered)} 个 " + f"decision/action 节点未被引用: {decision_action_uncovered}" + ) + + return warnings + + +def check_ensemble_confidence(units: list[dict]) -> list[str]: + """Check that every function_unit has confidence, ensemble_support, source_versions.""" + errors = [] + valid_conf = {"high", "medium", "low"} + for fu in units: + uid = fu.get("unit_id", "?") + conf = fu.get("confidence", "") + if not conf: + errors.append(f"{uid}: 缺少 confidence 字段") + elif conf not in valid_conf: + errors.append(f"{uid}: confidence='{conf}' 无效 (期望 high/medium/low)") + support = fu.get("ensemble_support", "") + if not support: + errors.append(f"{uid}: 缺少 ensemble_support 字段") + if "source_versions" not in fu: + errors.append(f"{uid}: 缺少 source_versions 字段") + return errors + + +def check_confidence_summary(si: dict) -> list[str]: + """Check that confidence_summary counts match actual unit/concept confidence.""" + errors = [] + cs = si.get("confidence_summary", {}) + if not cs: + errors.append("缺少 confidence_summary 字段") + return errors + + units = si.get("function_units", []) + concepts = si.get("concepts", []) + + # Count actual confidence levels + unit_high = sum(1 for u in units if u.get("confidence") == "high") + unit_medium = sum(1 for u in units if u.get("confidence") == "medium") + unit_low = sum(1 for u in units if u.get("confidence") == "low") + concept_high = sum(1 for c in concepts if c.get("confidence") == "high") + concept_medium = sum(1 for c in concepts if c.get("confidence") == "medium") + concept_low = sum(1 for c in concepts if c.get("confidence") == "low") + + if cs.get("total_units", 0) != len(units): + errors.append(f"confidence_summary.total_units={cs.get('total_units')} != 实际 {len(units)}") + if cs.get("high", 0) != unit_high: + errors.append(f"confidence_summary.high={cs.get('high')} != 实际 {unit_high}") + if cs.get("medium", 0) != unit_medium: + errors.append(f"confidence_summary.medium={cs.get('medium')} != 实际 {unit_medium}") + if cs.get("low", 0) != unit_low: + errors.append(f"confidence_summary.low={cs.get('low')} != 实际 {unit_low}") + if cs.get("total_concepts", 0) != len(concepts): + errors.append(f"confidence_summary.total_concepts={cs.get('total_concepts')} != 实际 {len(concepts)}") + if cs.get("concept_high", 0) != concept_high: + errors.append(f"confidence_summary.concept_high={cs.get('concept_high')} != 实际 {concept_high}") + if cs.get("concept_medium", 0) != concept_medium: + errors.append(f"confidence_summary.concept_medium={cs.get('concept_medium')} != 实际 {concept_medium}") + if cs.get("concept_low", 0) != concept_low: + errors.append(f"confidence_summary.concept_low={cs.get('concept_low')} != 实际 {concept_low}") + + return errors + + +def run_all_tests(): + print("=" * 60) + print("Step 1 自检测试") + print("=" * 60) + + si, doc = load_inputs() + units = si.get("function_units", []) + concepts = si.get("concepts", []) + image_index = build_image_index(doc) + node_index = build_logic_tree_node_index(doc) + + all_errors = [] + all_warnings = [] + + # Test 1: unit_id and name validity + errors = check_unit_ids(units) + if errors: + print(f"\n{FAIL} unit_id/name 检查: {len(errors)} 个错误") + for e in errors: + print(f" - {e}") + all_errors.extend(errors) + else: + print(f"\n{PASS} unit_id/name 检查: 全部通过 ({len(units)} 个功能单元)") + + # Test 2: path fields + errors = check_unit_paths(units) + if errors: + print(f"\n{FAIL} path 字段检查: {len(errors)} 个错误") + for e in errors: + print(f" - {e}") + all_errors.extend(errors) + else: + print(f"\n{PASS} path 字段检查: 全部通过") + + # Test 3: concept parent references + errors = check_concept_parents(concepts) + if errors: + print(f"\n{FAIL} concept parent 检查: {len(errors)} 个错误") + for e in errors: + print(f" - {e}") + all_errors.extend(errors) + else: + print(f"\n{PASS} concept parent 检查: 全部通过 ({len(concepts)} 个概念)") + + # Test 4: source references exist + errors = check_sources_exist(units, image_index, node_index) + if errors: + print(f"\n{FAIL} 来源引用检查: {len(errors)} 个错误") + for e in errors: + print(f" - {e}") + all_errors.extend(errors) + else: + print(f"\n{PASS} 来源引用检查: 全部通过") + + # Test 5: Logic tree coverage + warnings = check_logic_tree_coverage(units, node_index) + if warnings: + print(f"\n{WARN} 逻辑树节点覆盖率: {len(warnings)} 个警告") + for w in warnings: + print(f" - {w}") + all_warnings.extend(warnings) + else: + print(f"\n{PASS} 逻辑树节点覆盖率: 全部通过") + + # Test 6: Ensemble confidence fields on function_units + errors = check_ensemble_confidence(units) + if errors: + print(f"\n{FAIL} 集成置信度字段: {len(errors)} 个错误") + for e in errors: + print(f" - {e}") + all_errors.extend(errors) + else: + print(f"\n{PASS} 集成置信度字段: 全部通过") + + # Test 7: Confidence summary consistency + errors = check_confidence_summary(si) + if errors: + print(f"\n{FAIL} confidence_summary 一致性: {len(errors)} 个错误") + for e in errors: + print(f" - {e}") + all_errors.extend(errors) + else: + cs = si.get("confidence_summary", {}) + print(f"\n{PASS} confidence_summary 一致性: " + f"high={cs.get('high',0)}, medium={cs.get('medium',0)}, " + f"low={cs.get('low',0)}") + + # Summary + print(f"\n{'='*60}") + total_failures = len(all_errors) + total_warnings = len(all_warnings) + + if total_failures == 0 and total_warnings == 0: + print(f"{PASS} 所有测试通过!") + elif total_failures == 0: + print(f"{WARN} 全部通过但有 {total_warnings} 个警告") + else: + print(f"{FAIL} 测试失败: {total_failures} 个错误, {total_warnings} 个警告") + print("\n请检查 LLM 输出质量,可能需要调整 Prompt 并重新运行 step1_semantic_index.py") + + print(f"\n统计:") + print(f" 功能单元数: {len(units)}") + print(f" 概念数: {len(concepts)}") + print(f" 逻辑树图片数: {len(node_index)}") + + return total_failures == 0 + + +if __name__ == "__main__": + success = run_all_tests() + sys.exit(0 if success else 1) diff --git a/skills/ir_generation_skill/tests/test_step2.py b/skills/ir_generation_skill/tests/test_step2.py new file mode 100644 index 0000000..2e8cef2 --- /dev/null +++ b/skills/ir_generation_skill/tests/test_step2.py @@ -0,0 +1,322 @@ +""" +Tests for Stage 2 (IR Extraction). + +Validates that ir_fragments.json meets quality and structural requirements: +- All fragments have non-empty rules +- All rules have path arrays +- All rules have precondition.geographic_scope and precondition.screen_type +- All trigger conditions have signal/operator/value +- user_interaction content is non-empty and not a placeholder +- No duplicate rule_ids (across all fragments) +""" + +import json +import sys +from pathlib import Path +from collections import Counter + +sys.path.insert(0, str(Path(__file__).parent.parent)) +import config + + +PASS = "[PASS]" +FAIL = "[FAIL]" +WARN = "[WARN]" + +# Forbidden placeholder phrases in user_interaction content +FORBIDDEN_PLACEHOLDERS = [ + "文案由业务定义", "待定", "自定义", "TBD", "todo", "TODO" +] + + +def load_fragments(): + """Load ir_fragments.json.""" + try: + return config.load_json(config.IR_FRAGMENTS_JSON) + except FileNotFoundError: + print(f"{FAIL} ir_fragments.json 未找到: {config.IR_FRAGMENTS_JSON}") + print(" 请先运行 step2_ir_extraction.py") + sys.exit(1) + + +def check_non_empty_rules(fragments: list[dict]) -> list[str]: + """Every fragment must have at least one rule.""" + errors = [] + for f in fragments: + uid = f.get("unit_id", "?") + rules = f.get("rules", []) + if not rules: + if f.get("error"): + errors.append(f"{uid}: 提取失败 — {f['error']}") + else: + errors.append(f"{uid}: rules 为空") + return errors + + +def check_rule_paths(fragments: list[dict]) -> list[str]: + """Every rule must have a non-empty path array.""" + errors = [] + for f in fragments: + uid = f.get("unit_id", "?") + for j, rule in enumerate(f.get("rules", [])): + rid = rule.get("rule_id", f"rule[{j}]") + path = rule.get("path", []) + if not path: + errors.append(f"{rid}: path 字段为空或缺失") + elif not isinstance(path, list): + errors.append(f"{rid}: path 必须是数组") + return errors + + +def check_precondition_fields(fragments: list[dict]) -> list[str]: + """Every rule must have precondition with geographic_scope and screen_type.""" + errors = [] + for f in fragments: + uid = f.get("unit_id", "?") + for j, rule in enumerate(f.get("rules", [])): + rid = rule.get("rule_id", f"rule[{j}]") + precond = rule.get("precondition", {}) + if not precond: + errors.append(f"{rid}: precondition 缺失") + continue + if not precond.get("geographic_scope"): + errors.append(f"{rid}: precondition.geographic_scope 缺失") + if "screen_type" not in precond: + errors.append(f"{rid}: precondition.screen_type 缺失") + return errors + + +def check_user_interaction_content(fragments: list[dict]) -> list[str]: + """user_interaction actions must have non-empty, non-placeholder content.""" + errors = [] + for f in fragments: + uid = f.get("unit_id", "?") + for j, rule in enumerate(f.get("rules", [])): + rid = rule.get("rule_id", f"rule[{j}]") + for k, action in enumerate(rule.get("actions", [])): + if action.get("type") != "user_interaction": + continue + content = action.get("content", "") + if not content: + errors.append( + f"{rid}.actions[{k}]: user_interaction 的 content 为空" + ) + elif any(ph in content for ph in FORBIDDEN_PLACEHOLDERS): + errors.append( + f"{rid}.actions[{k}]: content 包含占位符: '{content}'" + ) + return errors + + +def check_sources_have_logic_tree_nodes(fragments: list[dict]) -> list[str]: + """Every rule should reference at least one logic tree node in its sources.""" + errors = [] + for f in fragments: + uid = f.get("unit_id", "?") + for j, rule in enumerate(f.get("rules", [])): + rid = rule.get("rule_id", f"rule[{j}]") + sources = rule.get("sources", []) + has_logic_tree = any( + src.get("type") == "logic_tree" and src.get("node_ids") + for src in sources + ) + if not has_logic_tree: + has_text = any( + src.get("type") in ("table", "para") for src in sources + ) + if not has_text: + errors.append(f"{rid}: sources 中既无逻辑树引用也无文字引用") + return errors + + +def check_trigger_conditions(fragments: list[dict]) -> list[str]: + """Every trigger condition must have signal, operator, value.""" + errors = [] + for f in fragments: + uid = f.get("unit_id", "?") + for j, rule in enumerate(f.get("rules", [])): + rid = rule.get("rule_id", f"rule[{j}]") + trigger = rule.get("trigger", {}) + conditions = trigger.get("conditions", []) + + if trigger.get("event") is not None: + continue + + for k, cond in enumerate(conditions): + signal = cond.get("signal", "") + operator = cond.get("operator", "") + has_value = "value" in cond + + if not signal: + errors.append(f"{rid}.condition[{k}]: 缺少 signal") + if not operator: + errors.append(f"{rid}.condition[{k}]: 缺少 operator") + if not has_value: + errors.append(f"{rid}.condition[{k}]: 缺少 value") + + return errors + + +def check_duplicate_rule_ids(fragments: list[dict]) -> list[str]: + """Check for duplicate rule_ids across all fragments.""" + all_rule_ids = [] + for f in fragments: + for rule in f.get("rules", []): + rid = rule.get("rule_id", "") + if rid: + all_rule_ids.append(rid) + + duplicates = [rid for rid, count in Counter(all_rule_ids).items() if count > 1] + errors = [] + if duplicates: + errors.append(f"重复 rule_id: {duplicates}") + return errors + + +def check_action_types(fragments: list[dict]) -> list[str]: + """Verify that actions have valid types.""" + valid_types = {"system", "user_interaction"} + errors = [] + for f in fragments: + for j, rule in enumerate(f.get("rules", [])): + rid = rule.get("rule_id", f"rule[{j}]") + for k, action in enumerate(rule.get("actions", [])): + atype = action.get("type", "") + if atype not in valid_types: + errors.append( + f"{rid}.action[{k}]: type='{atype}' 无效, " + f"应为 {valid_types}" + ) + if atype == "user_interaction" and "content" not in action: + errors.append( + f"{rid}.action[{k}]: user_interaction 类型缺少 content 字段" + ) + return errors + + +def run_all_tests(): + print("=" * 60) + print("Step 2 自检测试") + print("=" * 60) + + fragments = load_fragments() + all_errors = [] + total_units = len(fragments) + total_rules = sum(len(f.get("rules", [])) for f in fragments) + + # Test 1: Non-empty rules + errors = check_non_empty_rules(fragments) + if errors: + print(f"\n{FAIL} 非空规则检查: {len(errors)} 个错误") + for e in errors: + print(f" - {e}") + all_errors.extend(errors) + else: + print(f"\n{PASS} 非空规则检查: 全部通过 ({total_units} 个片段)") + + # Test 2: Rule path arrays + errors = check_rule_paths(fragments) + if errors: + print(f"\n{FAIL} 规则 path 字段: {len(errors)} 个错误") + for e in errors[:10]: + print(f" - {e}") + if len(errors) > 10: + print(f" ... 还有 {len(errors) - 10} 个") + all_errors.extend(errors) + else: + print(f"\n{PASS} 规则 path 字段: 全部通过") + + # Test 3: Precondition fields + errors = check_precondition_fields(fragments) + if errors: + print(f"\n{FAIL} precondition 字段: {len(errors)} 个错误") + for e in errors[:10]: + print(f" - {e}") + if len(errors) > 10: + print(f" ... 还有 {len(errors) - 10} 个") + all_errors.extend(errors) + else: + print(f"\n{PASS} precondition 字段: 全部通过") + + # Test 4: user_interaction content + errors = check_user_interaction_content(fragments) + if errors: + print(f"\n{FAIL} user_interaction content: {len(errors)} 个错误") + for e in errors[:10]: + print(f" - {e}") + if len(errors) > 10: + print(f" ... 还有 {len(errors) - 10} 个") + all_errors.extend(errors) + else: + print(f"\n{PASS} user_interaction content: 全部通过") + + # Test 5: Sources have logic tree references + errors = check_sources_have_logic_tree_nodes(fragments) + if errors: + print(f"\n{FAIL} 来源节点引用: {len(errors)} 个规则缺少来源引用") + for e in errors[:10]: + print(f" - {e}") + if len(errors) > 10: + print(f" ... 还有 {len(errors) - 10} 个") + all_errors.extend(errors) + else: + print(f"\n{PASS} 来源节点引用: 全部通过") + + # Test 6: Trigger conditions completeness + errors = check_trigger_conditions(fragments) + if errors: + print(f"\n{FAIL} 触发条件完整性: {len(errors)} 个条件不完整") + for e in errors[:10]: + print(f" - {e}") + if len(errors) > 10: + print(f" ... 还有 {len(errors) - 10} 个") + all_errors.extend(errors) + else: + print(f"\n{PASS} 触发条件完整性: 全部通过") + + # Test 7: No duplicate rule_ids + errors = check_duplicate_rule_ids(fragments) + if errors: + print(f"\n{FAIL} rule_id 唯一性: 发现重复") + for e in errors: + print(f" - {e}") + all_errors.extend(errors) + else: + print(f"\n{PASS} rule_id 唯一性: 全部通过") + + # Test 8: Valid action types + errors = check_action_types(fragments) + if errors: + print(f"\n{FAIL} 动作类型检查: {len(errors)} 个问题") + for e in errors[:10]: + print(f" - {e}") + all_errors.extend(errors) + else: + print(f"\n{PASS} 动作类型检查: 全部通过") + + # Summary + print(f"\n{'='*60}") + total_failures = len(all_errors) + + if total_failures == 0: + print(f"{PASS} 所有测试通过!") + else: + print(f"{FAIL} 测试失败: {total_failures} 个错误") + print("\n建议:") + print(" 1. 检查 ir_fragments.json 中出错的规则") + print(" 2. 如果某些功能单元的规则为空,检查上下文包是否丢失了关键信息") + print(" 3. 调整 Prompt (prompts/step2_ir_extraction.txt) 后重新运行") + + print(f"\n统计:") + print(f" 功能单元数: {total_units}") + print(f" 规则总数: {total_rules}") + error_units = sum(1 for f in fragments if f.get("error")) + if error_units: + print(f" 提取失败的单元: {error_units}") + + return total_failures == 0 + + +if __name__ == "__main__": + success = run_all_tests() + sys.exit(0 if success else 1) diff --git a/skills/ir_generation_skill/tests/test_step2_5.py b/skills/ir_generation_skill/tests/test_step2_5.py new file mode 100644 index 0000000..e24d210 --- /dev/null +++ b/skills/ir_generation_skill/tests/test_step2_5.py @@ -0,0 +1,152 @@ +""" +Tests for Stage 2.5 (Branch Coverage Auto-Completion). + +Validates: +- Path enumeration exists and is non-empty +- Auto-complete fragments have valid structure +- No duplicate unit_ids in autocomplete fragments +- Path coverage improved after autocomplete (if applicable) +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) +import config + + +PASS = "[PASS]" +FAIL = "[FAIL]" +WARN = "[WARN]" + + +def load_path_enumeration(): + """Load path_enumeration.json.""" + try: + return config.load_json(config.PATH_ENUM_JSON) + except FileNotFoundError: + print(f"{FAIL} path_enumeration.json 未找到: {config.PATH_ENUM_JSON}") + print(" 请先运行 step2_5_branch_coverage.py") + sys.exit(1) + + +def load_autocomplete_fragments(): + """Load ir_autocomplete_fragments.json, or return [] if absent.""" + path = config.IR_AUTOCOMPLETE_FRAGMENTS_JSON + if not Path(path).exists(): + return None + return config.load_json(path) + + +def check_path_enumeration(data: dict) -> list[str]: + """Check path enumeration has valid structure.""" + errors = [] + paths = data.get("logic_tree_paths", {}) + if not paths: + errors.append("logic_tree_paths 为空") + total = data.get("total_paths", 0) + if total <= 0: + errors.append(f"total_paths = {total}, 期望 > 0") + + for image_id, image_paths in paths.items(): + if not image_paths: + errors.append(f"{image_id}: 路径列表为空") + continue + for i, p in enumerate(image_paths): + if not p.get("path_id"): + errors.append(f"{image_id}[{i}]: 缺少 path_id") + if not p.get("image_id"): + errors.append(f"{image_id}[{i}]: 缺少 image_id") + if not p.get("node_ids"): + errors.append(f"{image_id}[{i}]: 缺少 node_ids") + + return errors + + +def check_autocomplete_fragments(fragments: list[dict] | None) -> list[str]: + """Check auto-complete fragments have valid structure.""" + if fragments is None: + return ["ir_autocomplete_fragments.json 未生成 (可能无需补全)"] + + errors = [] + seen_unit_ids = set() + + for frag in fragments: + uid = frag.get("unit_id", "") + if not uid: + errors.append("fragment 缺少 unit_id") + continue + if uid in seen_unit_ids: + errors.append(f"unit_id '{uid}' 重复") + seen_unit_ids.add(uid) + + if not frag.get("auto_generated"): + errors.append(f"{uid}: auto_generated 应为 true") + + rules = frag.get("rules", []) + for j, rule in enumerate(rules): + rid = rule.get("rule_id", f"rule[{j}]") + if not rule.get("path"): + errors.append(f"{rid}: path 字段缺失") + precond = rule.get("precondition", {}) + if not precond.get("geographic_scope"): + errors.append(f"{rid}: precondition.geographic_scope 缺失") + + return errors + + +def run_all_tests(): + print("=" * 60) + print("Step 2.5 自检测试") + print("=" * 60) + + all_errors = [] + + # Test 1: Path enumeration exists + try: + path_data = load_path_enumeration() + except SystemExit: + return False + + errors = check_path_enumeration(path_data) + if errors: + print(f"\n{FAIL} 路径枚举检查: {len(errors)} 个错误") + for e in errors: + print(f" - {e}") + all_errors.extend(errors) + else: + total = path_data.get("total_paths", 0) + n_images = len(path_data.get("logic_tree_paths", {})) + print(f"\n{PASS} 路径枚举检查: {total} 条路径, {n_images} 个逻辑树") + + # Test 2: Auto-complete fragments + fragments = load_autocomplete_fragments() + errors = check_autocomplete_fragments(fragments) + + if fragments is None: + print(f"\n{WARN} 自动补全片段: 未生成 (可能所有路径已覆盖)") + elif errors: + print(f"\n{FAIL} 自动补全片段检查: {len(errors)} 个错误") + for e in errors[:10]: + print(f" - {e}") + all_errors.extend(errors) + else: + auto_rules = sum(len(f.get("rules", [])) for f in fragments) + print(f"\n{PASS} 自动补全片段检查: " + f"{len(fragments)} 个片段, {auto_rules} 条规则") + + # Summary + print(f"\n{'='*60}") + total_failures = len(all_errors) + + if total_failures == 0: + print(f"{PASS} 所有测试通过!") + else: + print(f"{FAIL} 测试失败: {total_failures} 个错误") + + return total_failures == 0 + + +if __name__ == "__main__": + success = run_all_tests() + sys.exit(0 if success else 1) diff --git a/skills/ir_generation_skill/tests/test_step3.py b/skills/ir_generation_skill/tests/test_step3.py new file mode 100644 index 0000000..47676a1 --- /dev/null +++ b/skills/ir_generation_skill/tests/test_step3.py @@ -0,0 +1,232 @@ +""" +Tests for Stage 3 (Merge & Audit). + +Validates: +- ir_final.json exists and is well-formed +- No duplicate rule_ids +- All rule_ids follow new hierarchical naming convention +- All rules have path arrays +- ir_audit_report.md exists and contains all required sections +""" + +import re +import sys +from pathlib import Path +from collections import Counter + +sys.path.insert(0, str(Path(__file__).parent.parent)) +import config + + +PASS = "[PASS]" +FAIL = "[FAIL]" +WARN = "[WARN]" + + +def load_ir_final(): + """Load ir_final.json.""" + try: + return config.load_json(config.IR_FINAL_JSON) + except FileNotFoundError: + print(f"{FAIL} ir_final.json 未找到: {config.IR_FINAL_JSON}") + print(" 请先运行 step3_merge_and_audit.py") + sys.exit(1) + + +def load_audit_report(): + """Load ir_audit_report.md if it exists.""" + try: + with open(config.IR_AUDIT_REPORT_MD, "r", encoding="utf-8") as f: + return f.read() + except FileNotFoundError: + print(f"{FAIL} ir_audit_report.md 未找到: {config.IR_AUDIT_REPORT_MD}") + print(" 请先运行 step3_merge_and_audit.py") + sys.exit(1) + + +def check_rule_ids(ir: dict) -> list[str]: + """Check for duplicate rule_ids and hierarchical naming convention. + + Format: DRL-001-DOMESTIC-SYS-FG-INTERRUPT-01 + """ + errors = [] + rules = ir.get("rules", []) + rule_ids = [r.get("rule_id", "") for r in rules] + + # No duplicates + duplicates = [rid for rid, count in Counter(rule_ids).items() if count > 1] + if duplicates: + errors.append(f"重复 rule_id: {duplicates}") + + # New hierarchical naming convention + pattern = re.compile( + r"^[A-Z]+-\d{3}-(DOMESTIC|OVERSEAS)-" + r"(SYS|SDK|OTHER)-" + r"(FG-INTERRUPT|BG-BLOCK|BG-PAUSE|NO-RESTRICT|SWITCH-OFF)-\d{2}$" + ) + for rid in rule_ids: + if rid and not pattern.match(rid): + errors.append( + f"rule_id 命名不规范: '{rid}' " + f"(期望: FEATURE-SCOPE-METHOD-BEHAVIOR-NN)" + ) + + return errors + + +def check_top_level_structure(ir: dict) -> list[str]: + """Check that ir_final has the required top-level fields.""" + errors = [] + for field in ["feature", "feature_id", "rules"]: + if field not in ir: + errors.append(f"ir_final 缺少顶层字段: {field}") + + if not isinstance(ir.get("rules"), list): + errors.append("ir_final.rules 必须是数组") + elif len(ir["rules"]) == 0: + errors.append("ir_final.rules 为空") + + return errors + + +def check_rule_paths(rules: list[dict]) -> list[str]: + """Every rule must have a non-empty path array.""" + errors = [] + for rule in rules: + rid = rule.get("rule_id", "?") + path = rule.get("path", []) + if not path: + errors.append(f"{rid}: path 字段为空或缺失") + return errors + + +def check_rule_completeness(rules: list[dict]) -> list[str]: + """Check each rule has all required fields.""" + errors = [] + required_fields = [ + "rule_id", "description", "priority", "sources", + "precondition", "trigger", "actions" + ] + for i, rule in enumerate(rules): + rid = rule.get("rule_id", f"rule[{i}]") + for field in required_fields: + if field not in rule: + errors.append(f"{rid}: 缺少字段 '{field}'") + if not rule.get("sources"): + errors.append(f"{rid}: sources 为空") + if not rule.get("actions"): + errors.append(f"{rid}: actions 为空") + # Check precondition fields + precond = rule.get("precondition", {}) + if not precond.get("geographic_scope"): + errors.append(f"{rid}: precondition.geographic_scope 缺失") + if "screen_type" not in precond: + errors.append(f"{rid}: precondition.screen_type 缺失") + return errors + + +def check_audit_report(report: str) -> list[str]: + """Check audit report has all required sections.""" + errors = [] + + required_sections = [ + "逻辑树路径覆盖率", + "表格枚举覆盖", + "开关状态", + "一致性扫描报告", + "自动补全摘要", + "规则清单", + ] + for section in required_sections: + if section not in report: + errors.append(f"审计报告缺少章节: {section}") + + # Should have the human review notice + if "人工审查" not in report: + errors.append("审计报告缺少人工审查提示") + + return errors + + +def run_all_tests(): + print("=" * 60) + print("Step 3 自检测试") + print("=" * 60) + + ir = load_ir_final() + report = load_audit_report() + rules = ir.get("rules", []) + all_errors = [] + + # Test 1: Top-level structure + errors = check_top_level_structure(ir) + if errors: + print(f"\n{FAIL} 顶层结构检查: {len(errors)} 个错误") + for e in errors: + print(f" - {e}") + all_errors.extend(errors) + else: + print(f"\n{PASS} 顶层结构检查: 通过 " + f"(feature={ir.get('feature')}, feature_id={ir.get('feature_id')})") + + # Test 2: rule_id uniqueness and naming + errors = check_rule_ids(ir) + if errors: + print(f"\n{FAIL} rule_id 检查: {len(errors)} 个错误") + for e in errors: + print(f" - {e}") + all_errors.extend(errors) + else: + print(f"\n{PASS} rule_id 检查: 全部通过 ({len(rules)} 个唯一 ID, 层次化格式)") + + # Test 3: Rule path fields + errors = check_rule_paths(rules) + if errors: + print(f"\n{FAIL} 规则 path 字段: {len(errors)} 个错误") + for e in errors[:10]: + print(f" - {e}") + all_errors.extend(errors) + else: + print(f"\n{PASS} 规则 path 字段: 全部通过") + + # Test 4: Rule field completeness + errors = check_rule_completeness(rules) + if errors: + print(f"\n{FAIL} 规则字段完整性: {len(errors)} 个错误") + for e in errors[:10]: + print(f" - {e}") + if len(errors) > 10: + print(f" ... 还有 {len(errors) - 10} 个") + all_errors.extend(errors) + else: + print(f"\n{PASS} 规则字段完整性: 全部通过") + + # Test 5: Audit report content + errors = check_audit_report(report) + if errors: + print(f"\n{FAIL} 审计报告检查: {len(errors)} 个错误") + for e in errors: + print(f" - {e}") + all_errors.extend(errors) + else: + print(f"\n{PASS} 审计报告检查: 全部通过 (6 个章节)") + + # Summary + print(f"\n{'='*60}") + total_failures = len(all_errors) + + if total_failures == 0: + print(f"{PASS} 所有测试通过!") + print(f"\n最终交付物:") + print(f" - {config.IR_FINAL_JSON} ({len(rules)} 条规则)") + print(f" - {config.IR_AUDIT_REPORT_MD}") + else: + print(f"{FAIL} 测试失败: {total_failures} 个错误") + print("\n建议: 检查 ir_fragments.json 和合并逻辑,修复问题后重新运行 step3_merge_and_audit.py") + + return total_failures == 0 + + +if __name__ == "__main__": + success = run_all_tests() + sys.exit(0 if success else 1) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..8af4f71 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Tests package for document_analyzer diff --git a/tests/acceptance/__init__.py b/tests/acceptance/__init__.py new file mode 100644 index 0000000..4f4d8c2 --- /dev/null +++ b/tests/acceptance/__init__.py @@ -0,0 +1 @@ +# QE Acceptance Tests for document_analyzer diff --git a/tests/acceptance/conftest.py b/tests/acceptance/conftest.py new file mode 100644 index 0000000..ce30708 --- /dev/null +++ b/tests/acceptance/conftest.py @@ -0,0 +1,186 @@ +"""Pytest configuration and shared fixtures for QE acceptance tests. + +Usage:: + + pytest tests/acceptance/ -v --run-acceptance [--acceptance-runs=3] + +Environment variables: + DASHSCOPE_API_KEY — LLM API key (required for Layers B/C) + TEST_IR_PATH — path to IR JSON to validate (default: ir_final.json sample) + TEST_PARSED_PATH — path to _parsed.json or _updated.json for coverage analysis +""" + +from __future__ import annotations + +import json +import os +import sys +import tempfile +from pathlib import Path +from typing import Any + +import pytest + +# ── Path setup ────────────────────────────────────────────────────────────── + +_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent +sys.path.insert(0, str(_PROJECT_ROOT)) + + +def _skill_path(skill_name: str) -> str: + return str(_PROJECT_ROOT / "skills" / skill_name / "scripts") + + +# ── pytest configuration ──────────────────────────────────────────────────── + + +def pytest_addoption(parser): + parser.addoption( + "--run-acceptance", + action="store_true", + default=False, + help="Run QE acceptance tests (requires DASHSCOPE_API_KEY)", + ) + parser.addoption( + "--acceptance-runs", + type=int, + default=1, + help="Number of IR generation runs for Layer B stability testing (default: 1 = skip)", + ) + parser.addoption( + "--ir-path", + type=str, + default=None, + help="Path to IR JSON file to validate", + ) + parser.addoption( + "--parsed-path", + type=str, + default=None, + help="Path to _parsed.json or _updated.json for coverage analysis", + ) + + +def pytest_configure(config): + config.addinivalue_line( + "markers", + "acceptance: QE acceptance test (requires --run-acceptance flag and DASHSCOPE_API_KEY)", + ) + + +def pytest_collection_modifyitems(config, items): + acceptance_dir = str(_PROJECT_ROOT / "tests" / "acceptance") + acceptance_items = [i for i in items if str(i.fspath).startswith(acceptance_dir)] + non_acceptance_items = [i for i in items if not str(i.fspath).startswith(acceptance_dir)] + + if not config.getoption("--run-acceptance"): + skip_msg = pytest.mark.skip(reason="Need --run-acceptance flag to run") + for item in acceptance_items: + item.add_marker(skip_msg) + # Don't skip non-acceptance tests + return + + if not os.environ.get("DASHSCOPE_API_KEY"): + skip_msg = pytest.mark.skip(reason="DASHSCOPE_API_KEY not set") + for item in acceptance_items: + item.add_marker(skip_msg) + + +# ── Shared fixtures ───────────────────────────────────────────────────────── + + +@pytest.fixture(scope="session") +def project_root() -> Path: + return _PROJECT_ROOT + + +@pytest.fixture(scope="session") +def ir_path(request) -> str: + """Path to the IR JSON file under test.""" + path = ( + request.config.getoption("--ir-path") + or os.environ.get("TEST_IR_PATH") + or str( + Path.home() + / ".openclaw/workspace/skills/doc_parser_skill/output/ir_final.json" + ) + ) + if not os.path.exists(path): + pytest.skip(f"IR file not found: {path}") + return path + + +@pytest.fixture(scope="session") +def ir_data(ir_path: str) -> dict: + """Load the IR JSON data.""" + with open(ir_path, "r", encoding="utf-8") as f: + return json.load(f) + + +@pytest.fixture(scope="session") +def parsed_path(request) -> str | None: + """Path to the corresponding _parsed.json or _updated.json.""" + path = ( + request.config.getoption("--parsed-path") + or os.environ.get("TEST_PARSED_PATH") + or str( + _PROJECT_ROOT + / "skills/ir_generation_skill/车机娱乐系统禁止功能文档_精简_updated.json" + ) + ) + if os.path.exists(path): + return path + return None + + +@pytest.fixture(scope="session") +def parsed_data(parsed_path: str | None) -> dict | None: + """Load the parsed document JSON for coverage analysis.""" + if parsed_path is None: + return None + with open(parsed_path, "r", encoding="utf-8") as f: + return json.load(f) + + +@pytest.fixture(scope="session") +def llm_client(): + """Create an LLMClient instance for acceptance tests. + + Uses the DashScope-compatible LLMClient from the project. + """ + sys.path.insert(0, _skill_path("doc_parser_skill")) + from LLM import LLMClient + + return LLMClient() + + +@pytest.fixture(scope="session") +def acceptance_runs(request) -> int: + return request.config.getoption("--acceptance-runs", default=1) + + +# ── Pipeline runner ───────────────────────────────────────────────────────── + + +@pytest.fixture(scope="session") +def run_ir_pipeline(): + """Return a callable that runs the IR generation pipeline on a parsed JSON. + + Usage:: + + ir_data, ir_path = run_ir_pipeline(parsed_json_path, output_dir) + """ + sys.path.insert(0, _skill_path("ir_generation_skill")) + from ir_generator import generate_ir + + def _run(parsed_path: str, output_dir: str | None = None) -> tuple[dict, str]: + """Run IR generation and return (ir_data, ir_path).""" + out = output_dir or tempfile.mkdtemp(prefix="qe_acceptance_") + result = generate_ir(parsed_path, out, dry_run=False) + ir_list = result.get("ir", []) + ir_path = result.get("path", "") + # ir_generator produces a list; wrap to match rich format expectations + # for schema validation we accept both formats + return ir_list, ir_path + + return _run diff --git a/tests/acceptance/ir_schema.py b/tests/acceptance/ir_schema.py new file mode 100644 index 0000000..05bc3f5 --- /dev/null +++ b/tests/acceptance/ir_schema.py @@ -0,0 +1,325 @@ +"""Rich IR schema definition and validators for the document_analyzer QE framework. + +Target format is the production IR (``ir_final.json``): + {feature, feature_id, config_defaults?, rules: [{rule_id, path, description, + priority, sources, precondition, trigger, actions}]} +""" + +from __future__ import annotations + +import re +from typing import Any + +# ── Constants ──────────────────────────────────────────────────────────────── + +VALID_SOURCE_TYPES = {"table", "logic_tree", "text"} +VALID_ACTION_TYPES = {"system", "user_interaction"} +VALID_PRIORITIES = {"P0", "P1", "P2"} +VALID_TRIGGER_OPERATORS = {"AND", "OR"} + +# rule_id pattern: FEAT-NNN-SCOPE-TYPE-...-PATH-NN (variable middle segments) +RULE_ID_RE = re.compile( + r"^[A-Z]+-\d+(-[A-Z]+)+-\d+$" +) + + +# ── Validation helpers ────────────────────────────────────────────────────── + +def _check(condition: bool, message: str) -> list[str]: + """Return a list with an error message if *condition* is False, else empty list.""" + return [] if condition else [message] + + +def validate_rule(rule: dict, index: int = 0) -> list[str]: + """Validate a single rule dict. Returns a (possibly empty) list of error strings.""" + errors: list[str] = [] + label = f"rules[{index}]" + + if not isinstance(rule, dict): + return [f"{label}: not a dict"] + + # ── required top-level fields ── + for field in ("rule_id", "description"): + errors.extend(_check( + isinstance(rule.get(field), str) and bool(rule[field].strip()), + f'{label}.{field}: required non-empty string', + )) + + # sources is a list, not a string — validated separately below + + # ── rule_id naming ── + rid = rule.get("rule_id", "") + if rid and isinstance(rid, str): + errors.extend(_check( + bool(RULE_ID_RE.match(rid)), + f'{label}.rule_id: "{rid}" does not match pattern FEAT-NNN-SCOPE-TYPE-PATH-NN', + )) + + # ── priority ── + priority = rule.get("priority") + if priority is not None: + errors.extend(_check( + priority in VALID_PRIORITIES, + f'{label}.priority: "{priority}" not in {VALID_PRIORITIES}', + )) + + # ── path ── + path = rule.get("path") + if path is not None: + if not isinstance(path, list): + errors.append(f"{label}.path: must be a list") + elif len(path) == 0: + errors.append(f"{label}.path: must not be empty") + elif not all(isinstance(p, str) and p.strip() for p in path): + errors.append(f"{label}.path: all segments must be non-empty strings") + + # ── sources[] ── + sources = rule.get("sources", []) + if not isinstance(sources, list): + errors.append(f"{label}.sources: must be a list") + elif len(sources) == 0: + errors.append(f"{label}.sources: must have at least one source") + else: + for si, src in enumerate(sources): + errors.extend(_validate_source(src, f"{label}.sources[{si}]")) + + # ── precondition ── + precondition = rule.get("precondition") + if precondition is not None: + if not isinstance(precondition, dict): + errors.append(f"{label}.precondition: must be a dict") + elif len(precondition) == 0: + errors.append(f"{label}.precondition: must not be empty") + + # ── trigger ── + trigger = rule.get("trigger") + if trigger is not None: + if not isinstance(trigger, dict): + errors.append(f"{label}.trigger: must be a dict") + else: + errors.extend(_validate_trigger(trigger, f"{label}.trigger")) + + # ── actions ── + actions = rule.get("actions") + if actions is not None: + if not isinstance(actions, list): + errors.append(f"{label}.actions: must be a list") + else: + for ai, act in enumerate(actions): + errors.extend(_validate_action(act, f"{label}.actions[{ai}]")) + + # ── no null values at any depth ── + errors.extend(_find_nulls(rule, label)) + + return errors + + +def _validate_source(src: dict, label: str) -> list[str]: + errors: list[str] = [] + if not isinstance(src, dict): + return [f"{label}: not a dict"] + + stype = src.get("type", "") + errors.extend(_check( + stype in VALID_SOURCE_TYPES, + f'{label}.type: "{stype}" not in {VALID_SOURCE_TYPES}', + )) + + priority = src.get("priority", "") + if priority: + errors.extend(_check( + priority in ("primary_source", "supplementary"), + f'{label}.priority: "{priority}" must be primary_source or supplementary', + )) + + # type-specific fields + if stype == "table": + errors.extend(_check( + isinstance(src.get("section"), str) and bool(src["section"].strip()), + f"{label}.section: required non-empty string for table source", + )) + errors.extend(_check( + isinstance(src.get("row"), int), + f"{label}.row: required int for table source", + )) + elif stype == "logic_tree": + errors.extend(_check( + isinstance(src.get("image_id"), str) and bool(src["image_id"].strip()), + f"{label}.image_id: required non-empty string for logic_tree source", + )) + node_ids = src.get("node_ids", []) + errors.extend(_check( + isinstance(node_ids, list) and len(node_ids) > 0, + f"{label}.node_ids: required non-empty list for logic_tree source", + )) + elif stype == "text": + errors.extend(_check( + isinstance(src.get("section"), str) and bool(src["section"].strip()), + f"{label}.section: required non-empty string for text source", + )) + + return errors + + +def _validate_trigger(trigger: dict, label: str) -> list[str]: + errors: list[str] = [] + operator = trigger.get("operator", "") + errors.extend(_check( + operator in VALID_TRIGGER_OPERATORS, + f'{label}.operator: "{operator}" not in {VALID_TRIGGER_OPERATORS}', + )) + + conditions = trigger.get("conditions") + if conditions is not None: + if not isinstance(conditions, list): + errors.append(f"{label}.conditions: must be a list") + else: + for ci, cond in enumerate(conditions): + if not isinstance(cond, dict): + errors.append(f"{label}.conditions[{ci}]: not a dict") + else: + errors.extend(_check( + isinstance(cond.get("signal"), str) and bool(cond["signal"].strip()), + f"{label}.conditions[{ci}].signal: required non-empty string", + )) + errors.extend(_check( + "operator" in cond, + f"{label}.conditions[{ci}].operator: required", + )) + # empty conditions is valid (e.g. "switch always off, no conditions") + + return errors + + +def _validate_action(action: dict, label: str) -> list[str]: + errors: list[str] = [] + if not isinstance(action, dict): + return [f"{label}: not a dict"] + + atype = action.get("type", "") + errors.extend(_check( + atype in VALID_ACTION_TYPES, + f'{label}.type: "{atype}" not in {VALID_ACTION_TYPES}', + )) + errors.extend(_check( + isinstance(action.get("description"), str) and bool(action["description"].strip()), + f"{label}.description: required non-empty string", + )) + + return errors + + +def _find_nulls(obj: Any, label: str) -> list[str]: + """Find any None values at any depth in *obj*.""" + errors: list[str] = [] + if obj is None: + return [f"{label}: null value"] + elif isinstance(obj, dict): + for k, v in obj.items(): + errors.extend(_find_nulls(v, f"{label}.{k}")) + elif isinstance(obj, list): + for i, v in enumerate(obj): + errors.extend(_find_nulls(v, f"{label}[{i}]")) + return errors + + +# ── Top-level validation ──────────────────────────────────────────────────── + +def validate_ir(ir_data: dict) -> dict: + """Validate the entire IR document. + + Returns: + { + "valid": bool, + "errors": [str, ...], + "stats": {total_rules, valid_rules, has_config_defaults, ...} + } + """ + errors: list[str] = [] + stats = {"total_rules": 0, "valid_rules": 0, "has_config_defaults": False, "features": 0} + + if not isinstance(ir_data, dict): + return {"valid": False, "errors": ["IR root is not a dict"], "stats": stats} + + # top-level required fields + for field in ("feature", "feature_id", "rules"): + if field not in ir_data: + errors.append(f"root.{field}: missing required field") + elif field in ("feature", "feature_id") and not ( + isinstance(ir_data[field], str) and ir_data[field].strip() + ): + errors.append(f"root.{field}: must be non-empty string") + + # config_defaults (optional) + if "config_defaults" in ir_data: + stats["has_config_defaults"] = True + cd = ir_data["config_defaults"] + if not isinstance(cd, dict): + errors.append("root.config_defaults: must be a dict") + + # rules array + rules = ir_data.get("rules", []) + if not isinstance(rules, list): + errors.append("root.rules: must be a list") + else: + stats["total_rules"] = len(rules) + if len(rules) == 0: + errors.append("root.rules: must have at least one rule") + else: + for i, rule in enumerate(rules): + rule_errors = validate_rule(rule, i) + if rule_errors: + errors.extend(rule_errors) + else: + stats["valid_rules"] += 1 + + # feature count + if isinstance(ir_data.get("feature_id"), str): + stats["features"] = 1 + + return { + "valid": len(errors) == 0, + "errors": errors, + "stats": stats, + } + + +# ── Summary helpers ───────────────────────────────────────────────────────── + +def schema_checklist(ir_data: dict) -> list[dict]: + """Run individual checks and return a checklist for reporting. + + Each item: {"check": str, "passed": bool, "detail": str} + """ + report = validate_ir(ir_data) + checks: list[dict] = [] + + def _add(name: str, passed: bool, detail: str = ""): + checks.append({"check": name, "passed": passed, "detail": detail}) + + # Top-level + _add("root is dict", isinstance(ir_data, dict)) + _add("root.feature present", isinstance(ir_data.get("feature"), str) and bool(ir_data["feature"].strip())) + _add("root.feature_id present", isinstance(ir_data.get("feature_id"), str) and bool(ir_data["feature_id"].strip())) + _add("root.rules is non-empty list", isinstance(ir_data.get("rules"), list) and len(ir_data["rules"]) > 0) + + # Per-rule checks + rules = ir_data.get("rules", []) if isinstance(ir_data, dict) else [] + rule_ids = [] + for i, rule in enumerate(rules): + if not isinstance(rule, dict): + continue + rid = rule.get("rule_id", f"rules[{i}]") + rule_ids.append(rid) + + errs = validate_rule(rule, i) + _add(f"{rid}: valid", len(errs) == 0, "; ".join(errs) if errs else "") + + # Aggregate checks + _add("no duplicate rule_ids", len(rule_ids) == len(set(rule_ids)), + f"duplicates: {[r for r in rule_ids if rule_ids.count(r) > 1]}" if len(rule_ids) != len(set(rule_ids)) else "") + + _add("all rules valid", report["valid"], + f"{report['stats']['valid_rules']}/{report['stats']['total_rules']} valid") + + return checks diff --git a/tests/acceptance/report.py b/tests/acceptance/report.py new file mode 100644 index 0000000..b1ac3d7 --- /dev/null +++ b/tests/acceptance/report.py @@ -0,0 +1,178 @@ +"""Structured JSON report generation for QE acceptance test results. + +Produces a unified report with three-layer verdict: + Layer A – Schema compliance + Layer B – Structural coverage + stability + Layer C – LLM QE expert audit + +Final verdict: PASS (releasable) or FAIL (blocked). +""" + +from __future__ import annotations + +import json +import time +from pathlib import Path +from typing import Any + + +def generate_report( + schema_result: dict, + coverage_result: dict, + audit_result: dict | None, + *, + commit: str = "", + branch: str = "main", + output_path: str | None = None, +) -> dict: + """Assemble the three-layer report and return it. + + Args: + schema_result: ``{"verdict": "PASS"|"FAIL", "total_checks": N, "passed": N, "failed": N}`` + coverage_result: ``{"verdict": "PASS"|"FAIL", "coverage_rate": float, + "stability": {"runs": N, "values": [...], "std": float}}`` + audit_result: ``{"verdict": "ACCEPT"|"REJECT", "inadequate_ratio": float, + "rationale": str, "section_assessments": [...]}`` or None + commit: git commit SHA + branch: branch name + output_path: if set, write the report JSON to this path + + Returns the report dict. + """ + layers: dict[str, Any] = { + "A_schema": schema_result, + "B_coverage": coverage_result, + } + if audit_result is not None: + layers["C_qe_audit"] = audit_result + + # ── final verdict ── + a_pass = schema_result.get("verdict") == "PASS" + b_pass = coverage_result.get("verdict") == "PASS" + c_pass = ( + audit_result is None + or audit_result.get("verdict") == "ACCEPT" + ) + all_pass = a_pass and b_pass and c_pass + + report = { + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), + "commit": commit, + "branch": branch, + "layers": layers, + "final_verdict": "PASS" if all_pass else "FAIL", + "releasable": all_pass, + "failure_details": _failure_details(layers), + } + + if output_path: + out = Path(output_path) + out.parent.mkdir(parents=True, exist_ok=True) + out.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8") + + return report + + +def _failure_details(layers: dict) -> list[str]: + """Summarise which layers failed and why.""" + details: list[str] = [] + + schema = layers.get("A_schema", {}) + if schema.get("verdict") != "PASS": + details.append( + f"Layer A (Schema): {schema.get('failed', '?')}/{schema.get('total_checks', '?')} checks failed" + ) + + coverage = layers.get("B_coverage", {}) + if coverage.get("verdict") != "PASS": + cv = coverage.get("coverage_rate", "?") + details.append(f"Layer B (Coverage): rate={cv} (threshold: 0.70)") + + audit = layers.get("C_qe_audit", {}) + if audit.get("verdict") == "REJECT": + details.append( + f"Layer C (QE Audit): REJECT — inadequate_ratio={audit.get('inadequate_ratio', '?')}" + ) + + return details + + +# ── Layer-specific result builders ────────────────────────────────────────── + +def schema_verdict(errors: list[str], stats: dict) -> dict: + """Build Layer A result from schema validation errors & stats.""" + total = stats.get("total_rules", 0) + valid = stats.get("valid_rules", 0) + failed_checks = len(errors) + (total - valid) + + return { + "verdict": "PASS" if failed_checks == 0 else "FAIL", + "total_checks": max(total, 1), # at minimum, we checked the root + "passed": valid if failed_checks == 0 else valid, + "failed": failed_checks, + "rule_pass_rate": round(valid / max(total, 1), 2) if total > 0 else 0, + "sample_errors": errors[:10], # first 10 for the report + } + + +def coverage_verdict( + coverage_rate: float, + stability_std: float, + stability_values: list[float], + *, + coverage_threshold: float = 0.70, + stability_threshold: float = 0.05, + section_coverage: dict | None = None, + table_coverage: dict | None = None, + diagram_coverage: dict | None = None, +) -> dict: + """Build Layer B result from coverage metrics.""" + b1_pass = coverage_rate >= coverage_threshold + b2_pass = stability_std <= stability_threshold + both_pass = b1_pass and b2_pass + + result: dict[str, Any] = { + "verdict": "PASS" if both_pass else "FAIL", + "coverage_rate": round(coverage_rate, 3), + "coverage_threshold": coverage_threshold, + "coverage_pass": b1_pass, + "stability": { + "runs": len(stability_values), + "values": [round(v, 3) for v in stability_values], + "std": round(stability_std, 4), + "threshold": stability_threshold, + "pass": b2_pass, + }, + } + + if section_coverage: + result["section_coverage"] = section_coverage + if table_coverage: + result["table_coverage"] = table_coverage + if diagram_coverage: + result["diagram_coverage"] = diagram_coverage + + return result + + +def audit_verdict(audit_data: dict, *, inadequate_threshold: float = 0.30) -> dict: + """Build Layer C result from LLM QE audit. + + *audit_data* should contain: + inadequate_ratio: float + rationale: str + section_assessments: list[dict] + """ + ratio = audit_data.get("inadequate_ratio", 1.0) + passed = ratio <= inadequate_threshold + + return { + "verdict": "ACCEPT" if passed else "REJECT", + "inadequate_ratio": round(ratio, 3), + "threshold": inadequate_threshold, + "rationale": audit_data.get("rationale", ""), + "total_sections": audit_data.get("total_functional_sections", 0), + "adequate": audit_data.get("adequate", 0), + "inadequate": audit_data.get("inadequate", 0), + "not_applicable": audit_data.get("not_applicable", 0), + } diff --git a/tests/acceptance/test_main_health.py b/tests/acceptance/test_main_health.py new file mode 100644 index 0000000..b64a53c --- /dev/null +++ b/tests/acceptance/test_main_health.py @@ -0,0 +1,558 @@ +"""QE Acceptance Test — Three-layer main branch health check. + +Layer A (Schema): structural correctness of IR +Layer B (Coverage): structural source-traceability coverage + stability +Layer C (QE Audit): LLM as QE expert — functional coverage assessment + +Final verdict: all three layers must pass for main to be releasable. +""" + +from __future__ import annotations + +import json +import math +import re +import statistics +import tempfile +import time +from pathlib import Path +from typing import Any + +import pytest + +from .ir_schema import validate_ir, schema_checklist +from .report import generate_report, schema_verdict, coverage_verdict, audit_verdict + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Layer A: SCHEMA — deterministic structural validation +# ═══════════════════════════════════════════════════════════════════════════════ + +def test_layer_a_schema(ir_data: dict, request): + """Validate IR structure: required fields, types, naming conventions, no nulls.""" + report = validate_ir(ir_data) + checks = schema_checklist(ir_data) + + # Build Layer A result + a_errors = report["errors"] + a_stats = report["stats"] + a_result = schema_verdict(a_errors, a_stats) + a_result["checks"] = checks + + # Store for downstream layers & report + _stash(request, "layer_a", a_result) + + # Assert + assert report["valid"], ( + f"Schema validation FAILED ({len(a_errors)} errors)\n" + + "\n".join(f" - {e}" for e in a_errors[:20]) + ) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Layer B: STRUCTURAL COVERAGE + STABILITY +# ═══════════════════════════════════════════════════════════════════════════════ + +# Section titles that are NOT functional requirements +NON_FUNCTIONAL_PATTERNS = [ + re.compile(p) for p in [ + r"编制.*变更.*日志", + r"文档背景", + r"文档范围", + r"术语解释", + r"参考", + r"附录", + r"版本", + r"变更记录", + r"目录", + r"前言", + r"概述", + r"简介", + r"概述.*背景", + ] +] + + +def _is_functional_section(section_name: str) -> bool: + """Heuristic: exclude background, glossary, changelog, scope sections. + + Sections that are purely structural — preface, glossary, changelog — are excluded. + Sections with numbering like '3.1.1' are always considered functional. + """ + # Numbered sections are functional + if _section_number(section_name) != section_name: + return True + for pat in NON_FUNCTIONAL_PATTERNS: + if pat.search(section_name): + return False + return True + + +def _extract_content_units(parsed_data: dict) -> dict: + """Extract countable content units from parsed JSON. + + Returns: + {"sections": [{"name": ..., "number": ...}, ...], + "table_rows": int, "diagram_images": [rid, ...]} + """ + sections = parsed_data.get("sections", []) + + functional_sections: list[dict] = [] + total_table_rows = 0 + + for sec in sections: + name = sec.get("source", "") + if _is_functional_section(name): + functional_sections.append({ + "name": name, + "number": _section_number(name), + }) + + for block in sec.get("blocks", []): + if block.get("type") == "table": + rows = block.get("rows", []) + total_table_rows += len(rows) + + # Diagram-type images from image_analysis + diagram_rids: list[str] = [] + for img in parsed_data.get("image_analysis", []): + img_type = img.get("type", "") + if img_type in ("flowchart", "logic_tree", "architecture", + "state", "sequence", "activity"): + diagram_rids.append(img.get("rid", "")) + + return { + "functional_sections": functional_sections, + "table_rows": total_table_rows, + "diagram_images": diagram_rids, + } + + +def _section_number(section_name: str) -> str: + """Extract leading section number, e.g. '3.1.1 系统限制' → '3.1.1'.""" + import re + m = re.match(r"^([\d.]+)", section_name) + return m.group(1) if m else section_name + + +def _section_matches(sec_ref: str, func_sections: list[dict]) -> str | None: + """Find a functional section matching *sec_ref*. Returns the section name or None. + + Matching: exact match → starts-with match → number match → substring match. + """ + # exact + for s in func_sections: + if s["name"] == sec_ref: + return s["name"] + # starts with section number + for s in func_sections: + if s["name"].startswith(sec_ref) or sec_ref.startswith(s["name"]): + return s["name"] + # number match + sec_num = _section_number(sec_ref) + if sec_num: + for s in func_sections: + if s["number"] == sec_num: + return s["name"] + # substring + for s in func_sections: + if sec_ref in s["name"] or s["name"] in sec_ref: + return s["name"] + return None + + +def _measure_coverage(ir_data: dict, parsed_data: dict) -> dict: + """Compute structural coverage of IR over parsed document. + + Returns: + { + "section_coverage": {total, covered, rate, uncovered}, + "table_coverage": {total_rows, covered_rows, rate}, + "diagram_coverage": {total, covered, rate}, + "overall_rate": float, + } + """ + units = _extract_content_units(parsed_data) + rules = ir_data.get("rules", []) + + # ── section coverage ── + func_sections = units["functional_sections"] + covered_sections: set[str] = set() + for rule in rules: + for src in rule.get("sources", []): + sec_ref = src.get("section", "") + if sec_ref: + matched = _section_matches(sec_ref, func_sections) + if matched: + covered_sections.add(matched) + + section_coverage = { + "total": len(func_sections), + "covered": len(covered_sections), + "rate": round(len(covered_sections) / max(len(func_sections), 1), 3), + "uncovered": [s["name"] for s in func_sections + if s["name"] not in covered_sections], + } + + # ── table row coverage ── + covered_rows: set[tuple] = set() + for rule in rules: + for src in rule.get("sources", []): + if src.get("type") == "table": + sec = src.get("section", "") + row = src.get("row") + if sec and row is not None: + covered_rows.add((sec, row)) + + total_rows = units["table_rows"] + table_coverage = { + "total_rows": total_rows, + "covered_rows": len(covered_rows), + "rate": round(len(covered_rows) / max(total_rows, 1), 3), + } + + # ── diagram coverage ── + diagram_rids = units["diagram_images"] + covered_rids: set[str] = set() + for rule in rules: + for src in rule.get("sources", []): + if src.get("type") == "logic_tree": + img_id = src.get("image_id", "") + if img_id and img_id in diagram_rids: + covered_rids.add(img_id) + + diagram_coverage = { + "total": len(diagram_rids), + "covered": len(covered_rids), + "rate": round(len(covered_rids) / max(len(diagram_rids), 1), 3), + "uncovered": [r for r in diagram_rids if r not in covered_rids], + } + + # ── overall ── + rates = [ + section_coverage["rate"], + table_coverage["rate"], + diagram_coverage["rate"], + ] + overall = round(sum(rates) / len(rates), 3) if rates else 0.0 + + return { + "section_coverage": section_coverage, + "table_coverage": table_coverage, + "diagram_coverage": diagram_coverage, + "overall_rate": overall, + } + + +def test_layer_b_coverage( + ir_data: dict, + parsed_data: dict | None, + ir_path: str, + acceptance_runs: int, + run_ir_pipeline, + request, +): + """Measure structural coverage and (optionally) coverage stability.""" + if parsed_data is None: + pytest.skip("No parsed JSON available for coverage analysis") + + # ── B1: single-run coverage ── + cov = _measure_coverage(ir_data, parsed_data) + + # ── B2: stability (multi-run) ── + stability_values: list[float] = [cov["overall_rate"]] + stability_std = 0.0 + + if acceptance_runs > 1: + parsed_path = request.config.getoption("--parsed-path") + if parsed_path and os.path.exists(parsed_path): + for _ in range(acceptance_runs - 1): + try: + ir_list, _ = run_ir_pipeline(parsed_path) + # Convert list-format IR to dict for coverage measurement + run_ir = _wrap_list_ir(ir_list) + run_cov = _measure_coverage(run_ir, parsed_data) + stability_values.append(run_cov["overall_rate"]) + time.sleep(0.5) # rate limiting between runs + except Exception as e: + pytest.fail(f"Stability run failed: {e}") + + if len(stability_values) > 1: + stability_std = statistics.stdev(stability_values) + + # Build Layer B result + b_result = coverage_verdict( + coverage_rate=cov["overall_rate"], + stability_std=stability_std, + stability_values=stability_values, + section_coverage=cov["section_coverage"], + table_coverage=cov["table_coverage"], + diagram_coverage=cov["diagram_coverage"], + ) + _stash(request, "layer_b", b_result) + + # Assert — both B1 and B2 must pass + assert b_result["coverage_pass"], ( + f"Coverage {cov['overall_rate']:.1%} < threshold 70%\n" + f" Sections: {cov['section_coverage']['covered']}/{cov['section_coverage']['total']} " + f"({cov['section_coverage']['rate']:.1%})\n" + f" Uncovered: {cov['section_coverage']['uncovered']}\n" + f" Table rows: {cov['table_coverage']['covered_rows']}/{cov['table_coverage']['total_rows']} " + f"({cov['table_coverage']['rate']:.1%})\n" + f" Diagrams: {cov['diagram_coverage']['covered']}/{cov['diagram_coverage']['total']} " + f"({cov['diagram_coverage']['rate']:.1%})\n" + f" Uncovered diagrams: {cov['diagram_coverage']['uncovered']}" + ) + + if len(stability_values) > 1: + assert b_result["stability"]["pass"], ( + f"Coverage stability std={stability_std:.4f} > threshold 0.05\n" + f" Values across {len(stability_values)} runs: {stability_values}" + ) + + +def _wrap_list_ir(ir_list: list) -> dict: + """Wrap a list-format IR (from ir_generator.py) into a dict for schema compat.""" + # Convert simple format to rich format for coverage measurement + rules = [] + for i, entry in enumerate(ir_list): + if not isinstance(entry, dict): + continue + rule = { + "rule_id": f"GEN-001-RULE-{i:03d}", + "description": entry.get("function", ""), + "path": [], + "priority": "P2", + "sources": [], + "precondition": {}, + "trigger": entry.get("trigger", {"operator": "AND", "conditions": []}), + "actions": [], + } + # Convert source + src = entry.get("source", {}) + if src.get("section"): + rule["sources"].append({ + "type": "text", + "section": src["section"], + "paragraph": 1, + "text_snippet": src.get("location", ""), + "priority": "primary_source", + }) + rules.append(rule) + + return { + "feature": "generated", + "feature_id": "GEN-001", + "rules": rules, + } + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Layer C: LLM QE EXPERT AUDIT +# ═══════════════════════════════════════════════════════════════════════════════ + +QE_AUDITOR_PROMPT = """你是一个资深 QE 专家,负责审查需求文档的 IR(中间表示层)是否充分覆盖了源文档的所有可测试功能点。 + +你不是 IR 的生成者,你是独立的质量审计员。你的职责是判断 IR 的功能覆盖率是否充分。 + +## 审计输入 + +### Layer B 结构化覆盖率数据(参考) +{coverage_summary} + +### 源文档内容(Parsed JSON) +{parsed_content} + +### 生成的 IR(待审计) +{ir_content} + +## 审计要求 + +对源文档中的每个章节逐一评估其功能需求是否被 IR 充分覆盖。 + +**判断标准**: +- **adequate**(充分覆盖):该章节的所有功能需求在 IR 中都有对应的 rule,包括触发条件、执行动作 +- **inadequate**(覆盖不足):该章节存在功能需求未在 IR 中体现,或描述不完整(缺少触发条件或动作) +- **not_applicable**(不适用):该章节为背景介绍、术语定义、变更日志等,不包含功能需求 + +**注意**: +- 如果某个章节涉及多个决策路径(如流程图),检查 IR 是否覆盖了每条路径 +- 表格中的每个功能行都应被至少一个 IR rule 覆盖 +- 图片分析中的流程图/决策树节点应被 IR 引用 + +## 输出格式 + +请严格输出以下 JSON 格式(不要包含代码块标记): + +{{ + "total_functional_sections": , + "adequate": , + "inadequate": , + "not_applicable": , + "inadequate_ratio": , + "verdict": "ACCEPT 或 REJECT", + "rationale": "<一句话说明接受或拒绝的理由>", + "section_assessments": [ + {{ + "section": "<章节名>", + "assessment": "adequate | inadequate | not_applicable", + "reason": "<评估理由>", + "missing": ["<缺失项1>", "<缺失项2>"] // 仅 inadequate 时需要 + }} + ] +}} + +verdict 判定规则: +- inadequate_ratio ≤ 0.30 → "ACCEPT"(风险可控) +- inadequate_ratio > 0.30 → "REJECT"(功能点认知差异大,需要补充 IR) +""" + + +def test_layer_c_qe_audit( + ir_data: dict, parsed_data: dict | None, llm_client, request +): + """LLM QE expert audit of functional coverage.""" + if parsed_data is None: + pytest.skip("No parsed JSON available — cannot run QE audit") + + # ── get Layer B summary for context ── + layer_b = _unstash(request, "layer_b") or {} + cov_summary = json.dumps( + { + "coverage_rate": layer_b.get("coverage_rate", "N/A"), + "section_coverage": layer_b.get("section_coverage", {}), + "diagram_coverage": layer_b.get("diagram_coverage", {}), + }, + ensure_ascii=False, + indent=2, + ) + + # ── prepare content (trim to avoid token overflow) ── + parsed_str = json.dumps(parsed_data, ensure_ascii=False) + ir_str = json.dumps(ir_data, ensure_ascii=False) + + max_parsed = 12000 + max_ir = 8000 + if len(parsed_str) > max_parsed: + parsed_str = parsed_str[:max_parsed] + "\n...[truncated]" + if len(ir_str) > max_ir: + ir_str = ir_str[:max_ir] + "\n...[truncated]" + + prompt = QE_AUDITOR_PROMPT.format( + coverage_summary=cov_summary, + parsed_content=parsed_str, + ir_content=ir_str, + ) + + # ── call LLM ── + try: + raw = llm_client.chat( + model=llm_client.TEXT_MODEL, + messages=[{"role": "user", "content": prompt}], + response_format={"type": "json_object"}, + ) + except Exception as e: + pytest.fail(f"QE audit LLM call failed: {e}") + + # ── parse response ── + audit_data = _parse_json_response(raw) + if audit_data is None: + pytest.fail(f"QE audit returned unparseable response:\n{raw[:500]}") + + # Build Layer C result + c_result = audit_verdict(audit_data) + c_result["raw_assessments"] = audit_data.get("section_assessments", []) + _stash(request, "layer_c", c_result) + + # Assert + assert c_result["verdict"] == "ACCEPT", ( + f"QE Audit REJECTED — inadequate_ratio={c_result['inadequate_ratio']:.1%} > 30%\n" + f" Rationale: {c_result['rationale']}\n" + f" Adequate: {c_result['adequate']}, Inadequate: {c_result['inadequate']}" + ) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Final report (runs last) +# ═══════════════════════════════════════════════════════════════════════════════ + +def test_final_report(ir_data: dict, ir_path: str, request): + """Generate the final three-layer JSON report. + + This test always passes (report generation). The verdicts from layers A/B/C + determine the final releasable status, but the report itself is informational. + """ + layer_a = _unstash(request, "layer_a") or {"verdict": "SKIPPED"} + layer_b = _unstash(request, "layer_b") or {"verdict": "SKIPPED"} + layer_c = _unstash(request, "layer_c") or {"verdict": "SKIPPED"} + + report_path = request.config.getoption("--json-report-file", None) or str( + Path.cwd() / "acceptance-report.json" + ) + + report = generate_report( + layer_a, + layer_b, + layer_c, + commit=os.environ.get("GITEA_SHA", ""), + branch=os.environ.get("GITEA_BRANCH", "main"), + output_path=report_path, + ) + + # Print summary + print(f"\n{'='*60}") + print(f"QE ACCEPTANCE REPORT") + print(f"{'='*60}") + print(f" Layer A (Schema): {layer_a.get('verdict', '?')}") + print(f" Layer B (Coverage): {layer_b.get('verdict', '?')} " + f"(rate={layer_b.get('coverage_rate', '?')})") + print(f" Layer C (QE Audit): {layer_c.get('verdict', '?')}") + print(f" {'─'*40}") + print(f" FINAL: {report['final_verdict']} | " + f"Releasable: {report['releasable']}") + print(f" Report: {report_path}") + print(f"{'='*60}\n") + + # Fail if any layer failed (aggregate assertion) + failures = report.get("failure_details", []) + if failures: + pytest.fail( + "Acceptance tests FAILED:\n" + "\n".join(f" - {f}" for f in failures) + ) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Helpers +# ═══════════════════════════════════════════════════════════════════════════════ + +import os # noqa: E402 + +# Module-level stash for sharing results across tests in the same module. +# Each test function stores its result here; later tests read earlier results. +_module_stash: dict[str, dict] = {} + + +def _stash(request, key: str, value: dict): + """Store a result dict for cross-test access within this module.""" + _module_stash[key] = value + + +def _unstash(request, key: str) -> dict | None: + """Retrieve a stashed result.""" + return _module_stash.get(key) + + +def _parse_json_response(raw: str) -> dict | None: + """Parse JSON from an LLM response, handling markdown code fences.""" + if not raw: + return None + text = raw.strip() + if text.startswith("```"): + nl = text.find("\n") + text = text[nl + 1:] if nl != -1 else text[3:] + if text.endswith("```"): + text = text[:-3] + try: + return json.loads(text) + except json.JSONDecodeError: + return None diff --git a/tests/test_sample.py b/tests/test_sample.py index 785528d..088fd42 100644 --- a/tests/test_sample.py +++ b/tests/test_sample.py @@ -55,12 +55,17 @@ def test_import_detect_conflicts(): # -- IR generation tests ------------------------------------------------------ -def test_import_ir_generator(): - """ir_generator module should be importable.""" +def test_import_ir_main(): + """ir_generation main module should be importable (new project structure).""" os.environ.setdefault("DASHSCOPE_API_KEY", "test-fake-key") - _import_from_skill("ir_generation_skill", "ir_generator") - import ir_generator - assert hasattr(ir_generator, "generate_ir") + skill_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "skills", "ir_generation_skill" + ) + if skill_dir not in sys.path: + sys.path.insert(0, skill_dir) + import main + assert hasattr(main, "main") # -- Resolution application tests ---------------------------------------------