sync: update all skills from latest workspace code
CI / test (push) Successful in 8s

doc_parser_skill:
- New: verify_flowchart.py (flowchart validation)
- Updated: LLM.py (multi-provider: DeepSeek + DashScope)
- Updated: image_parser.py (logic tree support, external prompts)
- Updated: SKILL.md, prompts/image_prompt.md

conflict_detection_skill:
- Updated: LLM.py (multi-provider sync)
- Updated: detect_conflicts.py (logic tree text conversion)

ir_generation_skill:
- Replaced old scripts/LLM.py + ir_generator.py with standalone project
- New: main.py, config.py, step1-3_*.py, ensemble_merge.py
- New: prompts/, tests/ subdirectories

tests:
- New: acceptance/ test suite with schema validation
- Fixed: conftest no longer globally skips non-acceptance tests
- Updated: test_sample.py for new ir_generation structure

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-30 22:45:08 +08:00
parent db64df2da1
commit fec4c09ee0
35 changed files with 8021 additions and 530 deletions
+54
View File
@@ -0,0 +1,54 @@
name: QE Acceptance Tests
on:
workflow_dispatch:
inputs:
acceptance_runs:
description: 'Layer B stability runs (1 = skip stability testing)'
required: false
default: '1'
ir_path:
description: 'Path to IR JSON file (relative to workspace)'
required: false
default: 'output/ir_final.json'
parsed_path:
description: 'Path to _parsed.json or _updated.json (relative to workspace)'
required: false
default: 'output/车机娱乐系统禁止功能文档_精简_updated.json'
jobs:
acceptance:
runs-on: shell
timeout-minutes: 30
steps:
- name: Checkout main branch
run: |
git clone --depth 1 http://localhost:3000/pzhang_zywl/document_analyzer.git .
git checkout main
- name: Install dependencies
run: pip install -r requirements.txt
- name: Run QE Acceptance Tests
run: >-
python -m pytest tests/acceptance/ -v
--run-acceptance
--acceptance-runs=${{ github.event.inputs.acceptance_runs }}
--ir-path=${{ github.event.inputs.ir_path }}
--parsed-path=${{ github.event.inputs.parsed_path }}
--tb=long
env:
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
- name: Create issue on failure
if: failure()
env:
GITEA_API_TOKEN: ${{ secrets.GITEA_TOKEN }}
run: >-
python scripts/create_failure_issue.py
--sha "${{ github.sha }}"
--branch "main"
--run "${{ github.run_number }}"
--message "QE Acceptance Tests Failed"
--workflow "QE Acceptance"
--labels "acceptance-failure,agent-task"
+4
View File
@@ -7,3 +7,7 @@ output/
dist/ dist/
.runner .runner
*_output/ *_output/
*.png
*.jpg
acceptance-report.json
ir_final.json
+7 -3
View File
@@ -17,14 +17,18 @@ def main():
parser.add_argument("--run", required=True) parser.add_argument("--run", required=True)
parser.add_argument("--message", required=True) parser.add_argument("--message", required=True)
parser.add_argument("--api-token", default=os.environ.get("GITEA_API_TOKEN", "")) parser.add_argument("--api-token", default=os.environ.get("GITEA_API_TOKEN", ""))
parser.add_argument("--workflow", default="CI", help="Workflow name that triggered this (default: CI)")
parser.add_argument("--labels", default="ci-failure,agent-task",
help="Comma-separated labels for the issue (default: ci-failure,agent-task)")
args = parser.parse_args() args = parser.parse_args()
sha_short = args.sha[:7] sha_short = args.sha[:7]
run_url = f"{GITEA_URL}/{REPO}/actions/runs/{args.run}" run_url = f"{GITEA_URL}/{REPO}/actions/runs/{args.run}"
labels = [l.strip() for l in args.labels.split(",") if l.strip()]
title = f"CI Failure: {args.message[:80]}" title = f"[{args.workflow}] Failure: {args.message[:80]}"
body = ( body = (
f"## CI 测试失败\n\n" f"## {args.workflow} 测试失败\n\n"
f"- **Commit:** {sha_short}\n" f"- **Commit:** {sha_short}\n"
f"- **Branch:** {args.branch}\n" f"- **Branch:** {args.branch}\n"
f"- **工作流运行:** {run_url}\n\n" f"- **工作流运行:** {run_url}\n\n"
@@ -38,7 +42,7 @@ def main():
payload = json.dumps({ payload = json.dumps({
"title": title, "title": title,
"body": body, "body": body,
"labels": [], "labels": labels,
}).encode("utf-8") }).encode("utf-8")
url = f"{GITEA_URL}/api/v1/repos/{REPO}/issues" url = f"{GITEA_URL}/api/v1/repos/{REPO}/issues"
+85 -10
View File
@@ -1,38 +1,97 @@
import logging import logging
import os import os
import time import time
from pathlib import Path
from typing import Optional from typing import Optional
from openai import OpenAI from openai import OpenAI
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Resolve secrets file: priority 1) env OPENCLAW_SECRETS,
# 2) workspace-document-analyzer/config/ (relative to skills dir),
# 3) .openclaw/config/
_SECRETS_FILE = None
for _candidate in (
os.environ.get("OPENCLAW_SECRETS", ""),
Path(__file__).resolve().parents[3] / "config" / "secrets.yaml",
Path(__file__).resolve().parents[5] / ".openclaw" / "config" / "secrets.yaml",
):
if _candidate and Path(_candidate).exists():
_SECRETS_FILE = Path(_candidate)
break
if _SECRETS_FILE is None:
_SECRETS_FILE = Path("") # empty fallback
def _load_secrets() -> dict:
"""Load API keys from secrets.yaml, with env-var overrides."""
secrets = {}
if _SECRETS_FILE.exists():
try:
import yaml
with open(_SECRETS_FILE, "r", encoding="utf-8") as f:
data = yaml.safe_load(f) or {}
for provider in ("deepseek", "dashscope"):
if provider in data and isinstance(data[provider], dict):
secrets[provider] = data[provider]
except ImportError:
logger.warning("pyyaml not installed, cannot read %s", _SECRETS_FILE)
except Exception as e:
logger.warning("Failed to load %s: %s", _SECRETS_FILE, e)
# Env overrides
dk_env = os.environ.get("DEEPSEEK_API_KEY", "")
ds_env = os.environ.get("DASHSCOPE_API_KEY", "")
if dk_env:
secrets.setdefault("deepseek", {})["apiKey"] = dk_env
if ds_env:
secrets.setdefault("dashscope", {})["apiKey"] = ds_env
return secrets
class LLMClient: class LLMClient:
"""Low-level OpenAI-compatible LLM client with retry and token tracking. """Multi-provider LLM client with retry and token tracking.
Routes text models to DeepSeek, vision models to DashScope (Bailian).
Reads API keys from openclaw config/secrets.yaml, with env-var overrides.
Usage:: Usage::
llm = LLMClient() llm = LLMClient()
content = llm.chat("qwen3.5-flash", [{"role": "user", "content": "Hello"}]) content = llm.chat("deepseek-v4-pro", [{"role": "user", "content": "Hello"}])
print(llm.usage) print(llm.usage)
""" """
IMAGE_MODEL = "qwen3-vl-plus" IMAGE_MODEL = "qwen3-vl-plus"
TEXT_MODEL = "qwen3.5-flash-2026-02-23" TEXT_MODEL = "deepseek-v4-flash"
DASHSCOPE_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
DEEPSEEK_BASE = "https://api.deepseek.com/v1"
TIMEOUT = 120 TIMEOUT = 120
MAX_RETRIES = 3 MAX_RETRIES = 3
_VISION_KEYWORDS = ("vl", "vision", "qwen-vl", "qwen3-vl")
def __init__( def __init__(
self, self,
*, *,
base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1",
timeout: int | None = None, timeout: int | None = None,
): ):
key = os.environ.get("DASHSCOPE_API_KEY", "") secrets = _load_secrets()
if not key:
raise ValueError("DASHSCOPE_API_KEY environment variable is not set.") ds_cfg = secrets.get("dashscope", {})
self._client = OpenAI(api_key=key, base_url=base_url) dk_cfg = secrets.get("deepseek", {})
dashscope_key = ds_cfg.get("apiKey", "")
dashscope_url = ds_cfg.get("baseUrl", self.DASHSCOPE_BASE)
deepseek_key = dk_cfg.get("apiKey", "")
deepseek_url = dk_cfg.get("baseUrl", self.DEEPSEEK_BASE)
self._ds_client = OpenAI(api_key=dashscope_key, base_url=dashscope_url) if dashscope_key else None
self._dk_client = OpenAI(api_key=deepseek_key, base_url=deepseek_url) if deepseek_key else None
self._timeout = timeout or self.TIMEOUT self._timeout = timeout or self.TIMEOUT
self._prompt_tokens = 0 self._prompt_tokens = 0
self._completion_tokens = 0 self._completion_tokens = 0
@@ -49,7 +108,7 @@ class LLMClient:
@staticmethod @staticmethod
def estimate_tokens(text: str) -> int: def estimate_tokens(text: str) -> int:
"""Quick token estimate. CJK ≈1.7/token, others ≈3.0/token.""" """Quick token estimate. CJK ≈1.7/token, others ≈3.0/token."""
cjk = sum(1 for c in text if '' <= c <= '鿿' or ' ' <= c <= '') cjk = sum(1 for c in text if '\u4e00' <= c <= '\u9fff' or '\u3000' <= c <= '\u303f')
other = len(text) - cjk other = len(text) - cjk
return max(1, int(cjk / 1.7 + other / 3.0)) return max(1, int(cjk / 1.7 + other / 3.0))
@@ -58,6 +117,20 @@ class LLMClient:
"""Fixed estimate for one vision-model image (~500 tokens).""" """Fixed estimate for one vision-model image (~500 tokens)."""
return 500 return 500
@staticmethod
def _is_vision_model(model: str) -> bool:
return any(kw in model.lower() for kw in LLMClient._VISION_KEYWORDS)
def _get_client(self, model: str) -> OpenAI:
if self._is_vision_model(model):
if self._ds_client is None:
raise ValueError("DASHSCOPE_API_KEY not set but required for vision model")
return self._ds_client
else:
if self._dk_client is None:
raise ValueError("DEEPSEEK_API_KEY not set but required for text model")
return self._dk_client
def chat( def chat(
self, model: str, messages: list[dict], *, timeout: int | None = None, self, model: str, messages: list[dict], *, timeout: int | None = None,
response_format: dict | None = None, response_format: dict | None = None,
@@ -65,8 +138,10 @@ class LLMClient:
"""Send a chat completion request and return the response content. """Send a chat completion request and return the response content.
Automatically retries on failure and accumulates token usage. Automatically retries on failure and accumulates token usage.
Routes to DeepSeek for text, DashScope for vision.
""" """
label = f"chat({model})" label = f"chat({model})"
client = self._get_client(model)
def _call(): def _call():
t0 = time.time() t0 = time.time()
@@ -74,7 +149,7 @@ class LLMClient:
if response_format is not None: if response_format is not None:
kwargs["response_format"] = response_format kwargs["response_format"] = response_format
kwargs["temperature"] = 0 kwargs["temperature"] = 0
resp = self._client.chat.completions.create(**kwargs) resp = client.chat.completions.create(**kwargs)
content = resp.choices[0].message.content content = resp.choices[0].message.content
usg = resp.usage usg = resp.usage
if usg: if usg:
@@ -96,6 +96,77 @@ PROMPT_DETECT_CONFLICT = """你是一个文档一致性检查专家。以下内
""" """
def _is_nested_tree(lt: dict) -> bool:
"""Return True if logic_tree uses the nested children format."""
return isinstance(lt.get("children"), list)
def _logic_tree_to_text(lt: dict) -> str:
"""Convert logic_tree JSON to readable text for conflict detection.
Supports both the new nested-tree format and the legacy flat-nodes format.
"""
if _is_nested_tree(lt):
return _nested_tree_to_text(lt)
return _flat_tree_to_text(lt)
def _nested_tree_to_text(tree: dict) -> str:
"""Convert a nested flowchart tree to readable text."""
lines: list[str] = []
def _walk(node: dict, indent: int = 0):
prefix = " " * indent
nid = node.get("id", "")
name = node.get("name", "")
ntype = node.get("type", "")
type_label = {
"start": "起始", "end": "结束", "process": "处理",
"decision": "判断", "action": "动作",
}.get(ntype, ntype)
lines.append(f"{prefix}[{type_label}] {nid}: {name}")
if ntype == "decision":
for child in node.get("children", []):
cond = child.get("condition", "")
lines.append(f"{prefix} 分支 \"{cond}\":")
_walk(child["node"], indent + 2)
elif "children" in node:
for child in node.get("children", []):
_walk(child, indent + 1)
_walk(tree)
return "\n".join(lines)
def _flat_tree_to_text(lt: dict) -> str:
"""Convert legacy flat-nodes logic_tree to readable text."""
lines: list[str] = []
root = lt.get("root", "")
if root:
lines.append(f"根节点: {root}")
for node in lt.get("nodes", []):
nid = node.get("id", "")
ntype = node.get("type", "")
if ntype == "decision":
cond = node.get("condition", "")
branches = node.get("branches", [])
lines.append(f"判断节点 {nid}: 条件=\"{cond}\"")
for b in branches:
lines.append(f" - 分支 \"{b.get('value', '')}\"{b.get('target', '')}")
elif ntype == "action":
lines.append(f"动作节点 {nid}: {node.get('description', '')}")
elif ntype == "state":
lines.append(f"状态节点 {nid}: {node.get('description', '')}")
elif ntype == "start":
lines.append(f"起始节点 {nid}: {node.get('description', '')}")
elif ntype == "end":
lines.append(f"结束节点 {nid}: {node.get('description', '')}")
return "\n".join(lines)
def _build_text_for_section(sections: list[dict], section_name: str) -> str: def _build_text_for_section(sections: list[dict], section_name: str) -> str:
"""Build a single text block for the given section name.""" """Build a single text block for the given section name."""
texts: list[str] = [] texts: list[str] = []
@@ -184,8 +255,9 @@ def detect_conflicts(
img_type = img.get("type", "other") img_type = img.get("type", "other")
rid = img.get("rid", "") rid = img.get("rid", "")
description = img.get("description", "").strip() description = img.get("description", "").strip()
logic_tree = img.get("logic_tree_nested") or img.get("logic_tree")
if img_type not in DIAGRAM_TYPES or not description: if img_type not in DIAGRAM_TYPES or (not description and not logic_tree):
logger.info("Skip conflict check: rid=%s type=%s", rid, img_type) logger.info("Skip conflict check: rid=%s type=%s", rid, img_type)
continue continue
@@ -211,8 +283,17 @@ def detect_conflicts(
logger.info(" [DRY RUN] would call LLM to detect conflicts") logger.info(" [DRY RUN] would call LLM to detect conflicts")
continue continue
# Enrich description with logic_tree if available
combined_desc = description
if logic_tree:
lt_text = _logic_tree_to_text(logic_tree)
if combined_desc:
combined_desc = f"[结构化逻辑树]\n{lt_text}\n\n[文字描述]\n{combined_desc}"
else:
combined_desc = f"[结构化逻辑树]\n{lt_text}"
prompt = PROMPT_DETECT_CONFLICT.format( prompt = PROMPT_DETECT_CONFLICT.format(
image_description=description, image_description=combined_desc,
text_description=text_content, text_description=text_content,
section_name=section_name, section_name=section_name,
) )
+4 -1
View File
@@ -29,7 +29,10 @@ description: 解析文档(.docx, .pdf)以提取图像和文本结构,并
该技能生成一个结构化JSON文件,文件名为输入文档的基本名称后跟'_parsed.json',包含: 该技能生成一个结构化JSON文件,文件名为输入文档的基本名称后跟'_parsed.json',包含:
- `sections`:按标题分组的文档文本结构 - `sections`:按标题分组的文档文本结构
- `image_sources`:从图像标识符到其在文档中位置的映射 - `image_sources`:从图像标识符到其在文档中位置的映射
- `image_analysis`:由视觉大语言模型确定的每个图像的类型内容描述 - `image_analysis`:由视觉大语言模型确定的每个图像的类型内容描述和(如适用)结构化逻辑树
- `type`: 图片类型(flowchart/architecture/state/sequence/activity/other
- `description`: 图片的文字描述
- `logic_tree`(可选,仅图表类型):结构化逻辑树JSON,包含 `root`(根节点描述)和 `nodes` 数组。节点类型:`decision`(判断)、`action`(动作)、`state`(状态)、`start`(开始)、`end`(结束)。decision 节点包含 `condition``branches` 字段,其他节点包含 `description` 字段。
## 集成点 ## 集成点
@@ -0,0 +1,203 @@
请分析这张图片,判断类型并输出文字描述和(如适用)结构化逻辑树。
## 判断图片类型
如果是 **流程图 / 架构图 / 状态图 / 时序图 / 活动图**,你需要输出三项内容:
1. 类型标签
2. **嵌套逻辑树 JSON**(见下方格式)
3. 文字描述
如果是 **其他类型**(UI原型图 / 界面截图 / 设计稿 / 手机屏幕截图 / 网页截图等),只输出类型标签和简要文字描述。
## 嵌套逻辑树 JSON 格式(仅流程图/架构图/状态图/时序图/活动图需要)
**核心原则:用嵌套的 `children` 数组表达流程的层级关系,而不是用 id 引用。** 这种格式更贴近流程图的自然结构,每个节点的后续步骤直接嵌套在其下方。
### 节点类型
| 类型 | 含义 | 对应形状 |
|------|------|----------|
| `start` | 起始节点 | 椭圆/圆角矩形 |
| `end` | 结束节点 | 椭圆/圆角矩形 |
| `process` | 处理/状态节点 | 矩形/圆角矩形 |
| `decision` | 判断节点 | 菱形 |
| `action` | 动作节点 | 矩形 |
### 非判断节点的 children 格式
对于 `start``end``process``action` 节点,`children` 是一个数组,包含后续步骤节点:
```json
{
"id": "n1",
"name": "节点名称",
"type": "process",
"children": [
{
"id": "n2",
"name": "下一个步骤",
"type": "action",
"children": [...]
}
]
}
```
### 判断节点的 children 格式
对于 `decision` 节点,`children` 是一个数组,每个元素包含 `condition`(分支条件)和 `node`(该分支对应的子节点):
```json
{
"id": "n5",
"name": "是否满足条件?",
"type": "decision",
"children": [
{
"condition": "是",
"node": {
"id": "n6",
"name": "满足条件时的动作",
"type": "action",
"children": [...]
}
},
{
"condition": "否",
"node": {
"id": "n7",
"name": "不满足条件时的动作",
"type": "action",
"children": [...]
}
}
]
}
```
### 结束节点
`end` 节点没有 `children` 字段:
```json
{
"id": "n10",
"name": "流程结束",
"type": "end"
}
```
### 完整示例
```json
{
"id": "n1",
"name": "开关状态",
"type": "start",
"children": [
{
"id": "n2",
"name": "开启",
"type": "process",
"children": [
{
"id": "n3",
"name": "是否在目标场景?",
"type": "decision",
"children": [
{
"condition": "否",
"node": {
"id": "n4",
"name": "不受限",
"type": "end"
}
},
{
"condition": "是",
"node": {
"id": "n5",
"name": "车速是否≥15km/h且持续5秒?",
"type": "decision",
"children": [
{
"condition": "否",
"node": {
"id": "n6",
"name": "不受限",
"type": "end"
}
},
{
"condition": "是",
"node": {
"id": "n7",
"name": "暂停功能",
"type": "action",
"children": [
{
"id": "n8",
"name": "发起Toast提示",
"type": "end"
}
]
}
}
]
}
}
]
}
]
},
{
"id": "n9",
"name": "关闭",
"type": "process",
"children": [
{
"id": "n10",
"name": "不受限",
"type": "end"
}
]
}
]
}
```
### 规则
1. 每条从根节点到 `end` 节点的路径必须是完整的逻辑链
2. `decision` 节点的 `children` 必须穷举所有分支(通常为"是/否"),每条分支包含 `condition``node`
3. 只有 `end` 节点没有 `children` 字段,其他所有节点都应该有 `children`
4. 节点 id 使用 "n1", "n2", "n3"... 格式,按流程图从上到下、从左到右的顺序编号
5. 仔细阅读图片中的每个判断条件和分支走向,确保分支目标节点正确
6. 如果流程图中某个分支的后续步骤在图片中没有展示,将其标记为 `end` 节点,`name` 设为"(图中未展示)"
7. **如果图片包含多个独立的流程图**(例如上半部分和下半部分分别描述不同场景),使用一个统一的 `process` 根节点将它们组织在一起。例如图片中有"策略A"和"策略B"两个流程,结构为:
```json
{
"id": "n1",
"name": "策略总览",
"type": "process",
"children": [
{"id": "n2", "name": "策略A流程", "type": "process", "children": [...]},
{"id": "n3", "name": "策略B流程", "type": "process", "children": [...]}
]
}
```
## 输出格式
**1. 类型标签(单独一行):**
type: <flowchart|architecture|state|sequence|activity|other>
**2. 逻辑树 JSON(仅上述5种类型,以 logic_tree: 开头,后跟 JSON 对象):**
logic_tree:
{...}
**3. 文字描述(以 description: 开头):**
description:
该图片的详细文字描述。对于流程图/架构图等类型,这里提供自然语言总结;对于其他类型,这是唯一的描述内容。
不要输出 ``` 代码块包裹符号,不要输出 ---YAML--- 分隔符,不要添加任何额外的解释或问候语。
+85 -10
View File
@@ -1,38 +1,97 @@
import logging import logging
import os import os
import time import time
from pathlib import Path
from typing import Optional from typing import Optional
from openai import OpenAI from openai import OpenAI
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Resolve secrets file: priority 1) env OPENCLAW_SECRETS,
# 2) workspace-document-analyzer/config/ (relative to skills dir),
# 3) .openclaw/config/
_SECRETS_FILE = None
for _candidate in (
os.environ.get("OPENCLAW_SECRETS", ""),
Path(__file__).resolve().parents[3] / "config" / "secrets.yaml",
Path(__file__).resolve().parents[5] / ".openclaw" / "config" / "secrets.yaml",
):
if _candidate and Path(_candidate).exists():
_SECRETS_FILE = Path(_candidate)
break
if _SECRETS_FILE is None:
_SECRETS_FILE = Path("") # empty fallback
def _load_secrets() -> dict:
"""Load API keys from secrets.yaml, with env-var overrides."""
secrets = {}
if _SECRETS_FILE.exists():
try:
import yaml
with open(_SECRETS_FILE, "r", encoding="utf-8") as f:
data = yaml.safe_load(f) or {}
for provider in ("deepseek", "dashscope"):
if provider in data and isinstance(data[provider], dict):
secrets[provider] = data[provider]
except ImportError:
logger.warning("pyyaml not installed, cannot read %s", _SECRETS_FILE)
except Exception as e:
logger.warning("Failed to load %s: %s", _SECRETS_FILE, e)
# Env overrides
dk_env = os.environ.get("DEEPSEEK_API_KEY", "")
ds_env = os.environ.get("DASHSCOPE_API_KEY", "")
if dk_env:
secrets.setdefault("deepseek", {})["apiKey"] = dk_env
if ds_env:
secrets.setdefault("dashscope", {})["apiKey"] = ds_env
return secrets
class LLMClient: class LLMClient:
"""Low-level OpenAI-compatible LLM client with retry and token tracking. """Multi-provider LLM client with retry and token tracking.
Routes text models to DeepSeek, vision models to DashScope (Bailian).
Reads API keys from openclaw config/secrets.yaml, with env-var overrides.
Usage:: Usage::
llm = LLMClient() llm = LLMClient()
content = llm.chat("qwen3.5-flash", [{"role": "user", "content": "Hello"}]) content = llm.chat("deepseek-v4-pro", [{"role": "user", "content": "Hello"}])
print(llm.usage) print(llm.usage)
""" """
IMAGE_MODEL = "qwen3-vl-plus" IMAGE_MODEL = "qwen3-vl-plus"
TEXT_MODEL = "qwen3.5-flash-2026-02-23" TEXT_MODEL = "deepseek-v4-flash"
DASHSCOPE_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
DEEPSEEK_BASE = "https://api.deepseek.com/v1"
TIMEOUT = 120 TIMEOUT = 120
MAX_RETRIES = 3 MAX_RETRIES = 3
_VISION_KEYWORDS = ("vl", "vision", "qwen-vl", "qwen3-vl")
def __init__( def __init__(
self, self,
*, *,
base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1",
timeout: int | None = None, timeout: int | None = None,
): ):
key = os.environ.get("DASHSCOPE_API_KEY", "") secrets = _load_secrets()
if not key:
raise ValueError("DASHSCOPE_API_KEY environment variable is not set.") ds_cfg = secrets.get("dashscope", {})
self._client = OpenAI(api_key=key, base_url=base_url) dk_cfg = secrets.get("deepseek", {})
dashscope_key = ds_cfg.get("apiKey", "")
dashscope_url = ds_cfg.get("baseUrl", self.DASHSCOPE_BASE)
deepseek_key = dk_cfg.get("apiKey", "")
deepseek_url = dk_cfg.get("baseUrl", self.DEEPSEEK_BASE)
self._ds_client = OpenAI(api_key=dashscope_key, base_url=dashscope_url) if dashscope_key else None
self._dk_client = OpenAI(api_key=deepseek_key, base_url=deepseek_url) if deepseek_key else None
self._timeout = timeout or self.TIMEOUT self._timeout = timeout or self.TIMEOUT
self._prompt_tokens = 0 self._prompt_tokens = 0
self._completion_tokens = 0 self._completion_tokens = 0
@@ -49,7 +108,7 @@ class LLMClient:
@staticmethod @staticmethod
def estimate_tokens(text: str) -> int: def estimate_tokens(text: str) -> int:
"""Quick token estimate. CJK ≈1.7/token, others ≈3.0/token.""" """Quick token estimate. CJK ≈1.7/token, others ≈3.0/token."""
cjk = sum(1 for c in text if '' <= c <= '鿿' or ' ' <= c <= '') cjk = sum(1 for c in text if '\u4e00' <= c <= '\u9fff' or '\u3000' <= c <= '\u303f')
other = len(text) - cjk other = len(text) - cjk
return max(1, int(cjk / 1.7 + other / 3.0)) return max(1, int(cjk / 1.7 + other / 3.0))
@@ -58,6 +117,20 @@ class LLMClient:
"""Fixed estimate for one vision-model image (~500 tokens).""" """Fixed estimate for one vision-model image (~500 tokens)."""
return 500 return 500
@staticmethod
def _is_vision_model(model: str) -> bool:
return any(kw in model.lower() for kw in LLMClient._VISION_KEYWORDS)
def _get_client(self, model: str) -> OpenAI:
if self._is_vision_model(model):
if self._ds_client is None:
raise ValueError("DASHSCOPE_API_KEY not set but required for vision model")
return self._ds_client
else:
if self._dk_client is None:
raise ValueError("DEEPSEEK_API_KEY not set but required for text model")
return self._dk_client
def chat( def chat(
self, model: str, messages: list[dict], *, timeout: int | None = None, self, model: str, messages: list[dict], *, timeout: int | None = None,
response_format: dict | None = None, response_format: dict | None = None,
@@ -65,8 +138,10 @@ class LLMClient:
"""Send a chat completion request and return the response content. """Send a chat completion request and return the response content.
Automatically retries on failure and accumulates token usage. Automatically retries on failure and accumulates token usage.
Routes to DeepSeek for text, DashScope for vision.
""" """
label = f"chat({model})" label = f"chat({model})"
client = self._get_client(model)
def _call(): def _call():
t0 = time.time() t0 = time.time()
@@ -74,7 +149,7 @@ class LLMClient:
if response_format is not None: if response_format is not None:
kwargs["response_format"] = response_format kwargs["response_format"] = response_format
kwargs["temperature"] = 0 kwargs["temperature"] = 0
resp = self._client.chat.completions.create(**kwargs) resp = client.chat.completions.create(**kwargs)
content = resp.choices[0].message.content content = resp.choices[0].message.content
usg = resp.usage usg = resp.usage
if usg: if usg:
+320 -33
View File
@@ -1,6 +1,8 @@
import base64 import base64
import json
import logging import logging
import os import os
import re
from typing import Optional from typing import Optional
from LLM import LLMClient from LLM import LLMClient
@@ -8,32 +10,56 @@ from LLM import LLMClient
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Prompts # Prompt loading
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
PROMPT_IMAGE = """请分析这张图片,判断类型并输出文字描述。 def _load_prompt() -> str:
"""Load PROMPT_IMAGE from external file, falling back to inline default."""
prompt_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "prompts")
prompt_path = os.path.join(prompt_dir, "image_prompt.md")
if os.path.isfile(prompt_path):
with open(prompt_path, "r", encoding="utf-8") as f:
return f.read()
# Fallback inline prompt (nested tree format)
return """请分析这张图片,判断类型并输出文字描述和(如适用)结构化逻辑树。
## 判断图片类型 ## 判断图片类型
如果是 **流程图 / 架构图 / 状态图 / 时序图 / 活动图**,详细描述 如果是 **流程图 / 架构图 / 状态图 / 时序图 / 活动图**,你需要输出三项内容
- 图中所有节点/步骤/状态/组件的名称 1. 类型标签
- 所有连线/箭头/转换关系及其方向 2. **嵌套逻辑树 JSON**(见下方格式)
- 所有分支条件、判断逻辑和判断结果 3. 文字描述
- 所有文字标注、注释、标签
- 图的整体结构和逻辑流程
- 如果图片包含多个子图,拆解描述
如果是 **其他类型**(UI原型图 / 界面截图 / 设计稿 / 手机屏幕截图 / 网页截图等),简要描述图片内容 如果是 **其他类型**(UI原型图 / 界面截图 / 设计稿 / 手机屏幕截图 / 网页截图等),只输出类型标签和简要文字描述
## 嵌套逻辑树 JSON 格式(仅流程图/架构图/状态图/时序图/活动图需要)
**核心原则:用嵌套的 `children` 数组表达流程的层级关系,而不是用 id 引用。**
节点类型:`start`(起始), `end`(结束), `process`(处理/状态), `decision`(判断), `action`(动作)
非判断节点的 `children` 是子节点数组。`end` 节点无 `children`。
判断节点的 `children` 格式:
```json
{"condition": "", "node": {"id": "n6", "name": "...", "type": "action", "children": [...]}}
```
每条从根到 `end` 的路径必须是完整逻辑链。decision 必须穷举所有分支。
节点 id 使用 "n1", "n2", "n3"... 格式。
## 输出格式 ## 输出格式
**1. 类型标签(单独一行):**
type: <flowchart|architecture|state|sequence|activity|other> type: <flowchart|architecture|state|sequence|activity|other>
**2. 文字描述:** logic_tree:
该图片的详细文字描述。 {...}
不要输出 ---YAML--- 分隔符或 YAML 内容,不要添加任何额外的解释或问候语。""" description:
该图片的详细文字描述。"""
PROMPT_IMAGE = _load_prompt()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -41,7 +67,10 @@ type: <flowchart|architecture|state|sequence|activity|other>
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class ImageParser: class ImageParser:
"""Vision LLM wrapper for parsing images (type + description). """Vision LLM wrapper for parsing images (type + description + logic_tree).
The nested-tree ``logic_tree`` is stored alongside a backward-compatible
flat representation so downstream consumers are not broken.
Usage:: Usage::
@@ -49,7 +78,7 @@ class ImageParser:
result = parser.parse_image("images/img1.png") result = parser.parse_image("images/img1.png")
""" """
_VALID_TYPES = {"flowchart", "architecture", "state", "sequence", "activity", "text"} _VALID_TYPES = {"flowchart", "architecture", "state", "sequence", "activity", "other"}
def __init__(self, llm: LLMClient | None = None): def __init__(self, llm: LLMClient | None = None):
self._llm = llm or LLMClient() self._llm = llm or LLMClient()
@@ -59,9 +88,9 @@ class ImageParser:
return self._llm.usage return self._llm.usage
def parse_image(self, image_path: str) -> Optional[dict]: def parse_image(self, image_path: str) -> Optional[dict]:
"""Parse an image and return its type and description (no YAML IR). """Parse an image and return its type, description, and optional logic_tree.
Returns ``{type, description}``, or *None* for UI mockups. Returns ``{type, description, [logic_tree], [logic_tree_nested]}``.
""" """
logger.info("Parsing image: %s", image_path) logger.info("Parsing image: %s", image_path)
@@ -84,34 +113,292 @@ class ImageParser:
logger.error(str(e)) logger.error(str(e))
return {"type": "other", "description": "", "error": str(e)} return {"type": "other", "description": "", "error": str(e)}
parsed = self._parse_type_and_description(content) parsed = self._parse_response(content)
if parsed is None: if parsed is None:
return None return None
return {"type": parsed[0], "description": parsed[1]} ptype, description, logic_tree_nested = parsed
result: dict = {"type": ptype, "description": description}
if logic_tree_nested is not None:
result["logic_tree_nested"] = logic_tree_nested
result["logic_tree"] = self._flatten_tree(logic_tree_nested)
return result
# ---- internals ---------------------------------------------------------- # ---- internals ----------------------------------------------------------
def _parse_type_and_description(self, content: str) -> Optional[tuple[str, str]]: def _parse_response(self, content: str) -> Optional[tuple[str, str, Optional[dict]]]:
"""Extract ``(type, description)`` from LLM response. """Extract ``(type, description, logic_tree_nested)`` from LLM response.
Returns *None* for ``[[UI]]`` (skip). Parses the nested-tree format. Returns *None* for unparseable content.
""" """
content = content.strip() content = content.strip()
if content == "[[UI]]" or content.startswith("[[UI]]"):
return None
parsed_type = "other" parsed_type = "other"
desc_lines: list[str] = [] logic_tree = None
for line in content.splitlines(): description = ""
stripped = line.strip()
if (stripped.startswith("type:") or stripped.startswith("类型:")) and parsed_type == "other": # --- type ---
type_val = stripped.split(":", 1)[1].strip().lower() type_match = re.search(r'(?:type|类型):\s*(\S+)', content)
if type_match:
type_val = type_match.group(1).strip().lower()
if type_val in self._VALID_TYPES: if type_val in self._VALID_TYPES:
parsed_type = type_val parsed_type = type_val
else:
desc_lines.append(line)
return parsed_type, "\n".join(desc_lines).strip() # --- logic_tree (anchored at line start) ---
lt_match = re.search(r'(?m)^logic_tree:\s*', content)
desc_match = re.search(r'(?m)^description:\s*', content)
if lt_match:
lt_start = lt_match.end()
lt_end = desc_match.start() if desc_match and desc_match.start() > lt_start else len(content)
lt_raw = content[lt_start:lt_end].strip()
# Try multiple JSON extraction strategies
logic_tree = self._extract_json(lt_raw)
if logic_tree is not None:
is_valid, err_msg = self._validate_flowchart(logic_tree)
if not is_valid:
logger.warning("Flowchart validation warning: %s", err_msg)
else:
logger.info("Failed to extract logic_tree JSON. Raw block length=%d", len(lt_raw))
logger.debug("Raw logic_tree block: %s", lt_raw[:500])
elif parsed_type in self._VALID_TYPES - {"other"}:
logger.info("Diagram type=%s but no logic_tree: in response. Response length=%d",
parsed_type, len(content))
logger.debug("Raw response (first 500): %s", content[:500])
# --- description ---
if desc_match:
description = content[desc_match.end():].strip()
else:
desc = content
if type_match:
desc = desc[type_match.end():]
desc = re.sub(r'(?m)^logic_tree:\s*\{.*?\}\s*', '', desc, flags=re.DOTALL)
description = desc.strip()
return parsed_type, description, logic_tree
@staticmethod
def _validate_flowchart(tree: dict) -> tuple[bool, str]:
"""Validate a nested flowchart tree structure.
Returns ``(is_valid, error_message)``. Non-fatal: returns ``False``
with a warning message but the tree is still kept.
"""
if not isinstance(tree, dict):
return False, "logic_tree is not a dict"
seen_ids: set[str] = set()
def _walk(node: dict, depth: int = 0) -> tuple[bool, str]:
if depth > 20:
return False, f"Tree too deep (>20) at node {node.get('id', '?')}"
nid = node.get("id", "")
if not nid:
return False, "Node missing 'id' field"
if not isinstance(nid, str):
return False, f"Node id must be string, got {type(nid).__name__}"
if nid in seen_ids:
return False, f"Duplicate node id: {nid}"
seen_ids.add(nid)
ntype = node.get("type", "")
if ntype not in ("start", "end", "process", "decision", "action"):
return False, f"Unknown node type '{ntype}' at {nid}"
if ntype == "end":
if "children" in node:
return False, f"End node {nid} should not have children"
return True, ""
children = node.get("children")
if not children:
if ntype != "end":
return False, f"Non-end node {nid} ({ntype}) has no children"
return True, ""
if not isinstance(children, list):
return False, f"children of {nid} is not a list"
if ntype == "decision":
for child in children:
if not isinstance(child, dict):
return False, f"decision child of {nid} is not a dict"
if "condition" not in child:
return False, f"decision child of {nid} missing 'condition'"
if "node" not in child:
return False, f"decision child of {nid} missing 'node'"
ok, err = _walk(child["node"], depth + 1)
if not ok:
return False, err
else:
for child in children:
if not isinstance(child, dict):
return False, f"child of {nid} is not a dict"
ok, err = _walk(child, depth + 1)
if not ok:
return False, err
return True, ""
return _walk(tree)
@staticmethod
def _flatten_tree(tree: dict) -> dict:
"""Convert a nested flowchart tree into the legacy flat-nodes format.
This preserves backward compatibility with downstream consumers
(conflict_detection_skill, ir_generator) that expect the flat format.
"""
nodes: list[dict] = []
root_name = ""
def _collect(node: dict):
nonlocal root_name
nid = node.get("id", "")
ntype = node.get("type", "")
name = node.get("name", "")
if root_name == "" and "children" in node:
root_name = name
if ntype == "decision":
branches = []
for child in node.get("children", []):
branches.append({
"value": child.get("condition", ""),
"target": child["node"].get("id", ""),
})
_collect(child["node"])
nodes.append({
"id": nid,
"type": ntype,
"condition": name,
"branches": branches,
})
elif ntype in ("action", "process", "state"):
nodes.append({
"id": nid,
"type": ntype,
"description": name,
})
for child in node.get("children", []):
_collect(child)
elif ntype == "start":
nodes.append({
"id": nid,
"type": ntype,
"description": name,
})
for child in node.get("children", []):
_collect(child)
# end nodes are collected but have no children
_collect(tree)
# Add end nodes from the nested tree
ends: list[dict] = []
def _collect_ends(node: dict):
if node.get("type") == "end":
ends.append({
"id": node.get("id", ""),
"type": "end",
"description": node.get("name", ""),
})
elif "children" in node:
for child in node.get("children", []):
if isinstance(child, dict):
if "node" in child:
_collect_ends(child["node"])
else:
_collect_ends(child)
_collect_ends(tree)
nodes.extend(ends)
return {"root": root_name, "nodes": nodes}
@staticmethod
def extract_paths(tree: dict) -> list[list[dict]]:
"""Extract all root-to-leaf paths from a nested flowchart tree.
Each path is a list of node dicts (each with id, name, type).
Returns a list of paths useful for human review and LLM verification.
"""
paths: list[list[dict]] = []
def _walk(node: dict, current_path: list[dict]):
entry = {"id": node.get("id", ""), "name": node.get("name", ""), "type": node.get("type", "")}
new_path = current_path + [entry]
if node.get("type") == "end":
paths.append(new_path)
return
children = node.get("children", [])
if not children:
paths.append(new_path)
return
if node.get("type") == "decision":
for child in children:
_walk(child["node"], new_path)
else:
for child in children:
_walk(child, new_path)
_walk(tree, [])
return paths
@staticmethod
def paths_to_text(paths: list[list[dict]]) -> str:
"""Render extracted paths as human-readable text for review."""
lines: list[str] = []
for i, path in enumerate(paths, 1):
steps = []
for node in path:
if node["type"] == "decision":
steps.append(f"[判断] {node['name']}")
elif node["type"] == "end":
steps.append(f"[结束] {node['name']}")
else:
steps.append(f"[{node['type']}] {node['name']}")
lines.append(f"路径 {i}: {' -> '.join(steps)}")
return "\n".join(lines)
@staticmethod
def _extract_json(text: str) -> Optional[dict]:
"""Try multiple strategies to extract a JSON object from text.
Returns the parsed dict or None.
"""
# Strategy 1: first { ... } pair (simple regex)
json_match = re.search(r'\{.*\}', text, re.DOTALL)
if json_match:
try:
return json.loads(json_match.group())
except json.JSONDecodeError:
pass
# Strategy 2: find balanced braces
start = text.find("{")
if start >= 0:
depth = 0
for i in range(start, len(text)):
if text[i] == "{":
depth += 1
elif text[i] == "}":
depth -= 1
if depth == 0:
try:
return json.loads(text[start:i + 1])
except json.JSONDecodeError:
break
return None
@staticmethod @staticmethod
def _mime_type(image_path: str) -> str: def _mime_type(image_path: str) -> str:
@@ -0,0 +1,384 @@
#!/usr/bin/env python3
"""Verify flowchart logic trees for structural correctness and consistency.
Usage::
python verify_flowchart.py <parsed.json|flowchart.json> [--llm] [--output-report REPORT.md]
Performs three levels of checks:
1. **Structural validation** — tree integrity, node uniqueness, leaf types
2. **Path extraction** — renders all root-to-leaf paths as readable text
3. **LLM consistency check** (opt-in with ``--llm``) — compares extracted paths
against the original text description for logical inconsistencies
Outputs PASS/FAIL and a detailed report.
"""
import argparse
import json
import logging
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from image_parser import ImageParser
from LLM import LLMClient
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Prompt for LLM path-vs-description consistency check
# ---------------------------------------------------------------------------
PROMPT_VERIFY_PATHS = """你是一个流程图审核专家。以下内容来自同一张流程图的解析结果:
## 流程图路径(从嵌套逻辑树提取的所有根到叶路径)
```
{paths_text}
```
## 原始文字描述
```
{description}
```
## 你的任务
逐条检查每条路径是否与文字描述一致。重点关注:
1. **分支方向错误**:路径中的判断分支走向是否与文字描述矛盾?
例如:文字说"满足条件后退出",但路径中""分支走向了"不受限"
2. **缺失步骤**:路径中是否缺少文字描述中提到的关键步骤?
3. **冗余步骤**:路径中是否包含文字描述未提及的多余步骤?
4. **条件颠倒**:判断条件的"是/否"分支是否与文字描述相反?
## 输出格式
如果**所有路径一致**,只输出:
```
[[PATHS_CONSISTENT]]
```
如果**发现不一致**,输出 JSON 数组:
```json
[
{{
"path_index": 1,
"issue_type": "branch_error|missing_step|redundant_step|condition_reversed",
"severity": "high|medium|low",
"description": "用中文说明具体问题"
}}
]
```
注意:输出必须是严格合法的 JSON 数组,不要有尾随逗号,不要包含代码块包裹符号。
"""
# ---------------------------------------------------------------------------
# Core verification logic
# ---------------------------------------------------------------------------
def verify_parsed_json(parsed_path: str, *, use_llm: bool = False) -> dict:
"""Load _parsed.json and verify all flowchart logic trees.
Returns a report dict with keys:
- total_flowcharts: int
- passed: int
- failed: int
- results: list of per-flowchart results
"""
with open(parsed_path, "r", encoding="utf-8") as f:
data = json.load(f)
image_analysis = data.get("image_analysis", [])
flowcharts = [img for img in image_analysis if img.get("type") == "flowchart"]
report = {
"total_flowcharts": len(flowcharts),
"passed": 0,
"failed": 0,
"results": [],
}
llm = LLMClient() if use_llm else None
for img in flowcharts:
rid = img.get("rid", "unknown")
logger.info("Verifying flowchart: rid=%s", rid)
result = _verify_single(img, llm)
report["results"].append(result)
if result["structural_ok"] and (not use_llm or result.get("llm_ok", True)):
report["passed"] += 1
else:
report["failed"] += 1
return report
def verify_flowchart_file(filepath: str, *, use_llm: bool = False) -> dict:
"""Load a standalone flowchart JSON file and verify it."""
with open(filepath, "r", encoding="utf-8") as f:
tree = json.load(f)
img = {"logic_tree_nested": tree, "description": "", "rid": os.path.basename(filepath)}
llm = LLMClient() if use_llm else None
result = _verify_single(img, llm)
return {
"total_flowcharts": 1,
"passed": 1 if result["structural_ok"] else 0,
"failed": 0 if result["structural_ok"] else 1,
"results": [result],
}
def _verify_single(img: dict, llm: LLMClient | None) -> dict:
"""Verify a single flowchart image analysis entry."""
rid = img.get("rid", "unknown")
description = img.get("description", "").strip()
# Try nested format first, fall back to flat format
tree = img.get("logic_tree_nested") or img.get("logic_tree")
if tree is None:
return {
"rid": rid,
"structural_ok": False,
"errors": ["No logic_tree found"],
"paths_text": "",
"llm_issues": [],
}
# Check if it's the new nested format or old flat format
is_nested = "children" in tree and isinstance(tree.get("children"), list)
# --- Level 1: Structural validation ---
structural_ok = True
errors: list[str] = []
if is_nested:
ok, err = ImageParser._validate_flowchart(tree)
if not ok:
structural_ok = False
errors.append(f"Structure: {err}")
# Extract paths
paths = ImageParser.extract_paths(tree)
paths_text = ImageParser.paths_to_text(paths)
errors.append(f"Path count: {len(paths)}")
else:
# Old flat format — basic check
nodes = tree.get("nodes", [])
ids = [n.get("id", "") for n in nodes]
if len(ids) != len(set(ids)):
structural_ok = False
errors.append("Structure: duplicate node ids in flat format")
# Build simple path-like text for flat format
paths_text = _flat_to_text(tree)
# --- Level 2: Path count sanity check ---
if is_nested and len(paths) == 0:
structural_ok = False
errors.append("No paths extracted from tree")
# --- Level 3: LLM consistency check ---
llm_issues: list[dict] = []
llm_ok = True
if llm and description and paths_text:
prompt = PROMPT_VERIFY_PATHS.format(
paths_text=paths_text,
description=description,
)
try:
raw = llm.chat(
model=LLMClient.TEXT_MODEL,
messages=[{"role": "user", "content": prompt}],
)
llm_issues = _parse_llm_issues(raw)
if llm_issues:
llm_ok = False
errors.append(f"LLM found {len(llm_issues)} issue(s)")
except RuntimeError as e:
errors.append(f"LLM check failed: {e}")
return {
"rid": rid,
"structural_ok": structural_ok,
"errors": errors,
"paths_text": paths_text,
"llm_ok": llm_ok,
"llm_issues": llm_issues,
}
def _flat_to_text(tree: dict) -> str:
"""Build path-like text from old flat-format logic_tree."""
nodes = tree.get("nodes", [])
root = tree.get("root", "")
lines = [f"Root: {root}"]
node_map = {n["id"]: n for n in nodes}
def _trace(node_id: str, visited: set, path: list[str]) -> list[str]:
if node_id in visited:
path.append(f"[循环] {node_id}")
return path
visited.add(node_id)
node = node_map.get(node_id)
if node is None:
path.append(f"[缺失] {node_id}")
return path
ntype = node.get("type", "")
if ntype == "decision":
cond = node.get("condition", "")
for b in node.get("branches", []):
val = b.get("value", "")
tgt = b.get("target", "")
new_path = path + [f"[判断] {cond}{val}"]
_trace(tgt, visited.copy(), new_path)
elif ntype == "end":
path.append(f"[结束] {node.get('description', '')}")
lines.append(" -> ".join(path))
else:
path.append(f"[{ntype}] {node.get('description', '')}")
# Flat format doesn't have explicit children for non-decision nodes
# so we can't trace further
lines.append(" -> ".join(path))
return path
# Try to find start nodes
starts = [n for n in nodes if n.get("type") == "start"]
if starts:
for s in starts:
_trace(s["id"], set(), [])
else:
lines.append("(Cannot trace: no start node in flat format)")
return "\n".join(lines)
def _parse_llm_issues(content: str) -> list[dict]:
"""Parse LLM response for path consistency issues."""
stripped = content.strip()
if "[[PATHS_CONSISTENT]]" in stripped:
return []
# Remove markdown code fences
if "```json" in stripped:
stripped = stripped.split("```json", 1)[1]
if "```" in stripped:
stripped = stripped.split("```", 1)[0]
elif "```" in stripped:
stripped = stripped.split("```", 1)[1]
if "```" in stripped:
stripped = stripped.split("```", 1)[0]
stripped = stripped.strip()
if not stripped:
return []
try:
issues = json.loads(stripped)
if isinstance(issues, list):
return issues
return []
except json.JSONDecodeError:
logger.debug("Failed to parse LLM issues: %s", stripped[:200])
return []
# ---------------------------------------------------------------------------
# Report rendering
# ---------------------------------------------------------------------------
def print_report(report: dict) -> str:
"""Print a human-readable verification report and return it as a string."""
lines: list[str] = []
lines.append("=" * 60)
lines.append("流程图校验报告")
lines.append("=" * 60)
lines.append(f"流程图总数: {report['total_flowcharts']}")
lines.append(f"通过: {report['passed']}")
lines.append(f"失败: {report['failed']}")
overall = "PASS" if report["failed"] == 0 else "FAIL"
lines.append(f"总体结果: {overall}")
lines.append("")
for i, r in enumerate(report["results"], 1):
rid = r["rid"]
status = "[PASS]" if r["structural_ok"] else "[FAIL]"
lines.append(f"[{i}] rid={rid} {status}")
for err in r.get("errors", []):
lines.append(f" - {err}")
if r.get("paths_text"):
lines.append(" 路径:")
for path_line in r["paths_text"].split("\n"):
lines.append(f" {path_line}")
llm_issues = r.get("llm_issues", [])
if llm_issues:
lines.append(" LLM发现的问题:")
for issue in llm_issues:
lines.append(f" [{issue.get('severity', '?')}] {issue.get('description', '')}")
lines.append("")
report_text = "\n".join(lines)
print(report_text)
return report_text
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main():
parser = argparse.ArgumentParser(
description="Verify flowchart logic trees for correctness.",
)
parser.add_argument(
"input", metavar="FILE",
help="Path to _parsed.json or standalone flowchart JSON",
)
parser.add_argument(
"--llm", action="store_true",
help="Run LLM consistency check (compares paths against text description)",
)
parser.add_argument(
"--output-report", metavar="PATH",
help="Save verification report to a file",
)
args = parser.parse_args()
# Determine input type
with open(args.input, "r", encoding="utf-8") as f:
data = json.load(f)
if "image_analysis" in data:
report = verify_parsed_json(args.input, use_llm=args.llm)
else:
report = verify_flowchart_file(args.input, use_llm=args.llm)
report_text = print_report(report)
if args.output_report:
with open(args.output_report, "w", encoding="utf-8") as f:
f.write(report_text)
logger.info("Report saved: %s", args.output_report)
# Exit code: 0 for PASS, 1 for FAIL
if report["failed"] > 0:
sys.exit(1)
if __name__ == "__main__":
main()
+9
View File
@@ -0,0 +1,9 @@
# Generated output
output/
# Python
__pycache__/
*.pyc
# Console log
Console output.txt
+137
View File
@@ -0,0 +1,137 @@
"""
Shared configuration for the IR Generation pipeline.
Reads API keys from a secrets.yaml file, falling back to environment variables.
"""
import os
import json
import yaml
# ---- Paths ----
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
WORKSPACE_DIR = os.path.dirname(BASE_DIR)
DOC_PARSER_OUTPUT = os.path.join(WORKSPACE_DIR, "doc_parser_skill", "output")
PROMPTS_DIR = os.path.join(BASE_DIR, "prompts")
TESTS_DIR = os.path.join(BASE_DIR, "tests")
OUTPUT_DIR = os.path.join(BASE_DIR, "output")
# Input file (the parsed PRD JSON)
_DEFAULT_INPUT = os.path.join(
DOC_PARSER_OUTPUT,
"车机娱乐系统禁止功能文档_脱敏 v0.9_v2_updated.json",
)
INPUT_JSON = os.environ.get("IR_INPUT_JSON", _DEFAULT_INPUT)
def set_input_file(path: str) -> None:
"""Override the default input JSON path."""
global INPUT_JSON
INPUT_JSON = path
# Secrets file (shared with workspace-document-analyzer)
# .openclaw/workspace/skills/ir_generation_new_skill -> .openclaw/workspace-document-analyzer
OPENCLAW_HOME = os.path.dirname(os.path.dirname(WORKSPACE_DIR))
SECRETS_YAML = os.path.join(
OPENCLAW_HOME, "workspace-document-analyzer", "config", "secrets.yaml",
)
# Intermediate outputs
SEMANTIC_INDEX_R1_JSON = os.path.join(OUTPUT_DIR, "semantic_index_r1.json")
SEMANTIC_INDEX_R2_JSON = os.path.join(OUTPUT_DIR, "semantic_index_r2.json")
SEMANTIC_INDEX_R3_JSON = os.path.join(OUTPUT_DIR, "semantic_index_r3.json")
SEMANTIC_INDEX_JSON = os.path.join(OUTPUT_DIR, "semantic_index.json") # merged final
IR_FRAGMENTS_JSON = os.path.join(OUTPUT_DIR, "ir_fragments.json")
PATH_ENUM_JSON = os.path.join(OUTPUT_DIR, "path_enumeration.json")
IR_AUTOCOMPLETE_FRAGMENTS_JSON = os.path.join(OUTPUT_DIR, "ir_autocomplete_fragments.json")
# Final deliverables (placed in doc_parser output per spec)
IR_FINAL_JSON = os.path.join(DOC_PARSER_OUTPUT, "ir_final.json")
IR_AUDIT_REPORT_MD = os.path.join(DOC_PARSER_OUTPUT, "ir_audit_report.md")
# ---- LLM API ----
# Choose provider: "deepseek" | "dashscope"
LLM_PROVIDER = os.environ.get("IR_PROVIDER", "deepseek")
# Model names per provider
PROVIDER_MODELS = {
"deepseek": os.environ.get("IR_MODEL", "deepseek-v4-flash"),
"dashscope": os.environ.get("IR_MODEL", "qwen-max"),
}
MODEL_NAME = PROVIDER_MODELS.get(LLM_PROVIDER, PROVIDER_MODELS["deepseek"])
# Maximum tokens for LLM responses
MAX_TOKENS = int(os.environ.get("IR_MAX_TOKENS", "16000"))
TEMPERATURE = float(os.environ.get("IR_TEMPERATURE", "0.1"))
# ---- Iteration & Quality ----
MAX_RETRIES_PER_STAGE = int(os.environ.get("IR_MAX_RETRIES", "3"))
COVERAGE_TARGET = float(os.environ.get("IR_COVERAGE_TARGET", "0.95"))
# Stage 1 ensemble temperatures (parallel multi-temperature generation)
ENSEMBLE_TEMPERATURES = [
float(os.environ.get("IR_ENSEMBLE_T1", "0.0")),
float(os.environ.get("IR_ENSEMBLE_T2", "0.3")),
float(os.environ.get("IR_ENSEMBLE_T3", "0.7")),
]
def _load_secrets() -> dict[str, dict[str, str]]:
"""Load provider credentials from secrets.yaml.
Returns a dict like: {"deepseek": {"apiKey": "...", "baseUrl": "..."}, ...}
"""
if os.path.isfile(SECRETS_YAML):
with open(SECRETS_YAML, "r", encoding="utf-8") as f:
return yaml.safe_load(f) or {}
return {}
def _get_provider_config(provider: str) -> dict[str, str]:
"""Get {apiKey, baseUrl} for a provider from secrets, with env-var fallback."""
secrets = _load_secrets()
entry = secrets.get(provider, {})
env_prefix = provider.upper()
api_key = (
os.environ.get(f"{env_prefix}_API_KEY")
or entry.get("apiKey", "")
)
base_url = (
os.environ.get(f"{env_prefix}_BASE_URL")
or entry.get("baseUrl", "https://api.deepseek.com/v1")
)
if not api_key:
raise RuntimeError(
f"No API key found for provider '{provider}'. "
f"Check {SECRETS_YAML} or set {env_prefix}_API_KEY."
)
return {"apiKey": api_key, "baseUrl": base_url}
def llm_client():
"""Return an OpenAI-compatible client configured from secrets.yaml."""
from openai import OpenAI
cfg = _get_provider_config(LLM_PROVIDER)
return OpenAI(base_url=cfg["baseUrl"], api_key=cfg["apiKey"])
def load_input_document(path: str | None = None) -> dict:
"""Load the parsed PRD JSON document."""
path = path or INPUT_JSON
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
def save_json(data, path: str) -> None:
"""Save data as formatted JSON."""
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
def load_json(path: str) -> dict:
"""Load a JSON file."""
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
@@ -0,0 +1,593 @@
"""
Deterministic ensemble merge for semantic index generation.
All functions are pure Python with zero LLM calls. Fully testable with mock data.
Cross-references N semantic_index outputs (generated with different temperatures)
and produces a single merged index with confidence scores.
Used by: step1_semantic_index.py
Tested by: tests/test_ensemble_merge.py
"""
from collections import defaultdict
from difflib import SequenceMatcher
# =============================================================================
# Concept Name Similarity
# =============================================================================
def concept_name_similarity(name_a: str, name_b: str) -> float:
"""Compute similarity between two concept names for cross-version matching.
Strategy (in order of precedence):
1. Exact string match -> 1.0
2. Substring containment (one is a substring of the other) -> 0.9
3. SequenceMatcher ratio on character sequences -> 0.0-1.0
Returns:
float in [0.0, 1.0] where >= 0.7 means "likely the same concept".
"""
if name_a == name_b:
return 1.0
# Substring containment: one name is contained in the other
if name_a in name_b or name_b in name_a:
# Only count as similar if they're of comparable length
# (avoid matching "国内" with "国内行车娱乐限制")
len_ratio = min(len(name_a), len(name_b)) / max(len(name_a), len(name_b))
if len_ratio >= 0.5:
return 0.85 + 0.05 * len_ratio # range 0.875-0.90
return 0.55 # too different in length → below threshold
return SequenceMatcher(None, name_a, name_b).ratio()
# =============================================================================
# Concept Clustering & Merging
# =============================================================================
def cluster_concepts(
all_concepts_lists: list[list[dict]],
similarity_threshold: float = 0.7,
) -> list[list[tuple[int, dict]]]:
"""Group concepts across ensemble versions by name similarity.
Uses greedy single-pass clustering: for each concept, find the best-matching
existing cluster. If max similarity >= threshold, add to it; otherwise,
create a new cluster.
Args:
all_concepts_lists: List of concept lists, one per ensemble version.
all_concepts_lists[i] = concepts from version i.
similarity_threshold: Minimum name similarity to join a cluster.
Returns:
List of clusters. Each cluster is list of (version_idx, concept_dict).
"""
clusters = [] # type: list[list[tuple[int, dict]]]
for version_idx, concepts in enumerate(all_concepts_lists):
for c in concepts:
name = c.get("name", "")
if not name:
continue
best_cluster = None
best_sim = 0.0
for cluster in clusters:
# Compare against the first member of the cluster (seed)
seed_name = cluster[0][1].get("name", "")
sim = concept_name_similarity(name, seed_name)
if sim > best_sim:
best_sim = sim
best_cluster = cluster
if best_cluster is not None and best_sim >= similarity_threshold:
best_cluster.append((version_idx, c))
else:
clusters.append([(version_idx, c)])
return clusters
def merge_concept_cluster(
cluster: list[tuple[int, dict]],
total_versions: int,
) -> tuple[dict, str]:
"""Merge a single cluster of matched concepts into one concept dict.
Rules:
- name: Longest name (most specific). Tie-break by lower version_idx.
- aliases: Union of all aliases across versions.
- defined_in: Union of all defined_in across versions.
- parent: Most common non-null parent (voting). Tie-break by lower version_idx.
Returns:
(merged_concept_dict, confidence_level) where confidence is "high"/"medium"/"low".
"""
if not cluster:
return {}, "low"
# --- name: longest (most specific) ---
best_name = ""
best_name_len = 0
for v_idx, c in cluster:
n = c.get("name", "")
if len(n) > best_name_len:
best_name = n
best_name_len = len(n)
elif len(n) == best_name_len and v_idx < cluster[0][0]: # lower version idx
best_name = n
# --- aliases: union ---
aliases = set()
for _, c in cluster:
for a in c.get("aliases", []):
aliases.add(a)
# --- defined_in: union ---
defined_in = set()
for _, c in cluster:
for d in c.get("defined_in", []):
defined_in.add(d)
# --- parent: most common non-null parent (vote) ---
parent_votes = defaultdict(int)
for v_idx, c in cluster:
p = c.get("parent")
if p is not None:
parent_votes[p] += 1
if parent_votes:
best_parent = max(parent_votes, key=lambda p: (parent_votes[p], -1))
else:
best_parent = None
# --- confidence ---
versions_present = len({v_idx for v_idx, _ in cluster})
confidence = compute_confidence_versions(versions_present, total_versions,
any(v_idx == 0 for v_idx, _ in cluster))
merged = {
"name": best_name,
"aliases": sorted(aliases),
"defined_in": sorted(defined_in),
"parent": best_parent,
"confidence": confidence,
}
return merged, confidence
# =============================================================================
# Unit Similarity Functions
# =============================================================================
def _collect_logic_tree_nodes(unit: dict) -> set[str]:
"""Extract the flattened set of all logic tree node IDs from a function_unit."""
nodes = set()
for src in unit.get("sources", []):
if src.get("type") == "logic_tree":
nodes.update(src.get("logic_tree_nodes", []))
return nodes
def unit_node_jaccard(unit_a: dict, unit_b: dict) -> float:
"""Compute Jaccard similarity on logic tree node sets between two units.
Jaccard(A, B) = |A ∩ B| / |A B|. Returns 0.0 if both have no nodes.
"""
nodes_a = _collect_logic_tree_nodes(unit_a)
nodes_b = _collect_logic_tree_nodes(unit_b)
if not nodes_a and not nodes_b:
return 0.0
if not nodes_a or not nodes_b:
return 0.0
intersection = nodes_a & nodes_b
union = nodes_a | nodes_b
return len(intersection) / len(union)
def path_similarity(path_a: list[str], path_b: list[str]) -> float:
"""Compute similarity between two path arrays.
Hybrid approach:
- Sequential similarity (order-aware): SequenceMatcher on joined strings.
- Set similarity (order-independent): Jaccard on path element sets.
- Final score: 0.5 * seq_sim + 0.5 * set_sim
Returns:
float in [0.0, 1.0].
"""
if not path_a and not path_b:
return 1.0
if not path_a or not path_b:
return 0.0
# Sequential similarity
joined_a = "|".join(path_a)
joined_b = "|".join(path_b)
seq_sim = SequenceMatcher(None, joined_a, joined_b).ratio()
# Set similarity
set_a = set(path_a)
set_b = set(path_b)
set_sim = len(set_a & set_b) / len(set_a | set_b)
return 0.5 * seq_sim + 0.5 * set_sim
def unit_similarity(unit_a: dict, unit_b: dict) -> float:
"""Combined similarity between two function_units.
Weighted combination:
- 0.6 * unit_node_jaccard (primary signal: same logic tree nodes = same rule)
- 0.4 * path_similarity (secondary signal: semantic agreement)
Returns:
float in [0.0, 1.0]. >= 0.5 means "likely the same function_unit".
"""
return 0.6 * unit_node_jaccard(unit_a, unit_b) + 0.4 * path_similarity(
unit_a.get("path", []), unit_b.get("path", [])
)
# =============================================================================
# Function Unit Clustering & Merging
# =============================================================================
def cluster_function_units(
all_units_lists: list[list[dict]],
similarity_threshold: float = 0.5,
) -> list[list[tuple[int, dict]]]:
"""Group function_units across ensemble versions by content similarity.
Lowest-temperature versions are processed first (most stable → cluster seeds).
Higher-temperature variants join existing clusters if similar enough.
Args:
all_units_lists: List of unit lists, one per ensemble version.
similarity_threshold: Minimum unit_similarity to join a cluster.
Returns:
List of clusters. Each cluster is list of (version_idx, unit_dict).
"""
clusters = [] # type: list[list[tuple[int, dict]]]
for version_idx, units in enumerate(all_units_lists):
for unit in units:
best_cluster = None
best_sim = 0.0
for cluster in clusters:
# Compare against all members already in the cluster
cluster_sim = max(
unit_similarity(unit, existing_unit)
for (_, existing_unit) in cluster
)
if cluster_sim > best_sim:
best_sim = cluster_sim
best_cluster = cluster
if best_cluster is not None and best_sim >= similarity_threshold:
best_cluster.append((version_idx, unit))
else:
clusters.append([(version_idx, unit)])
return clusters
def pick_best_representative(
cluster: list[tuple[int, dict]],
) -> dict:
"""Select the best function_unit from a cluster as the merged representative.
Scoring formula (all normalized to [0, 1]):
- 0.35: Node count (more logic_tree_nodes = more complete trace)
- 0.25: Source count (more sources = more evidence)
- 0.20: Description length (longer = more detail, capped at 500 chars)
- 0.20: Temperature rank (lower version_idx = lower temp = more stable)
Returns a deep copy of the winning unit dict.
"""
if not cluster:
return {}
# Compute max values for normalization
max_nodes = max(
len(_collect_logic_tree_nodes(unit)) for _, unit in cluster
)
max_sources = max(
len(unit.get("sources", [])) for _, unit in cluster
)
max_desc_len = max(
len(unit.get("description", "")) for _, unit in cluster
)
max_version_idx = max(v_idx for v_idx, _ in cluster)
num_versions = len(cluster)
def score(v_idx: int, unit: dict) -> float:
nodes = len(_collect_logic_tree_nodes(unit))
sources = len(unit.get("sources", []))
desc_len = min(len(unit.get("description", "")), 500)
temp_rank = 1.0 - (v_idx / max(num_versions, max_version_idx + 1))
return (
0.35 * (nodes / max(1, max_nodes))
+ 0.25 * (sources / max(1, max_sources))
+ 0.20 * (desc_len / max(1, max_desc_len))
+ 0.20 * temp_rank
)
best = max(cluster, key=lambda x: score(x[0], x[1]))
return dict(best[1]) # deep-ish copy (1 level)
def merge_unit_sources(
cluster: list[tuple[int, dict]],
) -> list[dict]:
"""Union all sources from units in a cluster, deduplicating by (type, image_id, section).
When the same source key appears in multiple versions, keeps the one with
the most logic_tree_nodes.
"""
# Group by dedup key
source_groups = defaultdict(list)
for v_idx, unit in cluster:
for src in unit.get("sources", []):
# Build a dedup key
src_type = src.get("type", "")
if src_type == "logic_tree":
key = ("logic_tree", src.get("image_id", ""))
else:
key = (src_type, src.get("section", ""), src.get("row", ""))
source_groups[key].append(src)
# Pick best per group
result = []
for key, sources in source_groups.items():
# Pick the source with the most logic_tree_nodes (if any)
best = max(sources, key=lambda s: len(s.get("logic_tree_nodes", [])))
result.append(dict(best))
return result
def compute_confidence_versions(
versions_present: int,
total_versions: int,
includes_lowest_temp: bool = False,
) -> str:
"""Compute 3-level confidence based on cross-version agreement.
- "high": Appears in all versions, OR >= 2/3 with lowest-temp version (T=0.0).
- "medium": Appears in >= half the versions but not all.
- "low": Appears in fewer than half (singleton in ensemble).
Args:
versions_present: Number of versions this item appeared in.
total_versions: Total number of ensemble versions.
includes_lowest_temp: Whether the item appeared in the T=0.0 version.
"""
ratio = versions_present / total_versions
if ratio >= 1.0:
return "high"
if ratio >= 0.5 and includes_lowest_temp:
return "high"
if ratio >= 0.5:
return "medium"
return "low"
def ensemble_merge_concepts(
all_concepts_lists: list[list[dict]],
) -> list[dict]:
"""Merge concepts across all ensemble versions.
Returns:
List of merged concept dicts, each with added "confidence" field.
"""
total = len(all_concepts_lists)
clusters = cluster_concepts(all_concepts_lists)
merged = []
seen_names = set()
for cluster in clusters:
concept, confidence = merge_concept_cluster(cluster, total)
name = concept.get("name", "")
if name and name not in seen_names:
concept["ensemble_support"] = f"{len({v for v, _ in cluster})}/{total}"
merged.append(concept)
seen_names.add(name)
# Sort: high confidence first, then by name
conf_order = {"high": 0, "medium": 1, "low": 2}
merged.sort(key=lambda c: (conf_order.get(c.get("confidence", "low"), 3), c.get("name", "")))
# Validate and fix parent references
merged = _validate_concept_parents(merged)
return merged
def _validate_concept_parents(concepts: list[dict]) -> list[dict]:
"""Post-merge: validate that every concept's parent exists in the list.
Strategy for dangling parents:
1. Fuzzy match (concept_name_similarity >= 0.7) → fix reference
2. No match → set parent to null, downgrade confidence to "low"
"""
concept_names = {c["name"] for c in concepts}
conf_order = {"high": 0, "medium": 1, "low": 2}
for c in concepts:
parent = c.get("parent")
if parent is None:
continue
if parent in concept_names:
continue
# Dangling parent — try fuzzy match
best_match = None
best_sim = 0.0
for name in concept_names:
sim = concept_name_similarity(parent, name)
if sim > best_sim:
best_sim = sim
best_match = name
if best_match and best_sim >= 0.7:
c["parent"] = best_match
# Downgrade if match was fuzzy (not exact)
if best_sim < 1.0:
current_conf = c.get("confidence", "low")
c["confidence"] = _downgrade_confidence(current_conf)
else:
c["parent"] = None
c["confidence"] = _downgrade_confidence(c.get("confidence", "low"))
# Re-sort after confidence changes
concepts.sort(key=lambda c: (conf_order.get(c.get("confidence", "low"), 3), c.get("name", "")))
return concepts
def _downgrade_confidence(current: str) -> str:
"""Drop confidence one level."""
if current == "high":
return "medium"
return "low"
def ensemble_merge_function_units(
all_units_lists: list[list[dict]],
) -> list[dict]:
"""Merge function_units across all ensemble versions.
1. Cluster units across versions.
2. For each cluster: pick best, merge sources, compute confidence.
3. Reassign stable unit_ids: FU-ENS-001, FU-ENS-002, ...
Returns:
List of merged function_unit dicts with added "confidence",
"ensemble_support", "source_versions" fields.
"""
total = len(all_units_lists)
clusters = cluster_function_units(all_units_lists)
merged = []
for cluster in clusters:
# Pick best representative
best = pick_best_representative(cluster)
# Merge sources from all cluster members
best["sources"] = merge_unit_sources(cluster)
# Compute confidence
versions_present = len({v_idx for v_idx, _ in cluster})
includes_t0 = any(v_idx == 0 for v_idx, _ in cluster)
confidence = compute_confidence_versions(
versions_present, total, includes_t0
)
best["confidence"] = confidence
best["ensemble_support"] = f"{versions_present}/{total}"
best["source_versions"] = versions_present
merged.append(best)
# Sort by confidence desc, then by unit_id
conf_order = {"high": 0, "medium": 1, "low": 2}
merged.sort(key=lambda u: (conf_order.get(u.get("confidence", "low"), 3),
u.get("unit_id", "")))
# Reassign stable unit_ids
for i, unit in enumerate(merged):
# Preserve original unit_id for traceability
if "original_unit_id" not in unit:
unit["original_unit_id"] = unit.get("unit_id", "")
unit["unit_id"] = f"FU-ENS-{i + 1:03d}"
return merged
# =============================================================================
# Top-Level Ensemble Merge
# =============================================================================
def ensemble_merge(
semantic_indices: list[dict],
) -> dict:
"""Merge N semantic index outputs into one ensemble result.
Args:
semantic_indices: List of semantic_index dicts from each temperature run.
semantic_indices[0] should be the lowest-temperature version.
Returns:
Merged semantic_index dict with structure:
{
"feature_name": str,
"ensemble_versions": int,
"concepts": [...],
"function_units": [...],
"confidence_summary": {...},
}
"""
if not semantic_indices:
return {
"feature_name": "",
"ensemble_versions": 0,
"concepts": [],
"function_units": [],
"confidence_summary": {},
}
total = len(semantic_indices)
# Extract concepts and function_units from each version
all_concepts = [si.get("concepts", []) for si in semantic_indices]
all_units = [si.get("function_units", []) for si in semantic_indices]
# Merge
merged_concepts = ensemble_merge_concepts(all_concepts)
merged_units = ensemble_merge_function_units(all_units)
# Feature name: majority vote across versions
feature_names = [si.get("feature_name", "") for si in semantic_indices]
name_counts = defaultdict(int)
for fn in feature_names:
if fn:
name_counts[fn] += 1
feature_name = max(name_counts, key=name_counts.get) if name_counts else ""
# Confidence summary
unit_conf = defaultdict(int)
for u in merged_units:
unit_conf[u.get("confidence", "low")] += 1
concept_conf = defaultdict(int)
for c in merged_concepts:
concept_conf[c.get("confidence", "low")] += 1
return {
"feature_name": feature_name,
"ensemble_versions": total,
"concepts": merged_concepts,
"function_units": merged_units,
"confidence_summary": {
"total_units": len(merged_units),
"high": unit_conf.get("high", 0),
"medium": unit_conf.get("medium", 0),
"low": unit_conf.get("low", 0),
"total_concepts": len(merged_concepts),
"concept_high": concept_conf.get("high", 0),
"concept_medium": concept_conf.get("medium", 0),
"concept_low": concept_conf.get("low", 0),
},
}
+157
View File
@@ -0,0 +1,157 @@
"""
IR Generation Pipeline Orchestrator.
Run all four stages sequentially:
python main.py [--skip-step1] [--skip-step2] [--skip-step2.5] [--skip-step3] [--test-only]
The pipeline reads the parsed PRD JSON from doc_parser and produces:
- ir_final.json: the final IR rules
- ir_audit_report.md: completeness audit report for human review
"""
import argparse
import os
import subprocess
import sys
from pathlib import Path
import config
BASE_DIR = Path(__file__).parent
def _subprocess_env(extra: dict | None = None) -> dict:
"""Build environment dict for subprocesses, carrying forward overrides."""
env = os.environ.copy()
env.update(extra or {})
return env
def run_step(script_name: str, description: str, extra_env: dict | None = None) -> bool:
"""Run a single pipeline step script, return True if it succeeded."""
print(f"\n{'#' * 60}")
print(f"# {description}")
print(f"{'#' * 60}")
script_path = BASE_DIR / script_name
if not script_path.exists():
print(f"错误: 脚本不存在 {script_path}")
return False
result = subprocess.run(
[sys.executable, str(script_path)],
cwd=str(BASE_DIR),
env=_subprocess_env(extra_env),
)
return result.returncode == 0
def run_test(test_name: str, description: str, extra_env: dict | None = None) -> bool:
"""Run a test script, return True if all tests passed."""
print(f"\n{'='*60}")
print(f"测试: {description}")
print(f"{'='*60}")
test_path = BASE_DIR / "tests" / test_name
if not test_path.exists():
print(f"错误: 测试脚本不存在 {test_path}")
return False
result = subprocess.run(
[sys.executable, str(test_path)],
cwd=str(BASE_DIR),
env=_subprocess_env(extra_env),
)
return result.returncode == 0
def main():
parser = argparse.ArgumentParser(description="IR Generation Pipeline")
parser.add_argument("--skip-step1", action="store_true",
help="跳过阶段一(语义索引)")
parser.add_argument("--skip-step2", action="store_true",
help="跳过阶段二(IR 提取)")
parser.add_argument("--skip-step2.5", "--skip-step2-5", action="store_true",
dest="skip_step2_5",
help="跳过阶段2.5(分支覆盖自动补全)")
parser.add_argument("--skip-step3", action="store_true",
help="跳过阶段三(合并与审计)")
parser.add_argument("--test-only", action="store_true",
help="仅运行测试,不调用 LLM")
parser.add_argument(
"--input", "-i", type=str, default=None,
help="输入 JSON 文件路径(覆盖默认的 doc_parser 输出)"
)
parser.add_argument(
"--provider", "-p", type=str, default=None,
help="LLM provider: deepseek | dashscope(覆盖 IR_PROVIDER 环境变量)"
)
args = parser.parse_args()
# Build extra env vars for subprocesses
extra_env = {}
if args.input:
extra_env["IR_INPUT_JSON"] = args.input
print(f"输入文件: {args.input}")
if args.provider:
extra_env["IR_PROVIDER"] = args.provider
print(f"LLM Provider: {args.provider}")
if args.test_only:
all_ok = True
all_ok &= run_test("test_step1.py", "Step 1 验证", extra_env)
all_ok &= run_test("test_step2.py", "Step 2 验证", extra_env)
all_ok &= run_test("test_step2_5.py", "Step 2.5 验证", extra_env)
all_ok &= run_test("test_step3.py", "Step 3 验证", extra_env)
sys.exit(0 if all_ok else 1)
failures = []
# Stage 1
if not args.skip_step1:
ok = run_step("step1_semantic_index.py",
"阶段一:宏观语义索引", extra_env)
if not ok:
failures.append("阶段一")
print("\n阶段一失败,停止流水线。修复后重试。")
sys.exit(1)
run_test("test_step1.py", "Step 1 验证", extra_env)
# Stage 2
if not args.skip_step2:
ok = run_step("step2_ir_extraction.py",
"阶段二:逐功能单元 IR 提取", extra_env)
if not ok:
failures.append("阶段二")
print("\n阶段二失败,停止流水线。修复后重试。")
sys.exit(1)
run_test("test_step2.py", "Step 2 验证", extra_env)
# Stage 2.5
if not args.skip_step2_5:
ok = run_step("step2_5_branch_coverage.py",
"阶段2.5:分支覆盖自动补全", extra_env)
if not ok:
failures.append("阶段2.5")
print("\n阶段2.5失败,停止流水线。修复后重试。")
sys.exit(1)
run_test("test_step2_5.py", "Step 2.5 验证", extra_env)
# Stage 3
if not args.skip_step3:
ok = run_step("step3_merge_and_audit.py",
"阶段三:确定性合并与完整性校验", extra_env)
if not ok:
failures.append("阶段三")
sys.exit(1)
run_test("test_step3.py", "Step 3 验证", extra_env)
if failures:
print(f"\n失败阶段: {', '.join(failures)}")
sys.exit(1)
print(f"\n{'='*60}")
print("流水线全部完成!")
print(f"最终 IR: {config.IR_FINAL_JSON}")
print(f"审计报告: {config.IR_AUDIT_REPORT_MD}")
print(f"{'='*60}")
if __name__ == "__main__":
main()
@@ -0,0 +1,46 @@
## 上一轮遗漏分析
上一轮生成的语义索引经过自动校验,发现以下问题需要修正:
### 遗漏的逻辑树路径
以下逻辑树决策路径未被任何 function_unit 覆盖,请为每条路径生成对应的 function_unit
{missing_paths}
### 遗漏的概念
以下关键概念未在 concepts 列表中出现,请补充:
{missing_concepts}
### 格式问题
以下 function_unit 或 concept 的格式不符合要求:
{format_issues}
### concept parent 问题
以下概念的 parent 引用有问题(悬空引用或缺少 parent):
{parent_issues}
---
请在本次生成中针对以上问题进行修正。注意:
1. 你不需要从头生成完整的语义索引,只需要输出**补充和修正**的部分
2. function_units 的输出应只包含本次新增或修正的单元(已有的正确单元不需要重复)
3. concepts 的输出应只包含本次新增或修正的概念
4. 如果格式问题中提到"空壳单元":删除该 unit,或将其合并到包含实际 action 的 unit 中。纯开关状态不是独立的功能行为
5. 如果格式问题中提到"不构成有效路径":说明你引用了互斥分支上的节点。检查 logic_tree_nodes,确保它们都落在逻辑树的**同一条分支路径**上(例如 n4 是关闭分支,n8 是开启分支,不能共存)
6. 如果格式问题提到"缺少 path"或"缺少 sources":补充对应字段
## 输出格式
只输出 JSON
{
"feature_name": "(与之前相同)",
"supplemental_function_units": [
// 只放新增的或修正的 function_unit
],
"supplemental_concepts": [
// 只放新增的或修正的 concept
],
"corrections": {
// 需要修正的已有项: { "unit_id或concept_name": { 修正后的字段 }, ... }
}
}
@@ -0,0 +1,123 @@
你是吉利汽车车机系统(XX Auto)的产品需求分析师。你的任务是从行车娱乐限制功能 PRD 文档中提取"语义索引"——一份结构化、有层级的功能清单,而不是逐字翻译。
## 文档结构说明
下面是一份 Word 文档的解析结果,包含:
1. **sections**:按章节组织的混合内容(段落 + 表格),每个 section 有 `source`(章节标题)、`blocks``para` 文本段落和 `table` 结构表格)、`images`(引用的图片 ID 列表)
2. **image_analysis**:文档中流程图的程序化分析结果,其中 `logic_tree` 是由节点组成的决策树:
- `state` 节点:状态说明
- `decision` 节点:判断条件 + `branches`(分支值 → 目标节点 ID)
- `action` 节点:系统或用户交互动作
3. **resolved_conflicts**:文档中图文冲突的仲裁结果,明确指出应以文字还是图片为准
## 文档全文
{document_json}
## 你的任务
阅读整份文档后,输出一份 **语义索引 JSON**,包含:
### 1. feature_name
功能名称,如"行车娱乐限制"
### 2. concepts(带层级)
文档中定义或使用的关键概念列表。每个概念包含:
- `name`:概念的标准名称(必填)
- `aliases`:同义词/别名列表(如"行车娱乐限制"、"行车娱乐禁止"
- `defined_in`:定义该概念的章节号列表(如 ["3.1", "3.1.1"]
- `parent`:父概念名称(字符串或 null)(必填)
**概念层级规则(重要)**
你必须按照以下 4 层结构组织概念,并为每个概念指定正确的 `parent`:
- **Level 0(地理范围)**: "国内"、"海外" — parent 为 null
- **Level 1(功能)**: "行车娱乐限制"、"行车娱乐禁止" — parent 为对应的 scope(如 "国内"
- **Level 2(限制方式)**: "系统限制"、"SDK限制"、"其他应用" — parent 为对应的 feature
- **Level 3(具体行为)**: "前台打断"、"后台限制启动"、"后台暂停功能"、"无限制" — parent 为对应的 method
除了以上层级,还可以有"行车娱乐限制开关"、"车速条件"、"档位条件"、"Toast提示"等辅助概念,它们应有合理的 parent。
**重要约束:每个 concept 的 parent 值必须是 concepts 列表中已存在的另一个 concept 的 name,或者是 null。禁止引用不存在的概念名。**
### 3. function_units(带路径)
文档中描述的所有主要功能行为的列表。**每个 function_unit 对应逻辑树中的一条叶子路径**。每个 function unit 包含:
- `unit_id`:唯一标识,格式 "FU-001", "FU-002"...
- `name`:简短名称,如"国内-系统限制-前台-行车打断"
- `description`1-3 句描述该规则的行为
- `path`:层级路径数组,从高到低,如 `["国内", "系统限制", "前台打断"]`(必填)。**path 中的每个元素必须是 concepts 列表中已存在的概念名。**
- `sources`:该规则在文档中的来源锚点列表,每项包含:
- `section`:章节号
- `type`:来源类型,`"table"` 或 `"para"` 或 `"logic_tree"`
- `row`:如果是表格行(从 1 开始)
- `text_snippet`:前 200 字的关键文字
- `image_id`:如果是逻辑树来源,填写图片 rId
- `logic_tree_nodes`:如果是逻辑树来源,列出相关节点 ID 列表
## function_units 分解策略(重要)
**按逻辑树的每条叶子路径生成一个 function_unit**
1. **叶子路径 = 从根节点到叶子节点(end 类型)的完整决策链**,包含路径上所有中间节点和叶子节点的最终动作
2. **每条叶子路径对应一个 function_unit**:不同决策分支导向不同叶子节点 → 不同的 function_unit
3. **"不受限"叶子节点也必须建模**:即使 action 是"不执行任何限制操作",也要创建对应的 function_unit
4. **禁止合并不同叶子节点**:不要将多个不同叶子节点的结果合并到一个 function_unit(除非它们触发完全相同的动作且属于同一父分支)
5. **文字描述中的功能单独列出**:对于无法对应到逻辑树节点的功能(如纯文字描述的功能行为),用 table/para 类型 sourcepath 用语义路径
6. **非流程图的图片也可能包含功能行为**:rId18 等图片的描述文本中可能包含功能规则(如"使用语音打开受限应用"),同样需要提取为 function_unit
**重要:不要创建纯开关/状态的空壳 unit**。"开关开启"本身不是一个功能行为(它没有 action),它是其他单元的 precondition。如果一个 function_unit 的 path 只有 `["国内", "开关开启"]` 且 sources 中只有 n1/n2/n3 这样的根/开关节点,说明它不是真正的功能单元,不应该输出。
{feedback}
## 权威性规则
1. **逻辑树(流程图)是权威来源**:逻辑树定义了功能的确切行为。识别 function_unit 时必须优先按逻辑树路径建模。文字和表格用于补充描述、提供确切措辞(如 Toast 文案),但不应覆盖或曲解逻辑树路径。
2. **logic_tree_nodes 必须构成有效路径**:每个 function_unit 引用的 logic_tree_nodes 列表,必须对应逻辑树中的**一条连通路径**。禁止将互斥分支上的节点混入同一个 source(例如 n4 是"开关关闭"分支,n8 是"开关开启"分支的下游节点,它们不能出现在同一 function_unit 中)。
3. **resolved_conflicts 中的仲裁是最终决定**:如果文档有图文冲突且已仲裁,严格按仲裁结果处理。
4. **逻辑树路径应全部覆盖**:下面是程序从文档逻辑树中枚举的全部决策路径,请逐一确认每条路径都有对应的 function_unit
{logic_tree_paths}
## 关键要求
1. **必须覆盖所有逻辑树路径**:上面列出的每条路径必须被至少一个 function_unit 的 sources 引用。
2. **必须覆盖表格中的所有规则**:表格中列出的每种"限制方法"、"限制规则"都要有对应的 function_unit。
3. **区分"限制"与"禁止"**:文档中"行车娱乐限制"(前台应用打断)和"行车娱乐禁止"(后台应用启动限制)是两个不同的子场景,必须分别建模。
4. **区分不同应用类型**:系统限制、SDK 限制、其他应用的行为路径不同,必须分别建模。
5. **包含开关状态**:开关"开启"和"关闭"两种状态下的行为都要覆盖。
6. **概念和路径必须有层级**:每个 concept 指定正确的 parent;每个 function_unit 输出 path 数组。
## 输出格式
**只输出 JSON,不要有 markdown 代码块标记或其他文字**:
{
"feature_name": "...",
"concepts": [
{"name": "国内", "aliases": [], "defined_in": ["2.7", "3.1"], "parent": null},
{"name": "行车娱乐限制", "aliases": [], "defined_in": ["3.1", "3.1.1"], "parent": "国内"},
...
],
"function_units": [
{
"unit_id": "FU-001",
"name": "国内-系统限制-前台-行车打断",
"description": "...",
"path": ["国内", "系统限制", "前台打断"],
"sources": [
{"section": "3.1.1", "type": "table", "row": 2, "text_snippet": "打断:车速>=15km/h且持续5秒后..."},
{"image_id": "rId16", "type": "logic_tree", "logic_tree_nodes": ["n2","n3","n8","n19","n21","n23","n25","n26"]}
]
},
...
]
}
@@ -0,0 +1,200 @@
你是吉利汽车车机系统的需求分析专家。你的任务是基于给定的精准上下文包,为单个功能单元(Function Unit)提取详细的 **IR 规则(Intermediate Representation Rule**。
## 上下文
下面是一个功能单元的精准上下文包,包含了从原始需求文档中提取的相关文字、表格和逻辑树:
### 功能单元概要
- **unit_id**: {unit_id}
- **unit_name**: {unit_name}
- **unit_description**: {unit_description}
### 相关文字段落
{texts}
### 相关表格
{tables}
### 相关逻辑树
{logic_trees}
### 图文冲突仲裁(如有)
{resolved_conflicts}
## IR Schema
你需要为这个功能单元输出一个 **规则数组(rules)**。每条规则遵循以下 schema:
```json
{{
"rule_id": "{unit_id}-DOMESTIC-SYS-FG-INTERRUPT-01",
"path": ["国内", "系统限制", "前台打断"],
"description": "国内车型,开关开启,系统限制类应用在前台,车速>=15km/h且持续>5秒且非P档时,系统打断应用前台进程、将应用调入后台,显示Toast'在行车状态下无法使用该应用'",
"priority": "P0",
"sources": [
{{"type": "table", "section": "3.1.1", "row": 2, "text_snippet": "打断:车速>=15km/h且持续5秒后..."}},
{{"type": "logic_tree", "image_id": "rId16", "node_ids": ["n2","n3","n8","n19","n21","n23","n25","n26"], "priority": "primary_source"}}
],
"precondition": {{
"geographic_scope": "国内",
"screen_type": "any",
"switch": "开启",
"app_type": "系统限制",
"app_state": "前台"
}},
"trigger": {{
"operator": "AND",
"conditions": [
{{"signal": "车速", "operator": ">=", "value": 15, "unit": "km/h"}},
{{"signal": "车速_持续时间", "operator": ">", "value": 5, "unit": "秒"}},
{{"signal": "档位", "operator": "!=", "value": "P"}}
]
}},
"actions": [
{{"type": "system", "description": "打断应用前台进程"}},
{{"type": "system", "description": "将应用调入后台"}},
{{"type": "user_interaction", "description": "显示Toast", "content": "在行车状态下无法使用该应用"}}
]
}}
```
## 字段说明(必读)
1. **rule_id**: 格式为 `{unit_id}-SCOPE-METHOD-BEHAVIOR-NN`,其中:
- SCOPE: DOMESTIC(国内)| OVERSEAS(海外)
- METHOD: SYS(系统限制)| SDKSDK限制)| OTHER(其他应用)
- BEHAVIOR: FG-INTERRUPT(前台打断)| BG-BLOCK(后台限制启动)| BG-PAUSE(后台暂停功能)| NO-RESTRICT(无限制)| SWITCH-OFF(开关关闭)
- NN: 序号从 01 开始
2. **path**: 层级路径数组(必填)。从 scope 到 behavior 逐级列出,如 `["国内", "系统限制", "前台打断"]`。此字段用于程序化遍历所有功能点。
3. **description**: 完整但简洁地描述整个规则,必须包含:地理范围 + 开关状态 + 应用类型 + 前后台状态 + 触发条件 + 所有动作。人读取此字段即可设计测试用例。
4. **priority**: P0(核心安全规则)、P1(重要规则)、P2(边界情况)。
5. **sources**: 每条规则必须列出所有数据来源。逻辑树类型的 source 必须标记 `"priority": "primary_source"`。文字/表格类型的 source 标记 `"priority": "supplementary"`。**node_ids 必须列举该规则在逻辑树中经历的所有 decision 和 action 节点。**
6. **precondition**: 规则生效的前置状态条件。必须包含以下字段:
- `geographic_scope`(必填):"国内" | "海外"
- `screen_type`(必填):"CSD" | "PSD" | "RFD" | "any"(如文档未区分屏幕类型则填 "any"
- `switch`:开关状态("开启" | "关闭"
- `app_type`:应用类型
- `app_state`:应用前后台状态("前台" | "后台"
如某字段不适用,可省略。
7. **trigger**: 触发条件对象:
- `operator`: "AND" | "OR"
- `conditions`: 条件数组,每个条件必须有 `signal`、`operator`、`value`。有单位加 `unit`。
- 如为瞬时事件(用户点击),用 `event` 字段。
8. **actions**: 每个动作必须有 `type`"system" | "user_interaction")和 `description`。
- `"user_interaction"` 类型必须有 `content` 字段,填写**确切的提示文案**。
- **禁止使用占位符**content 不能是"文案由业务定义"、"待定"、"自定义"等。如果文档中给出了文案,必须原样填入。如果文档确实未给出文案,填写 `"(文档未指定)"` 并标注。
## Few-shot 示例
### 示例 1:行车娱乐限制(前台打断)
**输入上下文**:国内车型,开关开启,系统限制类应用在前台,车速>=15km/h且持续>5秒且非P档时,打断应用并显示Toast"在行车状态下无法使用该应用"。
**期望输出**
```json
{{
"rule_id": "FU-001-DOMESTIC-SYS-FG-INTERRUPT-01",
"path": ["国内", "系统限制", "前台打断"],
"description": "国内车型,开关开启,系统限制类应用在前台,当车速>=15km/h且持续超过5秒且非P档时,系统打断应用前台进程、将应用调入后台,并弹出Toast提示'在行车状态下无法使用该应用'",
"priority": "P0",
"sources": [
{{"type": "table", "section": "3.1.1", "row": 2, "text_snippet": "行车娱乐限制:目标应用/功能处于前台时 ○ 打断:车速>=15km/h且持续5秒后...", "priority": "supplementary"}},
{{"type": "logic_tree", "image_id": "rId16", "node_ids": ["n2","n3","n8","n19","n21","n23","n25","n26"], "priority": "primary_source"}}
],
"precondition": {{
"geographic_scope": "国内",
"screen_type": "any",
"switch": "开启",
"app_type": "系统限制",
"app_state": "前台"
}},
"trigger": {{
"operator": "AND",
"conditions": [
{{"signal": "车速", "operator": ">=", "value": 15, "unit": "km/h"}},
{{"signal": "车速_持续时间", "operator": ">", "value": 5, "unit": "秒"}},
{{"signal": "档位", "operator": "!=", "value": "P"}}
]
}},
"actions": [
{{"type": "system", "description": "打断应用前台进程"}},
{{"type": "system", "description": "将应用调入后台"}},
{{"type": "user_interaction", "description": "显示Toast", "content": "在行车状态下无法使用该应用"}}
]
}}
```
### 示例 2:行车娱乐禁止(后台启动拦截)
**输入上下文**:国内车型,开关开启,应用在后台,非P档时阻止应用启动,提示"请在P挡时使用该功能/应用"。
**期望输出**
```json
{{
"rule_id": "FU-002-DOMESTIC-SYS-BG-BLOCK-01",
"path": ["国内", "系统限制", "后台限制启动"],
"description": "国内车型,开关开启,目标应用处于后台,当用户尝试启动应用且档位非P档时,系统限制应用/功能启用,并弹出Toast提示'请在P挡时使用该功能/应用'",
"priority": "P0",
"sources": [
{{"type": "table", "section": "3.1.1", "row": 2, "text_snippet": "行车娱乐禁止:目标应用/功能处于后台时 ○ 限制:非P挡时,限制目标应用/功能启用...", "priority": "supplementary"}},
{{"type": "logic_tree", "image_id": "rId17", "node_ids": ["n1","n2","n5","n7"], "priority": "primary_source"}}
],
"precondition": {{
"geographic_scope": "国内",
"screen_type": "any",
"switch": "开启",
"app_state": "后台"
}},
"trigger": {{
"operator": "AND",
"conditions": [
{{"signal": "应用请求启动", "operator": "==", "value": true}},
{{"signal": "档位", "operator": "!=", "value": "P"}}
]
}},
"actions": [
{{"type": "system", "description": "限制应用/功能启用"}},
{{"type": "user_interaction", "description": "显示Toast", "content": "请在P挡时使用该功能/应用"}}
]
}}
```
## 关键要求
1. **逻辑树为唯一权威来源**:触发条件和动作序列必须严格按逻辑树路径建模。文字/表格描述仅用于补充确切措辞(如 Toast 文案),不得覆盖或曲解逻辑树路径。在 sources 中,逻辑树类型标记 `"priority": "primary_source"`,文字/表格标记 `"priority": "supplementary"`。
2. **信号和数值必须精确**:禁止写"车速超过阈值",必须写 `{{"signal": "车速", "operator": ">=", "value": 15, "unit": "km/h"}}`。
3. **条件必须完整**:逻辑树中的每个 decision 条件必须对应 trigger.conditions 中的一条。如果文档说"车速>=15km/h 且持续超过5秒 且非P档",这三个条件必须全部出现。
4. **每条规则必须自包含**:人仅凭一条 rule JSON 就能设计出对应的测试用例。必须包含:geographic_scope、screen_type、开关状态、应用类型、前后台状态、完整触发条件、所有动作及确切 Toast 文案、来源引用。
5. **禁止占位符**`"user_interaction"` 类型的 `content` 不能是"文案由业务定义"、"待定"、"自定义"。如文档确实未给出文案,填 `"(文档未指定)"`。
6. **逻辑树节点必须追踪**:在 sources 中列出该规则在逻辑树中经历的所有 decision 节点和 action 节点。
7. **多条规则**:如果一个功能单元包含多个独立行为分支,输出多条规则分别描述。
8. **开关关闭状态**:开关关闭时所有限制失效,这也必须作为一条规则输出(path: ["...", "开关关闭", "无限制"])。
{format_feedback}
## 输出格式
**只输出 JSON 数组,不要有任何其他文字或 markdown 标记**
[
{{ ... }},
{{ ... }}
]
注意:即使只有一个规则,也必须用数组格式 `[...]`。
-105
View File
@@ -1,105 +0,0 @@
import logging
import os
import time
from typing import Optional
from openai import OpenAI
logger = logging.getLogger(__name__)
class LLMClient:
"""Low-level OpenAI-compatible LLM client with retry and token tracking.
Usage::
llm = LLMClient()
content = llm.chat("qwen3.5-flash", [{"role": "user", "content": "Hello"}])
print(llm.usage)
"""
IMAGE_MODEL = "qwen3-vl-plus"
TEXT_MODEL = "qwen3.5-flash-2026-02-23"
TIMEOUT = 120
MAX_RETRIES = 3
def __init__(
self,
*,
base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1",
timeout: int | None = None,
):
key = os.environ.get("DASHSCOPE_API_KEY", "")
if not key:
raise ValueError("DASHSCOPE_API_KEY environment variable is not set.")
self._client = OpenAI(api_key=key, base_url=base_url)
self._timeout = timeout or self.TIMEOUT
self._prompt_tokens = 0
self._completion_tokens = 0
@property
def usage(self) -> dict:
"""Return accumulated token counts as ``{prompt, completion, total}``."""
return {
"prompt_tokens": self._prompt_tokens,
"completion_tokens": self._completion_tokens,
"total_tokens": self._prompt_tokens + self._completion_tokens,
}
@staticmethod
def estimate_tokens(text: str) -> int:
"""Quick token estimate. CJK ≈1.7/token, others ≈3.0/token."""
cjk = sum(1 for c in text if '' <= c <= '鿿' or ' ' <= c <= '')
other = len(text) - cjk
return max(1, int(cjk / 1.7 + other / 3.0))
@staticmethod
def estimate_image_tokens() -> int:
"""Fixed estimate for one vision-model image (~500 tokens)."""
return 500
def chat(
self, model: str, messages: list[dict], *, timeout: int | None = None,
response_format: dict | None = None,
) -> str:
"""Send a chat completion request and return the response content.
Automatically retries on failure and accumulates token usage.
"""
label = f"chat({model})"
def _call():
t0 = time.time()
kwargs = dict(model=model, messages=messages, timeout=timeout or self._timeout)
if response_format is not None:
kwargs["response_format"] = response_format
kwargs["temperature"] = 0
resp = self._client.chat.completions.create(**kwargs)
content = resp.choices[0].message.content
usg = resp.usage
if usg:
self._prompt_tokens += usg.prompt_tokens
self._completion_tokens += usg.completion_tokens
elapsed = time.time() - t0
logger.info("%s: %d chars in %.1fs", label, len(content) if content else 0, elapsed)
if not content:
raise RuntimeError("Empty response from LLM")
return content
return self._retry(_call, label)
def _retry(self, fn, label: str) -> str:
"""Call *fn()* with exponential-backoff retry."""
last_error: Optional[Exception] = None
for attempt in range(self.MAX_RETRIES):
try:
return fn()
except Exception as e:
last_error = e
logger.warning(
"%s error (attempt %d/%d): %s",
label, attempt + 1, self.MAX_RETRIES, e,
)
if attempt < self.MAX_RETRIES - 1:
time.sleep(2 ** attempt)
raise RuntimeError(f"{label}: all retries exhausted") from last_error
@@ -1,359 +0,0 @@
#!/usr/bin/env python3
"""Generate JSON intermediate representation from ``_parsed.json`` or ``_updated.json``.
Sends the JSON document directly to the LLM for analysis. If the document exceeds
``MAX_ANALYSIS_TOKENS``, sections are batched greedily without splitting any
individual section. Conflict corrections from ``resolved_conflicts`` are included
so the output respects user arbitration decisions.
Usage::
python scripts/ir_generator.py output/<basename>_updated.json [output_dir] [--dry-run]
Output: ``<basename>_ir.json``
"""
import argparse
import json
import logging
import os
import sys
import time
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from LLM import LLMClient
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
RATE_LIMIT_DELAY = 0.5
MAX_ANALYSIS_TOKENS = 6000 # max content size per LLM call
# ---------------------------------------------------------------------------
# Prompt
# ---------------------------------------------------------------------------
PROMPT = """你是一个需求文档分析助手。请分析以下需求文档的JSON内容,输出结构化JSON。
## 已知修正(来自冲突检测)
以下内容已确认修正生成JSON时请**使用修正后的值**不要同时输出两个版本
{conflict_context}
## 待分析内容(JSON格式)
{content}
## JSON字段说明
- sections: 文档章节列表每个章节含 source章节标题 blocks内容块数组
- blocks: 类型含 para段落字段 text table表格字段 rows每行含 columns 数组
- image_sources: 图片所在章节映射key 为图片 rid
- image_analysis: 图片分析结果每个含 ridtype流程图/架构图/状态图等description
- resolved_conflicts: 已知修正列表每个含 sectionconflict_typecorrectionsource
## 功能点定义
只有满足以下**全部条件**的才视为功能点
1. 描述了一个**系统或软件要实现的具体行为**有触发条件执行动作状态变化或逻辑规则
2. 该行为直接由**系统或框架**执行不是人的操作流程管理流程
3. 对用户或系统有**可观察的效果**
**以下内容不是功能点不要输出**
- 术语/缩略词定义
- 文档背景范围说明 "本文档涵盖xxx"
- 变更日志版本记录编制人信息
- 文档结构描述 "产品简介用户场景说明"
- 纯文本的概述没有具体行为的介绍
## 决策树/流程图分解规则(重要)
图片分析image_analysis中的流程图和决策树描述包含丰富的功能逻辑**必须完全分解**
1. **每个叶子路径 = 一个独立 function**从根节点到每个最终结果的完整路径都拆成一个 function
2. **每个判断分支 = 一个独立 function**菱形判断节点的每个分支方向和对应的结果单独作为一个 function
3. **不同约束条件 = 不同 function**例如"通过接入SDK限制""通过系统限制"是不同约束机制必须分别列出
4. **不要合并不同路径**即使最终结果相同只要到达路径不同就是不同的 function
## 输出格式
只输出功能点每个功能点格式如下
{
"function": "功能名称",
"source": {
"section": "章节名",
"location": "原文位置(如:正文第1段、表格1第2行、图片rId13)"
},
"trigger": {
"type": "AND或者OR",
"conditions": [
"触发条件1",
"触发条件2"
]
},
"actions": {
"场景/角色": [
"动作1",
"动作2"
]
}
}
## 输出原则
1. **只输出功能点**没有功能点就输出空数组 []
2. 每个功能点**必须**包含 source.section source.location
3. location 必须是具体的原文位置标签 "正文第1段""表格1""图片rId13"
4. **一个 function 只对应一种行为逻辑一条完整路径**决策树中的每个分支路径从根到叶子必须拆成独立 functionconditions 中明确写出该路径上的所有判断条件和分支方向
5. **穷举所有分支**流程图/决策树中的每一条分支路径都要输出对应的 function不能遗漏任何子逻辑
6. 没有 trigger actions 的字段直接**省略**不要写 null 或空列表/空对象
7. 所有功能点全部列出**宁多勿漏**
8. **已知修正**中确认的信息使用修正后的值
9. 输出一个JSON数组不要用 ```json 代码块包裹直接输出纯JSON
"""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _parse_llm_response(raw: str) -> list | dict | str | None:
"""Parse JSON from LLM response, handling markdown code fences."""
if raw is None:
return None
stripped = raw.strip()
if stripped.startswith("```"):
nl = stripped.find("\n")
stripped = stripped[nl + 1:] if nl != -1 else stripped[3:]
if stripped.endswith("```"):
stripped = stripped[:-3]
try:
return json.loads(stripped)
except json.JSONDecodeError:
logger.warning(" Failed to parse JSON, returning raw text")
return raw
def _build_conflict_context(
section_name: str | None,
resolved_conflicts: list[dict],
) -> str:
"""Build conflict correction context for a section, or all if section_name is None."""
if section_name is None:
relevant = resolved_conflicts
else:
relevant = [c for c in resolved_conflicts if c.get("section", "") == section_name]
if not relevant:
return "没有"
lines: list[str] = []
for c in relevant:
correction = c.get("correction", "")
conflict_type = c.get("conflict_type", "")
source = c.get("source", "")
lines.append(f"- 冲突类型:{conflict_type},依据:{source}")
lines.append(f" 修正后的值:{correction}")
return "\n".join(lines)
# ---------------------------------------------------------------------------
# LLM analysis
# ---------------------------------------------------------------------------
def _analyze_content(
content: str,
conflict_context: str,
llm: LLMClient,
*,
dry_run: bool = False,
) -> list[dict]:
"""Send content to the LLM and return IR entries."""
prompt = PROMPT.replace("{conflict_context}", conflict_context).replace("{content}", content)
if dry_run:
est = llm.estimate_tokens(prompt)
logger.info(" [DRY RUN] prompt ~%d tokens", est)
return []
try:
raw = llm.chat(
model=LLMClient.TEXT_MODEL,
messages=[{"role": "user", "content": prompt}],
response_format={"type": "json_object"},
)
logger.info(" Response: %d chars", len(raw))
except RuntimeError as e:
logger.error(" Analysis failed: %s", e)
return []
parsed = _parse_llm_response(raw)
if isinstance(parsed, list):
return parsed
elif isinstance(parsed, dict):
return [parsed]
else:
logger.warning(" Unparseable response, raw length: %d", len(raw))
return []
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def generate_ir(
parsed_path: str,
output_dir: str = "output",
*,
dry_run: bool = False,
) -> dict:
"""Read parsed/updated JSON and generate JSON IR.
Produces ``<basename>_ir.json`` in *output_dir*.
"""
with open(parsed_path, "r", encoding="utf-8") as f:
data = json.load(f)
basename = os.path.splitext(os.path.basename(parsed_path))[0]
for suffix in ("_parsed", "_updated"):
if basename.endswith(suffix):
basename = basename[:-len(suffix)]
break
os.makedirs(output_dir, exist_ok=True)
llm = LLMClient()
ir_output: list[dict] = []
sections = data.get("sections", [])
image_sources = data.get("image_sources", {})
image_analysis = data.get("image_analysis", [])
resolved_conflicts = data.get("resolved_conflicts", [])
# Build full document JSON to measure size
full_doc = {
"sections": sections,
"image_sources": image_sources,
"image_analysis": image_analysis,
}
full_json = json.dumps(full_doc, ensure_ascii=False)
total_chars = len(full_json)
logger.info("Total document JSON chars: %d", total_chars)
if total_chars < MAX_ANALYSIS_TOKENS:
logger.info("Document fits in one request (< %d chars)", MAX_ANALYSIS_TOKENS)
conflict_ctx = _build_conflict_context(None, resolved_conflicts)
entries = _analyze_content(full_json, conflict_ctx, llm, dry_run=dry_run)
ir_output.extend(entries)
else:
logger.info("Document is large (>= %d chars), batching sections", MAX_ANALYSIS_TOKENS)
# Filter to non-empty sections, measure effective size per section
# (section JSON + image_sources + image_analysis for images in that section)
sec_sizes = []
for sec in sections:
if not sec.get("blocks"):
continue
sec_json = json.dumps(sec, ensure_ascii=False)
sec_chars = len(sec_json)
# Add image overhead for this section
sec_name = sec.get("source", "")
sec_rids = [rid for rid, src in image_sources.items()
if src.get("section", "") == sec_name]
if sec_rids:
overhead_doc = {
"image_sources": {rid: image_sources[rid] for rid in sec_rids},
"image_analysis": [img for img in image_analysis
if img.get("rid", "") in sec_rids],
}
sec_chars += len(json.dumps(overhead_doc, ensure_ascii=False))
sec_sizes.append((sec, sec_chars))
# Greedy batch: never split a section, keep adding until next exceeds limit
i = 0
while i < len(sec_sizes):
batch = []
batch_size = 0
while i < len(sec_sizes) and batch_size + sec_sizes[i][1] <= MAX_ANALYSIS_TOKENS:
batch.append(sec_sizes[i][0])
batch_size += sec_sizes[i][1]
i += 1
if not batch:
i += 1
continue
# Collect sections and their images for this batch
batch_names = [s.get("source", "") for s in batch]
batch_image_sources = {
rid: src for rid, src in image_sources.items()
if src.get("section", "") in batch_names
}
batch_images = [
img for img in image_analysis
if image_sources.get(img.get("rid", ""), {}).get("section", "") in batch_names
]
batch_doc = {
"sections": batch,
"image_sources": batch_image_sources,
"image_analysis": batch_images,
}
batch_json = json.dumps(batch_doc, ensure_ascii=False)
# Merge conflict contexts
ctx_parts = []
for sn in batch_names:
ctx = _build_conflict_context(sn, resolved_conflicts)
if ctx != "没有":
ctx_parts.append(ctx)
conflict_ctx = "\n".join(ctx_parts) if ctx_parts else "没有"
label = " + ".join(batch_names)
logger.info("Batch [%s]: %d sections, %d chars", label, len(batch), len(batch_json))
entries = _analyze_content(batch_json, conflict_ctx, llm, dry_run=dry_run)
ir_output.extend(entries)
time.sleep(RATE_LIMIT_DELAY)
# ---- save ----------------------------------------------------------------
ir_path = os.path.join(output_dir, f"{basename}_ir.json")
os.makedirs(os.path.dirname(ir_path) or ".", exist_ok=True)
with open(ir_path, "w", encoding="utf-8") as f:
json.dump(ir_output, f, ensure_ascii=False, indent=2)
logger.info("Saved: %s (%d entries)", ir_path, len(ir_output))
# ---- summary -------------------------------------------------------------
usg = llm.usage
logger.info("Tokens: %d prompt + %d completion = %d total",
usg["prompt_tokens"], usg["completion_tokens"], usg["total_tokens"])
logger.info("Output: %s", ir_path)
return {"ir": ir_output, "path": ir_path}
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate JSON intermediate representation from parsed/updated JSON.",
)
parser.add_argument("input", metavar="parsed.json",
help="Path to _parsed.json or _updated.json")
parser.add_argument("output_dir", nargs="?", default="output", metavar="output_dir",
help="Directory for output files (default: output/)")
parser.add_argument("--dry-run", action="store_true",
help="Print token estimates without calling the API.")
args = parser.parse_args()
generate_ir(args.input, args.output_dir, dry_run=args.dry_run)
@@ -0,0 +1,717 @@
"""
Stage 1: Ensemble Semantic Index Generation.
Generates N parallel LLM calls with different temperatures (e.g., 0.0, 0.3, 0.7),
then deterministically merges the results via ensemble_merge (pure Python, no LLM).
The merged output includes confidence scores for each concept and function_unit.
Outputs:
- output/semantic_index_r1.json (T=0.0 raw)
- output/semantic_index_r2.json (T=0.3 raw)
- output/semantic_index_r3.json (T=0.7 raw)
- output/semantic_index.json (ensemble-merged final)
"""
import concurrent.futures
import json
import re
import sys
import time
from pathlib import Path
import config
from ensemble_merge import ensemble_merge
# ---- Path Enumeration (for prompt embedding) ----
def _traverse_nested(node: dict, image_id: str, path_nodes: list,
branch_taken: str | None) -> list[dict]:
"""DFS traversal of a logic_tree_nested node, returning leaf path records."""
node_id = node.get("id", "?")
node_type = node.get("type", "?")
node_name = node.get("name", "")
path_nodes = path_nodes + [{
"id": node_id,
"type": node_type,
"label": node_name,
"branch_taken": branch_taken,
}]
if node_type == "end":
return [_make_path_record(path_nodes, image_id)]
children = node.get("children", [])
if not children:
return [_make_path_record(path_nodes, image_id)]
all_paths = []
for child in children:
# Decision nodes have {condition, node} wrappers; others are direct node dicts
if node_type == "decision":
condition = child.get("condition", "")
child_node = child.get("node", child)
else:
condition = "(implicit)"
child_node = child
all_paths.extend(
_traverse_nested(child_node, image_id, path_nodes, condition)
)
return all_paths
def _make_path_record(path_nodes: list, image_id: str) -> dict:
"""Build a path record from a completed node chain."""
action_nodes = [n for n in path_nodes if n["type"] == "action"]
decision_nodes = [n for n in path_nodes if n["type"] == "decision"]
node_ids = [n["id"] for n in path_nodes]
return {
"path_id": f"PATH-{image_id}-{'-'.join(node_ids)}",
"nodes": path_nodes,
"meaning": _describe_path(path_nodes),
"image_id": image_id,
"action_nodes": action_nodes,
"decision_nodes": decision_nodes,
"node_ids": node_ids,
}
def enumerate_logic_tree_paths(nested_tree: dict, image_id: str = "") -> list[dict]:
"""Enumerate all root-to-leaf paths from a logic_tree_nested structure.
Uses the nested tree directly (no flat-list adjacency). Decision nodes
fork by {condition, node} branches; other nodes have direct children.
"""
if not nested_tree:
return []
return _traverse_nested(nested_tree, image_id, [], None)
def _describe_path(path_nodes: list[dict]) -> str:
"""Generate a human-readable description of a logic tree path."""
parts = []
for n in path_nodes:
label = n["label"]
if n["branch_taken"] and n["branch_taken"] != "(implicit)":
label = f"{label}{n['branch_taken']}"
parts.append(label)
return "".join(parts)
def enumerate_all_paths(doc: dict) -> dict[str, list[dict]]:
"""Enumerate paths for all logic trees in the document.
Uses logic_tree_nested when available (proper tree), falling back to
flat logic_tree. Returns {image_id: [path, ...]}.
"""
result = {}
for img in doc.get("image_analysis", []):
rid = img.get("rid", "")
if not rid:
continue
nested = img.get("logic_tree_nested")
if nested:
result[rid] = enumerate_logic_tree_paths(nested, image_id=rid)
else:
lt = img.get("logic_tree")
if lt and lt.get("nodes"):
lt["image_id"] = rid
result[rid] = _enumerate_flat_tree(lt)
elif lt:
result[rid] = []
return result
def _enumerate_flat_tree(tree: dict) -> list[dict]:
"""Fallback: enumerate paths from flat logic_tree using adjacency.
Handles start/process/action/state nodes as implicit chain links.
"""
nodes = tree.get("nodes", [])
if not nodes:
return []
node_map = {n["id"]: n for n in nodes}
image_id = tree.get("image_id", "")
# Find root: first start/state node, or first process node, or first node
root = None
for n in nodes:
if n["type"] in ("start", "state"):
root = n
break
if root is None:
for n in nodes:
if n["type"] == "process":
root = n
break
if root is None:
root = nodes[0]
adj = _build_adjacency(nodes, node_map)
paths = []
def dfs(current_id, visited, path_nodes, branch_taken):
if current_id in visited:
return
new_visited = visited | {current_id}
node = node_map.get(current_id)
if node is None:
return
path_nodes = path_nodes + [{
"id": current_id,
"type": node["type"],
"label": node.get("description") or node.get("condition", ""),
"branch_taken": branch_taken,
}]
outgoing = adj.get(current_id, [])
if not outgoing:
action_nodes = [n for n in path_nodes if n["type"] == "action"]
decision_nodes = [n for n in path_nodes if n["type"] == "decision"]
node_ids = [n["id"] for n in path_nodes]
paths.append({
"path_id": f"PATH-{image_id}-{'-'.join(node_ids)}",
"nodes": path_nodes,
"meaning": _describe_path(path_nodes),
"image_id": image_id,
"action_nodes": action_nodes,
"decision_nodes": decision_nodes,
"node_ids": node_ids,
})
else:
for branch_val, target_id in outgoing:
dfs(target_id, new_visited, path_nodes, branch_val)
dfs(root["id"], set(), [], None)
return paths
def _build_adjacency(nodes, node_map):
"""Build {node_id: [(branch_value, target_id)]} adjacency for flat trees.
Handles: decision branches (explicit), non-branching nodes (implicit sequential).
"""
NON_BRANCHING = {"start", "process", "state", "action"}
adj = {}
has_explicit_incoming = set()
for n in nodes:
for br in n.get("branches", []):
has_explicit_incoming.add(br["target"])
for i, node in enumerate(nodes):
nid = node["id"]
adj.setdefault(nid, [])
# Explicit edges from decision nodes
for br in node.get("branches", []):
adj[nid].append((br["value"], br["target"]))
# Implicit edges for non-branching nodes (start/process/state/action)
if node["type"] in NON_BRANCHING and not node.get("branches"):
j = i + 1
targets = []
while j < len(nodes):
next_node = nodes[j]
next_nid = next_node["id"]
if next_nid in has_explicit_incoming:
break
if next_node["type"] in NON_BRANCHING | {"end"}:
targets.append(next_nid)
has_explicit_incoming.add(next_nid)
j += 1
continue
elif next_node["type"] == "decision":
if not targets:
targets.append(next_nid)
break
j += 1
for t in targets:
adj[nid].append(("(implicit)", t))
return adj
def format_paths_for_prompt(all_paths: dict[str, list[dict]]) -> str:
"""Format enumerated paths as a readable list for the LLM prompt."""
if not all_paths:
return "(无逻辑树路径)"
lines = []
for image_id, paths in all_paths.items():
lines.append(f"\n### {image_id} 的全部决策路径(共 {len(paths)} 条):")
for i, path in enumerate(paths, 1):
lines.append(f"\n**路径 {i}** (ID: {path['path_id']})")
lines.append(f" 含义: {path['meaning']}")
lines.append(f" 节点: {path['node_ids']}")
lines.append(f" 决策节点: {[n['id'] for n in path['decision_nodes']]}")
lines.append(f" 动作节点: {[n['id'] for n in path['action_nodes']]}")
return "\n".join(lines)
# ---- Document Formatting ----
def format_document_for_prompt(doc: dict) -> str:
"""Render the full parsed document as a readable string for the LLM prompt."""
lines = []
lines.append("=== SECTIONS ===")
for i, section in enumerate(doc.get("sections", [])):
source = section.get("source", f"(无标题-章节{i})")
lines.append(f"\n--- Section: {source} ---")
for block in section.get("blocks", []):
if block["type"] == "para":
lines.append(f"[段落 {block['index']}] {block['text']}")
elif block["type"] == "table":
lines.append(f"[表格 {block.get('table', '?')}]")
headers = block.get("headers", [])
lines.append(f" 表头: {' | '.join(headers)}")
for row in block.get("rows", []):
cols = row.get("columns", [])
cell_texts = []
for c in cols:
cell_texts.append(
f"[行{c.get('row','?')}]{c.get('name','')}: {c.get('text','')}"
)
lines.append(f" {'; '.join(cell_texts)}")
images = section.get("images", [])
if images:
lines.append(f" 图片引用: {', '.join(images)}")
lines.append("\n\n=== IMAGE_ANALYSIS (流程图逻辑树) ===")
for img in doc.get("image_analysis", []):
rid = img.get("rid", "?")
img_type = img.get("type", "?")
lines.append(f"\n--- Image: {rid} (type={img_type}) ---")
lines.append(f" 描述: {img.get('description', '')[:300]}")
lt = img.get("logic_tree")
if lt:
lines.append(f" 逻辑树根节点: {lt.get('root', '?')}")
lines.append(" 节点详情:")
for node in lt.get("nodes", []):
nid = node.get("id", "?")
ntype = node.get("type", "?")
desc = node.get("description", "") or node.get("condition", "")
lines.append(f" [{ntype}] {nid}: {desc}")
branches = node.get("branches", [])
if branches:
for br in branches:
lines.append(f"{br['value']}{br['target']}")
conflicts = doc.get("resolved_conflicts", [])
if conflicts:
lines.append("\n\n=== RESOLVED_CONFLICTS (图文冲突仲裁) ===")
for c in conflicts:
lines.append(
f" [{c.get('conflict_type','?')}] {c.get('section','?')}: "
f"{c.get('source','?')}为准 — {c.get('correction','')}"
)
return "\n".join(lines)
# ---- Prompt Building ----
def build_prompt(doc: dict, feedback: str = "", all_paths: dict | None = None) -> str:
"""Load the prompt template and inject the formatted document + paths + feedback."""
template_path = Path(config.PROMPTS_DIR) / "step1_semantic_index.txt"
template = template_path.read_text(encoding="utf-8")
formatted_doc = format_document_for_prompt(doc)
prompt = template.replace("{document_json}", formatted_doc)
if all_paths is None:
all_paths = enumerate_all_paths(doc)
path_text = format_paths_for_prompt(all_paths)
prompt = prompt.replace("{logic_tree_paths}", path_text)
if feedback:
prompt = prompt.replace("{feedback}", feedback)
else:
prompt = prompt.replace("{feedback}", "")
return prompt
# ---- Validation ----
def _quick_validate(
semantic_index: dict, doc: dict, all_paths: dict | None = None
) -> tuple[bool, dict]:
"""Validate semantic index and return (passed, gaps).
Uses a single COVERAGE_TARGET threshold (default 0.95).
"""
gaps = {
"missing_paths": [],
"missing_concepts": [],
"format_issues": [],
"parent_issues": [],
}
units = semantic_index.get("function_units", [])
concepts = semantic_index.get("concepts", [])
# --- Check function_units non-empty ---
if not units:
gaps["format_issues"].append("function_units 为空")
return False, gaps
# --- Check each function_unit has path ---
for fu in units:
uid = fu.get("unit_id", "?")
if not fu.get("path"):
gaps["format_issues"].append(f"{uid}: 缺少 path 字段")
if not fu.get("sources"):
gaps["format_issues"].append(f"{uid}: 缺少 sources")
# --- Logic tree node coverage ---
all_nodes = _collect_logic_tree_nodes(doc)
referenced = _collect_referenced_nodes(units)
threshold = config.COVERAGE_TARGET
for image_id, node_set in all_nodes.items():
ref_set = referenced.get(image_id, set())
checkable = {
nid for nid, ntype in node_set.items()
if ntype in ("decision", "action")
}
if not checkable:
continue
covered = checkable & ref_set
coverage = len(covered) / len(checkable) if checkable else 1.0
if coverage < threshold:
missing = checkable - ref_set
gaps["missing_paths"].append(
f"{image_id}: 覆盖率 {coverage:.0%} < {threshold:.0%}, "
f"未覆盖节点: {sorted(missing)}"
)
# --- Check logic tree path consistency ---
# A unit's logic_tree_nodes must form a valid (connected) path in the tree.
if all_paths is not None:
for fu in units:
uid = fu.get("unit_id", "?")
for src in fu.get("sources", []):
if src.get("type") != "logic_tree":
continue
image_id = src.get("image_id", "")
unit_nodes = set(src.get("logic_tree_nodes", []))
if not unit_nodes:
continue
# Check if there exists a path containing all these nodes
valid = False
for path in all_paths.get(image_id, []):
path_nodes = set(path.get("node_ids", []))
if unit_nodes.issubset(path_nodes):
valid = True
break
if not valid:
gaps["format_issues"].append(
f"{uid}: logic_tree_nodes 不构成有效路径 "
f"(image={image_id}, nodes={sorted(unit_nodes)})"
)
# --- Check for trivial units (only state/switch nodes, no actions) ---
if all_paths is not None:
for fu in units:
uid = fu.get("unit_id", "?")
has_logic_ref = False
has_action = False
has_non_trivial_decision = False
for src in fu.get("sources", []):
if src.get("type") != "logic_tree":
continue
has_logic_ref = True
node_ids = src.get("logic_tree_nodes", [])
node_types = {}
for image_id, nset in all_nodes.items():
for nid in node_ids:
if nid in nset:
node_types[nid] = nset[nid]
for nid in node_ids:
ntype = node_types.get(nid, "")
if ntype == "action":
has_action = True
# Count decisions beyond first level (e.g., n1/n2 are just root+switch)
decisions = [nid for nid in node_ids
if node_types.get(nid, "") == "decision"]
if len(decisions) > 1:
has_non_trivial_decision = True
if has_logic_ref and not has_action and not has_non_trivial_decision:
gaps["format_issues"].append(
f"{uid}: 可能为空壳单元(仅有state/开关节点,无action或深层decision"
)
# --- Concept parent validity ---
concept_names = {c["name"] for c in concepts}
for c in concepts:
name = c.get("name", "?")
parent = c.get("parent") # can be None for scope-level
if parent is not None and parent not in concept_names:
gaps["parent_issues"].append(
f"concept '{name}' 的 parent '{parent}' 不存在"
)
# Warn about scope-level concepts without parent=null
for c in concepts:
if c.get("parent") is not None:
continue
name = c.get("name", "")
# Scope-level concepts (国内/海外) should have parent=null
if name not in ("国内", "海外", ""):
gaps["parent_issues"].append(
f"concept '{name}' 的 parent 为 null,但它可能不是 scope 概念"
)
# --- Check for missing scope concepts ---
if "国内" not in concept_names:
gaps["missing_concepts"].append("缺少 scope 概念: 国内")
if "海外" not in concept_names and any(
"海外" in s.get("source", "") for s in doc.get("sections", [])
):
gaps["missing_concepts"].append("缺少 scope 概念: 海外")
passed = (
not gaps["missing_paths"]
and not gaps["format_issues"]
and not gaps["parent_issues"]
)
return passed, gaps
def _collect_logic_tree_nodes(doc: dict) -> dict[str, dict[str, str]]:
"""Return {image_id: {node_id: node_type}} for all logic trees."""
result = {}
for img in doc.get("image_analysis", []):
lt = img.get("logic_tree")
rid = img.get("rid", "")
if lt and rid:
result[rid] = {n["id"]: n["type"] for n in lt.get("nodes", [])}
return result
def _collect_referenced_nodes(units: list[dict]) -> dict[str, set[str]]:
"""Return {image_id: {referenced node_ids}} across all function_units."""
refs = {}
for fu in units:
for src in fu.get("sources", []):
if src.get("type") == "logic_tree":
image_id = src.get("image_id", "")
if image_id not in refs:
refs[image_id] = set()
refs[image_id].update(src.get("logic_tree_nodes", []))
return refs
# ---- LLM Calls ----
def extract_json_from_response(text: str) -> str:
"""Robustly extract JSON from LLM response."""
m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text)
if m:
return m.group(1).strip()
start = text.find("{")
if start == -1:
raise ValueError("No JSON object found in LLM response")
depth = 0
for i in range(start, len(text)):
if text[i] == "{":
depth += 1
elif text[i] == "}":
depth -= 1
if depth == 0:
return text[start : i + 1]
raise ValueError("Unclosed JSON object in LLM response")
def call_llm(prompt: str, max_retries: int = 2,
temperature: float | None = None) -> dict:
"""Send prompt to LLM, return parsed JSON dict.
Args:
temperature: Override config.TEMPERATURE. If None, uses config default.
"""
client = config.llm_client()
temp = temperature if temperature is not None else config.TEMPERATURE
for attempt in range(max_retries + 1):
print(f" LLM 调用 T={temp} (尝试 {attempt + 1}/{max_retries + 1})...", flush=True)
try:
resp = client.chat.completions.create(
model=config.MODEL_NAME,
messages=[
{
"role": "system",
"content": "你是一个精确的 JSON 输出引擎。只输出合法的 JSON。",
},
{"role": "user", "content": prompt},
],
temperature=temp,
max_tokens=config.MAX_TOKENS,
)
content = resp.choices[0].message.content
if content is None:
raise RuntimeError("LLM returned empty response")
json_str = extract_json_from_response(content)
return json.loads(json_str)
except (json.JSONDecodeError, ValueError) as e:
print(f" JSON 解析失败: {e}")
if attempt < max_retries:
time.sleep(2)
raise RuntimeError("无法从 LLM 响应中解析 JSON")
# ---- Ensemble Orchestration ----
def run_ensemble_semantic_index(doc: dict) -> dict:
"""Run N parallel LLM calls at different temperatures, then ensemble-merge.
1. Enumerate all logic tree paths (once).
2. Build the prompt (once no iterative feedback needed).
3. Launch len(ENSEMBLE_TEMPERATURES) parallel LLM calls via ThreadPoolExecutor.
4. Collect all results.
5. Call ensemble_merge() for deterministic merge.
6. Validate final output with _quick_validate().
7. Save individual version outputs + merged output.
"""
all_paths = enumerate_all_paths(doc)
print(f" 已枚举逻辑树路径: {sum(len(v) for v in all_paths.values())}")
prompt = build_prompt(doc, "", all_paths)
print(f" Prompt 长度: {len(prompt)} 字符")
temperatures = config.ENSEMBLE_TEMPERATURES
print(f" 集成温度: {temperatures}")
# Parallel LLM calls
raw_results: list[tuple[int, float, dict]] = []
with concurrent.futures.ThreadPoolExecutor(
max_workers=len(temperatures)
) as executor:
future_to_meta = {}
for i, temp in enumerate(temperatures):
future = executor.submit(call_llm, prompt, 2, temp)
future_to_meta[future] = (i, temp)
for future in concurrent.futures.as_completed(future_to_meta):
idx, temp = future_to_meta[future]
try:
si = future.result()
n_units = len(si.get("function_units", []))
n_concepts = len(si.get("concepts", []))
print(f" T={temp}: {n_concepts} 概念, {n_units} 功能单元")
raw_results.append((idx, temp, si))
except Exception as e:
print(f" T={temp}: FAIL — {e}")
raw_results.append((idx, temp, {
"feature_name": "", "concepts": [], "function_units": []
}))
if not raw_results:
raise RuntimeError("所有集成的 LLM 调用均失败")
# Sort by temperature for determinism
raw_results.sort(key=lambda x: x[1])
semantic_indices = [r[2] for r in raw_results]
# Save individual version outputs
version_paths = {
0: config.SEMANTIC_INDEX_R1_JSON,
1: config.SEMANTIC_INDEX_R2_JSON,
2: config.SEMANTIC_INDEX_R3_JSON,
}
for i, si in enumerate(semantic_indices):
out_path = version_paths.get(i)
if out_path:
config.save_json(si, out_path)
print(f" 保存版本 {i} (T={temperatures[i]}): {out_path}")
# Ensemble merge
print(f"\n 集成合并 {len(semantic_indices)} 个版本...")
merged = ensemble_merge(semantic_indices)
merged["ensemble_temperatures"] = list(temperatures)
# Validate
passed, gaps = _quick_validate(merged, doc, all_paths)
merged["validation_passed"] = passed
merged["validation_gaps"] = {
k: v for k, v in gaps.items() if v
}
# Print summary
cs = merged.get("confidence_summary", {})
print(f" 合并后: {cs.get('total_concepts', 0)} 概念, "
f"{cs.get('total_units', 0)} 功能单元")
print(f" 置信度: high={cs.get('high', 0)}, medium={cs.get('medium', 0)}, "
f"low={cs.get('low', 0)}")
print(f" 验证: {'PASS' if passed else 'GAPS FOUND'}")
if not passed:
for k, v in gaps.items():
if v:
print(f" {k}: {len(v)} 个问题")
return merged
# ---- Main ----
def main():
print("=" * 60)
print("阶段一:集成语义索引 (Ensemble Semantic Index)")
print("=" * 60)
# 1. Load input
print(f"\n[1/3] 加载输入文档: {config.INPUT_JSON}")
doc = config.load_input_document()
print(f" 已加载 {len(doc.get('sections', []))} 个 section, "
f"{len(doc.get('image_analysis', []))} 张图片分析")
# 2. Run ensemble generation + merge
print(f"\n[2/3] 运行集成语义索引 ({len(config.ENSEMBLE_TEMPERATURES)} 个温度版本)...")
merged_index = run_ensemble_semantic_index(doc)
# 3. Save outputs
print(f"\n[3/3] 保存最终语义索引: {config.SEMANTIC_INDEX_JSON}")
config.save_json(merged_index, config.SEMANTIC_INDEX_JSON)
# Also save path enumeration for downstream use
all_paths = enumerate_all_paths(doc)
config.save_json(
{"logic_tree_paths": {k: v for k, v in all_paths.items()}},
config.PATH_ENUM_JSON,
)
print(f" 路径枚举: {config.PATH_ENUM_JSON}")
cs = merged_index.get("confidence_summary", {})
n_concepts = cs.get("total_concepts", len(merged_index.get("concepts", [])))
n_units = cs.get("total_units", len(merged_index.get("function_units", [])))
n_versions = merged_index.get("ensemble_versions", len(config.ENSEMBLE_TEMPERATURES))
print(f"\n完成! {n_versions} 版本集成, {n_concepts} 个概念, {n_units} 个功能单元.")
print(f"输出: {config.SEMANTIC_INDEX_JSON}")
if __name__ == "__main__":
main()
@@ -0,0 +1,399 @@
"""
Stage 2.5: Branch Coverage Auto-Completion.
1. Enumerates all root-to-leaf paths in every logic tree
2. Compares paths against existing IR rules to find uncovered paths
3. Generates synthetic function_units for uncovered paths
4. Calls LLM (same extract_rules_for_unit) to produce rules for synthetic units
5. Iterates up to MAX_RETRIES_PER_STAGE rounds to reach COVERAGE_TARGET
Outputs:
- output/path_enumeration.json
- output/ir_autocomplete_fragments.json
"""
import concurrent.futures
import json
import time
from pathlib import Path
import config
# ---- Path Enumeration (shared with step1, duplicated for module independence) ----
def enumerate_all_paths(doc: dict) -> dict[str, list[dict]]:
"""Enumerate all root-to-leaf paths for every logic tree."""
from step1_semantic_index import enumerate_all_paths as _enum
return _enum(doc)
# ---- Coverage Analysis ----
def find_referenced_path_ids(rules: list[dict]) -> dict[str, set[str]]:
"""Map each rule to the set of logic tree nodes it references.
Returns {rule_id: set of "image_id:node_id" pairs}
"""
result = {}
for rule in rules:
rid = rule.get("rule_id", "?")
refs = set()
for src in rule.get("sources", []):
if src.get("type") == "logic_tree":
image_id = src.get("image_id", "")
for nid in src.get("node_ids", []):
refs.add(f"{image_id}:{nid}")
result[rid] = refs
return result
def compute_path_coverage(
all_paths: dict[str, list[dict]], rules: list[dict]
) -> tuple[list[dict], list[dict], dict]:
"""Compute coverage of enumerated paths by existing rules.
Returns (covered_paths, uncovered_paths, stats).
A path is "covered" if at least one rule's node_ids form a superset
of the path's decision+action nodes for that image.
"""
# Build per-rule node sets keyed by image_id
rule_node_sets = {} # {rule_id: {image_id: set(node_ids)}}
for rule in rules:
rid = rule.get("rule_id", "?")
rule_node_sets[rid] = {}
for src in rule.get("sources", []):
if src.get("type") == "logic_tree":
image_id = src.get("image_id", "")
rule_node_sets[rid].setdefault(image_id, set()).update(
src.get("node_ids", [])
)
covered = []
uncovered = []
for image_id, paths in all_paths.items():
for path in paths:
# Get checkable nodes for this path (decision + action)
checkable = set(
n["id"] for n in path["nodes"]
if n["type"] in ("decision", "action")
)
if not checkable:
# Path with no decision/action nodes — trivially covered
covered.append(path)
continue
path_covered = False
for rid, img_sets in rule_node_sets.items():
rule_nodes = img_sets.get(image_id, set())
if checkable.issubset(rule_nodes):
path_covered = True
break
if path_covered:
covered.append(path)
else:
uncovered.append(path)
total = len(covered) + len(uncovered)
stats = {
"total_paths": total,
"covered_paths": len(covered),
"uncovered_paths": len(uncovered),
"coverage_pct": round(len(covered) / total * 100, 1) if total > 0 else 100.0,
}
return covered, uncovered, stats
# ---- Synthetic Function Unit Generation ----
def generate_synthetic_unit(path: dict, unit_seq: int) -> dict:
"""Create a synthetic function_unit from an uncovered logic tree path.
Infers preconditions and trigger from the decision nodes along the path.
"""
node_map = {n["id"]: n for n in path["nodes"]}
# Infer switch state from path
switch = _infer_switch_state(path)
# Infer app_type from path
app_type = _infer_app_type(path)
# Infer app_state from path
app_state = _infer_app_state(path)
# Infer geographic_scope from section context
scope = _infer_scope(path)
# Build description from path meaning
description = f"自动补全: {path.get('meaning', '')}"
if switch:
description = f"开关{switch}, {description}"
# Build path list
path_labels = []
if scope:
path_labels.append(scope)
if switch:
path_labels.append(f"开关{switch}")
if app_type:
path_labels.append(app_type)
if app_state:
path_labels.append(app_state)
# Add behavior from terminal action
action_nodes = path.get("action_nodes", [])
if action_nodes:
last_action = action_nodes[-1].get("label", "")
path_labels.append(last_action[:20])
unit_id = f"FU-AUTO-{path['image_id']}-{unit_seq:03d}"
seq = f"{unit_seq:03d}"
return {
"unit_id": unit_id,
"name": f"自动补全-{path.get('meaning', '')[:60]}",
"description": description,
"path": path_labels,
"auto_generated": True,
"sources": [
{
"section": "",
"type": "logic_tree",
"image_id": path["image_id"],
"logic_tree_nodes": path.get("node_ids", []),
}
],
}
def _infer_switch_state(path: dict) -> str:
"""Infer switch state from decision nodes in path."""
for n in path["nodes"]:
label = n.get("label", "")
branch = n.get("branch_taken", "")
if "开关" in label and n["type"] == "decision":
if branch == "开启":
return "开启"
elif branch == "关闭":
return "关闭"
return ""
def _infer_app_type(path: dict) -> str:
"""Infer app type from state nodes in path."""
type_map = {
"其他应用": "其他应用",
"SDK限制": "SDK限制",
"通过接入SDK限制的应用": "SDK限制",
"系统限制": "系统限制",
"通过系统限制应用": "系统限制",
}
for n in path["nodes"]:
if n["type"] == "state":
for key, val in type_map.items():
if key in n.get("label", ""):
return val
return ""
def _infer_app_state(path: dict) -> str:
"""Infer app state (前台/后台) from decision nodes."""
for n in path["nodes"]:
label = n.get("label", "")
branch = n.get("branch_taken", "")
if "前台" in label:
if branch == "":
return "前台"
elif branch == "":
return "后台"
return ""
def _infer_scope(path: dict) -> str:
"""Infer geographic scope. Defaults to 国内."""
return "国内"
# ---- LLM Extraction for Synthetic Units ----
def extract_rules_for_synthetic_units(
synthetic_units: list[dict], doc: dict, max_retries: int | None = None
) -> list[dict]:
"""Extract IR rules for synthetic function_units using step2's LLM logic."""
from step2_ir_extraction import (
build_document_lookup,
extract_context_package,
extract_rules_for_unit,
)
if max_retries is None:
max_retries = config.MAX_RETRIES_PER_STAGE
sections_by_source, image_by_rid, conflicts_by_section = build_document_lookup(doc)
fragments = []
for unit in synthetic_units:
pkg = extract_context_package(
unit, doc, sections_by_source, image_by_rid, conflicts_by_section
)
# Enrich pkg with unit's own path and description
pkg["unit_path"] = unit.get("path", [])
pkg["unit_description"] = unit.get("description", pkg["unit_description"])
try:
rules = extract_rules_for_unit(pkg, max_retries)
except Exception as e:
rules = []
fragments.append({
"unit_id": unit["unit_id"],
"unit_name": unit.get("name", ""),
"rules": rules,
"auto_generated": True,
})
print(f" {unit['unit_id']}: {len(rules)} 条规则")
return fragments
# ---- Iterative Auto-Completion ----
def run_autocomplete(
all_paths: dict[str, list[dict]],
existing_rules: list[dict],
doc: dict,
) -> tuple[list[dict], dict]:
"""Run iterative auto-completion. Returns (autocomplete_fragments, final_stats)."""
print(f"\n 初始路径覆盖率分析...")
covered, uncovered, stats = compute_path_coverage(all_paths, existing_rules)
print(f" 覆盖: {stats['covered_paths']}/{stats['total_paths']} "
f"({stats['coverage_pct']}%)")
if not uncovered:
print(f" 所有路径已覆盖,无需自动补全")
return [], stats
print(f" 未覆盖路径: {len(uncovered)}")
all_fragments = []
best_stats = stats
for round_n in range(1, config.MAX_RETRIES_PER_STAGE + 1):
if not uncovered:
break
print(f"\n--- 自动补全 第 {round_n} 轮 ---")
print(f"{len(uncovered)} 条未覆盖路径生成合成单元...")
# Generate synthetic units
start_seq = (round_n - 1) * len(uncovered) + 1
synthetic_units = [
generate_synthetic_unit(path, start_seq + i)
for i, path in enumerate(uncovered)
]
# Extract rules via LLM
max_llm_workers = min(2, len(synthetic_units))
if len(synthetic_units) <= 1:
fragments = extract_rules_for_synthetic_units(synthetic_units, doc)
else:
# Sequential to avoid flooding the API
fragments = extract_rules_for_synthetic_units(synthetic_units, doc)
all_fragments.extend(fragments)
# Re-compute coverage
all_rules = existing_rules + [
rule for f in fragments for rule in f.get("rules", [])
]
covered, uncovered, stats = compute_path_coverage(all_paths, all_rules)
print(f"{round_n} 轮后覆盖: {stats['covered_paths']}/{stats['total_paths']} "
f"({stats['coverage_pct']}%)")
if stats["coverage_pct"] > best_stats["coverage_pct"]:
best_stats = stats
if stats["coverage_pct"] >= config.COVERAGE_TARGET * 100:
print(f" 达到目标覆盖率 {config.COVERAGE_TARGET:.0%},停止")
break
# If coverage didn't improve, try a different approach next round
uncovered_decision_nodes = set()
for p in uncovered:
for n in p.get("decision_nodes", []):
uncovered_decision_nodes.add(n.get("label", ""))
if not uncovered_decision_nodes:
print(f" 无更多可补全路径,停止")
break
return all_fragments, best_stats
# ---- Main ----
def main():
print("=" * 60)
print("阶段 2.5:分支覆盖自动补全")
print("=" * 60)
# 1. Load inputs
print(f"\n[1/5] 加载输入...")
doc = config.load_input_document()
fragments = config.load_json(config.IR_FRAGMENTS_JSON)
all_rules = []
for f in fragments:
all_rules.extend(f.get("rules", []))
print(f" 已有规则: {len(all_rules)}")
# 2. Enumerate paths
print(f"\n[2/5] 枚举逻辑树路径...")
all_paths = enumerate_all_paths(doc)
total_paths = sum(len(v) for v in all_paths.values())
print(f"{total_paths} 条路径")
# Save path enumeration for downstream audit
path_enum_data = {
"logic_tree_paths": {
k: [{kk: vv for kk, vv in p.items() if kk != "nodes"} for p in v]
for k, v in all_paths.items()
},
"total_paths": total_paths,
}
config.save_json(path_enum_data, config.PATH_ENUM_JSON)
# 3. Run auto-completion
print(f"\n[3/5] 运行自动补全...")
autocomplete_fragments, final_stats = run_autocomplete(
all_paths, all_rules, doc
)
# 4. Save
print(f"\n[4/5] 保存自动补全片段...")
config.save_json(
autocomplete_fragments, config.IR_AUTOCOMPLETE_FRAGMENTS_JSON
)
print(f" 输出: {config.IR_AUTOCOMPLETE_FRAGMENTS_JSON}")
print(f" 生成 {len(autocomplete_fragments)} 个补全片段")
# 5. Summary
print(f"\n[5/5] 完成!")
print(f" 最终路径覆盖: {final_stats['covered_paths']}/{final_stats['total_paths']} "
f"({final_stats['coverage_pct']}%)")
if final_stats["coverage_pct"] < config.COVERAGE_TARGET * 100:
remaining = final_stats["total_paths"] - final_stats["covered_paths"]
print(f" WARN: {remaining} 条路径仍未覆盖,将在审计报告中列出")
if __name__ == "__main__":
main()
@@ -0,0 +1,508 @@
"""
Stage 2: Per Function Unit IR Extraction.
For each function unit from the semantic index, constructs a precision context
package and calls the LLM to extract detailed IR rules.
Runs multiple LLM calls in parallel (up to MAX_CONCURRENCY).
Output: output/ir_fragments.json
"""
import concurrent.futures
import json
import re
import sys
import time
from pathlib import Path
import config
MAX_CONCURRENCY = 3 # Max parallel LLM calls
def load_semantic_index() -> dict:
"""Load the semantic index from Stage 1."""
return config.load_json(config.SEMANTIC_INDEX_JSON)
def build_document_lookup(doc: dict):
"""Build lookup structures for fast context extraction from the document."""
# sections_by_source: "3.1.1" -> section dict
sections_by_source = {}
for section in doc.get("sections", []):
source = section.get("source", "")
# Normalize: extract leading number like "3.1.1"
parts = source.split()
if parts:
key = parts[0].strip()
sections_by_source[key] = section
# image_by_rid: "rId16" -> image_analysis entry
image_by_rid = {}
for img in doc.get("image_analysis", []):
rid = img.get("rid", "")
if rid:
image_by_rid[rid] = img
# Conflicts indexed by section
conflicts_by_section = {}
for c in doc.get("resolved_conflicts", []):
section = c.get("section", "")
key = section.split()[0] if section else ""
conflicts_by_section.setdefault(key, []).append(c)
return sections_by_source, image_by_rid, conflicts_by_section
def extract_context_package(
fu: dict, doc: dict, sections_by_source: dict, image_by_rid: dict,
conflicts_by_section: dict
) -> dict:
"""Build a precision context package for a single function unit."""
texts = []
tables = []
logic_trees = []
seen_sections = set()
seen_images = set()
for src in fu.get("sources", []):
src_type = src.get("type", "")
section_key = src.get("section", "").split()[0] if src.get("section") else ""
# --- Text source ---
if src_type in ("table", "para") and section_key:
if section_key in seen_sections:
continue
seen_sections.add(section_key)
section = sections_by_source.get(section_key)
if section is None:
# Fuzzy match by prefix
for key in sections_by_source:
if key.startswith(section_key):
section = sections_by_source[key]
break
if section:
for block in section.get("blocks", []):
if block["type"] == "para":
texts.append({
"section": section_key,
"text": block["text"]
})
elif block["type"] == "table":
row_num = src.get("row") if src_type == "table" else None
if row_num is not None:
# Extract only the specific row
matching_rows = []
for r in block.get("rows", []):
for c in r.get("columns", []):
if c.get("row") == row_num:
matching_rows.append({
"headers": block.get("headers", []),
"cells": {
col["name"]: col["text"]
for col in r["columns"]
},
"row": row_num
})
break
tables.append({
"section": section_key,
"headers": block.get("headers", []),
"rows": matching_rows,
"all_rows": [
{
"row": col.get("row"),
"name": col.get("name"),
"text": col.get("text")
}
for row in block.get("rows", [])
for col in row.get("columns", [])
]
})
else:
# Include full table
tables.append({
"section": section_key,
"headers": block.get("headers", []),
"all_rows": [
{
"row": col.get("row"),
"name": col.get("name"),
"text": col.get("text")
}
for row in block.get("rows", [])
for col in row.get("columns", [])
]
})
# --- Logic tree source ---
if src_type == "logic_tree":
image_id = src.get("image_id", "")
if not image_id or image_id in seen_images:
continue
seen_images.add(image_id)
img = image_by_rid.get(image_id)
if img:
lt = img.get("logic_tree")
if lt:
logic_trees.append({
"image_id": image_id,
"description": img.get("description", ""),
"tree": lt
})
# Include relevant resolved conflicts
relevant_conflicts = []
for section_key in seen_sections:
for c in conflicts_by_section.get(section_key, []):
relevant_conflicts.append(c)
return {
"unit_id": fu["unit_id"],
"unit_name": fu.get("name", ""),
"unit_description": fu.get("description", ""),
"unit_path": fu.get("path", []),
"texts": texts,
"tables": tables,
"logic_trees": logic_trees,
"resolved_conflicts": relevant_conflicts
}
def format_context_package(pkg: dict) -> str:
"""Format a context package as a readable string for the prompt."""
parts = []
# Texts
parts.append("【文字段落】")
for i, t in enumerate(pkg.get("texts", [])):
parts.append(f"[{t.get('section', '?')}] {t.get('text', '')}")
if not pkg.get("texts"):
parts.append("(无)")
# Tables
parts.append("\n【表格数据】")
for i, tbl in enumerate(pkg.get("tables", [])):
parts.append(f"表格 {i+1} (section={tbl.get('section', '?')})")
headers = tbl.get("headers", [])
parts.append(f" 表头: {headers}")
parts.append(" 全部行数据:")
for row in tbl.get("all_rows", []):
parts.append(
f"{row.get('row','?')}[{row.get('name','?')}]: {row.get('text','')}"
)
# Highlight matched rows if any
matched = tbl.get("rows", [])
if matched:
parts.append(" <重点关注行>:")
for mr in matched:
parts.append(f"{mr.get('row','?')}: {mr.get('cells', {})}")
if not pkg.get("tables"):
parts.append("(无)")
# Logic trees
parts.append("\n【逻辑树】")
for i, lt in enumerate(pkg.get("logic_trees", [])):
parts.append(f"逻辑树 {i+1} (image_id={lt.get('image_id', '?')})")
parts.append(f" 描述: {lt.get('description', '')[:200]}")
tree = lt.get("tree", {})
parts.append(f" 根: {tree.get('root', '?')}")
parts.append(" 节点:")
for node in tree.get("nodes", []):
nid = node.get("id", "?")
ntype = node.get("type", "?")
desc = node.get("description", "") or node.get("condition", "")
parts.append(f" [{ntype}] {nid}: {desc}")
for br in node.get("branches", []):
parts.append(f"{br['value']}{br['target']}")
if not pkg.get("logic_trees"):
parts.append("(无)")
# Conflicts
conflicts = pkg.get("resolved_conflicts", [])
if conflicts:
parts.append("\n【图文冲突仲裁】")
for c in conflicts:
parts.append(
f" [{c.get('conflict_type', '?')}] 以{c.get('source', '?')}为准: "
f"{c.get('correction', '')}"
)
return "\n".join(parts)
def _escape_json_for_format(s: str) -> str:
"""Escape curly braces in a JSON string for use with str.format()."""
return s.replace("{", "{{").replace("}", "}}")
def build_prompt(pkg: dict, format_feedback: str = "") -> str:
"""Build the LLM prompt for a single function unit."""
template_path = Path(config.PROMPTS_DIR) / "step2_ir_extraction.txt"
template = template_path.read_text(encoding="utf-8")
prompt = template.format(
unit_id=pkg["unit_id"],
unit_name=_escape_json_for_format(pkg["unit_name"]),
unit_description=_escape_json_for_format(pkg["unit_description"]),
texts=_escape_json_for_format(
json.dumps(pkg.get("texts", []), ensure_ascii=False, indent=2)
),
tables=_escape_json_for_format(
json.dumps(pkg.get("tables", []), ensure_ascii=False, indent=2)
),
logic_trees=_escape_json_for_format(
json.dumps(pkg.get("logic_trees", []), ensure_ascii=False, indent=2)
),
resolved_conflicts=_escape_json_for_format(
json.dumps(pkg.get("resolved_conflicts", []), ensure_ascii=False, indent=2)
),
format_feedback=_escape_json_for_format(format_feedback),
)
return prompt
def extract_json_from_response(text: str) -> str:
"""Extract JSON array from LLM response."""
m = re.search(r"```(?:json)?\s*(\[[\s\S]*?\])\s*```", text)
if m:
return m.group(1).strip()
# Find outermost [ ... ]
start = text.find("[")
if start == -1:
raise ValueError("No JSON array found in LLM response")
depth = 0
for i in range(start, len(text)):
if text[i] == "[":
depth += 1
elif text[i] == "]":
depth -= 1
if depth == 0:
return text[start : i + 1]
raise ValueError("Unclosed JSON array in LLM response")
def _check_rule_fields(rules: list[dict]) -> tuple[bool, list[dict]]:
"""Validate each rule has required fields. Returns (passed, failures).
Each failure: {rule_id, field, issue}
"""
failures = []
for j, rule in enumerate(rules):
if not isinstance(rule, dict):
failures.append({"rule_id": f"rule[{j}]", "field": "-", "issue": "规则不是 dict"})
continue
rid = rule.get("rule_id") or f"rule[{j}]"
if not rule.get("path"):
failures.append({"rule_id": rid, "field": "path", "issue": "缺少 path 字段(必填)"})
precond = rule.get("precondition") or {}
if not precond.get("geographic_scope"):
failures.append({"rule_id": rid, "field": "precondition.geographic_scope", "issue": "缺少 geographic_scope(必填)"})
for k, action in enumerate(rule.get("actions") or []):
if not isinstance(action, dict):
continue
if action.get("type") == "user_interaction":
content = action.get("content") or ""
if not content:
failures.append({
"rule_id": rid, "field": f"actions[{k}].content",
"issue": "user_interaction 的 content 为空"
})
elif any(ph in content for ph in ["文案由业务定义", "待定", "自定义"]):
failures.append({
"rule_id": rid, "field": f"actions[{k}].content",
"issue": f"content 包含占位符: '{content}'"
})
trigger = rule.get("trigger") or {}
for k, cond in enumerate(trigger.get("conditions") or []):
if isinstance(cond, dict):
if not cond.get("signal"):
failures.append({
"rule_id": rid, "field": f"trigger.conditions[{k}].signal",
"issue": "缺少 signal"
})
if not cond.get("operator"):
failures.append({
"rule_id": rid, "field": f"trigger.conditions[{k}].operator",
"issue": "缺少 operator"
})
if "value" not in cond:
failures.append({
"rule_id": rid, "field": f"trigger.conditions[{k}].value",
"issue": "缺少 value"
})
return len(failures) == 0, failures
def _build_fix_prompt(failures: list[dict]) -> str:
"""Build a format-fix instruction block for the prompt."""
if not failures:
return ""
lines = [
"\n## 上一轮格式问题修正\n",
"上一轮输出的规则存在以下格式问题,请修正后重新输出:\n",
]
for f in failures:
lines.append(f"- **{f['rule_id']}.{f['field']}**: {f['issue']}")
lines.append("\n请修正以上所有问题,重新输出完整的规则数组。")
return "\n".join(lines)
def extract_rules_for_unit(pkg: dict, max_retries: int | None = None) -> list[dict]:
"""Call LLM for one function unit, return its IR rules.
Includes format validation with auto-fix retries.
"""
if max_retries is None:
max_retries = config.MAX_RETRIES_PER_STAGE
client = config.llm_client()
prompt = build_prompt(pkg)
last_failures = []
for attempt in range(max_retries + 1):
# Append format feedback on retry
if attempt > 0 and last_failures:
fix_text = _build_fix_prompt(last_failures)
prompt = build_prompt(pkg, format_feedback=fix_text)
try:
resp = client.chat.completions.create(
model=config.MODEL_NAME,
messages=[
{
"role": "system",
"content": "你是一个精确的 JSON 输出引擎。只输出合法的 JSON 数组。",
},
{"role": "user", "content": prompt},
],
temperature=config.TEMPERATURE,
max_tokens=config.MAX_TOKENS,
)
content = resp.choices[0].message.content
if content is None:
raise RuntimeError("LLM returned empty response")
json_str = extract_json_from_response(content)
rules = json.loads(json_str)
if not isinstance(rules, list):
raise ValueError(f"Expected JSON array, got {type(rules).__name__}")
# Format validation
passed, failures = _check_rule_fields(rules)
if passed:
return rules
# Format issues found — retry with fix instructions
print(f" 格式问题 ({len(failures)} 个): {[f['field'] for f in failures[:5]]}")
last_failures = failures
if attempt < max_retries:
time.sleep(1)
except (json.JSONDecodeError, ValueError) as e:
print(f" JSON 解析失败 (尝试 {attempt + 1}): {e}")
last_failures = [{"rule_id": "?", "field": "json", "issue": str(e)}]
if attempt < max_retries:
time.sleep(2)
# Exhausted retries — return what we have (even if imperfect)
print(f" WARN: {pkg['unit_id']} 格式修复耗尽了 {max_retries} 次重试")
return []
def extract_all_rules(
semantic_index: dict, doc: dict
) -> list[dict]:
"""Extract IR rules for all function units. Runs in parallel up to MAX_CONCURRENCY."""
sections_by_source, image_by_rid, conflicts_by_section = build_document_lookup(doc)
function_units = semantic_index.get("function_units", [])
print(f"{len(function_units)} 个功能单元待处理")
print(f" 最大并发: {MAX_CONCURRENCY}")
# Build context packages (serial — fast)
packages = []
for fu in function_units:
pkg = extract_context_package(
fu, doc, sections_by_source, image_by_rid, conflicts_by_section
)
packages.append(pkg)
# Run LLM calls in parallel
fragments = []
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_CONCURRENCY) as executor:
futures = {}
for i, pkg in enumerate(packages):
future = executor.submit(extract_rules_for_unit, pkg)
futures[future] = (i, pkg["unit_id"], pkg["unit_name"])
for future in concurrent.futures.as_completed(futures):
i, uid, uname = futures[future]
try:
rules = future.result()
fragments.append({
"unit_id": uid,
"unit_name": uname,
"rules": rules
})
print(f" [OK] {uid} ({uname}): {len(rules)} 条规则")
except Exception as e:
print(f" [FAIL] {uid} ({uname}): 失败 — {e}")
fragments.append({
"unit_id": uid,
"unit_name": uname,
"rules": [],
"error": str(e)
})
# Sort by unit_id to maintain stable ordering
fragments.sort(key=lambda f: f["unit_id"])
return fragments
def main():
print("=" * 60)
print("阶段二:逐功能单元 IR 提取")
print("=" * 60)
# 1. Load inputs
print(f"\n[1/3] 加载输入...")
semantic_index = load_semantic_index()
doc = config.load_input_document()
n_units = len(semantic_index.get("function_units", []))
print(f" 语义索引: {n_units} 个功能单元")
# 2. Extract rules
print(f"\n[2/3] 逐单元提取 IR 规则...")
fragments = extract_all_rules(semantic_index, doc)
# 3. Save
print(f"\n[3/3] 保存 IR 片段...")
config.save_json(fragments, config.IR_FRAGMENTS_JSON)
total_rules = sum(len(f["rules"]) for f in fragments)
failed_units = [f for f in fragments if f.get("error")]
print(f"\n完成! {len(fragments)} 个功能单元, 共 {total_rules} 条规则")
if failed_units:
print(f" [WARN] {len(failed_units)} 个单元提取失败: "
f"{[f['unit_id'] for f in failed_units]}")
print(f"输出: {config.IR_FRAGMENTS_JSON}")
if __name__ == "__main__":
main()
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,472 @@
"""
Tests for ensemble_merge.py all pure Python, no LLM calls, no file I/O.
Each test uses hardcoded mock data to verify one piece of the merge logic.
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from ensemble_merge import (
concept_name_similarity,
cluster_concepts,
merge_concept_cluster,
unit_node_jaccard,
path_similarity,
unit_similarity,
cluster_function_units,
pick_best_representative,
compute_confidence_versions,
ensemble_merge_concepts,
ensemble_merge_function_units,
ensemble_merge,
_collect_logic_tree_nodes,
)
PASS = "[PASS]"
FAIL = "[FAIL]"
# ---- Mock helpers ----
def _mk_unit(unit_id, name, path, logic_tree_nodes, description="", sources=None):
"""Create a minimal function_unit dict for testing."""
if sources is None:
srcs = []
if logic_tree_nodes:
srcs.append({
"image_id": "rId16",
"type": "logic_tree",
"logic_tree_nodes": logic_tree_nodes,
})
if not srcs:
srcs.append({
"section": "3.1",
"type": "table",
"text_snippet": "test",
})
else:
srcs = sources
return {
"unit_id": unit_id,
"name": name,
"description": description or f"desc for {name}",
"path": path,
"sources": srcs,
}
def _mk_concept(name, parent=None, aliases=None, defined_in=None):
"""Create a minimal concept dict for testing."""
return {
"name": name,
"aliases": aliases or [],
"defined_in": defined_in or ["3.1"],
"parent": parent,
}
# =============================================================================
# Test 1: concept_name_similarity
# =============================================================================
def test_concept_name_similarity_exact():
assert concept_name_similarity("国内", "国内") == 1.0
assert concept_name_similarity("行车娱乐限制", "行车娱乐限制") == 1.0
def test_concept_name_similarity_substring():
sim = concept_name_similarity("国内行车娱乐限制", "行车娱乐限制")
assert sim >= 0.85, f"expected >= 0.85, got {sim}"
def test_concept_name_similarity_different():
sim = concept_name_similarity("国内", "海外")
assert sim < 0.7, f"expected < 0.7, got {sim}"
def test_concept_name_similarity_seq_matcher():
sim = concept_name_similarity("前台打断", "前台应用打断")
assert 0.6 < sim < 0.95, f"expected 0.6-0.95, got {sim}"
# =============================================================================
# Test 2: _collect_logic_tree_nodes
# =============================================================================
def test_collect_logic_tree_nodes():
unit = _mk_unit("U1", "test", ["A"], ["n1", "n2", "n3"])
nodes = _collect_logic_tree_nodes(unit)
assert nodes == {"n1", "n2", "n3"}
def test_collect_logic_tree_nodes_empty():
unit = _mk_unit("U2", "test", ["A"], [], sources=[{"section": "3.1", "type": "table"}])
nodes = _collect_logic_tree_nodes(unit)
assert nodes == set()
# =============================================================================
# Test 3: unit_node_jaccard
# =============================================================================
def test_unit_node_jaccard_identical():
u1 = _mk_unit("U1", "a", ["A"], ["n1", "n2", "n3"])
u2 = _mk_unit("U2", "b", ["A"], ["n1", "n2", "n3"])
assert unit_node_jaccard(u1, u2) == 1.0
def test_unit_node_jaccard_partial():
u1 = _mk_unit("U1", "a", ["A"], ["n1", "n2", "n3", "n4"])
u2 = _mk_unit("U2", "b", ["A"], ["n1", "n2", "n3"])
# intersection=3, union=4
assert abs(unit_node_jaccard(u1, u2) - 0.75) < 0.01
def test_unit_node_jaccard_disjoint():
u1 = _mk_unit("U1", "a", ["A"], ["n1", "n2"])
u2 = _mk_unit("U2", "b", ["B"], ["n3", "n4"])
assert unit_node_jaccard(u1, u2) == 0.0
def test_unit_node_jaccard_both_empty():
u1 = _mk_unit("U1", "a", ["A"], [], sources=[{"section": "3.1", "type": "table"}])
u2 = _mk_unit("U2", "b", ["B"], [], sources=[{"section": "3.1", "type": "table"}])
assert unit_node_jaccard(u1, u2) == 0.0
# =============================================================================
# Test 4: path_similarity
# =============================================================================
def test_path_similarity_identical():
assert path_similarity(
["国内", "系统限制", "前台打断"],
["国内", "系统限制", "前台打断"],
) == 1.0
def test_path_similarity_partial():
sim = path_similarity(
["国内", "系统限制", "前台打断"],
["国内", "系统限制", "后台限制启动"],
)
# 2/3 set overlap, sequential 3/5 ≈ 0.6
assert 0.4 < sim < 0.9, f"expected 0.4-0.9, got {sim}"
def test_path_similarity_different():
sim = path_similarity(["国内"], ["海外"])
assert sim < 0.7, f"expected < 0.7, got {sim}"
# =============================================================================
# Test 5: unit_similarity
# =============================================================================
def test_unit_similarity_identical():
u = _mk_unit("U1", "国内-系统限制-前台打断",
["国内", "系统限制", "前台打断"],
["n1", "n2", "n3", "n8", "n19"])
assert unit_similarity(u, u) > 0.99
def test_unit_similarity_different():
u1 = _mk_unit("U1", "a", ["国内", "系统限制", "前台打断"], ["n1", "n2", "n3"])
u2 = _mk_unit("U2", "b", ["海外", "SDK限制"], ["n10", "n11", "n12"])
assert unit_similarity(u1, u2) < 0.3
# =============================================================================
# Test 6: cluster_concepts
# =============================================================================
def test_cluster_concepts_identical():
v0 = [_mk_concept("国内"), _mk_concept("海外"), _mk_concept("系统限制", parent="国内")]
v1 = [_mk_concept("国内"), _mk_concept("海外"), _mk_concept("系统限制", parent="国内")]
v2 = [_mk_concept("国内"), _mk_concept("海外"), _mk_concept("系统限制", parent="国内")]
clusters = cluster_concepts([v0, v1, v2])
# Should have exactly 3 clusters (国内, 海外, 系统限制)
assert len(clusters) == 3, f"expected 3 clusters, got {len(clusters)}"
for c in clusters:
assert len(c) == 3, f"expected each cluster to have 3 members, got {len(c)}"
def test_cluster_concepts_name_variation():
v0 = [_mk_concept("国内行车娱乐限制", parent="国内")]
v1 = [_mk_concept("行车娱乐限制", parent="国内")]
v2 = [_mk_concept("国内行车娱乐限制", parent="国内")]
clusters = cluster_concepts([v0, v1, v2])
assert len(clusters) == 1, f"expected 1 cluster, got {len(clusters)}"
assert len(clusters[0]) == 3, f"expected 3 members, got {len(clusters[0])}"
# =============================================================================
# Test 7: merge_concept_cluster
# =============================================================================
def test_merge_concept_cluster():
cluster = [
(0, _mk_concept("国内行车娱乐限制", parent="国内", aliases=["限制"])),
(1, _mk_concept("行车娱乐限制", parent="国内", aliases=["行车限制"])),
(2, _mk_concept("行车娱乐限制", parent="国内", aliases=["限制"])),
]
merged, conf = merge_concept_cluster(cluster, 3)
assert "行车娱乐限制" in merged["name"]
assert merged["parent"] == "国内"
assert set(merged["aliases"]) == {"限制", "行车限制"}
assert conf in ("high", "medium")
# =============================================================================
# Test 8: cluster_function_units
# =============================================================================
def test_cluster_function_units_all_agree():
u0 = _mk_unit("U-001", "国内-系统限制-前台打断",
["国内", "系统限制", "前台打断"],
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
"switch ON, system app, foreground, speed>=15, non-P, interrupt + toast")
u1 = _mk_unit("U-001", "国内-系统限制-前台打断",
["国内", "系统限制", "前台打断"],
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
"switch ON, system app, foreground, speed>=15, non-P, interrupt + toast")
u2 = _mk_unit("U-001", "国内-系统限制-前台打断",
["国内", "系统限制", "前台打断"],
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
"switch ON, system app, foreground, interrupt")
clusters = cluster_function_units([[u0], [u1], [u2]])
assert len(clusters) == 1, f"expected 1 cluster, got {len(clusters)}"
assert len(clusters[0]) == 3
def test_cluster_function_units_partial_agree():
u0 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
["n1", "n2", "n3", "n8", "n19"])
u1 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
["n1", "n2", "n3", "n8", "n19"])
u2 = _mk_unit("U-002", "禁止", ["国内", "系统限制", "后台限制启动"],
["n5", "n6"])
clusters = cluster_function_units([[u0], [u1], [u2]])
# u0+u1 in one cluster, u2 in another
assert len(clusters) == 2, f"expected 2 clusters, got {len(clusters)}"
cluster_sizes = sorted(len(c) for c in clusters)
assert cluster_sizes == [1, 2], f"expected cluster sizes [1,2], got {cluster_sizes}"
def test_cluster_function_units_all_disagree():
u0 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"], ["n1", "n2", "n3"])
u1 = _mk_unit("U-002", "禁止", ["国内", "系统限制", "后台限制启动"], ["n5", "n6"])
u2 = _mk_unit("U-003", "SDK", ["国内", "SDK限制"], ["n10", "n11"])
clusters = cluster_function_units([[u0], [u1], [u2]])
assert len(clusters) == 3, f"expected 3 clusters, got {len(clusters)}"
# =============================================================================
# Test 9: pick_best_representative
# =============================================================================
def test_pick_best_representative_prefers_rich():
u0 = _mk_unit("U-001", "short", ["国内", "系统限制"],
["n1", "n2", "n3"],
description="short desc")
u1 = _mk_unit("U-001", "detailed", ["国内", "系统限制", "前台打断"],
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
description="very detailed description of the full rule behavior " * 5)
cluster = [(0, u0), (1, u1)]
best = pick_best_representative(cluster)
# u1 should win: more nodes, longer description, though u0 has lower temp
assert best["name"] == "detailed"
# =============================================================================
# Test 10: compute_confidence_versions
# =============================================================================
def test_confidence_high_unanimous():
assert compute_confidence_versions(3, 3, True) == "high"
def test_confidence_high_two_of_three_with_t0():
assert compute_confidence_versions(2, 3, True) == "high"
def test_confidence_medium_two_of_three_without_t0():
assert compute_confidence_versions(2, 3, False) == "medium"
def test_confidence_low_one_of_three():
assert compute_confidence_versions(1, 3, False) == "low"
def test_confidence_high_all_two_versions():
assert compute_confidence_versions(2, 2, True) == "high"
# =============================================================================
# Test 11: ensemble_merge_concepts
# =============================================================================
def test_ensemble_merge_concepts():
v0 = [_mk_concept("国内"), _mk_concept("海外"),
_mk_concept("国内行车娱乐限制", parent="国内")]
v1 = [_mk_concept("国内"), _mk_concept("海外"),
_mk_concept("行车娱乐限制", parent="国内",
aliases=["限制"], defined_in=["3.1", "3.1.1"])]
v2 = [_mk_concept("国内"), _mk_concept("海外"),
_mk_concept("行车娱乐限制", parent="国内")]
merged = ensemble_merge_concepts([v0, v1, v2])
# Should merge the 3 concepts across 3 versions into 3 clusters
assert len(merged) == 3, f"expected 3 merged concepts, got {len(merged)}"
for c in merged:
assert "confidence" in c
assert "ensemble_support" in c
assert c["ensemble_support"] == "3/3"
# =============================================================================
# Test 12: ensemble_merge_function_units
# =============================================================================
def test_ensemble_merge_function_units():
u0 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
description="full description A")
u1 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
description="full description B (more detail)")
u2 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25"],
description="partial description")
merged = ensemble_merge_function_units([[u0], [u1], [u2]])
assert len(merged) == 1, f"expected 1 unit, got {len(merged)}"
unit = merged[0]
assert unit["confidence"] == "high"
assert unit["ensemble_support"] == "3/3"
assert unit["source_versions"] == 3
assert unit["unit_id"].startswith("FU-ENS-")
# Should have picked u1 (more detail)
assert "more detail" in unit["description"]
# =============================================================================
# Test 13: ensemble_merge full integration
# =============================================================================
def test_ensemble_merge_full():
v0 = {
"feature_name": "行车娱乐限制",
"concepts": [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")],
"function_units": [
_mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
["n1", "n2", "n3", "n8", "n19", "n25", "n26"]),
_mk_unit("U-002", "后台禁止", ["国内", "系统限制", "后台限制启动"],
["n5", "n6"]),
],
}
v1 = {
"feature_name": "行车娱乐限制",
"concepts": [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")],
"function_units": [
_mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
["n1", "n2", "n3", "n8", "n19", "n25", "n26"]),
_mk_unit("U-003", "SDK自定义", ["国内", "SDK限制", "自定义限制"],
["n10", "n11"]),
],
}
v2 = {
"feature_name": "行车娱乐限制",
"concepts": [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")],
"function_units": [
_mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
["n1", "n2", "n3", "n8", "n19", "n25", "n26"]),
],
}
result = ensemble_merge([v0, v1, v2])
assert result["feature_name"] == "行车娱乐限制"
assert result["ensemble_versions"] == 3
units = result["function_units"]
concepts = result["concepts"]
# Concepts: 国内 + 系统限制
assert len(concepts) == 2
# Units: 打断 (3 versions → high), 后台禁止 (1 version → low), SDK (1 version → low)
assert len(units) == 3
high_units = [u for u in units if u["confidence"] == "high"]
low_units = [u for u in units if u["confidence"] == "low"]
assert len(high_units) == 1
assert len(low_units) == 2
# All units should have ensemble fields
for u in units:
assert "confidence" in u
assert "ensemble_support" in u
assert "source_versions" in u
# Confidence summary
cs = result["confidence_summary"]
assert cs["total_units"] == 3
assert cs["high"] == 1
assert cs["low"] == 2
# =============================================================================
# Runner
# =============================================================================
def run_all_tests():
print("=" * 60)
print("Ensemble Merge 测试 (纯 Python, 无 LLM)")
print("=" * 60)
tests = [
("concept_name_similarity exact", test_concept_name_similarity_exact),
("concept_name_similarity substring", test_concept_name_similarity_substring),
("concept_name_similarity different", test_concept_name_similarity_different),
("concept_name_similarity seq_matcher", test_concept_name_similarity_seq_matcher),
("collect_logic_tree_nodes", test_collect_logic_tree_nodes),
("collect_logic_tree_nodes empty", test_collect_logic_tree_nodes_empty),
("unit_node_jaccard identical", test_unit_node_jaccard_identical),
("unit_node_jaccard partial", test_unit_node_jaccard_partial),
("unit_node_jaccard disjoint", test_unit_node_jaccard_disjoint),
("unit_node_jaccard both_empty", test_unit_node_jaccard_both_empty),
("path_similarity identical", test_path_similarity_identical),
("path_similarity partial", test_path_similarity_partial),
("path_similarity different", test_path_similarity_different),
("unit_similarity identical", test_unit_similarity_identical),
("unit_similarity different", test_unit_similarity_different),
("cluster_concepts identical", test_cluster_concepts_identical),
("cluster_concepts name variation", test_cluster_concepts_name_variation),
("merge_concept_cluster", test_merge_concept_cluster),
("cluster_function_units all_agree", test_cluster_function_units_all_agree),
("cluster_function_units partial_agree", test_cluster_function_units_partial_agree),
("cluster_function_units all_disagree", test_cluster_function_units_all_disagree),
("pick_best_representative", test_pick_best_representative_prefers_rich),
("confidence high unanimous", test_confidence_high_unanimous),
("confidence high 2/3 with t0", test_confidence_high_two_of_three_with_t0),
("confidence medium 2/3 no t0", test_confidence_medium_two_of_three_without_t0),
("confidence low 1/3", test_confidence_low_one_of_three),
("confidence high 2/2", test_confidence_high_all_two_versions),
("ensemble_merge_concepts", test_ensemble_merge_concepts),
("ensemble_merge_function_units", test_ensemble_merge_function_units),
("ensemble_merge full", test_ensemble_merge_full),
]
passed = 0
failed = 0
for name, test_fn in tests:
try:
test_fn()
print(f" {PASS} {name}")
passed += 1
except AssertionError as e:
print(f" {FAIL} {name}: {e}")
failed += 1
except Exception as e:
print(f" {FAIL} {name}: unexpected {type(e).__name__}: {e}")
failed += 1
print(f"\n{'='*60}")
if failed == 0:
print(f"{PASS} 所有 {passed} 个测试通过!")
else:
print(f"{FAIL} {failed}/{passed + failed} 个测试失败")
print(f"{'='*60}")
return failed == 0
if __name__ == "__main__":
success = run_all_tests()
sys.exit(0 if success else 1)
@@ -0,0 +1,370 @@
"""
Tests for Stage 1 (Semantic Index).
Validates that the generated semantic_index.json meets all completeness
and structural requirements, including the new iterative features:
- function_units have path fields
- concepts have parent references
- logic tree node coverage meets thresholds
"""
import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import config
PASS = "[PASS]"
FAIL = "[FAIL]"
WARN = "[WARN]"
def load_inputs():
"""Load semantic_index.json and the original parsed document."""
try:
si = config.load_json(config.SEMANTIC_INDEX_JSON)
except FileNotFoundError:
print(f"{FAIL} semantic_index.json 未找到: {config.SEMANTIC_INDEX_JSON}")
print(" 请先运行 step1_semantic_index.py")
sys.exit(1)
doc = config.load_input_document()
return si, doc
def build_image_index(doc: dict) -> dict[str, dict]:
"""Build lookup: image rId -> image_analysis entry."""
idx = {}
for img in doc.get("image_analysis", []):
rid = img.get("rid", "")
if rid:
idx[rid] = img
return idx
def build_logic_tree_node_index(doc: dict) -> dict[str, set[str]]:
"""Build lookup: image rId -> set of all node IDs in that logic_tree."""
idx = {}
for img in doc.get("image_analysis", []):
rid = img.get("rid", "")
lt = img.get("logic_tree")
if lt and rid:
node_ids = {n["id"] for n in lt.get("nodes", [])}
idx[rid] = node_ids
return idx
def check_unit_ids(units: list[dict]) -> list[str]:
"""Check that every function_unit has a non-empty unit_id and name."""
errors = []
seen_ids = set()
for i, fu in enumerate(units):
uid = fu.get("unit_id", "")
name = fu.get("name", "")
if not uid:
errors.append(f"function_unit[{i}]: unit_id 为空")
elif uid in seen_ids:
errors.append(f"function_unit[{i}]: unit_id '{uid}' 重复")
seen_ids.add(uid)
if not name:
errors.append(f"function_unit[{i}] ({uid}): name 为空")
return errors
def check_unit_paths(units: list[dict]) -> list[str]:
"""Check that every function_unit has a non-empty path array."""
errors = []
for fu in units:
uid = fu.get("unit_id", "?")
path = fu.get("path", [])
if not path:
errors.append(f"{uid}: path 字段为空或缺失")
elif not isinstance(path, list):
errors.append(f"{uid}: path 必须是数组")
return errors
def check_concept_parents(concepts: list[dict]) -> list[str]:
"""Check that non-scope concepts have valid parent references."""
errors = []
concept_names = {c.get("name", "") for c in concepts}
scope_concepts = {"国内", "海外"}
for c in concepts:
name = c.get("name", "?")
parent = c.get("parent", "")
if name in scope_concepts:
# Scope concepts should have no parent
if parent:
errors.append(f"scope 概念 '{name}' 不应有 parent (当前: '{parent}')")
else:
# Non-scope concepts must have a parent
if not parent:
errors.append(f"概念 '{name}' 缺少 parent 字段")
elif parent not in concept_names:
errors.append(f"概念 '{name}' 的 parent '{parent}' 不存在于 concepts 中")
return errors
def check_sources_exist(
units: list[dict], image_index: dict[str, dict], node_index: dict[str, set[str]]
) -> list[str]:
"""Check that all source references point to real content."""
errors = []
for fu in units:
uid = fu.get("unit_id", "?")
sources = fu.get("sources", [])
if not sources:
errors.append(f"{uid}: sources 为空,必须至少引用一张图片或一段文字")
continue
has_text = False
has_image = False
for j, src in enumerate(sources):
src_type = src.get("type", "")
if src_type in ("table", "para"):
has_text = True
section = src.get("section", "")
if not section:
errors.append(f"{uid}.sources[{j}]: 缺少 section")
elif src_type == "logic_tree":
has_image = True
image_id = src.get("image_id", "")
if not image_id:
errors.append(f"{uid}.sources[{j}]: logic_tree 缺少 image_id")
continue
if image_id not in image_index:
errors.append(
f"{uid}.sources[{j}]: image_id '{image_id}' "
f"在 image_analysis 中不存在"
)
continue
node_ids = src.get("logic_tree_nodes", [])
if node_ids and image_id in node_index:
valid_nodes = node_index[image_id]
for nid in node_ids:
if nid not in valid_nodes:
errors.append(
f"{uid}.sources[{j}]: 节点 '{nid}'"
f"{image_id} 的逻辑树中不存在"
)
elif not node_ids:
errors.append(
f"{uid}.sources[{j}]: logic_tree 类型但未提供 logic_tree_nodes"
)
if not has_text and not has_image:
errors.append(f"{uid}: 必须至少引用一个文本或图片来源")
return errors
def check_logic_tree_coverage(
units: list[dict], node_index: dict[str, set[str]]
) -> list[str]:
"""Check that decision and action nodes in logic trees are covered."""
warnings = []
for image_id, all_nodes in node_index.items():
referenced = set()
for fu in units:
for src in fu.get("sources", []):
if src.get("image_id") == image_id:
for nid in src.get("logic_tree_nodes", []):
referenced.add(nid)
uncovered = all_nodes - referenced
if uncovered:
doc = config.load_input_document()
node_types = {}
for img in doc.get("image_analysis", []):
if img.get("rid") == image_id:
lt = img.get("logic_tree", {})
for n in lt.get("nodes", []):
node_types[n["id"]] = n.get("type", "?")
break
decision_action_uncovered = [
n for n in uncovered if node_types.get(n) in ("decision", "action")
]
if decision_action_uncovered:
warnings.append(
f"{image_id}: {len(decision_action_uncovered)}"
f"decision/action 节点未被引用: {decision_action_uncovered}"
)
return warnings
def check_ensemble_confidence(units: list[dict]) -> list[str]:
"""Check that every function_unit has confidence, ensemble_support, source_versions."""
errors = []
valid_conf = {"high", "medium", "low"}
for fu in units:
uid = fu.get("unit_id", "?")
conf = fu.get("confidence", "")
if not conf:
errors.append(f"{uid}: 缺少 confidence 字段")
elif conf not in valid_conf:
errors.append(f"{uid}: confidence='{conf}' 无效 (期望 high/medium/low)")
support = fu.get("ensemble_support", "")
if not support:
errors.append(f"{uid}: 缺少 ensemble_support 字段")
if "source_versions" not in fu:
errors.append(f"{uid}: 缺少 source_versions 字段")
return errors
def check_confidence_summary(si: dict) -> list[str]:
"""Check that confidence_summary counts match actual unit/concept confidence."""
errors = []
cs = si.get("confidence_summary", {})
if not cs:
errors.append("缺少 confidence_summary 字段")
return errors
units = si.get("function_units", [])
concepts = si.get("concepts", [])
# Count actual confidence levels
unit_high = sum(1 for u in units if u.get("confidence") == "high")
unit_medium = sum(1 for u in units if u.get("confidence") == "medium")
unit_low = sum(1 for u in units if u.get("confidence") == "low")
concept_high = sum(1 for c in concepts if c.get("confidence") == "high")
concept_medium = sum(1 for c in concepts if c.get("confidence") == "medium")
concept_low = sum(1 for c in concepts if c.get("confidence") == "low")
if cs.get("total_units", 0) != len(units):
errors.append(f"confidence_summary.total_units={cs.get('total_units')} != 实际 {len(units)}")
if cs.get("high", 0) != unit_high:
errors.append(f"confidence_summary.high={cs.get('high')} != 实际 {unit_high}")
if cs.get("medium", 0) != unit_medium:
errors.append(f"confidence_summary.medium={cs.get('medium')} != 实际 {unit_medium}")
if cs.get("low", 0) != unit_low:
errors.append(f"confidence_summary.low={cs.get('low')} != 实际 {unit_low}")
if cs.get("total_concepts", 0) != len(concepts):
errors.append(f"confidence_summary.total_concepts={cs.get('total_concepts')} != 实际 {len(concepts)}")
if cs.get("concept_high", 0) != concept_high:
errors.append(f"confidence_summary.concept_high={cs.get('concept_high')} != 实际 {concept_high}")
if cs.get("concept_medium", 0) != concept_medium:
errors.append(f"confidence_summary.concept_medium={cs.get('concept_medium')} != 实际 {concept_medium}")
if cs.get("concept_low", 0) != concept_low:
errors.append(f"confidence_summary.concept_low={cs.get('concept_low')} != 实际 {concept_low}")
return errors
def run_all_tests():
print("=" * 60)
print("Step 1 自检测试")
print("=" * 60)
si, doc = load_inputs()
units = si.get("function_units", [])
concepts = si.get("concepts", [])
image_index = build_image_index(doc)
node_index = build_logic_tree_node_index(doc)
all_errors = []
all_warnings = []
# Test 1: unit_id and name validity
errors = check_unit_ids(units)
if errors:
print(f"\n{FAIL} unit_id/name 检查: {len(errors)} 个错误")
for e in errors:
print(f" - {e}")
all_errors.extend(errors)
else:
print(f"\n{PASS} unit_id/name 检查: 全部通过 ({len(units)} 个功能单元)")
# Test 2: path fields
errors = check_unit_paths(units)
if errors:
print(f"\n{FAIL} path 字段检查: {len(errors)} 个错误")
for e in errors:
print(f" - {e}")
all_errors.extend(errors)
else:
print(f"\n{PASS} path 字段检查: 全部通过")
# Test 3: concept parent references
errors = check_concept_parents(concepts)
if errors:
print(f"\n{FAIL} concept parent 检查: {len(errors)} 个错误")
for e in errors:
print(f" - {e}")
all_errors.extend(errors)
else:
print(f"\n{PASS} concept parent 检查: 全部通过 ({len(concepts)} 个概念)")
# Test 4: source references exist
errors = check_sources_exist(units, image_index, node_index)
if errors:
print(f"\n{FAIL} 来源引用检查: {len(errors)} 个错误")
for e in errors:
print(f" - {e}")
all_errors.extend(errors)
else:
print(f"\n{PASS} 来源引用检查: 全部通过")
# Test 5: Logic tree coverage
warnings = check_logic_tree_coverage(units, node_index)
if warnings:
print(f"\n{WARN} 逻辑树节点覆盖率: {len(warnings)} 个警告")
for w in warnings:
print(f" - {w}")
all_warnings.extend(warnings)
else:
print(f"\n{PASS} 逻辑树节点覆盖率: 全部通过")
# Test 6: Ensemble confidence fields on function_units
errors = check_ensemble_confidence(units)
if errors:
print(f"\n{FAIL} 集成置信度字段: {len(errors)} 个错误")
for e in errors:
print(f" - {e}")
all_errors.extend(errors)
else:
print(f"\n{PASS} 集成置信度字段: 全部通过")
# Test 7: Confidence summary consistency
errors = check_confidence_summary(si)
if errors:
print(f"\n{FAIL} confidence_summary 一致性: {len(errors)} 个错误")
for e in errors:
print(f" - {e}")
all_errors.extend(errors)
else:
cs = si.get("confidence_summary", {})
print(f"\n{PASS} confidence_summary 一致性: "
f"high={cs.get('high',0)}, medium={cs.get('medium',0)}, "
f"low={cs.get('low',0)}")
# Summary
print(f"\n{'='*60}")
total_failures = len(all_errors)
total_warnings = len(all_warnings)
if total_failures == 0 and total_warnings == 0:
print(f"{PASS} 所有测试通过!")
elif total_failures == 0:
print(f"{WARN} 全部通过但有 {total_warnings} 个警告")
else:
print(f"{FAIL} 测试失败: {total_failures} 个错误, {total_warnings} 个警告")
print("\n请检查 LLM 输出质量,可能需要调整 Prompt 并重新运行 step1_semantic_index.py")
print(f"\n统计:")
print(f" 功能单元数: {len(units)}")
print(f" 概念数: {len(concepts)}")
print(f" 逻辑树图片数: {len(node_index)}")
return total_failures == 0
if __name__ == "__main__":
success = run_all_tests()
sys.exit(0 if success else 1)
@@ -0,0 +1,322 @@
"""
Tests for Stage 2 (IR Extraction).
Validates that ir_fragments.json meets quality and structural requirements:
- All fragments have non-empty rules
- All rules have path arrays
- All rules have precondition.geographic_scope and precondition.screen_type
- All trigger conditions have signal/operator/value
- user_interaction content is non-empty and not a placeholder
- No duplicate rule_ids (across all fragments)
"""
import json
import sys
from pathlib import Path
from collections import Counter
sys.path.insert(0, str(Path(__file__).parent.parent))
import config
PASS = "[PASS]"
FAIL = "[FAIL]"
WARN = "[WARN]"
# Forbidden placeholder phrases in user_interaction content
FORBIDDEN_PLACEHOLDERS = [
"文案由业务定义", "待定", "自定义", "TBD", "todo", "TODO"
]
def load_fragments():
"""Load ir_fragments.json."""
try:
return config.load_json(config.IR_FRAGMENTS_JSON)
except FileNotFoundError:
print(f"{FAIL} ir_fragments.json 未找到: {config.IR_FRAGMENTS_JSON}")
print(" 请先运行 step2_ir_extraction.py")
sys.exit(1)
def check_non_empty_rules(fragments: list[dict]) -> list[str]:
"""Every fragment must have at least one rule."""
errors = []
for f in fragments:
uid = f.get("unit_id", "?")
rules = f.get("rules", [])
if not rules:
if f.get("error"):
errors.append(f"{uid}: 提取失败 — {f['error']}")
else:
errors.append(f"{uid}: rules 为空")
return errors
def check_rule_paths(fragments: list[dict]) -> list[str]:
"""Every rule must have a non-empty path array."""
errors = []
for f in fragments:
uid = f.get("unit_id", "?")
for j, rule in enumerate(f.get("rules", [])):
rid = rule.get("rule_id", f"rule[{j}]")
path = rule.get("path", [])
if not path:
errors.append(f"{rid}: path 字段为空或缺失")
elif not isinstance(path, list):
errors.append(f"{rid}: path 必须是数组")
return errors
def check_precondition_fields(fragments: list[dict]) -> list[str]:
"""Every rule must have precondition with geographic_scope and screen_type."""
errors = []
for f in fragments:
uid = f.get("unit_id", "?")
for j, rule in enumerate(f.get("rules", [])):
rid = rule.get("rule_id", f"rule[{j}]")
precond = rule.get("precondition", {})
if not precond:
errors.append(f"{rid}: precondition 缺失")
continue
if not precond.get("geographic_scope"):
errors.append(f"{rid}: precondition.geographic_scope 缺失")
if "screen_type" not in precond:
errors.append(f"{rid}: precondition.screen_type 缺失")
return errors
def check_user_interaction_content(fragments: list[dict]) -> list[str]:
"""user_interaction actions must have non-empty, non-placeholder content."""
errors = []
for f in fragments:
uid = f.get("unit_id", "?")
for j, rule in enumerate(f.get("rules", [])):
rid = rule.get("rule_id", f"rule[{j}]")
for k, action in enumerate(rule.get("actions", [])):
if action.get("type") != "user_interaction":
continue
content = action.get("content", "")
if not content:
errors.append(
f"{rid}.actions[{k}]: user_interaction 的 content 为空"
)
elif any(ph in content for ph in FORBIDDEN_PLACEHOLDERS):
errors.append(
f"{rid}.actions[{k}]: content 包含占位符: '{content}'"
)
return errors
def check_sources_have_logic_tree_nodes(fragments: list[dict]) -> list[str]:
"""Every rule should reference at least one logic tree node in its sources."""
errors = []
for f in fragments:
uid = f.get("unit_id", "?")
for j, rule in enumerate(f.get("rules", [])):
rid = rule.get("rule_id", f"rule[{j}]")
sources = rule.get("sources", [])
has_logic_tree = any(
src.get("type") == "logic_tree" and src.get("node_ids")
for src in sources
)
if not has_logic_tree:
has_text = any(
src.get("type") in ("table", "para") for src in sources
)
if not has_text:
errors.append(f"{rid}: sources 中既无逻辑树引用也无文字引用")
return errors
def check_trigger_conditions(fragments: list[dict]) -> list[str]:
"""Every trigger condition must have signal, operator, value."""
errors = []
for f in fragments:
uid = f.get("unit_id", "?")
for j, rule in enumerate(f.get("rules", [])):
rid = rule.get("rule_id", f"rule[{j}]")
trigger = rule.get("trigger", {})
conditions = trigger.get("conditions", [])
if trigger.get("event") is not None:
continue
for k, cond in enumerate(conditions):
signal = cond.get("signal", "")
operator = cond.get("operator", "")
has_value = "value" in cond
if not signal:
errors.append(f"{rid}.condition[{k}]: 缺少 signal")
if not operator:
errors.append(f"{rid}.condition[{k}]: 缺少 operator")
if not has_value:
errors.append(f"{rid}.condition[{k}]: 缺少 value")
return errors
def check_duplicate_rule_ids(fragments: list[dict]) -> list[str]:
"""Check for duplicate rule_ids across all fragments."""
all_rule_ids = []
for f in fragments:
for rule in f.get("rules", []):
rid = rule.get("rule_id", "")
if rid:
all_rule_ids.append(rid)
duplicates = [rid for rid, count in Counter(all_rule_ids).items() if count > 1]
errors = []
if duplicates:
errors.append(f"重复 rule_id: {duplicates}")
return errors
def check_action_types(fragments: list[dict]) -> list[str]:
"""Verify that actions have valid types."""
valid_types = {"system", "user_interaction"}
errors = []
for f in fragments:
for j, rule in enumerate(f.get("rules", [])):
rid = rule.get("rule_id", f"rule[{j}]")
for k, action in enumerate(rule.get("actions", [])):
atype = action.get("type", "")
if atype not in valid_types:
errors.append(
f"{rid}.action[{k}]: type='{atype}' 无效, "
f"应为 {valid_types}"
)
if atype == "user_interaction" and "content" not in action:
errors.append(
f"{rid}.action[{k}]: user_interaction 类型缺少 content 字段"
)
return errors
def run_all_tests():
print("=" * 60)
print("Step 2 自检测试")
print("=" * 60)
fragments = load_fragments()
all_errors = []
total_units = len(fragments)
total_rules = sum(len(f.get("rules", [])) for f in fragments)
# Test 1: Non-empty rules
errors = check_non_empty_rules(fragments)
if errors:
print(f"\n{FAIL} 非空规则检查: {len(errors)} 个错误")
for e in errors:
print(f" - {e}")
all_errors.extend(errors)
else:
print(f"\n{PASS} 非空规则检查: 全部通过 ({total_units} 个片段)")
# Test 2: Rule path arrays
errors = check_rule_paths(fragments)
if errors:
print(f"\n{FAIL} 规则 path 字段: {len(errors)} 个错误")
for e in errors[:10]:
print(f" - {e}")
if len(errors) > 10:
print(f" ... 还有 {len(errors) - 10}")
all_errors.extend(errors)
else:
print(f"\n{PASS} 规则 path 字段: 全部通过")
# Test 3: Precondition fields
errors = check_precondition_fields(fragments)
if errors:
print(f"\n{FAIL} precondition 字段: {len(errors)} 个错误")
for e in errors[:10]:
print(f" - {e}")
if len(errors) > 10:
print(f" ... 还有 {len(errors) - 10}")
all_errors.extend(errors)
else:
print(f"\n{PASS} precondition 字段: 全部通过")
# Test 4: user_interaction content
errors = check_user_interaction_content(fragments)
if errors:
print(f"\n{FAIL} user_interaction content: {len(errors)} 个错误")
for e in errors[:10]:
print(f" - {e}")
if len(errors) > 10:
print(f" ... 还有 {len(errors) - 10}")
all_errors.extend(errors)
else:
print(f"\n{PASS} user_interaction content: 全部通过")
# Test 5: Sources have logic tree references
errors = check_sources_have_logic_tree_nodes(fragments)
if errors:
print(f"\n{FAIL} 来源节点引用: {len(errors)} 个规则缺少来源引用")
for e in errors[:10]:
print(f" - {e}")
if len(errors) > 10:
print(f" ... 还有 {len(errors) - 10}")
all_errors.extend(errors)
else:
print(f"\n{PASS} 来源节点引用: 全部通过")
# Test 6: Trigger conditions completeness
errors = check_trigger_conditions(fragments)
if errors:
print(f"\n{FAIL} 触发条件完整性: {len(errors)} 个条件不完整")
for e in errors[:10]:
print(f" - {e}")
if len(errors) > 10:
print(f" ... 还有 {len(errors) - 10}")
all_errors.extend(errors)
else:
print(f"\n{PASS} 触发条件完整性: 全部通过")
# Test 7: No duplicate rule_ids
errors = check_duplicate_rule_ids(fragments)
if errors:
print(f"\n{FAIL} rule_id 唯一性: 发现重复")
for e in errors:
print(f" - {e}")
all_errors.extend(errors)
else:
print(f"\n{PASS} rule_id 唯一性: 全部通过")
# Test 8: Valid action types
errors = check_action_types(fragments)
if errors:
print(f"\n{FAIL} 动作类型检查: {len(errors)} 个问题")
for e in errors[:10]:
print(f" - {e}")
all_errors.extend(errors)
else:
print(f"\n{PASS} 动作类型检查: 全部通过")
# Summary
print(f"\n{'='*60}")
total_failures = len(all_errors)
if total_failures == 0:
print(f"{PASS} 所有测试通过!")
else:
print(f"{FAIL} 测试失败: {total_failures} 个错误")
print("\n建议:")
print(" 1. 检查 ir_fragments.json 中出错的规则")
print(" 2. 如果某些功能单元的规则为空,检查上下文包是否丢失了关键信息")
print(" 3. 调整 Prompt (prompts/step2_ir_extraction.txt) 后重新运行")
print(f"\n统计:")
print(f" 功能单元数: {total_units}")
print(f" 规则总数: {total_rules}")
error_units = sum(1 for f in fragments if f.get("error"))
if error_units:
print(f" 提取失败的单元: {error_units}")
return total_failures == 0
if __name__ == "__main__":
success = run_all_tests()
sys.exit(0 if success else 1)
@@ -0,0 +1,152 @@
"""
Tests for Stage 2.5 (Branch Coverage Auto-Completion).
Validates:
- Path enumeration exists and is non-empty
- Auto-complete fragments have valid structure
- No duplicate unit_ids in autocomplete fragments
- Path coverage improved after autocomplete (if applicable)
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import config
PASS = "[PASS]"
FAIL = "[FAIL]"
WARN = "[WARN]"
def load_path_enumeration():
"""Load path_enumeration.json."""
try:
return config.load_json(config.PATH_ENUM_JSON)
except FileNotFoundError:
print(f"{FAIL} path_enumeration.json 未找到: {config.PATH_ENUM_JSON}")
print(" 请先运行 step2_5_branch_coverage.py")
sys.exit(1)
def load_autocomplete_fragments():
"""Load ir_autocomplete_fragments.json, or return [] if absent."""
path = config.IR_AUTOCOMPLETE_FRAGMENTS_JSON
if not Path(path).exists():
return None
return config.load_json(path)
def check_path_enumeration(data: dict) -> list[str]:
"""Check path enumeration has valid structure."""
errors = []
paths = data.get("logic_tree_paths", {})
if not paths:
errors.append("logic_tree_paths 为空")
total = data.get("total_paths", 0)
if total <= 0:
errors.append(f"total_paths = {total}, 期望 > 0")
for image_id, image_paths in paths.items():
if not image_paths:
errors.append(f"{image_id}: 路径列表为空")
continue
for i, p in enumerate(image_paths):
if not p.get("path_id"):
errors.append(f"{image_id}[{i}]: 缺少 path_id")
if not p.get("image_id"):
errors.append(f"{image_id}[{i}]: 缺少 image_id")
if not p.get("node_ids"):
errors.append(f"{image_id}[{i}]: 缺少 node_ids")
return errors
def check_autocomplete_fragments(fragments: list[dict] | None) -> list[str]:
"""Check auto-complete fragments have valid structure."""
if fragments is None:
return ["ir_autocomplete_fragments.json 未生成 (可能无需补全)"]
errors = []
seen_unit_ids = set()
for frag in fragments:
uid = frag.get("unit_id", "")
if not uid:
errors.append("fragment 缺少 unit_id")
continue
if uid in seen_unit_ids:
errors.append(f"unit_id '{uid}' 重复")
seen_unit_ids.add(uid)
if not frag.get("auto_generated"):
errors.append(f"{uid}: auto_generated 应为 true")
rules = frag.get("rules", [])
for j, rule in enumerate(rules):
rid = rule.get("rule_id", f"rule[{j}]")
if not rule.get("path"):
errors.append(f"{rid}: path 字段缺失")
precond = rule.get("precondition", {})
if not precond.get("geographic_scope"):
errors.append(f"{rid}: precondition.geographic_scope 缺失")
return errors
def run_all_tests():
print("=" * 60)
print("Step 2.5 自检测试")
print("=" * 60)
all_errors = []
# Test 1: Path enumeration exists
try:
path_data = load_path_enumeration()
except SystemExit:
return False
errors = check_path_enumeration(path_data)
if errors:
print(f"\n{FAIL} 路径枚举检查: {len(errors)} 个错误")
for e in errors:
print(f" - {e}")
all_errors.extend(errors)
else:
total = path_data.get("total_paths", 0)
n_images = len(path_data.get("logic_tree_paths", {}))
print(f"\n{PASS} 路径枚举检查: {total} 条路径, {n_images} 个逻辑树")
# Test 2: Auto-complete fragments
fragments = load_autocomplete_fragments()
errors = check_autocomplete_fragments(fragments)
if fragments is None:
print(f"\n{WARN} 自动补全片段: 未生成 (可能所有路径已覆盖)")
elif errors:
print(f"\n{FAIL} 自动补全片段检查: {len(errors)} 个错误")
for e in errors[:10]:
print(f" - {e}")
all_errors.extend(errors)
else:
auto_rules = sum(len(f.get("rules", [])) for f in fragments)
print(f"\n{PASS} 自动补全片段检查: "
f"{len(fragments)} 个片段, {auto_rules} 条规则")
# Summary
print(f"\n{'='*60}")
total_failures = len(all_errors)
if total_failures == 0:
print(f"{PASS} 所有测试通过!")
else:
print(f"{FAIL} 测试失败: {total_failures} 个错误")
return total_failures == 0
if __name__ == "__main__":
success = run_all_tests()
sys.exit(0 if success else 1)
@@ -0,0 +1,232 @@
"""
Tests for Stage 3 (Merge & Audit).
Validates:
- ir_final.json exists and is well-formed
- No duplicate rule_ids
- All rule_ids follow new hierarchical naming convention
- All rules have path arrays
- ir_audit_report.md exists and contains all required sections
"""
import re
import sys
from pathlib import Path
from collections import Counter
sys.path.insert(0, str(Path(__file__).parent.parent))
import config
PASS = "[PASS]"
FAIL = "[FAIL]"
WARN = "[WARN]"
def load_ir_final():
"""Load ir_final.json."""
try:
return config.load_json(config.IR_FINAL_JSON)
except FileNotFoundError:
print(f"{FAIL} ir_final.json 未找到: {config.IR_FINAL_JSON}")
print(" 请先运行 step3_merge_and_audit.py")
sys.exit(1)
def load_audit_report():
"""Load ir_audit_report.md if it exists."""
try:
with open(config.IR_AUDIT_REPORT_MD, "r", encoding="utf-8") as f:
return f.read()
except FileNotFoundError:
print(f"{FAIL} ir_audit_report.md 未找到: {config.IR_AUDIT_REPORT_MD}")
print(" 请先运行 step3_merge_and_audit.py")
sys.exit(1)
def check_rule_ids(ir: dict) -> list[str]:
"""Check for duplicate rule_ids and hierarchical naming convention.
Format: DRL-001-DOMESTIC-SYS-FG-INTERRUPT-01
"""
errors = []
rules = ir.get("rules", [])
rule_ids = [r.get("rule_id", "") for r in rules]
# No duplicates
duplicates = [rid for rid, count in Counter(rule_ids).items() if count > 1]
if duplicates:
errors.append(f"重复 rule_id: {duplicates}")
# New hierarchical naming convention
pattern = re.compile(
r"^[A-Z]+-\d{3}-(DOMESTIC|OVERSEAS)-"
r"(SYS|SDK|OTHER)-"
r"(FG-INTERRUPT|BG-BLOCK|BG-PAUSE|NO-RESTRICT|SWITCH-OFF)-\d{2}$"
)
for rid in rule_ids:
if rid and not pattern.match(rid):
errors.append(
f"rule_id 命名不规范: '{rid}' "
f"(期望: FEATURE-SCOPE-METHOD-BEHAVIOR-NN)"
)
return errors
def check_top_level_structure(ir: dict) -> list[str]:
"""Check that ir_final has the required top-level fields."""
errors = []
for field in ["feature", "feature_id", "rules"]:
if field not in ir:
errors.append(f"ir_final 缺少顶层字段: {field}")
if not isinstance(ir.get("rules"), list):
errors.append("ir_final.rules 必须是数组")
elif len(ir["rules"]) == 0:
errors.append("ir_final.rules 为空")
return errors
def check_rule_paths(rules: list[dict]) -> list[str]:
"""Every rule must have a non-empty path array."""
errors = []
for rule in rules:
rid = rule.get("rule_id", "?")
path = rule.get("path", [])
if not path:
errors.append(f"{rid}: path 字段为空或缺失")
return errors
def check_rule_completeness(rules: list[dict]) -> list[str]:
"""Check each rule has all required fields."""
errors = []
required_fields = [
"rule_id", "description", "priority", "sources",
"precondition", "trigger", "actions"
]
for i, rule in enumerate(rules):
rid = rule.get("rule_id", f"rule[{i}]")
for field in required_fields:
if field not in rule:
errors.append(f"{rid}: 缺少字段 '{field}'")
if not rule.get("sources"):
errors.append(f"{rid}: sources 为空")
if not rule.get("actions"):
errors.append(f"{rid}: actions 为空")
# Check precondition fields
precond = rule.get("precondition", {})
if not precond.get("geographic_scope"):
errors.append(f"{rid}: precondition.geographic_scope 缺失")
if "screen_type" not in precond:
errors.append(f"{rid}: precondition.screen_type 缺失")
return errors
def check_audit_report(report: str) -> list[str]:
"""Check audit report has all required sections."""
errors = []
required_sections = [
"逻辑树路径覆盖率",
"表格枚举覆盖",
"开关状态",
"一致性扫描报告",
"自动补全摘要",
"规则清单",
]
for section in required_sections:
if section not in report:
errors.append(f"审计报告缺少章节: {section}")
# Should have the human review notice
if "人工审查" not in report:
errors.append("审计报告缺少人工审查提示")
return errors
def run_all_tests():
print("=" * 60)
print("Step 3 自检测试")
print("=" * 60)
ir = load_ir_final()
report = load_audit_report()
rules = ir.get("rules", [])
all_errors = []
# Test 1: Top-level structure
errors = check_top_level_structure(ir)
if errors:
print(f"\n{FAIL} 顶层结构检查: {len(errors)} 个错误")
for e in errors:
print(f" - {e}")
all_errors.extend(errors)
else:
print(f"\n{PASS} 顶层结构检查: 通过 "
f"(feature={ir.get('feature')}, feature_id={ir.get('feature_id')})")
# Test 2: rule_id uniqueness and naming
errors = check_rule_ids(ir)
if errors:
print(f"\n{FAIL} rule_id 检查: {len(errors)} 个错误")
for e in errors:
print(f" - {e}")
all_errors.extend(errors)
else:
print(f"\n{PASS} rule_id 检查: 全部通过 ({len(rules)} 个唯一 ID, 层次化格式)")
# Test 3: Rule path fields
errors = check_rule_paths(rules)
if errors:
print(f"\n{FAIL} 规则 path 字段: {len(errors)} 个错误")
for e in errors[:10]:
print(f" - {e}")
all_errors.extend(errors)
else:
print(f"\n{PASS} 规则 path 字段: 全部通过")
# Test 4: Rule field completeness
errors = check_rule_completeness(rules)
if errors:
print(f"\n{FAIL} 规则字段完整性: {len(errors)} 个错误")
for e in errors[:10]:
print(f" - {e}")
if len(errors) > 10:
print(f" ... 还有 {len(errors) - 10}")
all_errors.extend(errors)
else:
print(f"\n{PASS} 规则字段完整性: 全部通过")
# Test 5: Audit report content
errors = check_audit_report(report)
if errors:
print(f"\n{FAIL} 审计报告检查: {len(errors)} 个错误")
for e in errors:
print(f" - {e}")
all_errors.extend(errors)
else:
print(f"\n{PASS} 审计报告检查: 全部通过 (6 个章节)")
# Summary
print(f"\n{'='*60}")
total_failures = len(all_errors)
if total_failures == 0:
print(f"{PASS} 所有测试通过!")
print(f"\n最终交付物:")
print(f" - {config.IR_FINAL_JSON} ({len(rules)} 条规则)")
print(f" - {config.IR_AUDIT_REPORT_MD}")
else:
print(f"{FAIL} 测试失败: {total_failures} 个错误")
print("\n建议: 检查 ir_fragments.json 和合并逻辑,修复问题后重新运行 step3_merge_and_audit.py")
return total_failures == 0
if __name__ == "__main__":
success = run_all_tests()
sys.exit(0 if success else 1)
+1
View File
@@ -0,0 +1 @@
# Tests package for document_analyzer
+1
View File
@@ -0,0 +1 @@
# QE Acceptance Tests for document_analyzer
+186
View File
@@ -0,0 +1,186 @@
"""Pytest configuration and shared fixtures for QE acceptance tests.
Usage::
pytest tests/acceptance/ -v --run-acceptance [--acceptance-runs=3]
Environment variables:
DASHSCOPE_API_KEY LLM API key (required for Layers B/C)
TEST_IR_PATH path to IR JSON to validate (default: ir_final.json sample)
TEST_PARSED_PATH path to _parsed.json or _updated.json for coverage analysis
"""
from __future__ import annotations
import json
import os
import sys
import tempfile
from pathlib import Path
from typing import Any
import pytest
# ── Path setup ──────────────────────────────────────────────────────────────
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(_PROJECT_ROOT))
def _skill_path(skill_name: str) -> str:
return str(_PROJECT_ROOT / "skills" / skill_name / "scripts")
# ── pytest configuration ────────────────────────────────────────────────────
def pytest_addoption(parser):
parser.addoption(
"--run-acceptance",
action="store_true",
default=False,
help="Run QE acceptance tests (requires DASHSCOPE_API_KEY)",
)
parser.addoption(
"--acceptance-runs",
type=int,
default=1,
help="Number of IR generation runs for Layer B stability testing (default: 1 = skip)",
)
parser.addoption(
"--ir-path",
type=str,
default=None,
help="Path to IR JSON file to validate",
)
parser.addoption(
"--parsed-path",
type=str,
default=None,
help="Path to _parsed.json or _updated.json for coverage analysis",
)
def pytest_configure(config):
config.addinivalue_line(
"markers",
"acceptance: QE acceptance test (requires --run-acceptance flag and DASHSCOPE_API_KEY)",
)
def pytest_collection_modifyitems(config, items):
acceptance_dir = str(_PROJECT_ROOT / "tests" / "acceptance")
acceptance_items = [i for i in items if str(i.fspath).startswith(acceptance_dir)]
non_acceptance_items = [i for i in items if not str(i.fspath).startswith(acceptance_dir)]
if not config.getoption("--run-acceptance"):
skip_msg = pytest.mark.skip(reason="Need --run-acceptance flag to run")
for item in acceptance_items:
item.add_marker(skip_msg)
# Don't skip non-acceptance tests
return
if not os.environ.get("DASHSCOPE_API_KEY"):
skip_msg = pytest.mark.skip(reason="DASHSCOPE_API_KEY not set")
for item in acceptance_items:
item.add_marker(skip_msg)
# ── Shared fixtures ─────────────────────────────────────────────────────────
@pytest.fixture(scope="session")
def project_root() -> Path:
return _PROJECT_ROOT
@pytest.fixture(scope="session")
def ir_path(request) -> str:
"""Path to the IR JSON file under test."""
path = (
request.config.getoption("--ir-path")
or os.environ.get("TEST_IR_PATH")
or str(
Path.home()
/ ".openclaw/workspace/skills/doc_parser_skill/output/ir_final.json"
)
)
if not os.path.exists(path):
pytest.skip(f"IR file not found: {path}")
return path
@pytest.fixture(scope="session")
def ir_data(ir_path: str) -> dict:
"""Load the IR JSON data."""
with open(ir_path, "r", encoding="utf-8") as f:
return json.load(f)
@pytest.fixture(scope="session")
def parsed_path(request) -> str | None:
"""Path to the corresponding _parsed.json or _updated.json."""
path = (
request.config.getoption("--parsed-path")
or os.environ.get("TEST_PARSED_PATH")
or str(
_PROJECT_ROOT
/ "skills/ir_generation_skill/车机娱乐系统禁止功能文档_精简_updated.json"
)
)
if os.path.exists(path):
return path
return None
@pytest.fixture(scope="session")
def parsed_data(parsed_path: str | None) -> dict | None:
"""Load the parsed document JSON for coverage analysis."""
if parsed_path is None:
return None
with open(parsed_path, "r", encoding="utf-8") as f:
return json.load(f)
@pytest.fixture(scope="session")
def llm_client():
"""Create an LLMClient instance for acceptance tests.
Uses the DashScope-compatible LLMClient from the project.
"""
sys.path.insert(0, _skill_path("doc_parser_skill"))
from LLM import LLMClient
return LLMClient()
@pytest.fixture(scope="session")
def acceptance_runs(request) -> int:
return request.config.getoption("--acceptance-runs", default=1)
# ── Pipeline runner ─────────────────────────────────────────────────────────
@pytest.fixture(scope="session")
def run_ir_pipeline():
"""Return a callable that runs the IR generation pipeline on a parsed JSON.
Usage::
ir_data, ir_path = run_ir_pipeline(parsed_json_path, output_dir)
"""
sys.path.insert(0, _skill_path("ir_generation_skill"))
from ir_generator import generate_ir
def _run(parsed_path: str, output_dir: str | None = None) -> tuple[dict, str]:
"""Run IR generation and return (ir_data, ir_path)."""
out = output_dir or tempfile.mkdtemp(prefix="qe_acceptance_")
result = generate_ir(parsed_path, out, dry_run=False)
ir_list = result.get("ir", [])
ir_path = result.get("path", "")
# ir_generator produces a list; wrap to match rich format expectations
# for schema validation we accept both formats
return ir_list, ir_path
return _run
+325
View File
@@ -0,0 +1,325 @@
"""Rich IR schema definition and validators for the document_analyzer QE framework.
Target format is the production IR (``ir_final.json``):
{feature, feature_id, config_defaults?, rules: [{rule_id, path, description,
priority, sources, precondition, trigger, actions}]}
"""
from __future__ import annotations
import re
from typing import Any
# ── Constants ────────────────────────────────────────────────────────────────
VALID_SOURCE_TYPES = {"table", "logic_tree", "text"}
VALID_ACTION_TYPES = {"system", "user_interaction"}
VALID_PRIORITIES = {"P0", "P1", "P2"}
VALID_TRIGGER_OPERATORS = {"AND", "OR"}
# rule_id pattern: FEAT-NNN-SCOPE-TYPE-...-PATH-NN (variable middle segments)
RULE_ID_RE = re.compile(
r"^[A-Z]+-\d+(-[A-Z]+)+-\d+$"
)
# ── Validation helpers ──────────────────────────────────────────────────────
def _check(condition: bool, message: str) -> list[str]:
"""Return a list with an error message if *condition* is False, else empty list."""
return [] if condition else [message]
def validate_rule(rule: dict, index: int = 0) -> list[str]:
"""Validate a single rule dict. Returns a (possibly empty) list of error strings."""
errors: list[str] = []
label = f"rules[{index}]"
if not isinstance(rule, dict):
return [f"{label}: not a dict"]
# ── required top-level fields ──
for field in ("rule_id", "description"):
errors.extend(_check(
isinstance(rule.get(field), str) and bool(rule[field].strip()),
f'{label}.{field}: required non-empty string',
))
# sources is a list, not a string — validated separately below
# ── rule_id naming ──
rid = rule.get("rule_id", "")
if rid and isinstance(rid, str):
errors.extend(_check(
bool(RULE_ID_RE.match(rid)),
f'{label}.rule_id: "{rid}" does not match pattern FEAT-NNN-SCOPE-TYPE-PATH-NN',
))
# ── priority ──
priority = rule.get("priority")
if priority is not None:
errors.extend(_check(
priority in VALID_PRIORITIES,
f'{label}.priority: "{priority}" not in {VALID_PRIORITIES}',
))
# ── path ──
path = rule.get("path")
if path is not None:
if not isinstance(path, list):
errors.append(f"{label}.path: must be a list")
elif len(path) == 0:
errors.append(f"{label}.path: must not be empty")
elif not all(isinstance(p, str) and p.strip() for p in path):
errors.append(f"{label}.path: all segments must be non-empty strings")
# ── sources[] ──
sources = rule.get("sources", [])
if not isinstance(sources, list):
errors.append(f"{label}.sources: must be a list")
elif len(sources) == 0:
errors.append(f"{label}.sources: must have at least one source")
else:
for si, src in enumerate(sources):
errors.extend(_validate_source(src, f"{label}.sources[{si}]"))
# ── precondition ──
precondition = rule.get("precondition")
if precondition is not None:
if not isinstance(precondition, dict):
errors.append(f"{label}.precondition: must be a dict")
elif len(precondition) == 0:
errors.append(f"{label}.precondition: must not be empty")
# ── trigger ──
trigger = rule.get("trigger")
if trigger is not None:
if not isinstance(trigger, dict):
errors.append(f"{label}.trigger: must be a dict")
else:
errors.extend(_validate_trigger(trigger, f"{label}.trigger"))
# ── actions ──
actions = rule.get("actions")
if actions is not None:
if not isinstance(actions, list):
errors.append(f"{label}.actions: must be a list")
else:
for ai, act in enumerate(actions):
errors.extend(_validate_action(act, f"{label}.actions[{ai}]"))
# ── no null values at any depth ──
errors.extend(_find_nulls(rule, label))
return errors
def _validate_source(src: dict, label: str) -> list[str]:
errors: list[str] = []
if not isinstance(src, dict):
return [f"{label}: not a dict"]
stype = src.get("type", "")
errors.extend(_check(
stype in VALID_SOURCE_TYPES,
f'{label}.type: "{stype}" not in {VALID_SOURCE_TYPES}',
))
priority = src.get("priority", "")
if priority:
errors.extend(_check(
priority in ("primary_source", "supplementary"),
f'{label}.priority: "{priority}" must be primary_source or supplementary',
))
# type-specific fields
if stype == "table":
errors.extend(_check(
isinstance(src.get("section"), str) and bool(src["section"].strip()),
f"{label}.section: required non-empty string for table source",
))
errors.extend(_check(
isinstance(src.get("row"), int),
f"{label}.row: required int for table source",
))
elif stype == "logic_tree":
errors.extend(_check(
isinstance(src.get("image_id"), str) and bool(src["image_id"].strip()),
f"{label}.image_id: required non-empty string for logic_tree source",
))
node_ids = src.get("node_ids", [])
errors.extend(_check(
isinstance(node_ids, list) and len(node_ids) > 0,
f"{label}.node_ids: required non-empty list for logic_tree source",
))
elif stype == "text":
errors.extend(_check(
isinstance(src.get("section"), str) and bool(src["section"].strip()),
f"{label}.section: required non-empty string for text source",
))
return errors
def _validate_trigger(trigger: dict, label: str) -> list[str]:
errors: list[str] = []
operator = trigger.get("operator", "")
errors.extend(_check(
operator in VALID_TRIGGER_OPERATORS,
f'{label}.operator: "{operator}" not in {VALID_TRIGGER_OPERATORS}',
))
conditions = trigger.get("conditions")
if conditions is not None:
if not isinstance(conditions, list):
errors.append(f"{label}.conditions: must be a list")
else:
for ci, cond in enumerate(conditions):
if not isinstance(cond, dict):
errors.append(f"{label}.conditions[{ci}]: not a dict")
else:
errors.extend(_check(
isinstance(cond.get("signal"), str) and bool(cond["signal"].strip()),
f"{label}.conditions[{ci}].signal: required non-empty string",
))
errors.extend(_check(
"operator" in cond,
f"{label}.conditions[{ci}].operator: required",
))
# empty conditions is valid (e.g. "switch always off, no conditions")
return errors
def _validate_action(action: dict, label: str) -> list[str]:
errors: list[str] = []
if not isinstance(action, dict):
return [f"{label}: not a dict"]
atype = action.get("type", "")
errors.extend(_check(
atype in VALID_ACTION_TYPES,
f'{label}.type: "{atype}" not in {VALID_ACTION_TYPES}',
))
errors.extend(_check(
isinstance(action.get("description"), str) and bool(action["description"].strip()),
f"{label}.description: required non-empty string",
))
return errors
def _find_nulls(obj: Any, label: str) -> list[str]:
"""Find any None values at any depth in *obj*."""
errors: list[str] = []
if obj is None:
return [f"{label}: null value"]
elif isinstance(obj, dict):
for k, v in obj.items():
errors.extend(_find_nulls(v, f"{label}.{k}"))
elif isinstance(obj, list):
for i, v in enumerate(obj):
errors.extend(_find_nulls(v, f"{label}[{i}]"))
return errors
# ── Top-level validation ────────────────────────────────────────────────────
def validate_ir(ir_data: dict) -> dict:
"""Validate the entire IR document.
Returns:
{
"valid": bool,
"errors": [str, ...],
"stats": {total_rules, valid_rules, has_config_defaults, ...}
}
"""
errors: list[str] = []
stats = {"total_rules": 0, "valid_rules": 0, "has_config_defaults": False, "features": 0}
if not isinstance(ir_data, dict):
return {"valid": False, "errors": ["IR root is not a dict"], "stats": stats}
# top-level required fields
for field in ("feature", "feature_id", "rules"):
if field not in ir_data:
errors.append(f"root.{field}: missing required field")
elif field in ("feature", "feature_id") and not (
isinstance(ir_data[field], str) and ir_data[field].strip()
):
errors.append(f"root.{field}: must be non-empty string")
# config_defaults (optional)
if "config_defaults" in ir_data:
stats["has_config_defaults"] = True
cd = ir_data["config_defaults"]
if not isinstance(cd, dict):
errors.append("root.config_defaults: must be a dict")
# rules array
rules = ir_data.get("rules", [])
if not isinstance(rules, list):
errors.append("root.rules: must be a list")
else:
stats["total_rules"] = len(rules)
if len(rules) == 0:
errors.append("root.rules: must have at least one rule")
else:
for i, rule in enumerate(rules):
rule_errors = validate_rule(rule, i)
if rule_errors:
errors.extend(rule_errors)
else:
stats["valid_rules"] += 1
# feature count
if isinstance(ir_data.get("feature_id"), str):
stats["features"] = 1
return {
"valid": len(errors) == 0,
"errors": errors,
"stats": stats,
}
# ── Summary helpers ─────────────────────────────────────────────────────────
def schema_checklist(ir_data: dict) -> list[dict]:
"""Run individual checks and return a checklist for reporting.
Each item: {"check": str, "passed": bool, "detail": str}
"""
report = validate_ir(ir_data)
checks: list[dict] = []
def _add(name: str, passed: bool, detail: str = ""):
checks.append({"check": name, "passed": passed, "detail": detail})
# Top-level
_add("root is dict", isinstance(ir_data, dict))
_add("root.feature present", isinstance(ir_data.get("feature"), str) and bool(ir_data["feature"].strip()))
_add("root.feature_id present", isinstance(ir_data.get("feature_id"), str) and bool(ir_data["feature_id"].strip()))
_add("root.rules is non-empty list", isinstance(ir_data.get("rules"), list) and len(ir_data["rules"]) > 0)
# Per-rule checks
rules = ir_data.get("rules", []) if isinstance(ir_data, dict) else []
rule_ids = []
for i, rule in enumerate(rules):
if not isinstance(rule, dict):
continue
rid = rule.get("rule_id", f"rules[{i}]")
rule_ids.append(rid)
errs = validate_rule(rule, i)
_add(f"{rid}: valid", len(errs) == 0, "; ".join(errs) if errs else "")
# Aggregate checks
_add("no duplicate rule_ids", len(rule_ids) == len(set(rule_ids)),
f"duplicates: {[r for r in rule_ids if rule_ids.count(r) > 1]}" if len(rule_ids) != len(set(rule_ids)) else "")
_add("all rules valid", report["valid"],
f"{report['stats']['valid_rules']}/{report['stats']['total_rules']} valid")
return checks
+178
View File
@@ -0,0 +1,178 @@
"""Structured JSON report generation for QE acceptance test results.
Produces a unified report with three-layer verdict:
Layer A Schema compliance
Layer B Structural coverage + stability
Layer C LLM QE expert audit
Final verdict: PASS (releasable) or FAIL (blocked).
"""
from __future__ import annotations
import json
import time
from pathlib import Path
from typing import Any
def generate_report(
schema_result: dict,
coverage_result: dict,
audit_result: dict | None,
*,
commit: str = "",
branch: str = "main",
output_path: str | None = None,
) -> dict:
"""Assemble the three-layer report and return it.
Args:
schema_result: ``{"verdict": "PASS"|"FAIL", "total_checks": N, "passed": N, "failed": N}``
coverage_result: ``{"verdict": "PASS"|"FAIL", "coverage_rate": float,
"stability": {"runs": N, "values": [...], "std": float}}``
audit_result: ``{"verdict": "ACCEPT"|"REJECT", "inadequate_ratio": float,
"rationale": str, "section_assessments": [...]}`` or None
commit: git commit SHA
branch: branch name
output_path: if set, write the report JSON to this path
Returns the report dict.
"""
layers: dict[str, Any] = {
"A_schema": schema_result,
"B_coverage": coverage_result,
}
if audit_result is not None:
layers["C_qe_audit"] = audit_result
# ── final verdict ──
a_pass = schema_result.get("verdict") == "PASS"
b_pass = coverage_result.get("verdict") == "PASS"
c_pass = (
audit_result is None
or audit_result.get("verdict") == "ACCEPT"
)
all_pass = a_pass and b_pass and c_pass
report = {
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
"commit": commit,
"branch": branch,
"layers": layers,
"final_verdict": "PASS" if all_pass else "FAIL",
"releasable": all_pass,
"failure_details": _failure_details(layers),
}
if output_path:
out = Path(output_path)
out.parent.mkdir(parents=True, exist_ok=True)
out.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
return report
def _failure_details(layers: dict) -> list[str]:
"""Summarise which layers failed and why."""
details: list[str] = []
schema = layers.get("A_schema", {})
if schema.get("verdict") != "PASS":
details.append(
f"Layer A (Schema): {schema.get('failed', '?')}/{schema.get('total_checks', '?')} checks failed"
)
coverage = layers.get("B_coverage", {})
if coverage.get("verdict") != "PASS":
cv = coverage.get("coverage_rate", "?")
details.append(f"Layer B (Coverage): rate={cv} (threshold: 0.70)")
audit = layers.get("C_qe_audit", {})
if audit.get("verdict") == "REJECT":
details.append(
f"Layer C (QE Audit): REJECT — inadequate_ratio={audit.get('inadequate_ratio', '?')}"
)
return details
# ── Layer-specific result builders ──────────────────────────────────────────
def schema_verdict(errors: list[str], stats: dict) -> dict:
"""Build Layer A result from schema validation errors & stats."""
total = stats.get("total_rules", 0)
valid = stats.get("valid_rules", 0)
failed_checks = len(errors) + (total - valid)
return {
"verdict": "PASS" if failed_checks == 0 else "FAIL",
"total_checks": max(total, 1), # at minimum, we checked the root
"passed": valid if failed_checks == 0 else valid,
"failed": failed_checks,
"rule_pass_rate": round(valid / max(total, 1), 2) if total > 0 else 0,
"sample_errors": errors[:10], # first 10 for the report
}
def coverage_verdict(
coverage_rate: float,
stability_std: float,
stability_values: list[float],
*,
coverage_threshold: float = 0.70,
stability_threshold: float = 0.05,
section_coverage: dict | None = None,
table_coverage: dict | None = None,
diagram_coverage: dict | None = None,
) -> dict:
"""Build Layer B result from coverage metrics."""
b1_pass = coverage_rate >= coverage_threshold
b2_pass = stability_std <= stability_threshold
both_pass = b1_pass and b2_pass
result: dict[str, Any] = {
"verdict": "PASS" if both_pass else "FAIL",
"coverage_rate": round(coverage_rate, 3),
"coverage_threshold": coverage_threshold,
"coverage_pass": b1_pass,
"stability": {
"runs": len(stability_values),
"values": [round(v, 3) for v in stability_values],
"std": round(stability_std, 4),
"threshold": stability_threshold,
"pass": b2_pass,
},
}
if section_coverage:
result["section_coverage"] = section_coverage
if table_coverage:
result["table_coverage"] = table_coverage
if diagram_coverage:
result["diagram_coverage"] = diagram_coverage
return result
def audit_verdict(audit_data: dict, *, inadequate_threshold: float = 0.30) -> dict:
"""Build Layer C result from LLM QE audit.
*audit_data* should contain:
inadequate_ratio: float
rationale: str
section_assessments: list[dict]
"""
ratio = audit_data.get("inadequate_ratio", 1.0)
passed = ratio <= inadequate_threshold
return {
"verdict": "ACCEPT" if passed else "REJECT",
"inadequate_ratio": round(ratio, 3),
"threshold": inadequate_threshold,
"rationale": audit_data.get("rationale", ""),
"total_sections": audit_data.get("total_functional_sections", 0),
"adequate": audit_data.get("adequate", 0),
"inadequate": audit_data.get("inadequate", 0),
"not_applicable": audit_data.get("not_applicable", 0),
}
+558
View File
@@ -0,0 +1,558 @@
"""QE Acceptance Test — Three-layer main branch health check.
Layer A (Schema): structural correctness of IR
Layer B (Coverage): structural source-traceability coverage + stability
Layer C (QE Audit): LLM as QE expert functional coverage assessment
Final verdict: all three layers must pass for main to be releasable.
"""
from __future__ import annotations
import json
import math
import re
import statistics
import tempfile
import time
from pathlib import Path
from typing import Any
import pytest
from .ir_schema import validate_ir, schema_checklist
from .report import generate_report, schema_verdict, coverage_verdict, audit_verdict
# ═══════════════════════════════════════════════════════════════════════════════
# Layer A: SCHEMA — deterministic structural validation
# ═══════════════════════════════════════════════════════════════════════════════
def test_layer_a_schema(ir_data: dict, request):
"""Validate IR structure: required fields, types, naming conventions, no nulls."""
report = validate_ir(ir_data)
checks = schema_checklist(ir_data)
# Build Layer A result
a_errors = report["errors"]
a_stats = report["stats"]
a_result = schema_verdict(a_errors, a_stats)
a_result["checks"] = checks
# Store for downstream layers & report
_stash(request, "layer_a", a_result)
# Assert
assert report["valid"], (
f"Schema validation FAILED ({len(a_errors)} errors)\n"
+ "\n".join(f" - {e}" for e in a_errors[:20])
)
# ═══════════════════════════════════════════════════════════════════════════════
# Layer B: STRUCTURAL COVERAGE + STABILITY
# ═══════════════════════════════════════════════════════════════════════════════
# Section titles that are NOT functional requirements
NON_FUNCTIONAL_PATTERNS = [
re.compile(p) for p in [
r"编制.*变更.*日志",
r"文档背景",
r"文档范围",
r"术语解释",
r"参考",
r"附录",
r"版本",
r"变更记录",
r"目录",
r"前言",
r"概述",
r"简介",
r"概述.*背景",
]
]
def _is_functional_section(section_name: str) -> bool:
"""Heuristic: exclude background, glossary, changelog, scope sections.
Sections that are purely structural preface, glossary, changelog are excluded.
Sections with numbering like '3.1.1' are always considered functional.
"""
# Numbered sections are functional
if _section_number(section_name) != section_name:
return True
for pat in NON_FUNCTIONAL_PATTERNS:
if pat.search(section_name):
return False
return True
def _extract_content_units(parsed_data: dict) -> dict:
"""Extract countable content units from parsed JSON.
Returns:
{"sections": [{"name": ..., "number": ...}, ...],
"table_rows": int, "diagram_images": [rid, ...]}
"""
sections = parsed_data.get("sections", [])
functional_sections: list[dict] = []
total_table_rows = 0
for sec in sections:
name = sec.get("source", "")
if _is_functional_section(name):
functional_sections.append({
"name": name,
"number": _section_number(name),
})
for block in sec.get("blocks", []):
if block.get("type") == "table":
rows = block.get("rows", [])
total_table_rows += len(rows)
# Diagram-type images from image_analysis
diagram_rids: list[str] = []
for img in parsed_data.get("image_analysis", []):
img_type = img.get("type", "")
if img_type in ("flowchart", "logic_tree", "architecture",
"state", "sequence", "activity"):
diagram_rids.append(img.get("rid", ""))
return {
"functional_sections": functional_sections,
"table_rows": total_table_rows,
"diagram_images": diagram_rids,
}
def _section_number(section_name: str) -> str:
"""Extract leading section number, e.g. '3.1.1 系统限制''3.1.1'."""
import re
m = re.match(r"^([\d.]+)", section_name)
return m.group(1) if m else section_name
def _section_matches(sec_ref: str, func_sections: list[dict]) -> str | None:
"""Find a functional section matching *sec_ref*. Returns the section name or None.
Matching: exact match starts-with match number match substring match.
"""
# exact
for s in func_sections:
if s["name"] == sec_ref:
return s["name"]
# starts with section number
for s in func_sections:
if s["name"].startswith(sec_ref) or sec_ref.startswith(s["name"]):
return s["name"]
# number match
sec_num = _section_number(sec_ref)
if sec_num:
for s in func_sections:
if s["number"] == sec_num:
return s["name"]
# substring
for s in func_sections:
if sec_ref in s["name"] or s["name"] in sec_ref:
return s["name"]
return None
def _measure_coverage(ir_data: dict, parsed_data: dict) -> dict:
"""Compute structural coverage of IR over parsed document.
Returns:
{
"section_coverage": {total, covered, rate, uncovered},
"table_coverage": {total_rows, covered_rows, rate},
"diagram_coverage": {total, covered, rate},
"overall_rate": float,
}
"""
units = _extract_content_units(parsed_data)
rules = ir_data.get("rules", [])
# ── section coverage ──
func_sections = units["functional_sections"]
covered_sections: set[str] = set()
for rule in rules:
for src in rule.get("sources", []):
sec_ref = src.get("section", "")
if sec_ref:
matched = _section_matches(sec_ref, func_sections)
if matched:
covered_sections.add(matched)
section_coverage = {
"total": len(func_sections),
"covered": len(covered_sections),
"rate": round(len(covered_sections) / max(len(func_sections), 1), 3),
"uncovered": [s["name"] for s in func_sections
if s["name"] not in covered_sections],
}
# ── table row coverage ──
covered_rows: set[tuple] = set()
for rule in rules:
for src in rule.get("sources", []):
if src.get("type") == "table":
sec = src.get("section", "")
row = src.get("row")
if sec and row is not None:
covered_rows.add((sec, row))
total_rows = units["table_rows"]
table_coverage = {
"total_rows": total_rows,
"covered_rows": len(covered_rows),
"rate": round(len(covered_rows) / max(total_rows, 1), 3),
}
# ── diagram coverage ──
diagram_rids = units["diagram_images"]
covered_rids: set[str] = set()
for rule in rules:
for src in rule.get("sources", []):
if src.get("type") == "logic_tree":
img_id = src.get("image_id", "")
if img_id and img_id in diagram_rids:
covered_rids.add(img_id)
diagram_coverage = {
"total": len(diagram_rids),
"covered": len(covered_rids),
"rate": round(len(covered_rids) / max(len(diagram_rids), 1), 3),
"uncovered": [r for r in diagram_rids if r not in covered_rids],
}
# ── overall ──
rates = [
section_coverage["rate"],
table_coverage["rate"],
diagram_coverage["rate"],
]
overall = round(sum(rates) / len(rates), 3) if rates else 0.0
return {
"section_coverage": section_coverage,
"table_coverage": table_coverage,
"diagram_coverage": diagram_coverage,
"overall_rate": overall,
}
def test_layer_b_coverage(
ir_data: dict,
parsed_data: dict | None,
ir_path: str,
acceptance_runs: int,
run_ir_pipeline,
request,
):
"""Measure structural coverage and (optionally) coverage stability."""
if parsed_data is None:
pytest.skip("No parsed JSON available for coverage analysis")
# ── B1: single-run coverage ──
cov = _measure_coverage(ir_data, parsed_data)
# ── B2: stability (multi-run) ──
stability_values: list[float] = [cov["overall_rate"]]
stability_std = 0.0
if acceptance_runs > 1:
parsed_path = request.config.getoption("--parsed-path")
if parsed_path and os.path.exists(parsed_path):
for _ in range(acceptance_runs - 1):
try:
ir_list, _ = run_ir_pipeline(parsed_path)
# Convert list-format IR to dict for coverage measurement
run_ir = _wrap_list_ir(ir_list)
run_cov = _measure_coverage(run_ir, parsed_data)
stability_values.append(run_cov["overall_rate"])
time.sleep(0.5) # rate limiting between runs
except Exception as e:
pytest.fail(f"Stability run failed: {e}")
if len(stability_values) > 1:
stability_std = statistics.stdev(stability_values)
# Build Layer B result
b_result = coverage_verdict(
coverage_rate=cov["overall_rate"],
stability_std=stability_std,
stability_values=stability_values,
section_coverage=cov["section_coverage"],
table_coverage=cov["table_coverage"],
diagram_coverage=cov["diagram_coverage"],
)
_stash(request, "layer_b", b_result)
# Assert — both B1 and B2 must pass
assert b_result["coverage_pass"], (
f"Coverage {cov['overall_rate']:.1%} < threshold 70%\n"
f" Sections: {cov['section_coverage']['covered']}/{cov['section_coverage']['total']} "
f"({cov['section_coverage']['rate']:.1%})\n"
f" Uncovered: {cov['section_coverage']['uncovered']}\n"
f" Table rows: {cov['table_coverage']['covered_rows']}/{cov['table_coverage']['total_rows']} "
f"({cov['table_coverage']['rate']:.1%})\n"
f" Diagrams: {cov['diagram_coverage']['covered']}/{cov['diagram_coverage']['total']} "
f"({cov['diagram_coverage']['rate']:.1%})\n"
f" Uncovered diagrams: {cov['diagram_coverage']['uncovered']}"
)
if len(stability_values) > 1:
assert b_result["stability"]["pass"], (
f"Coverage stability std={stability_std:.4f} > threshold 0.05\n"
f" Values across {len(stability_values)} runs: {stability_values}"
)
def _wrap_list_ir(ir_list: list) -> dict:
"""Wrap a list-format IR (from ir_generator.py) into a dict for schema compat."""
# Convert simple format to rich format for coverage measurement
rules = []
for i, entry in enumerate(ir_list):
if not isinstance(entry, dict):
continue
rule = {
"rule_id": f"GEN-001-RULE-{i:03d}",
"description": entry.get("function", ""),
"path": [],
"priority": "P2",
"sources": [],
"precondition": {},
"trigger": entry.get("trigger", {"operator": "AND", "conditions": []}),
"actions": [],
}
# Convert source
src = entry.get("source", {})
if src.get("section"):
rule["sources"].append({
"type": "text",
"section": src["section"],
"paragraph": 1,
"text_snippet": src.get("location", ""),
"priority": "primary_source",
})
rules.append(rule)
return {
"feature": "generated",
"feature_id": "GEN-001",
"rules": rules,
}
# ═══════════════════════════════════════════════════════════════════════════════
# Layer C: LLM QE EXPERT AUDIT
# ═══════════════════════════════════════════════════════════════════════════════
QE_AUDITOR_PROMPT = """你是一个资深 QE 专家,负责审查需求文档的 IR(中间表示层)是否充分覆盖了源文档的所有可测试功能点。
你不是 IR 的生成者你是独立的质量审计员你的职责是判断 IR 的功能覆盖率是否充分
## 审计输入
### Layer B 结构化覆盖率数据(参考)
{coverage_summary}
### 源文档内容(Parsed JSON
{parsed_content}
### 生成的 IR(待审计)
{ir_content}
## 审计要求
对源文档中的每个章节逐一评估其功能需求是否被 IR 充分覆盖
**判断标准**
- **adequate**充分覆盖该章节的所有功能需求在 IR 中都有对应的 rule包括触发条件执行动作
- **inadequate**覆盖不足该章节存在功能需求未在 IR 中体现或描述不完整缺少触发条件或动作
- **not_applicable**不适用该章节为背景介绍术语定义变更日志等不包含功能需求
**注意**
- 如果某个章节涉及多个决策路径如流程图检查 IR 是否覆盖了每条路径
- 表格中的每个功能行都应被至少一个 IR rule 覆盖
- 图片分析中的流程图/决策树节点应被 IR 引用
## 输出格式
请严格输出以下 JSON 格式不要包含代码块标记
{{
"total_functional_sections": <number>,
"adequate": <number>,
"inadequate": <number>,
"not_applicable": <number>,
"inadequate_ratio": <float>,
"verdict": "ACCEPT 或 REJECT",
"rationale": "<一句话说明接受或拒绝的理由>",
"section_assessments": [
{{
"section": "<章节名>",
"assessment": "adequate | inadequate | not_applicable",
"reason": "<评估理由>",
"missing": ["<缺失项1>", "<缺失项2>"] // inadequate 时需要
}}
]
}}
verdict 判定规则
- inadequate_ratio 0.30 "ACCEPT"风险可控
- inadequate_ratio > 0.30 "REJECT"功能点认知差异大需要补充 IR
"""
def test_layer_c_qe_audit(
ir_data: dict, parsed_data: dict | None, llm_client, request
):
"""LLM QE expert audit of functional coverage."""
if parsed_data is None:
pytest.skip("No parsed JSON available — cannot run QE audit")
# ── get Layer B summary for context ──
layer_b = _unstash(request, "layer_b") or {}
cov_summary = json.dumps(
{
"coverage_rate": layer_b.get("coverage_rate", "N/A"),
"section_coverage": layer_b.get("section_coverage", {}),
"diagram_coverage": layer_b.get("diagram_coverage", {}),
},
ensure_ascii=False,
indent=2,
)
# ── prepare content (trim to avoid token overflow) ──
parsed_str = json.dumps(parsed_data, ensure_ascii=False)
ir_str = json.dumps(ir_data, ensure_ascii=False)
max_parsed = 12000
max_ir = 8000
if len(parsed_str) > max_parsed:
parsed_str = parsed_str[:max_parsed] + "\n...[truncated]"
if len(ir_str) > max_ir:
ir_str = ir_str[:max_ir] + "\n...[truncated]"
prompt = QE_AUDITOR_PROMPT.format(
coverage_summary=cov_summary,
parsed_content=parsed_str,
ir_content=ir_str,
)
# ── call LLM ──
try:
raw = llm_client.chat(
model=llm_client.TEXT_MODEL,
messages=[{"role": "user", "content": prompt}],
response_format={"type": "json_object"},
)
except Exception as e:
pytest.fail(f"QE audit LLM call failed: {e}")
# ── parse response ──
audit_data = _parse_json_response(raw)
if audit_data is None:
pytest.fail(f"QE audit returned unparseable response:\n{raw[:500]}")
# Build Layer C result
c_result = audit_verdict(audit_data)
c_result["raw_assessments"] = audit_data.get("section_assessments", [])
_stash(request, "layer_c", c_result)
# Assert
assert c_result["verdict"] == "ACCEPT", (
f"QE Audit REJECTED — inadequate_ratio={c_result['inadequate_ratio']:.1%} > 30%\n"
f" Rationale: {c_result['rationale']}\n"
f" Adequate: {c_result['adequate']}, Inadequate: {c_result['inadequate']}"
)
# ═══════════════════════════════════════════════════════════════════════════════
# Final report (runs last)
# ═══════════════════════════════════════════════════════════════════════════════
def test_final_report(ir_data: dict, ir_path: str, request):
"""Generate the final three-layer JSON report.
This test always passes (report generation). The verdicts from layers A/B/C
determine the final releasable status, but the report itself is informational.
"""
layer_a = _unstash(request, "layer_a") or {"verdict": "SKIPPED"}
layer_b = _unstash(request, "layer_b") or {"verdict": "SKIPPED"}
layer_c = _unstash(request, "layer_c") or {"verdict": "SKIPPED"}
report_path = request.config.getoption("--json-report-file", None) or str(
Path.cwd() / "acceptance-report.json"
)
report = generate_report(
layer_a,
layer_b,
layer_c,
commit=os.environ.get("GITEA_SHA", ""),
branch=os.environ.get("GITEA_BRANCH", "main"),
output_path=report_path,
)
# Print summary
print(f"\n{'='*60}")
print(f"QE ACCEPTANCE REPORT")
print(f"{'='*60}")
print(f" Layer A (Schema): {layer_a.get('verdict', '?')}")
print(f" Layer B (Coverage): {layer_b.get('verdict', '?')} "
f"(rate={layer_b.get('coverage_rate', '?')})")
print(f" Layer C (QE Audit): {layer_c.get('verdict', '?')}")
print(f" {''*40}")
print(f" FINAL: {report['final_verdict']} | "
f"Releasable: {report['releasable']}")
print(f" Report: {report_path}")
print(f"{'='*60}\n")
# Fail if any layer failed (aggregate assertion)
failures = report.get("failure_details", [])
if failures:
pytest.fail(
"Acceptance tests FAILED:\n" + "\n".join(f" - {f}" for f in failures)
)
# ═══════════════════════════════════════════════════════════════════════════════
# Helpers
# ═══════════════════════════════════════════════════════════════════════════════
import os # noqa: E402
# Module-level stash for sharing results across tests in the same module.
# Each test function stores its result here; later tests read earlier results.
_module_stash: dict[str, dict] = {}
def _stash(request, key: str, value: dict):
"""Store a result dict for cross-test access within this module."""
_module_stash[key] = value
def _unstash(request, key: str) -> dict | None:
"""Retrieve a stashed result."""
return _module_stash.get(key)
def _parse_json_response(raw: str) -> dict | None:
"""Parse JSON from an LLM response, handling markdown code fences."""
if not raw:
return None
text = raw.strip()
if text.startswith("```"):
nl = text.find("\n")
text = text[nl + 1:] if nl != -1 else text[3:]
if text.endswith("```"):
text = text[:-3]
try:
return json.loads(text)
except json.JSONDecodeError:
return None
+10 -5
View File
@@ -55,12 +55,17 @@ def test_import_detect_conflicts():
# -- IR generation tests ------------------------------------------------------ # -- IR generation tests ------------------------------------------------------
def test_import_ir_generator(): def test_import_ir_main():
"""ir_generator module should be importable.""" """ir_generation main module should be importable (new project structure)."""
os.environ.setdefault("DASHSCOPE_API_KEY", "test-fake-key") os.environ.setdefault("DASHSCOPE_API_KEY", "test-fake-key")
_import_from_skill("ir_generation_skill", "ir_generator") skill_dir = os.path.join(
import ir_generator os.path.dirname(os.path.dirname(__file__)),
assert hasattr(ir_generator, "generate_ir") "skills", "ir_generation_skill"
)
if skill_dir not in sys.path:
sys.path.insert(0, skill_dir)
import main
assert hasattr(main, "main")
# -- Resolution application tests --------------------------------------------- # -- Resolution application tests ---------------------------------------------