doc_parser_skill: - New: verify_flowchart.py (flowchart validation) - Updated: LLM.py (multi-provider: DeepSeek + DashScope) - Updated: image_parser.py (logic tree support, external prompts) - Updated: SKILL.md, prompts/image_prompt.md conflict_detection_skill: - Updated: LLM.py (multi-provider sync) - Updated: detect_conflicts.py (logic tree text conversion) ir_generation_skill: - Replaced old scripts/LLM.py + ir_generator.py with standalone project - New: main.py, config.py, step1-3_*.py, ensemble_merge.py - New: prompts/, tests/ subdirectories tests: - New: acceptance/ test suite with schema validation - Fixed: conftest no longer globally skips non-acceptance tests - Updated: test_sample.py for new ir_generation structure Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,54 @@
|
||||
name: QE Acceptance Tests
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
acceptance_runs:
|
||||
description: 'Layer B stability runs (1 = skip stability testing)'
|
||||
required: false
|
||||
default: '1'
|
||||
ir_path:
|
||||
description: 'Path to IR JSON file (relative to workspace)'
|
||||
required: false
|
||||
default: 'output/ir_final.json'
|
||||
parsed_path:
|
||||
description: 'Path to _parsed.json or _updated.json (relative to workspace)'
|
||||
required: false
|
||||
default: 'output/车机娱乐系统禁止功能文档_精简_updated.json'
|
||||
|
||||
jobs:
|
||||
acceptance:
|
||||
runs-on: shell
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- name: Checkout main branch
|
||||
run: |
|
||||
git clone --depth 1 http://localhost:3000/pzhang_zywl/document_analyzer.git .
|
||||
git checkout main
|
||||
|
||||
- name: Install dependencies
|
||||
run: pip install -r requirements.txt
|
||||
|
||||
- name: Run QE Acceptance Tests
|
||||
run: >-
|
||||
python -m pytest tests/acceptance/ -v
|
||||
--run-acceptance
|
||||
--acceptance-runs=${{ github.event.inputs.acceptance_runs }}
|
||||
--ir-path=${{ github.event.inputs.ir_path }}
|
||||
--parsed-path=${{ github.event.inputs.parsed_path }}
|
||||
--tb=long
|
||||
env:
|
||||
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
||||
|
||||
- name: Create issue on failure
|
||||
if: failure()
|
||||
env:
|
||||
GITEA_API_TOKEN: ${{ secrets.GITEA_TOKEN }}
|
||||
run: >-
|
||||
python scripts/create_failure_issue.py
|
||||
--sha "${{ github.sha }}"
|
||||
--branch "main"
|
||||
--run "${{ github.run_number }}"
|
||||
--message "QE Acceptance Tests Failed"
|
||||
--workflow "QE Acceptance"
|
||||
--labels "acceptance-failure,agent-task"
|
||||
@@ -7,3 +7,7 @@ output/
|
||||
dist/
|
||||
.runner
|
||||
*_output/
|
||||
*.png
|
||||
*.jpg
|
||||
acceptance-report.json
|
||||
ir_final.json
|
||||
|
||||
@@ -17,14 +17,18 @@ def main():
|
||||
parser.add_argument("--run", required=True)
|
||||
parser.add_argument("--message", required=True)
|
||||
parser.add_argument("--api-token", default=os.environ.get("GITEA_API_TOKEN", ""))
|
||||
parser.add_argument("--workflow", default="CI", help="Workflow name that triggered this (default: CI)")
|
||||
parser.add_argument("--labels", default="ci-failure,agent-task",
|
||||
help="Comma-separated labels for the issue (default: ci-failure,agent-task)")
|
||||
args = parser.parse_args()
|
||||
|
||||
sha_short = args.sha[:7]
|
||||
run_url = f"{GITEA_URL}/{REPO}/actions/runs/{args.run}"
|
||||
labels = [l.strip() for l in args.labels.split(",") if l.strip()]
|
||||
|
||||
title = f"CI Failure: {args.message[:80]}"
|
||||
title = f"[{args.workflow}] Failure: {args.message[:80]}"
|
||||
body = (
|
||||
f"## CI 测试失败\n\n"
|
||||
f"## {args.workflow} 测试失败\n\n"
|
||||
f"- **Commit:** {sha_short}\n"
|
||||
f"- **Branch:** {args.branch}\n"
|
||||
f"- **工作流运行:** {run_url}\n\n"
|
||||
@@ -38,7 +42,7 @@ def main():
|
||||
payload = json.dumps({
|
||||
"title": title,
|
||||
"body": body,
|
||||
"labels": [],
|
||||
"labels": labels,
|
||||
}).encode("utf-8")
|
||||
|
||||
url = f"{GITEA_URL}/api/v1/repos/{REPO}/issues"
|
||||
|
||||
@@ -1,38 +1,97 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Resolve secrets file: priority 1) env OPENCLAW_SECRETS,
|
||||
# 2) workspace-document-analyzer/config/ (relative to skills dir),
|
||||
# 3) .openclaw/config/
|
||||
_SECRETS_FILE = None
|
||||
for _candidate in (
|
||||
os.environ.get("OPENCLAW_SECRETS", ""),
|
||||
Path(__file__).resolve().parents[3] / "config" / "secrets.yaml",
|
||||
Path(__file__).resolve().parents[5] / ".openclaw" / "config" / "secrets.yaml",
|
||||
):
|
||||
if _candidate and Path(_candidate).exists():
|
||||
_SECRETS_FILE = Path(_candidate)
|
||||
break
|
||||
if _SECRETS_FILE is None:
|
||||
_SECRETS_FILE = Path("") # empty fallback
|
||||
|
||||
|
||||
def _load_secrets() -> dict:
|
||||
"""Load API keys from secrets.yaml, with env-var overrides."""
|
||||
secrets = {}
|
||||
if _SECRETS_FILE.exists():
|
||||
try:
|
||||
import yaml
|
||||
with open(_SECRETS_FILE, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
for provider in ("deepseek", "dashscope"):
|
||||
if provider in data and isinstance(data[provider], dict):
|
||||
secrets[provider] = data[provider]
|
||||
except ImportError:
|
||||
logger.warning("pyyaml not installed, cannot read %s", _SECRETS_FILE)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load %s: %s", _SECRETS_FILE, e)
|
||||
|
||||
# Env overrides
|
||||
dk_env = os.environ.get("DEEPSEEK_API_KEY", "")
|
||||
ds_env = os.environ.get("DASHSCOPE_API_KEY", "")
|
||||
if dk_env:
|
||||
secrets.setdefault("deepseek", {})["apiKey"] = dk_env
|
||||
if ds_env:
|
||||
secrets.setdefault("dashscope", {})["apiKey"] = ds_env
|
||||
return secrets
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""Low-level OpenAI-compatible LLM client with retry and token tracking.
|
||||
"""Multi-provider LLM client with retry and token tracking.
|
||||
|
||||
Routes text models to DeepSeek, vision models to DashScope (Bailian).
|
||||
Reads API keys from openclaw config/secrets.yaml, with env-var overrides.
|
||||
|
||||
Usage::
|
||||
|
||||
llm = LLMClient()
|
||||
content = llm.chat("qwen3.5-flash", [{"role": "user", "content": "Hello"}])
|
||||
content = llm.chat("deepseek-v4-pro", [{"role": "user", "content": "Hello"}])
|
||||
print(llm.usage)
|
||||
"""
|
||||
|
||||
IMAGE_MODEL = "qwen3-vl-plus"
|
||||
TEXT_MODEL = "qwen3.5-flash-2026-02-23"
|
||||
TEXT_MODEL = "deepseek-v4-flash"
|
||||
|
||||
DASHSCOPE_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
DEEPSEEK_BASE = "https://api.deepseek.com/v1"
|
||||
|
||||
TIMEOUT = 120
|
||||
MAX_RETRIES = 3
|
||||
|
||||
_VISION_KEYWORDS = ("vl", "vision", "qwen-vl", "qwen3-vl")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
timeout: int | None = None,
|
||||
):
|
||||
key = os.environ.get("DASHSCOPE_API_KEY", "")
|
||||
if not key:
|
||||
raise ValueError("DASHSCOPE_API_KEY environment variable is not set.")
|
||||
self._client = OpenAI(api_key=key, base_url=base_url)
|
||||
secrets = _load_secrets()
|
||||
|
||||
ds_cfg = secrets.get("dashscope", {})
|
||||
dk_cfg = secrets.get("deepseek", {})
|
||||
|
||||
dashscope_key = ds_cfg.get("apiKey", "")
|
||||
dashscope_url = ds_cfg.get("baseUrl", self.DASHSCOPE_BASE)
|
||||
deepseek_key = dk_cfg.get("apiKey", "")
|
||||
deepseek_url = dk_cfg.get("baseUrl", self.DEEPSEEK_BASE)
|
||||
|
||||
self._ds_client = OpenAI(api_key=dashscope_key, base_url=dashscope_url) if dashscope_key else None
|
||||
self._dk_client = OpenAI(api_key=deepseek_key, base_url=deepseek_url) if deepseek_key else None
|
||||
|
||||
self._timeout = timeout or self.TIMEOUT
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
@@ -49,7 +108,7 @@ class LLMClient:
|
||||
@staticmethod
|
||||
def estimate_tokens(text: str) -> int:
|
||||
"""Quick token estimate. CJK ≈1.7/token, others ≈3.0/token."""
|
||||
cjk = sum(1 for c in text if '一' <= c <= '鿿' or ' ' <= c <= '〿')
|
||||
cjk = sum(1 for c in text if '\u4e00' <= c <= '\u9fff' or '\u3000' <= c <= '\u303f')
|
||||
other = len(text) - cjk
|
||||
return max(1, int(cjk / 1.7 + other / 3.0))
|
||||
|
||||
@@ -58,6 +117,20 @@ class LLMClient:
|
||||
"""Fixed estimate for one vision-model image (~500 tokens)."""
|
||||
return 500
|
||||
|
||||
@staticmethod
|
||||
def _is_vision_model(model: str) -> bool:
|
||||
return any(kw in model.lower() for kw in LLMClient._VISION_KEYWORDS)
|
||||
|
||||
def _get_client(self, model: str) -> OpenAI:
|
||||
if self._is_vision_model(model):
|
||||
if self._ds_client is None:
|
||||
raise ValueError("DASHSCOPE_API_KEY not set but required for vision model")
|
||||
return self._ds_client
|
||||
else:
|
||||
if self._dk_client is None:
|
||||
raise ValueError("DEEPSEEK_API_KEY not set but required for text model")
|
||||
return self._dk_client
|
||||
|
||||
def chat(
|
||||
self, model: str, messages: list[dict], *, timeout: int | None = None,
|
||||
response_format: dict | None = None,
|
||||
@@ -65,8 +138,10 @@ class LLMClient:
|
||||
"""Send a chat completion request and return the response content.
|
||||
|
||||
Automatically retries on failure and accumulates token usage.
|
||||
Routes to DeepSeek for text, DashScope for vision.
|
||||
"""
|
||||
label = f"chat({model})"
|
||||
client = self._get_client(model)
|
||||
|
||||
def _call():
|
||||
t0 = time.time()
|
||||
@@ -74,7 +149,7 @@ class LLMClient:
|
||||
if response_format is not None:
|
||||
kwargs["response_format"] = response_format
|
||||
kwargs["temperature"] = 0
|
||||
resp = self._client.chat.completions.create(**kwargs)
|
||||
resp = client.chat.completions.create(**kwargs)
|
||||
content = resp.choices[0].message.content
|
||||
usg = resp.usage
|
||||
if usg:
|
||||
|
||||
@@ -96,6 +96,77 @@ PROMPT_DETECT_CONFLICT = """你是一个文档一致性检查专家。以下内
|
||||
"""
|
||||
|
||||
|
||||
def _is_nested_tree(lt: dict) -> bool:
|
||||
"""Return True if logic_tree uses the nested children format."""
|
||||
return isinstance(lt.get("children"), list)
|
||||
|
||||
|
||||
def _logic_tree_to_text(lt: dict) -> str:
|
||||
"""Convert logic_tree JSON to readable text for conflict detection.
|
||||
|
||||
Supports both the new nested-tree format and the legacy flat-nodes format.
|
||||
"""
|
||||
if _is_nested_tree(lt):
|
||||
return _nested_tree_to_text(lt)
|
||||
return _flat_tree_to_text(lt)
|
||||
|
||||
|
||||
def _nested_tree_to_text(tree: dict) -> str:
|
||||
"""Convert a nested flowchart tree to readable text."""
|
||||
lines: list[str] = []
|
||||
|
||||
def _walk(node: dict, indent: int = 0):
|
||||
prefix = " " * indent
|
||||
nid = node.get("id", "")
|
||||
name = node.get("name", "")
|
||||
ntype = node.get("type", "")
|
||||
|
||||
type_label = {
|
||||
"start": "起始", "end": "结束", "process": "处理",
|
||||
"decision": "判断", "action": "动作",
|
||||
}.get(ntype, ntype)
|
||||
|
||||
lines.append(f"{prefix}[{type_label}] {nid}: {name}")
|
||||
|
||||
if ntype == "decision":
|
||||
for child in node.get("children", []):
|
||||
cond = child.get("condition", "")
|
||||
lines.append(f"{prefix} 分支 \"{cond}\":")
|
||||
_walk(child["node"], indent + 2)
|
||||
elif "children" in node:
|
||||
for child in node.get("children", []):
|
||||
_walk(child, indent + 1)
|
||||
|
||||
_walk(tree)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _flat_tree_to_text(lt: dict) -> str:
|
||||
"""Convert legacy flat-nodes logic_tree to readable text."""
|
||||
lines: list[str] = []
|
||||
root = lt.get("root", "")
|
||||
if root:
|
||||
lines.append(f"根节点: {root}")
|
||||
for node in lt.get("nodes", []):
|
||||
nid = node.get("id", "")
|
||||
ntype = node.get("type", "")
|
||||
if ntype == "decision":
|
||||
cond = node.get("condition", "")
|
||||
branches = node.get("branches", [])
|
||||
lines.append(f"判断节点 {nid}: 条件=\"{cond}\"")
|
||||
for b in branches:
|
||||
lines.append(f" - 分支 \"{b.get('value', '')}\" → {b.get('target', '')}")
|
||||
elif ntype == "action":
|
||||
lines.append(f"动作节点 {nid}: {node.get('description', '')}")
|
||||
elif ntype == "state":
|
||||
lines.append(f"状态节点 {nid}: {node.get('description', '')}")
|
||||
elif ntype == "start":
|
||||
lines.append(f"起始节点 {nid}: {node.get('description', '')}")
|
||||
elif ntype == "end":
|
||||
lines.append(f"结束节点 {nid}: {node.get('description', '')}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _build_text_for_section(sections: list[dict], section_name: str) -> str:
|
||||
"""Build a single text block for the given section name."""
|
||||
texts: list[str] = []
|
||||
@@ -184,8 +255,9 @@ def detect_conflicts(
|
||||
img_type = img.get("type", "other")
|
||||
rid = img.get("rid", "")
|
||||
description = img.get("description", "").strip()
|
||||
logic_tree = img.get("logic_tree_nested") or img.get("logic_tree")
|
||||
|
||||
if img_type not in DIAGRAM_TYPES or not description:
|
||||
if img_type not in DIAGRAM_TYPES or (not description and not logic_tree):
|
||||
logger.info("Skip conflict check: rid=%s type=%s", rid, img_type)
|
||||
continue
|
||||
|
||||
@@ -211,8 +283,17 @@ def detect_conflicts(
|
||||
logger.info(" [DRY RUN] would call LLM to detect conflicts")
|
||||
continue
|
||||
|
||||
# Enrich description with logic_tree if available
|
||||
combined_desc = description
|
||||
if logic_tree:
|
||||
lt_text = _logic_tree_to_text(logic_tree)
|
||||
if combined_desc:
|
||||
combined_desc = f"[结构化逻辑树]\n{lt_text}\n\n[文字描述]\n{combined_desc}"
|
||||
else:
|
||||
combined_desc = f"[结构化逻辑树]\n{lt_text}"
|
||||
|
||||
prompt = PROMPT_DETECT_CONFLICT.format(
|
||||
image_description=description,
|
||||
image_description=combined_desc,
|
||||
text_description=text_content,
|
||||
section_name=section_name,
|
||||
)
|
||||
|
||||
@@ -29,7 +29,10 @@ description: 解析文档(.docx, .pdf)以提取图像和文本结构,并
|
||||
该技能生成一个结构化JSON文件,文件名为输入文档的基本名称后跟'_parsed.json',包含:
|
||||
- `sections`:按标题分组的文档文本结构
|
||||
- `image_sources`:从图像标识符到其在文档中位置的映射
|
||||
- `image_analysis`:由视觉大语言模型确定的每个图像的类型和内容描述
|
||||
- `image_analysis`:由视觉大语言模型确定的每个图像的类型、内容描述和(如适用)结构化逻辑树
|
||||
- `type`: 图片类型(flowchart/architecture/state/sequence/activity/other)
|
||||
- `description`: 图片的文字描述
|
||||
- `logic_tree`(可选,仅图表类型):结构化逻辑树JSON,包含 `root`(根节点描述)和 `nodes` 数组。节点类型:`decision`(判断)、`action`(动作)、`state`(状态)、`start`(开始)、`end`(结束)。decision 节点包含 `condition` 和 `branches` 字段,其他节点包含 `description` 字段。
|
||||
|
||||
## 集成点
|
||||
|
||||
|
||||
@@ -0,0 +1,203 @@
|
||||
请分析这张图片,判断类型并输出文字描述和(如适用)结构化逻辑树。
|
||||
|
||||
## 判断图片类型
|
||||
|
||||
如果是 **流程图 / 架构图 / 状态图 / 时序图 / 活动图**,你需要输出三项内容:
|
||||
1. 类型标签
|
||||
2. **嵌套逻辑树 JSON**(见下方格式)
|
||||
3. 文字描述
|
||||
|
||||
如果是 **其他类型**(UI原型图 / 界面截图 / 设计稿 / 手机屏幕截图 / 网页截图等),只输出类型标签和简要文字描述。
|
||||
|
||||
## 嵌套逻辑树 JSON 格式(仅流程图/架构图/状态图/时序图/活动图需要)
|
||||
|
||||
**核心原则:用嵌套的 `children` 数组表达流程的层级关系,而不是用 id 引用。** 这种格式更贴近流程图的自然结构,每个节点的后续步骤直接嵌套在其下方。
|
||||
|
||||
### 节点类型
|
||||
|
||||
| 类型 | 含义 | 对应形状 |
|
||||
|------|------|----------|
|
||||
| `start` | 起始节点 | 椭圆/圆角矩形 |
|
||||
| `end` | 结束节点 | 椭圆/圆角矩形 |
|
||||
| `process` | 处理/状态节点 | 矩形/圆角矩形 |
|
||||
| `decision` | 判断节点 | 菱形 |
|
||||
| `action` | 动作节点 | 矩形 |
|
||||
|
||||
### 非判断节点的 children 格式
|
||||
|
||||
对于 `start`、`end`、`process`、`action` 节点,`children` 是一个数组,包含后续步骤节点:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "n1",
|
||||
"name": "节点名称",
|
||||
"type": "process",
|
||||
"children": [
|
||||
{
|
||||
"id": "n2",
|
||||
"name": "下一个步骤",
|
||||
"type": "action",
|
||||
"children": [...]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 判断节点的 children 格式
|
||||
|
||||
对于 `decision` 节点,`children` 是一个数组,每个元素包含 `condition`(分支条件)和 `node`(该分支对应的子节点):
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "n5",
|
||||
"name": "是否满足条件?",
|
||||
"type": "decision",
|
||||
"children": [
|
||||
{
|
||||
"condition": "是",
|
||||
"node": {
|
||||
"id": "n6",
|
||||
"name": "满足条件时的动作",
|
||||
"type": "action",
|
||||
"children": [...]
|
||||
}
|
||||
},
|
||||
{
|
||||
"condition": "否",
|
||||
"node": {
|
||||
"id": "n7",
|
||||
"name": "不满足条件时的动作",
|
||||
"type": "action",
|
||||
"children": [...]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 结束节点
|
||||
|
||||
`end` 节点没有 `children` 字段:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "n10",
|
||||
"name": "流程结束",
|
||||
"type": "end"
|
||||
}
|
||||
```
|
||||
|
||||
### 完整示例
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "n1",
|
||||
"name": "开关状态",
|
||||
"type": "start",
|
||||
"children": [
|
||||
{
|
||||
"id": "n2",
|
||||
"name": "开启",
|
||||
"type": "process",
|
||||
"children": [
|
||||
{
|
||||
"id": "n3",
|
||||
"name": "是否在目标场景?",
|
||||
"type": "decision",
|
||||
"children": [
|
||||
{
|
||||
"condition": "否",
|
||||
"node": {
|
||||
"id": "n4",
|
||||
"name": "不受限",
|
||||
"type": "end"
|
||||
}
|
||||
},
|
||||
{
|
||||
"condition": "是",
|
||||
"node": {
|
||||
"id": "n5",
|
||||
"name": "车速是否≥15km/h且持续5秒?",
|
||||
"type": "decision",
|
||||
"children": [
|
||||
{
|
||||
"condition": "否",
|
||||
"node": {
|
||||
"id": "n6",
|
||||
"name": "不受限",
|
||||
"type": "end"
|
||||
}
|
||||
},
|
||||
{
|
||||
"condition": "是",
|
||||
"node": {
|
||||
"id": "n7",
|
||||
"name": "暂停功能",
|
||||
"type": "action",
|
||||
"children": [
|
||||
{
|
||||
"id": "n8",
|
||||
"name": "发起Toast提示",
|
||||
"type": "end"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "n9",
|
||||
"name": "关闭",
|
||||
"type": "process",
|
||||
"children": [
|
||||
{
|
||||
"id": "n10",
|
||||
"name": "不受限",
|
||||
"type": "end"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 规则
|
||||
|
||||
1. 每条从根节点到 `end` 节点的路径必须是完整的逻辑链
|
||||
2. `decision` 节点的 `children` 必须穷举所有分支(通常为"是/否"),每条分支包含 `condition` 和 `node`
|
||||
3. 只有 `end` 节点没有 `children` 字段,其他所有节点都应该有 `children`
|
||||
4. 节点 id 使用 "n1", "n2", "n3"... 格式,按流程图从上到下、从左到右的顺序编号
|
||||
5. 仔细阅读图片中的每个判断条件和分支走向,确保分支目标节点正确
|
||||
6. 如果流程图中某个分支的后续步骤在图片中没有展示,将其标记为 `end` 节点,`name` 设为"(图中未展示)"
|
||||
7. **如果图片包含多个独立的流程图**(例如上半部分和下半部分分别描述不同场景),使用一个统一的 `process` 根节点将它们组织在一起。例如图片中有"策略A"和"策略B"两个流程,结构为:
|
||||
```json
|
||||
{
|
||||
"id": "n1",
|
||||
"name": "策略总览",
|
||||
"type": "process",
|
||||
"children": [
|
||||
{"id": "n2", "name": "策略A流程", "type": "process", "children": [...]},
|
||||
{"id": "n3", "name": "策略B流程", "type": "process", "children": [...]}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## 输出格式
|
||||
|
||||
**1. 类型标签(单独一行):**
|
||||
type: <flowchart|architecture|state|sequence|activity|other>
|
||||
|
||||
**2. 逻辑树 JSON(仅上述5种类型,以 logic_tree: 开头,后跟 JSON 对象):**
|
||||
logic_tree:
|
||||
{...}
|
||||
|
||||
**3. 文字描述(以 description: 开头):**
|
||||
description:
|
||||
该图片的详细文字描述。对于流程图/架构图等类型,这里提供自然语言总结;对于其他类型,这是唯一的描述内容。
|
||||
|
||||
不要输出 ``` 代码块包裹符号,不要输出 ---YAML--- 分隔符,不要添加任何额外的解释或问候语。
|
||||
@@ -1,38 +1,97 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Resolve secrets file: priority 1) env OPENCLAW_SECRETS,
|
||||
# 2) workspace-document-analyzer/config/ (relative to skills dir),
|
||||
# 3) .openclaw/config/
|
||||
_SECRETS_FILE = None
|
||||
for _candidate in (
|
||||
os.environ.get("OPENCLAW_SECRETS", ""),
|
||||
Path(__file__).resolve().parents[3] / "config" / "secrets.yaml",
|
||||
Path(__file__).resolve().parents[5] / ".openclaw" / "config" / "secrets.yaml",
|
||||
):
|
||||
if _candidate and Path(_candidate).exists():
|
||||
_SECRETS_FILE = Path(_candidate)
|
||||
break
|
||||
if _SECRETS_FILE is None:
|
||||
_SECRETS_FILE = Path("") # empty fallback
|
||||
|
||||
|
||||
def _load_secrets() -> dict:
|
||||
"""Load API keys from secrets.yaml, with env-var overrides."""
|
||||
secrets = {}
|
||||
if _SECRETS_FILE.exists():
|
||||
try:
|
||||
import yaml
|
||||
with open(_SECRETS_FILE, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
for provider in ("deepseek", "dashscope"):
|
||||
if provider in data and isinstance(data[provider], dict):
|
||||
secrets[provider] = data[provider]
|
||||
except ImportError:
|
||||
logger.warning("pyyaml not installed, cannot read %s", _SECRETS_FILE)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load %s: %s", _SECRETS_FILE, e)
|
||||
|
||||
# Env overrides
|
||||
dk_env = os.environ.get("DEEPSEEK_API_KEY", "")
|
||||
ds_env = os.environ.get("DASHSCOPE_API_KEY", "")
|
||||
if dk_env:
|
||||
secrets.setdefault("deepseek", {})["apiKey"] = dk_env
|
||||
if ds_env:
|
||||
secrets.setdefault("dashscope", {})["apiKey"] = ds_env
|
||||
return secrets
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""Low-level OpenAI-compatible LLM client with retry and token tracking.
|
||||
"""Multi-provider LLM client with retry and token tracking.
|
||||
|
||||
Routes text models to DeepSeek, vision models to DashScope (Bailian).
|
||||
Reads API keys from openclaw config/secrets.yaml, with env-var overrides.
|
||||
|
||||
Usage::
|
||||
|
||||
llm = LLMClient()
|
||||
content = llm.chat("qwen3.5-flash", [{"role": "user", "content": "Hello"}])
|
||||
content = llm.chat("deepseek-v4-pro", [{"role": "user", "content": "Hello"}])
|
||||
print(llm.usage)
|
||||
"""
|
||||
|
||||
IMAGE_MODEL = "qwen3-vl-plus"
|
||||
TEXT_MODEL = "qwen3.5-flash-2026-02-23"
|
||||
TEXT_MODEL = "deepseek-v4-flash"
|
||||
|
||||
DASHSCOPE_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
DEEPSEEK_BASE = "https://api.deepseek.com/v1"
|
||||
|
||||
TIMEOUT = 120
|
||||
MAX_RETRIES = 3
|
||||
|
||||
_VISION_KEYWORDS = ("vl", "vision", "qwen-vl", "qwen3-vl")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
timeout: int | None = None,
|
||||
):
|
||||
key = os.environ.get("DASHSCOPE_API_KEY", "")
|
||||
if not key:
|
||||
raise ValueError("DASHSCOPE_API_KEY environment variable is not set.")
|
||||
self._client = OpenAI(api_key=key, base_url=base_url)
|
||||
secrets = _load_secrets()
|
||||
|
||||
ds_cfg = secrets.get("dashscope", {})
|
||||
dk_cfg = secrets.get("deepseek", {})
|
||||
|
||||
dashscope_key = ds_cfg.get("apiKey", "")
|
||||
dashscope_url = ds_cfg.get("baseUrl", self.DASHSCOPE_BASE)
|
||||
deepseek_key = dk_cfg.get("apiKey", "")
|
||||
deepseek_url = dk_cfg.get("baseUrl", self.DEEPSEEK_BASE)
|
||||
|
||||
self._ds_client = OpenAI(api_key=dashscope_key, base_url=dashscope_url) if dashscope_key else None
|
||||
self._dk_client = OpenAI(api_key=deepseek_key, base_url=deepseek_url) if deepseek_key else None
|
||||
|
||||
self._timeout = timeout or self.TIMEOUT
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
@@ -49,7 +108,7 @@ class LLMClient:
|
||||
@staticmethod
|
||||
def estimate_tokens(text: str) -> int:
|
||||
"""Quick token estimate. CJK ≈1.7/token, others ≈3.0/token."""
|
||||
cjk = sum(1 for c in text if '一' <= c <= '鿿' or ' ' <= c <= '〿')
|
||||
cjk = sum(1 for c in text if '\u4e00' <= c <= '\u9fff' or '\u3000' <= c <= '\u303f')
|
||||
other = len(text) - cjk
|
||||
return max(1, int(cjk / 1.7 + other / 3.0))
|
||||
|
||||
@@ -58,6 +117,20 @@ class LLMClient:
|
||||
"""Fixed estimate for one vision-model image (~500 tokens)."""
|
||||
return 500
|
||||
|
||||
@staticmethod
|
||||
def _is_vision_model(model: str) -> bool:
|
||||
return any(kw in model.lower() for kw in LLMClient._VISION_KEYWORDS)
|
||||
|
||||
def _get_client(self, model: str) -> OpenAI:
|
||||
if self._is_vision_model(model):
|
||||
if self._ds_client is None:
|
||||
raise ValueError("DASHSCOPE_API_KEY not set but required for vision model")
|
||||
return self._ds_client
|
||||
else:
|
||||
if self._dk_client is None:
|
||||
raise ValueError("DEEPSEEK_API_KEY not set but required for text model")
|
||||
return self._dk_client
|
||||
|
||||
def chat(
|
||||
self, model: str, messages: list[dict], *, timeout: int | None = None,
|
||||
response_format: dict | None = None,
|
||||
@@ -65,8 +138,10 @@ class LLMClient:
|
||||
"""Send a chat completion request and return the response content.
|
||||
|
||||
Automatically retries on failure and accumulates token usage.
|
||||
Routes to DeepSeek for text, DashScope for vision.
|
||||
"""
|
||||
label = f"chat({model})"
|
||||
client = self._get_client(model)
|
||||
|
||||
def _call():
|
||||
t0 = time.time()
|
||||
@@ -74,7 +149,7 @@ class LLMClient:
|
||||
if response_format is not None:
|
||||
kwargs["response_format"] = response_format
|
||||
kwargs["temperature"] = 0
|
||||
resp = self._client.chat.completions.create(**kwargs)
|
||||
resp = client.chat.completions.create(**kwargs)
|
||||
content = resp.choices[0].message.content
|
||||
usg = resp.usage
|
||||
if usg:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from LLM import LLMClient
|
||||
@@ -8,32 +10,56 @@ from LLM import LLMClient
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompts
|
||||
# Prompt loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PROMPT_IMAGE = """请分析这张图片,判断类型并输出文字描述。
|
||||
def _load_prompt() -> str:
|
||||
"""Load PROMPT_IMAGE from external file, falling back to inline default."""
|
||||
prompt_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "prompts")
|
||||
prompt_path = os.path.join(prompt_dir, "image_prompt.md")
|
||||
if os.path.isfile(prompt_path):
|
||||
with open(prompt_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
# Fallback inline prompt (nested tree format)
|
||||
return """请分析这张图片,判断类型并输出文字描述和(如适用)结构化逻辑树。
|
||||
|
||||
## 判断图片类型
|
||||
|
||||
如果是 **流程图 / 架构图 / 状态图 / 时序图 / 活动图**,详细描述:
|
||||
- 图中所有节点/步骤/状态/组件的名称
|
||||
- 所有连线/箭头/转换关系及其方向
|
||||
- 所有分支条件、判断逻辑和判断结果
|
||||
- 所有文字标注、注释、标签
|
||||
- 图的整体结构和逻辑流程
|
||||
- 如果图片包含多个子图,拆解描述
|
||||
如果是 **流程图 / 架构图 / 状态图 / 时序图 / 活动图**,你需要输出三项内容:
|
||||
1. 类型标签
|
||||
2. **嵌套逻辑树 JSON**(见下方格式)
|
||||
3. 文字描述
|
||||
|
||||
如果是 **其他类型**(UI原型图 / 界面截图 / 设计稿 / 手机屏幕截图 / 网页截图等),简要描述图片内容。
|
||||
如果是 **其他类型**(UI原型图 / 界面截图 / 设计稿 / 手机屏幕截图 / 网页截图等),只输出类型标签和简要文字描述。
|
||||
|
||||
## 嵌套逻辑树 JSON 格式(仅流程图/架构图/状态图/时序图/活动图需要)
|
||||
|
||||
**核心原则:用嵌套的 `children` 数组表达流程的层级关系,而不是用 id 引用。**
|
||||
|
||||
节点类型:`start`(起始), `end`(结束), `process`(处理/状态), `decision`(判断), `action`(动作)
|
||||
|
||||
非判断节点的 `children` 是子节点数组。`end` 节点无 `children`。
|
||||
|
||||
判断节点的 `children` 格式:
|
||||
```json
|
||||
{"condition": "是", "node": {"id": "n6", "name": "...", "type": "action", "children": [...]}}
|
||||
```
|
||||
|
||||
每条从根到 `end` 的路径必须是完整逻辑链。decision 必须穷举所有分支。
|
||||
节点 id 使用 "n1", "n2", "n3"... 格式。
|
||||
|
||||
## 输出格式
|
||||
|
||||
**1. 类型标签(单独一行):**
|
||||
type: <flowchart|architecture|state|sequence|activity|other>
|
||||
|
||||
**2. 文字描述:**
|
||||
该图片的详细文字描述。
|
||||
logic_tree:
|
||||
{...}
|
||||
|
||||
不要输出 ---YAML--- 分隔符或 YAML 内容,不要添加任何额外的解释或问候语。"""
|
||||
description:
|
||||
该图片的详细文字描述。"""
|
||||
|
||||
PROMPT_IMAGE = _load_prompt()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -41,7 +67,10 @@ type: <flowchart|architecture|state|sequence|activity|other>
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ImageParser:
|
||||
"""Vision LLM wrapper for parsing images (type + description).
|
||||
"""Vision LLM wrapper for parsing images (type + description + logic_tree).
|
||||
|
||||
The nested-tree ``logic_tree`` is stored alongside a backward-compatible
|
||||
flat representation so downstream consumers are not broken.
|
||||
|
||||
Usage::
|
||||
|
||||
@@ -49,7 +78,7 @@ class ImageParser:
|
||||
result = parser.parse_image("images/img1.png")
|
||||
"""
|
||||
|
||||
_VALID_TYPES = {"flowchart", "architecture", "state", "sequence", "activity", "text"}
|
||||
_VALID_TYPES = {"flowchart", "architecture", "state", "sequence", "activity", "other"}
|
||||
|
||||
def __init__(self, llm: LLMClient | None = None):
|
||||
self._llm = llm or LLMClient()
|
||||
@@ -59,9 +88,9 @@ class ImageParser:
|
||||
return self._llm.usage
|
||||
|
||||
def parse_image(self, image_path: str) -> Optional[dict]:
|
||||
"""Parse an image and return its type and description (no YAML IR).
|
||||
"""Parse an image and return its type, description, and optional logic_tree.
|
||||
|
||||
Returns ``{type, description}``, or *None* for UI mockups.
|
||||
Returns ``{type, description, [logic_tree], [logic_tree_nested]}``.
|
||||
"""
|
||||
logger.info("Parsing image: %s", image_path)
|
||||
|
||||
@@ -84,34 +113,292 @@ class ImageParser:
|
||||
logger.error(str(e))
|
||||
return {"type": "other", "description": "", "error": str(e)}
|
||||
|
||||
parsed = self._parse_type_and_description(content)
|
||||
parsed = self._parse_response(content)
|
||||
if parsed is None:
|
||||
return None
|
||||
return {"type": parsed[0], "description": parsed[1]}
|
||||
ptype, description, logic_tree_nested = parsed
|
||||
|
||||
result: dict = {"type": ptype, "description": description}
|
||||
if logic_tree_nested is not None:
|
||||
result["logic_tree_nested"] = logic_tree_nested
|
||||
result["logic_tree"] = self._flatten_tree(logic_tree_nested)
|
||||
return result
|
||||
|
||||
# ---- internals ----------------------------------------------------------
|
||||
|
||||
def _parse_type_and_description(self, content: str) -> Optional[tuple[str, str]]:
|
||||
"""Extract ``(type, description)`` from LLM response.
|
||||
def _parse_response(self, content: str) -> Optional[tuple[str, str, Optional[dict]]]:
|
||||
"""Extract ``(type, description, logic_tree_nested)`` from LLM response.
|
||||
|
||||
Returns *None* for ``[[UI]]`` (skip).
|
||||
Parses the nested-tree format. Returns *None* for unparseable content.
|
||||
"""
|
||||
content = content.strip()
|
||||
if content == "[[UI]]" or content.startswith("[[UI]]"):
|
||||
return None
|
||||
|
||||
parsed_type = "other"
|
||||
desc_lines: list[str] = []
|
||||
for line in content.splitlines():
|
||||
stripped = line.strip()
|
||||
if (stripped.startswith("type:") or stripped.startswith("类型:")) and parsed_type == "other":
|
||||
type_val = stripped.split(":", 1)[1].strip().lower()
|
||||
if type_val in self._VALID_TYPES:
|
||||
parsed_type = type_val
|
||||
else:
|
||||
desc_lines.append(line)
|
||||
logic_tree = None
|
||||
description = ""
|
||||
|
||||
return parsed_type, "\n".join(desc_lines).strip()
|
||||
# --- type ---
|
||||
type_match = re.search(r'(?:type|类型):\s*(\S+)', content)
|
||||
if type_match:
|
||||
type_val = type_match.group(1).strip().lower()
|
||||
if type_val in self._VALID_TYPES:
|
||||
parsed_type = type_val
|
||||
|
||||
# --- logic_tree (anchored at line start) ---
|
||||
lt_match = re.search(r'(?m)^logic_tree:\s*', content)
|
||||
desc_match = re.search(r'(?m)^description:\s*', content)
|
||||
|
||||
if lt_match:
|
||||
lt_start = lt_match.end()
|
||||
lt_end = desc_match.start() if desc_match and desc_match.start() > lt_start else len(content)
|
||||
lt_raw = content[lt_start:lt_end].strip()
|
||||
|
||||
# Try multiple JSON extraction strategies
|
||||
logic_tree = self._extract_json(lt_raw)
|
||||
|
||||
if logic_tree is not None:
|
||||
is_valid, err_msg = self._validate_flowchart(logic_tree)
|
||||
if not is_valid:
|
||||
logger.warning("Flowchart validation warning: %s", err_msg)
|
||||
else:
|
||||
logger.info("Failed to extract logic_tree JSON. Raw block length=%d", len(lt_raw))
|
||||
logger.debug("Raw logic_tree block: %s", lt_raw[:500])
|
||||
elif parsed_type in self._VALID_TYPES - {"other"}:
|
||||
logger.info("Diagram type=%s but no logic_tree: in response. Response length=%d",
|
||||
parsed_type, len(content))
|
||||
logger.debug("Raw response (first 500): %s", content[:500])
|
||||
|
||||
# --- description ---
|
||||
if desc_match:
|
||||
description = content[desc_match.end():].strip()
|
||||
else:
|
||||
desc = content
|
||||
if type_match:
|
||||
desc = desc[type_match.end():]
|
||||
desc = re.sub(r'(?m)^logic_tree:\s*\{.*?\}\s*', '', desc, flags=re.DOTALL)
|
||||
description = desc.strip()
|
||||
|
||||
return parsed_type, description, logic_tree
|
||||
|
||||
@staticmethod
|
||||
def _validate_flowchart(tree: dict) -> tuple[bool, str]:
|
||||
"""Validate a nested flowchart tree structure.
|
||||
|
||||
Returns ``(is_valid, error_message)``. Non-fatal: returns ``False``
|
||||
with a warning message but the tree is still kept.
|
||||
"""
|
||||
if not isinstance(tree, dict):
|
||||
return False, "logic_tree is not a dict"
|
||||
|
||||
seen_ids: set[str] = set()
|
||||
|
||||
def _walk(node: dict, depth: int = 0) -> tuple[bool, str]:
|
||||
if depth > 20:
|
||||
return False, f"Tree too deep (>20) at node {node.get('id', '?')}"
|
||||
|
||||
nid = node.get("id", "")
|
||||
if not nid:
|
||||
return False, "Node missing 'id' field"
|
||||
if not isinstance(nid, str):
|
||||
return False, f"Node id must be string, got {type(nid).__name__}"
|
||||
if nid in seen_ids:
|
||||
return False, f"Duplicate node id: {nid}"
|
||||
seen_ids.add(nid)
|
||||
|
||||
ntype = node.get("type", "")
|
||||
if ntype not in ("start", "end", "process", "decision", "action"):
|
||||
return False, f"Unknown node type '{ntype}' at {nid}"
|
||||
|
||||
if ntype == "end":
|
||||
if "children" in node:
|
||||
return False, f"End node {nid} should not have children"
|
||||
return True, ""
|
||||
|
||||
children = node.get("children")
|
||||
if not children:
|
||||
if ntype != "end":
|
||||
return False, f"Non-end node {nid} ({ntype}) has no children"
|
||||
return True, ""
|
||||
|
||||
if not isinstance(children, list):
|
||||
return False, f"children of {nid} is not a list"
|
||||
|
||||
if ntype == "decision":
|
||||
for child in children:
|
||||
if not isinstance(child, dict):
|
||||
return False, f"decision child of {nid} is not a dict"
|
||||
if "condition" not in child:
|
||||
return False, f"decision child of {nid} missing 'condition'"
|
||||
if "node" not in child:
|
||||
return False, f"decision child of {nid} missing 'node'"
|
||||
ok, err = _walk(child["node"], depth + 1)
|
||||
if not ok:
|
||||
return False, err
|
||||
else:
|
||||
for child in children:
|
||||
if not isinstance(child, dict):
|
||||
return False, f"child of {nid} is not a dict"
|
||||
ok, err = _walk(child, depth + 1)
|
||||
if not ok:
|
||||
return False, err
|
||||
|
||||
return True, ""
|
||||
|
||||
return _walk(tree)
|
||||
|
||||
@staticmethod
|
||||
def _flatten_tree(tree: dict) -> dict:
|
||||
"""Convert a nested flowchart tree into the legacy flat-nodes format.
|
||||
|
||||
This preserves backward compatibility with downstream consumers
|
||||
(conflict_detection_skill, ir_generator) that expect the flat format.
|
||||
"""
|
||||
nodes: list[dict] = []
|
||||
root_name = ""
|
||||
|
||||
def _collect(node: dict):
|
||||
nonlocal root_name
|
||||
nid = node.get("id", "")
|
||||
ntype = node.get("type", "")
|
||||
name = node.get("name", "")
|
||||
|
||||
if root_name == "" and "children" in node:
|
||||
root_name = name
|
||||
|
||||
if ntype == "decision":
|
||||
branches = []
|
||||
for child in node.get("children", []):
|
||||
branches.append({
|
||||
"value": child.get("condition", ""),
|
||||
"target": child["node"].get("id", ""),
|
||||
})
|
||||
_collect(child["node"])
|
||||
nodes.append({
|
||||
"id": nid,
|
||||
"type": ntype,
|
||||
"condition": name,
|
||||
"branches": branches,
|
||||
})
|
||||
elif ntype in ("action", "process", "state"):
|
||||
nodes.append({
|
||||
"id": nid,
|
||||
"type": ntype,
|
||||
"description": name,
|
||||
})
|
||||
for child in node.get("children", []):
|
||||
_collect(child)
|
||||
elif ntype == "start":
|
||||
nodes.append({
|
||||
"id": nid,
|
||||
"type": ntype,
|
||||
"description": name,
|
||||
})
|
||||
for child in node.get("children", []):
|
||||
_collect(child)
|
||||
# end nodes are collected but have no children
|
||||
|
||||
_collect(tree)
|
||||
|
||||
# Add end nodes from the nested tree
|
||||
ends: list[dict] = []
|
||||
|
||||
def _collect_ends(node: dict):
|
||||
if node.get("type") == "end":
|
||||
ends.append({
|
||||
"id": node.get("id", ""),
|
||||
"type": "end",
|
||||
"description": node.get("name", ""),
|
||||
})
|
||||
elif "children" in node:
|
||||
for child in node.get("children", []):
|
||||
if isinstance(child, dict):
|
||||
if "node" in child:
|
||||
_collect_ends(child["node"])
|
||||
else:
|
||||
_collect_ends(child)
|
||||
|
||||
_collect_ends(tree)
|
||||
nodes.extend(ends)
|
||||
|
||||
return {"root": root_name, "nodes": nodes}
|
||||
|
||||
@staticmethod
|
||||
def extract_paths(tree: dict) -> list[list[dict]]:
|
||||
"""Extract all root-to-leaf paths from a nested flowchart tree.
|
||||
|
||||
Each path is a list of node dicts (each with id, name, type).
|
||||
Returns a list of paths useful for human review and LLM verification.
|
||||
"""
|
||||
paths: list[list[dict]] = []
|
||||
|
||||
def _walk(node: dict, current_path: list[dict]):
|
||||
entry = {"id": node.get("id", ""), "name": node.get("name", ""), "type": node.get("type", "")}
|
||||
new_path = current_path + [entry]
|
||||
|
||||
if node.get("type") == "end":
|
||||
paths.append(new_path)
|
||||
return
|
||||
|
||||
children = node.get("children", [])
|
||||
if not children:
|
||||
paths.append(new_path)
|
||||
return
|
||||
|
||||
if node.get("type") == "decision":
|
||||
for child in children:
|
||||
_walk(child["node"], new_path)
|
||||
else:
|
||||
for child in children:
|
||||
_walk(child, new_path)
|
||||
|
||||
_walk(tree, [])
|
||||
return paths
|
||||
|
||||
@staticmethod
|
||||
def paths_to_text(paths: list[list[dict]]) -> str:
|
||||
"""Render extracted paths as human-readable text for review."""
|
||||
lines: list[str] = []
|
||||
for i, path in enumerate(paths, 1):
|
||||
steps = []
|
||||
for node in path:
|
||||
if node["type"] == "decision":
|
||||
steps.append(f"[判断] {node['name']}")
|
||||
elif node["type"] == "end":
|
||||
steps.append(f"[结束] {node['name']}")
|
||||
else:
|
||||
steps.append(f"[{node['type']}] {node['name']}")
|
||||
lines.append(f"路径 {i}: {' -> '.join(steps)}")
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _extract_json(text: str) -> Optional[dict]:
|
||||
"""Try multiple strategies to extract a JSON object from text.
|
||||
|
||||
Returns the parsed dict or None.
|
||||
"""
|
||||
# Strategy 1: first { ... } pair (simple regex)
|
||||
json_match = re.search(r'\{.*\}', text, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
return json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Strategy 2: find balanced braces
|
||||
start = text.find("{")
|
||||
if start >= 0:
|
||||
depth = 0
|
||||
for i in range(start, len(text)):
|
||||
if text[i] == "{":
|
||||
depth += 1
|
||||
elif text[i] == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
try:
|
||||
return json.loads(text[start:i + 1])
|
||||
except json.JSONDecodeError:
|
||||
break
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _mime_type(image_path: str) -> str:
|
||||
|
||||
@@ -0,0 +1,384 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Verify flowchart logic trees for structural correctness and consistency.
|
||||
|
||||
Usage::
|
||||
|
||||
python verify_flowchart.py <parsed.json|flowchart.json> [--llm] [--output-report REPORT.md]
|
||||
|
||||
Performs three levels of checks:
|
||||
|
||||
1. **Structural validation** — tree integrity, node uniqueness, leaf types
|
||||
2. **Path extraction** — renders all root-to-leaf paths as readable text
|
||||
3. **LLM consistency check** (opt-in with ``--llm``) — compares extracted paths
|
||||
against the original text description for logical inconsistencies
|
||||
|
||||
Outputs PASS/FAIL and a detailed report.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from image_parser import ImageParser
|
||||
from LLM import LLMClient
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt for LLM path-vs-description consistency check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PROMPT_VERIFY_PATHS = """你是一个流程图审核专家。以下内容来自同一张流程图的解析结果:
|
||||
|
||||
## 流程图路径(从嵌套逻辑树提取的所有根到叶路径)
|
||||
```
|
||||
{paths_text}
|
||||
```
|
||||
|
||||
## 原始文字描述
|
||||
```
|
||||
{description}
|
||||
```
|
||||
|
||||
## 你的任务
|
||||
逐条检查每条路径是否与文字描述一致。重点关注:
|
||||
|
||||
1. **分支方向错误**:路径中的判断分支走向是否与文字描述矛盾?
|
||||
例如:文字说"满足条件后退出",但路径中"是"分支走向了"不受限"。
|
||||
2. **缺失步骤**:路径中是否缺少文字描述中提到的关键步骤?
|
||||
3. **冗余步骤**:路径中是否包含文字描述未提及的多余步骤?
|
||||
4. **条件颠倒**:判断条件的"是/否"分支是否与文字描述相反?
|
||||
|
||||
## 输出格式
|
||||
|
||||
如果**所有路径一致**,只输出:
|
||||
```
|
||||
[[PATHS_CONSISTENT]]
|
||||
```
|
||||
|
||||
如果**发现不一致**,输出 JSON 数组:
|
||||
```json
|
||||
[
|
||||
{{
|
||||
"path_index": 1,
|
||||
"issue_type": "branch_error|missing_step|redundant_step|condition_reversed",
|
||||
"severity": "high|medium|low",
|
||||
"description": "用中文说明具体问题"
|
||||
}}
|
||||
]
|
||||
```
|
||||
|
||||
注意:输出必须是严格合法的 JSON 数组,不要有尾随逗号,不要包含代码块包裹符号。
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core verification logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def verify_parsed_json(parsed_path: str, *, use_llm: bool = False) -> dict:
|
||||
"""Load _parsed.json and verify all flowchart logic trees.
|
||||
|
||||
Returns a report dict with keys:
|
||||
- total_flowcharts: int
|
||||
- passed: int
|
||||
- failed: int
|
||||
- results: list of per-flowchart results
|
||||
"""
|
||||
with open(parsed_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
image_analysis = data.get("image_analysis", [])
|
||||
flowcharts = [img for img in image_analysis if img.get("type") == "flowchart"]
|
||||
|
||||
report = {
|
||||
"total_flowcharts": len(flowcharts),
|
||||
"passed": 0,
|
||||
"failed": 0,
|
||||
"results": [],
|
||||
}
|
||||
|
||||
llm = LLMClient() if use_llm else None
|
||||
|
||||
for img in flowcharts:
|
||||
rid = img.get("rid", "unknown")
|
||||
logger.info("Verifying flowchart: rid=%s", rid)
|
||||
|
||||
result = _verify_single(img, llm)
|
||||
report["results"].append(result)
|
||||
|
||||
if result["structural_ok"] and (not use_llm or result.get("llm_ok", True)):
|
||||
report["passed"] += 1
|
||||
else:
|
||||
report["failed"] += 1
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def verify_flowchart_file(filepath: str, *, use_llm: bool = False) -> dict:
|
||||
"""Load a standalone flowchart JSON file and verify it."""
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
tree = json.load(f)
|
||||
|
||||
img = {"logic_tree_nested": tree, "description": "", "rid": os.path.basename(filepath)}
|
||||
llm = LLMClient() if use_llm else None
|
||||
result = _verify_single(img, llm)
|
||||
|
||||
return {
|
||||
"total_flowcharts": 1,
|
||||
"passed": 1 if result["structural_ok"] else 0,
|
||||
"failed": 0 if result["structural_ok"] else 1,
|
||||
"results": [result],
|
||||
}
|
||||
|
||||
|
||||
def _verify_single(img: dict, llm: LLMClient | None) -> dict:
|
||||
"""Verify a single flowchart image analysis entry."""
|
||||
rid = img.get("rid", "unknown")
|
||||
description = img.get("description", "").strip()
|
||||
|
||||
# Try nested format first, fall back to flat format
|
||||
tree = img.get("logic_tree_nested") or img.get("logic_tree")
|
||||
if tree is None:
|
||||
return {
|
||||
"rid": rid,
|
||||
"structural_ok": False,
|
||||
"errors": ["No logic_tree found"],
|
||||
"paths_text": "",
|
||||
"llm_issues": [],
|
||||
}
|
||||
|
||||
# Check if it's the new nested format or old flat format
|
||||
is_nested = "children" in tree and isinstance(tree.get("children"), list)
|
||||
|
||||
# --- Level 1: Structural validation ---
|
||||
structural_ok = True
|
||||
errors: list[str] = []
|
||||
|
||||
if is_nested:
|
||||
ok, err = ImageParser._validate_flowchart(tree)
|
||||
if not ok:
|
||||
structural_ok = False
|
||||
errors.append(f"Structure: {err}")
|
||||
|
||||
# Extract paths
|
||||
paths = ImageParser.extract_paths(tree)
|
||||
paths_text = ImageParser.paths_to_text(paths)
|
||||
errors.append(f"Path count: {len(paths)}")
|
||||
else:
|
||||
# Old flat format — basic check
|
||||
nodes = tree.get("nodes", [])
|
||||
ids = [n.get("id", "") for n in nodes]
|
||||
if len(ids) != len(set(ids)):
|
||||
structural_ok = False
|
||||
errors.append("Structure: duplicate node ids in flat format")
|
||||
|
||||
# Build simple path-like text for flat format
|
||||
paths_text = _flat_to_text(tree)
|
||||
|
||||
# --- Level 2: Path count sanity check ---
|
||||
if is_nested and len(paths) == 0:
|
||||
structural_ok = False
|
||||
errors.append("No paths extracted from tree")
|
||||
|
||||
# --- Level 3: LLM consistency check ---
|
||||
llm_issues: list[dict] = []
|
||||
llm_ok = True
|
||||
if llm and description and paths_text:
|
||||
prompt = PROMPT_VERIFY_PATHS.format(
|
||||
paths_text=paths_text,
|
||||
description=description,
|
||||
)
|
||||
try:
|
||||
raw = llm.chat(
|
||||
model=LLMClient.TEXT_MODEL,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
)
|
||||
llm_issues = _parse_llm_issues(raw)
|
||||
if llm_issues:
|
||||
llm_ok = False
|
||||
errors.append(f"LLM found {len(llm_issues)} issue(s)")
|
||||
except RuntimeError as e:
|
||||
errors.append(f"LLM check failed: {e}")
|
||||
|
||||
return {
|
||||
"rid": rid,
|
||||
"structural_ok": structural_ok,
|
||||
"errors": errors,
|
||||
"paths_text": paths_text,
|
||||
"llm_ok": llm_ok,
|
||||
"llm_issues": llm_issues,
|
||||
}
|
||||
|
||||
|
||||
def _flat_to_text(tree: dict) -> str:
|
||||
"""Build path-like text from old flat-format logic_tree."""
|
||||
nodes = tree.get("nodes", [])
|
||||
root = tree.get("root", "")
|
||||
lines = [f"Root: {root}"]
|
||||
|
||||
node_map = {n["id"]: n for n in nodes}
|
||||
|
||||
def _trace(node_id: str, visited: set, path: list[str]) -> list[str]:
|
||||
if node_id in visited:
|
||||
path.append(f"[循环] {node_id}")
|
||||
return path
|
||||
visited.add(node_id)
|
||||
node = node_map.get(node_id)
|
||||
if node is None:
|
||||
path.append(f"[缺失] {node_id}")
|
||||
return path
|
||||
ntype = node.get("type", "")
|
||||
if ntype == "decision":
|
||||
cond = node.get("condition", "")
|
||||
for b in node.get("branches", []):
|
||||
val = b.get("value", "")
|
||||
tgt = b.get("target", "")
|
||||
new_path = path + [f"[判断] {cond} → {val}"]
|
||||
_trace(tgt, visited.copy(), new_path)
|
||||
elif ntype == "end":
|
||||
path.append(f"[结束] {node.get('description', '')}")
|
||||
lines.append(" -> ".join(path))
|
||||
else:
|
||||
path.append(f"[{ntype}] {node.get('description', '')}")
|
||||
# Flat format doesn't have explicit children for non-decision nodes
|
||||
# so we can't trace further
|
||||
lines.append(" -> ".join(path))
|
||||
return path
|
||||
|
||||
# Try to find start nodes
|
||||
starts = [n for n in nodes if n.get("type") == "start"]
|
||||
if starts:
|
||||
for s in starts:
|
||||
_trace(s["id"], set(), [])
|
||||
else:
|
||||
lines.append("(Cannot trace: no start node in flat format)")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _parse_llm_issues(content: str) -> list[dict]:
|
||||
"""Parse LLM response for path consistency issues."""
|
||||
stripped = content.strip()
|
||||
if "[[PATHS_CONSISTENT]]" in stripped:
|
||||
return []
|
||||
|
||||
# Remove markdown code fences
|
||||
if "```json" in stripped:
|
||||
stripped = stripped.split("```json", 1)[1]
|
||||
if "```" in stripped:
|
||||
stripped = stripped.split("```", 1)[0]
|
||||
elif "```" in stripped:
|
||||
stripped = stripped.split("```", 1)[1]
|
||||
if "```" in stripped:
|
||||
stripped = stripped.split("```", 1)[0]
|
||||
|
||||
stripped = stripped.strip()
|
||||
if not stripped:
|
||||
return []
|
||||
|
||||
try:
|
||||
issues = json.loads(stripped)
|
||||
if isinstance(issues, list):
|
||||
return issues
|
||||
return []
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Failed to parse LLM issues: %s", stripped[:200])
|
||||
return []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Report rendering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def print_report(report: dict) -> str:
|
||||
"""Print a human-readable verification report and return it as a string."""
|
||||
lines: list[str] = []
|
||||
lines.append("=" * 60)
|
||||
lines.append("流程图校验报告")
|
||||
lines.append("=" * 60)
|
||||
lines.append(f"流程图总数: {report['total_flowcharts']}")
|
||||
lines.append(f"通过: {report['passed']}")
|
||||
lines.append(f"失败: {report['failed']}")
|
||||
|
||||
overall = "PASS" if report["failed"] == 0 else "FAIL"
|
||||
lines.append(f"总体结果: {overall}")
|
||||
lines.append("")
|
||||
|
||||
for i, r in enumerate(report["results"], 1):
|
||||
rid = r["rid"]
|
||||
status = "[PASS]" if r["structural_ok"] else "[FAIL]"
|
||||
lines.append(f"[{i}] rid={rid} {status}")
|
||||
for err in r.get("errors", []):
|
||||
lines.append(f" - {err}")
|
||||
|
||||
if r.get("paths_text"):
|
||||
lines.append(" 路径:")
|
||||
for path_line in r["paths_text"].split("\n"):
|
||||
lines.append(f" {path_line}")
|
||||
|
||||
llm_issues = r.get("llm_issues", [])
|
||||
if llm_issues:
|
||||
lines.append(" LLM发现的问题:")
|
||||
for issue in llm_issues:
|
||||
lines.append(f" [{issue.get('severity', '?')}] {issue.get('description', '')}")
|
||||
lines.append("")
|
||||
|
||||
report_text = "\n".join(lines)
|
||||
print(report_text)
|
||||
return report_text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Verify flowchart logic trees for correctness.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"input", metavar="FILE",
|
||||
help="Path to _parsed.json or standalone flowchart JSON",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm", action="store_true",
|
||||
help="Run LLM consistency check (compares paths against text description)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-report", metavar="PATH",
|
||||
help="Save verification report to a file",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine input type
|
||||
with open(args.input, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
if "image_analysis" in data:
|
||||
report = verify_parsed_json(args.input, use_llm=args.llm)
|
||||
else:
|
||||
report = verify_flowchart_file(args.input, use_llm=args.llm)
|
||||
|
||||
report_text = print_report(report)
|
||||
|
||||
if args.output_report:
|
||||
with open(args.output_report, "w", encoding="utf-8") as f:
|
||||
f.write(report_text)
|
||||
logger.info("Report saved: %s", args.output_report)
|
||||
|
||||
# Exit code: 0 for PASS, 1 for FAIL
|
||||
if report["failed"] > 0:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,9 @@
|
||||
# Generated output
|
||||
output/
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.pyc
|
||||
|
||||
# Console log
|
||||
Console output.txt
|
||||
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
Shared configuration for the IR Generation pipeline.
|
||||
Reads API keys from a secrets.yaml file, falling back to environment variables.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import yaml
|
||||
|
||||
# ---- Paths ----
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
WORKSPACE_DIR = os.path.dirname(BASE_DIR)
|
||||
DOC_PARSER_OUTPUT = os.path.join(WORKSPACE_DIR, "doc_parser_skill", "output")
|
||||
PROMPTS_DIR = os.path.join(BASE_DIR, "prompts")
|
||||
TESTS_DIR = os.path.join(BASE_DIR, "tests")
|
||||
OUTPUT_DIR = os.path.join(BASE_DIR, "output")
|
||||
|
||||
# Input file (the parsed PRD JSON)
|
||||
_DEFAULT_INPUT = os.path.join(
|
||||
DOC_PARSER_OUTPUT,
|
||||
"车机娱乐系统禁止功能文档_脱敏 v0.9_v2_updated.json",
|
||||
)
|
||||
INPUT_JSON = os.environ.get("IR_INPUT_JSON", _DEFAULT_INPUT)
|
||||
|
||||
|
||||
def set_input_file(path: str) -> None:
|
||||
"""Override the default input JSON path."""
|
||||
global INPUT_JSON
|
||||
INPUT_JSON = path
|
||||
|
||||
# Secrets file (shared with workspace-document-analyzer)
|
||||
# .openclaw/workspace/skills/ir_generation_new_skill -> .openclaw/workspace-document-analyzer
|
||||
OPENCLAW_HOME = os.path.dirname(os.path.dirname(WORKSPACE_DIR))
|
||||
SECRETS_YAML = os.path.join(
|
||||
OPENCLAW_HOME, "workspace-document-analyzer", "config", "secrets.yaml",
|
||||
)
|
||||
|
||||
# Intermediate outputs
|
||||
SEMANTIC_INDEX_R1_JSON = os.path.join(OUTPUT_DIR, "semantic_index_r1.json")
|
||||
SEMANTIC_INDEX_R2_JSON = os.path.join(OUTPUT_DIR, "semantic_index_r2.json")
|
||||
SEMANTIC_INDEX_R3_JSON = os.path.join(OUTPUT_DIR, "semantic_index_r3.json")
|
||||
SEMANTIC_INDEX_JSON = os.path.join(OUTPUT_DIR, "semantic_index.json") # merged final
|
||||
IR_FRAGMENTS_JSON = os.path.join(OUTPUT_DIR, "ir_fragments.json")
|
||||
PATH_ENUM_JSON = os.path.join(OUTPUT_DIR, "path_enumeration.json")
|
||||
IR_AUTOCOMPLETE_FRAGMENTS_JSON = os.path.join(OUTPUT_DIR, "ir_autocomplete_fragments.json")
|
||||
|
||||
# Final deliverables (placed in doc_parser output per spec)
|
||||
IR_FINAL_JSON = os.path.join(DOC_PARSER_OUTPUT, "ir_final.json")
|
||||
IR_AUDIT_REPORT_MD = os.path.join(DOC_PARSER_OUTPUT, "ir_audit_report.md")
|
||||
|
||||
# ---- LLM API ----
|
||||
# Choose provider: "deepseek" | "dashscope"
|
||||
LLM_PROVIDER = os.environ.get("IR_PROVIDER", "deepseek")
|
||||
|
||||
# Model names per provider
|
||||
PROVIDER_MODELS = {
|
||||
"deepseek": os.environ.get("IR_MODEL", "deepseek-v4-flash"),
|
||||
"dashscope": os.environ.get("IR_MODEL", "qwen-max"),
|
||||
}
|
||||
MODEL_NAME = PROVIDER_MODELS.get(LLM_PROVIDER, PROVIDER_MODELS["deepseek"])
|
||||
|
||||
# Maximum tokens for LLM responses
|
||||
MAX_TOKENS = int(os.environ.get("IR_MAX_TOKENS", "16000"))
|
||||
TEMPERATURE = float(os.environ.get("IR_TEMPERATURE", "0.1"))
|
||||
|
||||
# ---- Iteration & Quality ----
|
||||
MAX_RETRIES_PER_STAGE = int(os.environ.get("IR_MAX_RETRIES", "3"))
|
||||
COVERAGE_TARGET = float(os.environ.get("IR_COVERAGE_TARGET", "0.95"))
|
||||
|
||||
# Stage 1 ensemble temperatures (parallel multi-temperature generation)
|
||||
ENSEMBLE_TEMPERATURES = [
|
||||
float(os.environ.get("IR_ENSEMBLE_T1", "0.0")),
|
||||
float(os.environ.get("IR_ENSEMBLE_T2", "0.3")),
|
||||
float(os.environ.get("IR_ENSEMBLE_T3", "0.7")),
|
||||
]
|
||||
|
||||
|
||||
def _load_secrets() -> dict[str, dict[str, str]]:
|
||||
"""Load provider credentials from secrets.yaml.
|
||||
|
||||
Returns a dict like: {"deepseek": {"apiKey": "...", "baseUrl": "..."}, ...}
|
||||
"""
|
||||
if os.path.isfile(SECRETS_YAML):
|
||||
with open(SECRETS_YAML, "r", encoding="utf-8") as f:
|
||||
return yaml.safe_load(f) or {}
|
||||
return {}
|
||||
|
||||
|
||||
def _get_provider_config(provider: str) -> dict[str, str]:
|
||||
"""Get {apiKey, baseUrl} for a provider from secrets, with env-var fallback."""
|
||||
secrets = _load_secrets()
|
||||
entry = secrets.get(provider, {})
|
||||
|
||||
env_prefix = provider.upper()
|
||||
api_key = (
|
||||
os.environ.get(f"{env_prefix}_API_KEY")
|
||||
or entry.get("apiKey", "")
|
||||
)
|
||||
base_url = (
|
||||
os.environ.get(f"{env_prefix}_BASE_URL")
|
||||
or entry.get("baseUrl", "https://api.deepseek.com/v1")
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
raise RuntimeError(
|
||||
f"No API key found for provider '{provider}'. "
|
||||
f"Check {SECRETS_YAML} or set {env_prefix}_API_KEY."
|
||||
)
|
||||
return {"apiKey": api_key, "baseUrl": base_url}
|
||||
|
||||
|
||||
def llm_client():
|
||||
"""Return an OpenAI-compatible client configured from secrets.yaml."""
|
||||
from openai import OpenAI
|
||||
|
||||
cfg = _get_provider_config(LLM_PROVIDER)
|
||||
return OpenAI(base_url=cfg["baseUrl"], api_key=cfg["apiKey"])
|
||||
|
||||
|
||||
def load_input_document(path: str | None = None) -> dict:
|
||||
"""Load the parsed PRD JSON document."""
|
||||
path = path or INPUT_JSON
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def save_json(data, path: str) -> None:
|
||||
"""Save data as formatted JSON."""
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def load_json(path: str) -> dict:
|
||||
"""Load a JSON file."""
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
@@ -0,0 +1,593 @@
|
||||
"""
|
||||
Deterministic ensemble merge for semantic index generation.
|
||||
|
||||
All functions are pure Python with zero LLM calls. Fully testable with mock data.
|
||||
|
||||
Cross-references N semantic_index outputs (generated with different temperatures)
|
||||
and produces a single merged index with confidence scores.
|
||||
|
||||
Used by: step1_semantic_index.py
|
||||
Tested by: tests/test_ensemble_merge.py
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Concept Name Similarity
|
||||
# =============================================================================
|
||||
|
||||
def concept_name_similarity(name_a: str, name_b: str) -> float:
|
||||
"""Compute similarity between two concept names for cross-version matching.
|
||||
|
||||
Strategy (in order of precedence):
|
||||
1. Exact string match -> 1.0
|
||||
2. Substring containment (one is a substring of the other) -> 0.9
|
||||
3. SequenceMatcher ratio on character sequences -> 0.0-1.0
|
||||
|
||||
Returns:
|
||||
float in [0.0, 1.0] where >= 0.7 means "likely the same concept".
|
||||
"""
|
||||
if name_a == name_b:
|
||||
return 1.0
|
||||
|
||||
# Substring containment: one name is contained in the other
|
||||
if name_a in name_b or name_b in name_a:
|
||||
# Only count as similar if they're of comparable length
|
||||
# (avoid matching "国内" with "国内行车娱乐限制")
|
||||
len_ratio = min(len(name_a), len(name_b)) / max(len(name_a), len(name_b))
|
||||
if len_ratio >= 0.5:
|
||||
return 0.85 + 0.05 * len_ratio # range 0.875-0.90
|
||||
return 0.55 # too different in length → below threshold
|
||||
|
||||
return SequenceMatcher(None, name_a, name_b).ratio()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Concept Clustering & Merging
|
||||
# =============================================================================
|
||||
|
||||
def cluster_concepts(
|
||||
all_concepts_lists: list[list[dict]],
|
||||
similarity_threshold: float = 0.7,
|
||||
) -> list[list[tuple[int, dict]]]:
|
||||
"""Group concepts across ensemble versions by name similarity.
|
||||
|
||||
Uses greedy single-pass clustering: for each concept, find the best-matching
|
||||
existing cluster. If max similarity >= threshold, add to it; otherwise,
|
||||
create a new cluster.
|
||||
|
||||
Args:
|
||||
all_concepts_lists: List of concept lists, one per ensemble version.
|
||||
all_concepts_lists[i] = concepts from version i.
|
||||
similarity_threshold: Minimum name similarity to join a cluster.
|
||||
|
||||
Returns:
|
||||
List of clusters. Each cluster is list of (version_idx, concept_dict).
|
||||
"""
|
||||
clusters = [] # type: list[list[tuple[int, dict]]]
|
||||
|
||||
for version_idx, concepts in enumerate(all_concepts_lists):
|
||||
for c in concepts:
|
||||
name = c.get("name", "")
|
||||
if not name:
|
||||
continue
|
||||
|
||||
best_cluster = None
|
||||
best_sim = 0.0
|
||||
|
||||
for cluster in clusters:
|
||||
# Compare against the first member of the cluster (seed)
|
||||
seed_name = cluster[0][1].get("name", "")
|
||||
sim = concept_name_similarity(name, seed_name)
|
||||
if sim > best_sim:
|
||||
best_sim = sim
|
||||
best_cluster = cluster
|
||||
|
||||
if best_cluster is not None and best_sim >= similarity_threshold:
|
||||
best_cluster.append((version_idx, c))
|
||||
else:
|
||||
clusters.append([(version_idx, c)])
|
||||
|
||||
return clusters
|
||||
|
||||
|
||||
def merge_concept_cluster(
|
||||
cluster: list[tuple[int, dict]],
|
||||
total_versions: int,
|
||||
) -> tuple[dict, str]:
|
||||
"""Merge a single cluster of matched concepts into one concept dict.
|
||||
|
||||
Rules:
|
||||
- name: Longest name (most specific). Tie-break by lower version_idx.
|
||||
- aliases: Union of all aliases across versions.
|
||||
- defined_in: Union of all defined_in across versions.
|
||||
- parent: Most common non-null parent (voting). Tie-break by lower version_idx.
|
||||
|
||||
Returns:
|
||||
(merged_concept_dict, confidence_level) where confidence is "high"/"medium"/"low".
|
||||
"""
|
||||
if not cluster:
|
||||
return {}, "low"
|
||||
|
||||
# --- name: longest (most specific) ---
|
||||
best_name = ""
|
||||
best_name_len = 0
|
||||
for v_idx, c in cluster:
|
||||
n = c.get("name", "")
|
||||
if len(n) > best_name_len:
|
||||
best_name = n
|
||||
best_name_len = len(n)
|
||||
elif len(n) == best_name_len and v_idx < cluster[0][0]: # lower version idx
|
||||
best_name = n
|
||||
|
||||
# --- aliases: union ---
|
||||
aliases = set()
|
||||
for _, c in cluster:
|
||||
for a in c.get("aliases", []):
|
||||
aliases.add(a)
|
||||
|
||||
# --- defined_in: union ---
|
||||
defined_in = set()
|
||||
for _, c in cluster:
|
||||
for d in c.get("defined_in", []):
|
||||
defined_in.add(d)
|
||||
|
||||
# --- parent: most common non-null parent (vote) ---
|
||||
parent_votes = defaultdict(int)
|
||||
for v_idx, c in cluster:
|
||||
p = c.get("parent")
|
||||
if p is not None:
|
||||
parent_votes[p] += 1
|
||||
|
||||
if parent_votes:
|
||||
best_parent = max(parent_votes, key=lambda p: (parent_votes[p], -1))
|
||||
else:
|
||||
best_parent = None
|
||||
|
||||
# --- confidence ---
|
||||
versions_present = len({v_idx for v_idx, _ in cluster})
|
||||
confidence = compute_confidence_versions(versions_present, total_versions,
|
||||
any(v_idx == 0 for v_idx, _ in cluster))
|
||||
|
||||
merged = {
|
||||
"name": best_name,
|
||||
"aliases": sorted(aliases),
|
||||
"defined_in": sorted(defined_in),
|
||||
"parent": best_parent,
|
||||
"confidence": confidence,
|
||||
}
|
||||
return merged, confidence
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Unit Similarity Functions
|
||||
# =============================================================================
|
||||
|
||||
def _collect_logic_tree_nodes(unit: dict) -> set[str]:
|
||||
"""Extract the flattened set of all logic tree node IDs from a function_unit."""
|
||||
nodes = set()
|
||||
for src in unit.get("sources", []):
|
||||
if src.get("type") == "logic_tree":
|
||||
nodes.update(src.get("logic_tree_nodes", []))
|
||||
return nodes
|
||||
|
||||
|
||||
def unit_node_jaccard(unit_a: dict, unit_b: dict) -> float:
|
||||
"""Compute Jaccard similarity on logic tree node sets between two units.
|
||||
|
||||
Jaccard(A, B) = |A ∩ B| / |A ∪ B|. Returns 0.0 if both have no nodes.
|
||||
"""
|
||||
nodes_a = _collect_logic_tree_nodes(unit_a)
|
||||
nodes_b = _collect_logic_tree_nodes(unit_b)
|
||||
|
||||
if not nodes_a and not nodes_b:
|
||||
return 0.0
|
||||
if not nodes_a or not nodes_b:
|
||||
return 0.0
|
||||
|
||||
intersection = nodes_a & nodes_b
|
||||
union = nodes_a | nodes_b
|
||||
return len(intersection) / len(union)
|
||||
|
||||
|
||||
def path_similarity(path_a: list[str], path_b: list[str]) -> float:
|
||||
"""Compute similarity between two path arrays.
|
||||
|
||||
Hybrid approach:
|
||||
- Sequential similarity (order-aware): SequenceMatcher on joined strings.
|
||||
- Set similarity (order-independent): Jaccard on path element sets.
|
||||
- Final score: 0.5 * seq_sim + 0.5 * set_sim
|
||||
|
||||
Returns:
|
||||
float in [0.0, 1.0].
|
||||
"""
|
||||
if not path_a and not path_b:
|
||||
return 1.0
|
||||
if not path_a or not path_b:
|
||||
return 0.0
|
||||
|
||||
# Sequential similarity
|
||||
joined_a = "|".join(path_a)
|
||||
joined_b = "|".join(path_b)
|
||||
seq_sim = SequenceMatcher(None, joined_a, joined_b).ratio()
|
||||
|
||||
# Set similarity
|
||||
set_a = set(path_a)
|
||||
set_b = set(path_b)
|
||||
set_sim = len(set_a & set_b) / len(set_a | set_b)
|
||||
|
||||
return 0.5 * seq_sim + 0.5 * set_sim
|
||||
|
||||
|
||||
def unit_similarity(unit_a: dict, unit_b: dict) -> float:
|
||||
"""Combined similarity between two function_units.
|
||||
|
||||
Weighted combination:
|
||||
- 0.6 * unit_node_jaccard (primary signal: same logic tree nodes = same rule)
|
||||
- 0.4 * path_similarity (secondary signal: semantic agreement)
|
||||
|
||||
Returns:
|
||||
float in [0.0, 1.0]. >= 0.5 means "likely the same function_unit".
|
||||
"""
|
||||
return 0.6 * unit_node_jaccard(unit_a, unit_b) + 0.4 * path_similarity(
|
||||
unit_a.get("path", []), unit_b.get("path", [])
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Function Unit Clustering & Merging
|
||||
# =============================================================================
|
||||
|
||||
def cluster_function_units(
|
||||
all_units_lists: list[list[dict]],
|
||||
similarity_threshold: float = 0.5,
|
||||
) -> list[list[tuple[int, dict]]]:
|
||||
"""Group function_units across ensemble versions by content similarity.
|
||||
|
||||
Lowest-temperature versions are processed first (most stable → cluster seeds).
|
||||
Higher-temperature variants join existing clusters if similar enough.
|
||||
|
||||
Args:
|
||||
all_units_lists: List of unit lists, one per ensemble version.
|
||||
similarity_threshold: Minimum unit_similarity to join a cluster.
|
||||
|
||||
Returns:
|
||||
List of clusters. Each cluster is list of (version_idx, unit_dict).
|
||||
"""
|
||||
clusters = [] # type: list[list[tuple[int, dict]]]
|
||||
|
||||
for version_idx, units in enumerate(all_units_lists):
|
||||
for unit in units:
|
||||
best_cluster = None
|
||||
best_sim = 0.0
|
||||
|
||||
for cluster in clusters:
|
||||
# Compare against all members already in the cluster
|
||||
cluster_sim = max(
|
||||
unit_similarity(unit, existing_unit)
|
||||
for (_, existing_unit) in cluster
|
||||
)
|
||||
if cluster_sim > best_sim:
|
||||
best_sim = cluster_sim
|
||||
best_cluster = cluster
|
||||
|
||||
if best_cluster is not None and best_sim >= similarity_threshold:
|
||||
best_cluster.append((version_idx, unit))
|
||||
else:
|
||||
clusters.append([(version_idx, unit)])
|
||||
|
||||
return clusters
|
||||
|
||||
|
||||
def pick_best_representative(
|
||||
cluster: list[tuple[int, dict]],
|
||||
) -> dict:
|
||||
"""Select the best function_unit from a cluster as the merged representative.
|
||||
|
||||
Scoring formula (all normalized to [0, 1]):
|
||||
- 0.35: Node count (more logic_tree_nodes = more complete trace)
|
||||
- 0.25: Source count (more sources = more evidence)
|
||||
- 0.20: Description length (longer = more detail, capped at 500 chars)
|
||||
- 0.20: Temperature rank (lower version_idx = lower temp = more stable)
|
||||
|
||||
Returns a deep copy of the winning unit dict.
|
||||
"""
|
||||
if not cluster:
|
||||
return {}
|
||||
|
||||
# Compute max values for normalization
|
||||
max_nodes = max(
|
||||
len(_collect_logic_tree_nodes(unit)) for _, unit in cluster
|
||||
)
|
||||
max_sources = max(
|
||||
len(unit.get("sources", [])) for _, unit in cluster
|
||||
)
|
||||
max_desc_len = max(
|
||||
len(unit.get("description", "")) for _, unit in cluster
|
||||
)
|
||||
max_version_idx = max(v_idx for v_idx, _ in cluster)
|
||||
num_versions = len(cluster)
|
||||
|
||||
def score(v_idx: int, unit: dict) -> float:
|
||||
nodes = len(_collect_logic_tree_nodes(unit))
|
||||
sources = len(unit.get("sources", []))
|
||||
desc_len = min(len(unit.get("description", "")), 500)
|
||||
temp_rank = 1.0 - (v_idx / max(num_versions, max_version_idx + 1))
|
||||
|
||||
return (
|
||||
0.35 * (nodes / max(1, max_nodes))
|
||||
+ 0.25 * (sources / max(1, max_sources))
|
||||
+ 0.20 * (desc_len / max(1, max_desc_len))
|
||||
+ 0.20 * temp_rank
|
||||
)
|
||||
|
||||
best = max(cluster, key=lambda x: score(x[0], x[1]))
|
||||
return dict(best[1]) # deep-ish copy (1 level)
|
||||
|
||||
|
||||
def merge_unit_sources(
|
||||
cluster: list[tuple[int, dict]],
|
||||
) -> list[dict]:
|
||||
"""Union all sources from units in a cluster, deduplicating by (type, image_id, section).
|
||||
|
||||
When the same source key appears in multiple versions, keeps the one with
|
||||
the most logic_tree_nodes.
|
||||
"""
|
||||
# Group by dedup key
|
||||
source_groups = defaultdict(list)
|
||||
|
||||
for v_idx, unit in cluster:
|
||||
for src in unit.get("sources", []):
|
||||
# Build a dedup key
|
||||
src_type = src.get("type", "")
|
||||
if src_type == "logic_tree":
|
||||
key = ("logic_tree", src.get("image_id", ""))
|
||||
else:
|
||||
key = (src_type, src.get("section", ""), src.get("row", ""))
|
||||
|
||||
source_groups[key].append(src)
|
||||
|
||||
# Pick best per group
|
||||
result = []
|
||||
for key, sources in source_groups.items():
|
||||
# Pick the source with the most logic_tree_nodes (if any)
|
||||
best = max(sources, key=lambda s: len(s.get("logic_tree_nodes", [])))
|
||||
result.append(dict(best))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def compute_confidence_versions(
|
||||
versions_present: int,
|
||||
total_versions: int,
|
||||
includes_lowest_temp: bool = False,
|
||||
) -> str:
|
||||
"""Compute 3-level confidence based on cross-version agreement.
|
||||
|
||||
- "high": Appears in all versions, OR >= 2/3 with lowest-temp version (T=0.0).
|
||||
- "medium": Appears in >= half the versions but not all.
|
||||
- "low": Appears in fewer than half (singleton in ensemble).
|
||||
|
||||
Args:
|
||||
versions_present: Number of versions this item appeared in.
|
||||
total_versions: Total number of ensemble versions.
|
||||
includes_lowest_temp: Whether the item appeared in the T=0.0 version.
|
||||
"""
|
||||
ratio = versions_present / total_versions
|
||||
|
||||
if ratio >= 1.0:
|
||||
return "high"
|
||||
if ratio >= 0.5 and includes_lowest_temp:
|
||||
return "high"
|
||||
if ratio >= 0.5:
|
||||
return "medium"
|
||||
return "low"
|
||||
|
||||
|
||||
def ensemble_merge_concepts(
|
||||
all_concepts_lists: list[list[dict]],
|
||||
) -> list[dict]:
|
||||
"""Merge concepts across all ensemble versions.
|
||||
|
||||
Returns:
|
||||
List of merged concept dicts, each with added "confidence" field.
|
||||
"""
|
||||
total = len(all_concepts_lists)
|
||||
clusters = cluster_concepts(all_concepts_lists)
|
||||
merged = []
|
||||
seen_names = set()
|
||||
|
||||
for cluster in clusters:
|
||||
concept, confidence = merge_concept_cluster(cluster, total)
|
||||
name = concept.get("name", "")
|
||||
if name and name not in seen_names:
|
||||
concept["ensemble_support"] = f"{len({v for v, _ in cluster})}/{total}"
|
||||
merged.append(concept)
|
||||
seen_names.add(name)
|
||||
|
||||
# Sort: high confidence first, then by name
|
||||
conf_order = {"high": 0, "medium": 1, "low": 2}
|
||||
merged.sort(key=lambda c: (conf_order.get(c.get("confidence", "low"), 3), c.get("name", "")))
|
||||
|
||||
# Validate and fix parent references
|
||||
merged = _validate_concept_parents(merged)
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def _validate_concept_parents(concepts: list[dict]) -> list[dict]:
|
||||
"""Post-merge: validate that every concept's parent exists in the list.
|
||||
|
||||
Strategy for dangling parents:
|
||||
1. Fuzzy match (concept_name_similarity >= 0.7) → fix reference
|
||||
2. No match → set parent to null, downgrade confidence to "low"
|
||||
"""
|
||||
concept_names = {c["name"] for c in concepts}
|
||||
conf_order = {"high": 0, "medium": 1, "low": 2}
|
||||
|
||||
for c in concepts:
|
||||
parent = c.get("parent")
|
||||
if parent is None:
|
||||
continue
|
||||
if parent in concept_names:
|
||||
continue
|
||||
|
||||
# Dangling parent — try fuzzy match
|
||||
best_match = None
|
||||
best_sim = 0.0
|
||||
for name in concept_names:
|
||||
sim = concept_name_similarity(parent, name)
|
||||
if sim > best_sim:
|
||||
best_sim = sim
|
||||
best_match = name
|
||||
|
||||
if best_match and best_sim >= 0.7:
|
||||
c["parent"] = best_match
|
||||
# Downgrade if match was fuzzy (not exact)
|
||||
if best_sim < 1.0:
|
||||
current_conf = c.get("confidence", "low")
|
||||
c["confidence"] = _downgrade_confidence(current_conf)
|
||||
else:
|
||||
c["parent"] = None
|
||||
c["confidence"] = _downgrade_confidence(c.get("confidence", "low"))
|
||||
|
||||
# Re-sort after confidence changes
|
||||
concepts.sort(key=lambda c: (conf_order.get(c.get("confidence", "low"), 3), c.get("name", "")))
|
||||
return concepts
|
||||
|
||||
|
||||
def _downgrade_confidence(current: str) -> str:
|
||||
"""Drop confidence one level."""
|
||||
if current == "high":
|
||||
return "medium"
|
||||
return "low"
|
||||
|
||||
|
||||
def ensemble_merge_function_units(
|
||||
all_units_lists: list[list[dict]],
|
||||
) -> list[dict]:
|
||||
"""Merge function_units across all ensemble versions.
|
||||
|
||||
1. Cluster units across versions.
|
||||
2. For each cluster: pick best, merge sources, compute confidence.
|
||||
3. Reassign stable unit_ids: FU-ENS-001, FU-ENS-002, ...
|
||||
|
||||
Returns:
|
||||
List of merged function_unit dicts with added "confidence",
|
||||
"ensemble_support", "source_versions" fields.
|
||||
"""
|
||||
total = len(all_units_lists)
|
||||
clusters = cluster_function_units(all_units_lists)
|
||||
|
||||
merged = []
|
||||
for cluster in clusters:
|
||||
# Pick best representative
|
||||
best = pick_best_representative(cluster)
|
||||
|
||||
# Merge sources from all cluster members
|
||||
best["sources"] = merge_unit_sources(cluster)
|
||||
|
||||
# Compute confidence
|
||||
versions_present = len({v_idx for v_idx, _ in cluster})
|
||||
includes_t0 = any(v_idx == 0 for v_idx, _ in cluster)
|
||||
confidence = compute_confidence_versions(
|
||||
versions_present, total, includes_t0
|
||||
)
|
||||
|
||||
best["confidence"] = confidence
|
||||
best["ensemble_support"] = f"{versions_present}/{total}"
|
||||
best["source_versions"] = versions_present
|
||||
|
||||
merged.append(best)
|
||||
|
||||
# Sort by confidence desc, then by unit_id
|
||||
conf_order = {"high": 0, "medium": 1, "low": 2}
|
||||
merged.sort(key=lambda u: (conf_order.get(u.get("confidence", "low"), 3),
|
||||
u.get("unit_id", "")))
|
||||
|
||||
# Reassign stable unit_ids
|
||||
for i, unit in enumerate(merged):
|
||||
# Preserve original unit_id for traceability
|
||||
if "original_unit_id" not in unit:
|
||||
unit["original_unit_id"] = unit.get("unit_id", "")
|
||||
unit["unit_id"] = f"FU-ENS-{i + 1:03d}"
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Top-Level Ensemble Merge
|
||||
# =============================================================================
|
||||
|
||||
def ensemble_merge(
|
||||
semantic_indices: list[dict],
|
||||
) -> dict:
|
||||
"""Merge N semantic index outputs into one ensemble result.
|
||||
|
||||
Args:
|
||||
semantic_indices: List of semantic_index dicts from each temperature run.
|
||||
semantic_indices[0] should be the lowest-temperature version.
|
||||
|
||||
Returns:
|
||||
Merged semantic_index dict with structure:
|
||||
{
|
||||
"feature_name": str,
|
||||
"ensemble_versions": int,
|
||||
"concepts": [...],
|
||||
"function_units": [...],
|
||||
"confidence_summary": {...},
|
||||
}
|
||||
"""
|
||||
if not semantic_indices:
|
||||
return {
|
||||
"feature_name": "",
|
||||
"ensemble_versions": 0,
|
||||
"concepts": [],
|
||||
"function_units": [],
|
||||
"confidence_summary": {},
|
||||
}
|
||||
|
||||
total = len(semantic_indices)
|
||||
|
||||
# Extract concepts and function_units from each version
|
||||
all_concepts = [si.get("concepts", []) for si in semantic_indices]
|
||||
all_units = [si.get("function_units", []) for si in semantic_indices]
|
||||
|
||||
# Merge
|
||||
merged_concepts = ensemble_merge_concepts(all_concepts)
|
||||
merged_units = ensemble_merge_function_units(all_units)
|
||||
|
||||
# Feature name: majority vote across versions
|
||||
feature_names = [si.get("feature_name", "") for si in semantic_indices]
|
||||
name_counts = defaultdict(int)
|
||||
for fn in feature_names:
|
||||
if fn:
|
||||
name_counts[fn] += 1
|
||||
feature_name = max(name_counts, key=name_counts.get) if name_counts else ""
|
||||
|
||||
# Confidence summary
|
||||
unit_conf = defaultdict(int)
|
||||
for u in merged_units:
|
||||
unit_conf[u.get("confidence", "low")] += 1
|
||||
concept_conf = defaultdict(int)
|
||||
for c in merged_concepts:
|
||||
concept_conf[c.get("confidence", "low")] += 1
|
||||
|
||||
return {
|
||||
"feature_name": feature_name,
|
||||
"ensemble_versions": total,
|
||||
"concepts": merged_concepts,
|
||||
"function_units": merged_units,
|
||||
"confidence_summary": {
|
||||
"total_units": len(merged_units),
|
||||
"high": unit_conf.get("high", 0),
|
||||
"medium": unit_conf.get("medium", 0),
|
||||
"low": unit_conf.get("low", 0),
|
||||
"total_concepts": len(merged_concepts),
|
||||
"concept_high": concept_conf.get("high", 0),
|
||||
"concept_medium": concept_conf.get("medium", 0),
|
||||
"concept_low": concept_conf.get("low", 0),
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
IR Generation Pipeline Orchestrator.
|
||||
|
||||
Run all four stages sequentially:
|
||||
python main.py [--skip-step1] [--skip-step2] [--skip-step2.5] [--skip-step3] [--test-only]
|
||||
|
||||
The pipeline reads the parsed PRD JSON from doc_parser and produces:
|
||||
- ir_final.json: the final IR rules
|
||||
- ir_audit_report.md: completeness audit report for human review
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import config
|
||||
|
||||
BASE_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def _subprocess_env(extra: dict | None = None) -> dict:
|
||||
"""Build environment dict for subprocesses, carrying forward overrides."""
|
||||
env = os.environ.copy()
|
||||
env.update(extra or {})
|
||||
return env
|
||||
|
||||
|
||||
def run_step(script_name: str, description: str, extra_env: dict | None = None) -> bool:
|
||||
"""Run a single pipeline step script, return True if it succeeded."""
|
||||
print(f"\n{'#' * 60}")
|
||||
print(f"# {description}")
|
||||
print(f"{'#' * 60}")
|
||||
script_path = BASE_DIR / script_name
|
||||
if not script_path.exists():
|
||||
print(f"错误: 脚本不存在 {script_path}")
|
||||
return False
|
||||
result = subprocess.run(
|
||||
[sys.executable, str(script_path)],
|
||||
cwd=str(BASE_DIR),
|
||||
env=_subprocess_env(extra_env),
|
||||
)
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
def run_test(test_name: str, description: str, extra_env: dict | None = None) -> bool:
|
||||
"""Run a test script, return True if all tests passed."""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"测试: {description}")
|
||||
print(f"{'='*60}")
|
||||
test_path = BASE_DIR / "tests" / test_name
|
||||
if not test_path.exists():
|
||||
print(f"错误: 测试脚本不存在 {test_path}")
|
||||
return False
|
||||
result = subprocess.run(
|
||||
[sys.executable, str(test_path)],
|
||||
cwd=str(BASE_DIR),
|
||||
env=_subprocess_env(extra_env),
|
||||
)
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="IR Generation Pipeline")
|
||||
parser.add_argument("--skip-step1", action="store_true",
|
||||
help="跳过阶段一(语义索引)")
|
||||
parser.add_argument("--skip-step2", action="store_true",
|
||||
help="跳过阶段二(IR 提取)")
|
||||
parser.add_argument("--skip-step2.5", "--skip-step2-5", action="store_true",
|
||||
dest="skip_step2_5",
|
||||
help="跳过阶段2.5(分支覆盖自动补全)")
|
||||
parser.add_argument("--skip-step3", action="store_true",
|
||||
help="跳过阶段三(合并与审计)")
|
||||
parser.add_argument("--test-only", action="store_true",
|
||||
help="仅运行测试,不调用 LLM")
|
||||
parser.add_argument(
|
||||
"--input", "-i", type=str, default=None,
|
||||
help="输入 JSON 文件路径(覆盖默认的 doc_parser 输出)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--provider", "-p", type=str, default=None,
|
||||
help="LLM provider: deepseek | dashscope(覆盖 IR_PROVIDER 环境变量)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Build extra env vars for subprocesses
|
||||
extra_env = {}
|
||||
if args.input:
|
||||
extra_env["IR_INPUT_JSON"] = args.input
|
||||
print(f"输入文件: {args.input}")
|
||||
if args.provider:
|
||||
extra_env["IR_PROVIDER"] = args.provider
|
||||
print(f"LLM Provider: {args.provider}")
|
||||
|
||||
if args.test_only:
|
||||
all_ok = True
|
||||
all_ok &= run_test("test_step1.py", "Step 1 验证", extra_env)
|
||||
all_ok &= run_test("test_step2.py", "Step 2 验证", extra_env)
|
||||
all_ok &= run_test("test_step2_5.py", "Step 2.5 验证", extra_env)
|
||||
all_ok &= run_test("test_step3.py", "Step 3 验证", extra_env)
|
||||
sys.exit(0 if all_ok else 1)
|
||||
|
||||
failures = []
|
||||
|
||||
# Stage 1
|
||||
if not args.skip_step1:
|
||||
ok = run_step("step1_semantic_index.py",
|
||||
"阶段一:宏观语义索引", extra_env)
|
||||
if not ok:
|
||||
failures.append("阶段一")
|
||||
print("\n阶段一失败,停止流水线。修复后重试。")
|
||||
sys.exit(1)
|
||||
run_test("test_step1.py", "Step 1 验证", extra_env)
|
||||
|
||||
# Stage 2
|
||||
if not args.skip_step2:
|
||||
ok = run_step("step2_ir_extraction.py",
|
||||
"阶段二:逐功能单元 IR 提取", extra_env)
|
||||
if not ok:
|
||||
failures.append("阶段二")
|
||||
print("\n阶段二失败,停止流水线。修复后重试。")
|
||||
sys.exit(1)
|
||||
run_test("test_step2.py", "Step 2 验证", extra_env)
|
||||
|
||||
# Stage 2.5
|
||||
if not args.skip_step2_5:
|
||||
ok = run_step("step2_5_branch_coverage.py",
|
||||
"阶段2.5:分支覆盖自动补全", extra_env)
|
||||
if not ok:
|
||||
failures.append("阶段2.5")
|
||||
print("\n阶段2.5失败,停止流水线。修复后重试。")
|
||||
sys.exit(1)
|
||||
run_test("test_step2_5.py", "Step 2.5 验证", extra_env)
|
||||
|
||||
# Stage 3
|
||||
if not args.skip_step3:
|
||||
ok = run_step("step3_merge_and_audit.py",
|
||||
"阶段三:确定性合并与完整性校验", extra_env)
|
||||
if not ok:
|
||||
failures.append("阶段三")
|
||||
sys.exit(1)
|
||||
run_test("test_step3.py", "Step 3 验证", extra_env)
|
||||
|
||||
if failures:
|
||||
print(f"\n失败阶段: {', '.join(failures)}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("流水线全部完成!")
|
||||
print(f"最终 IR: {config.IR_FINAL_JSON}")
|
||||
print(f"审计报告: {config.IR_AUDIT_REPORT_MD}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,46 @@
|
||||
## 上一轮遗漏分析
|
||||
|
||||
上一轮生成的语义索引经过自动校验,发现以下问题需要修正:
|
||||
|
||||
### 遗漏的逻辑树路径
|
||||
以下逻辑树决策路径未被任何 function_unit 覆盖,请为每条路径生成对应的 function_unit:
|
||||
{missing_paths}
|
||||
|
||||
### 遗漏的概念
|
||||
以下关键概念未在 concepts 列表中出现,请补充:
|
||||
{missing_concepts}
|
||||
|
||||
### 格式问题
|
||||
以下 function_unit 或 concept 的格式不符合要求:
|
||||
{format_issues}
|
||||
|
||||
### concept parent 问题
|
||||
以下概念的 parent 引用有问题(悬空引用或缺少 parent):
|
||||
{parent_issues}
|
||||
|
||||
---
|
||||
|
||||
请在本次生成中针对以上问题进行修正。注意:
|
||||
1. 你不需要从头生成完整的语义索引,只需要输出**补充和修正**的部分
|
||||
2. function_units 的输出应只包含本次新增或修正的单元(已有的正确单元不需要重复)
|
||||
3. concepts 的输出应只包含本次新增或修正的概念
|
||||
4. 如果格式问题中提到"空壳单元":删除该 unit,或将其合并到包含实际 action 的 unit 中。纯开关状态不是独立的功能行为
|
||||
5. 如果格式问题中提到"不构成有效路径":说明你引用了互斥分支上的节点。检查 logic_tree_nodes,确保它们都落在逻辑树的**同一条分支路径**上(例如 n4 是关闭分支,n8 是开启分支,不能共存)
|
||||
6. 如果格式问题提到"缺少 path"或"缺少 sources":补充对应字段
|
||||
|
||||
## 输出格式
|
||||
|
||||
只输出 JSON:
|
||||
|
||||
{
|
||||
"feature_name": "(与之前相同)",
|
||||
"supplemental_function_units": [
|
||||
// 只放新增的或修正的 function_unit
|
||||
],
|
||||
"supplemental_concepts": [
|
||||
// 只放新增的或修正的 concept
|
||||
],
|
||||
"corrections": {
|
||||
// 需要修正的已有项: { "unit_id或concept_name": { 修正后的字段 }, ... }
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
你是吉利汽车车机系统(XX Auto)的产品需求分析师。你的任务是从行车娱乐限制功能 PRD 文档中提取"语义索引"——一份结构化、有层级的功能清单,而不是逐字翻译。
|
||||
|
||||
## 文档结构说明
|
||||
|
||||
下面是一份 Word 文档的解析结果,包含:
|
||||
|
||||
1. **sections**:按章节组织的混合内容(段落 + 表格),每个 section 有 `source`(章节标题)、`blocks`(`para` 文本段落和 `table` 结构表格)、`images`(引用的图片 ID 列表)
|
||||
2. **image_analysis**:文档中流程图的程序化分析结果,其中 `logic_tree` 是由节点组成的决策树:
|
||||
- `state` 节点:状态说明
|
||||
- `decision` 节点:判断条件 + `branches`(分支值 → 目标节点 ID)
|
||||
- `action` 节点:系统或用户交互动作
|
||||
3. **resolved_conflicts**:文档中图文冲突的仲裁结果,明确指出应以文字还是图片为准
|
||||
|
||||
## 文档全文
|
||||
|
||||
{document_json}
|
||||
|
||||
## 你的任务
|
||||
|
||||
阅读整份文档后,输出一份 **语义索引 JSON**,包含:
|
||||
|
||||
### 1. feature_name
|
||||
功能名称,如"行车娱乐限制"
|
||||
|
||||
### 2. concepts(带层级)
|
||||
文档中定义或使用的关键概念列表。每个概念包含:
|
||||
- `name`:概念的标准名称(必填)
|
||||
- `aliases`:同义词/别名列表(如"行车娱乐限制"、"行车娱乐禁止")
|
||||
- `defined_in`:定义该概念的章节号列表(如 ["3.1", "3.1.1"])
|
||||
- `parent`:父概念名称(字符串或 null)(必填)
|
||||
|
||||
**概念层级规则(重要)**:
|
||||
你必须按照以下 4 层结构组织概念,并为每个概念指定正确的 `parent`:
|
||||
- **Level 0(地理范围)**: "国内"、"海外" — parent 为 null
|
||||
- **Level 1(功能)**: "行车娱乐限制"、"行车娱乐禁止" — parent 为对应的 scope(如 "国内")
|
||||
- **Level 2(限制方式)**: "系统限制"、"SDK限制"、"其他应用" — parent 为对应的 feature
|
||||
- **Level 3(具体行为)**: "前台打断"、"后台限制启动"、"后台暂停功能"、"无限制" — parent 为对应的 method
|
||||
|
||||
除了以上层级,还可以有"行车娱乐限制开关"、"车速条件"、"档位条件"、"Toast提示"等辅助概念,它们应有合理的 parent。
|
||||
|
||||
**重要约束:每个 concept 的 parent 值必须是 concepts 列表中已存在的另一个 concept 的 name,或者是 null。禁止引用不存在的概念名。**
|
||||
|
||||
### 3. function_units(带路径)
|
||||
文档中描述的所有主要功能行为的列表。**每个 function_unit 对应逻辑树中的一条叶子路径**。每个 function unit 包含:
|
||||
|
||||
- `unit_id`:唯一标识,格式 "FU-001", "FU-002"...
|
||||
- `name`:简短名称,如"国内-系统限制-前台-行车打断"
|
||||
- `description`:1-3 句描述该规则的行为
|
||||
- `path`:层级路径数组,从高到低,如 `["国内", "系统限制", "前台打断"]`(必填)。**path 中的每个元素必须是 concepts 列表中已存在的概念名。**
|
||||
- `sources`:该规则在文档中的来源锚点列表,每项包含:
|
||||
- `section`:章节号
|
||||
- `type`:来源类型,`"table"` 或 `"para"` 或 `"logic_tree"`
|
||||
- `row`:如果是表格行(从 1 开始)
|
||||
- `text_snippet`:前 200 字的关键文字
|
||||
- `image_id`:如果是逻辑树来源,填写图片 rId
|
||||
- `logic_tree_nodes`:如果是逻辑树来源,列出相关节点 ID 列表
|
||||
|
||||
## function_units 分解策略(重要)
|
||||
|
||||
**按逻辑树的每条叶子路径生成一个 function_unit**:
|
||||
|
||||
1. **叶子路径 = 从根节点到叶子节点(end 类型)的完整决策链**,包含路径上所有中间节点和叶子节点的最终动作
|
||||
2. **每条叶子路径对应一个 function_unit**:不同决策分支导向不同叶子节点 → 不同的 function_unit
|
||||
3. **"不受限"叶子节点也必须建模**:即使 action 是"不执行任何限制操作",也要创建对应的 function_unit
|
||||
4. **禁止合并不同叶子节点**:不要将多个不同叶子节点的结果合并到一个 function_unit(除非它们触发完全相同的动作且属于同一父分支)
|
||||
5. **文字描述中的功能单独列出**:对于无法对应到逻辑树节点的功能(如纯文字描述的功能行为),用 table/para 类型 source,path 用语义路径
|
||||
6. **非流程图的图片也可能包含功能行为**:rId18 等图片的描述文本中可能包含功能规则(如"使用语音打开受限应用"),同样需要提取为 function_unit
|
||||
|
||||
**重要:不要创建纯开关/状态的空壳 unit**。"开关开启"本身不是一个功能行为(它没有 action),它是其他单元的 precondition。如果一个 function_unit 的 path 只有 `["国内", "开关开启"]` 且 sources 中只有 n1/n2/n3 这样的根/开关节点,说明它不是真正的功能单元,不应该输出。
|
||||
|
||||
{feedback}
|
||||
|
||||
## 权威性规则
|
||||
|
||||
1. **逻辑树(流程图)是权威来源**:逻辑树定义了功能的确切行为。识别 function_unit 时必须优先按逻辑树路径建模。文字和表格用于补充描述、提供确切措辞(如 Toast 文案),但不应覆盖或曲解逻辑树路径。
|
||||
|
||||
2. **logic_tree_nodes 必须构成有效路径**:每个 function_unit 引用的 logic_tree_nodes 列表,必须对应逻辑树中的**一条连通路径**。禁止将互斥分支上的节点混入同一个 source(例如 n4 是"开关关闭"分支,n8 是"开关开启"分支的下游节点,它们不能出现在同一 function_unit 中)。
|
||||
|
||||
3. **resolved_conflicts 中的仲裁是最终决定**:如果文档有图文冲突且已仲裁,严格按仲裁结果处理。
|
||||
|
||||
4. **逻辑树路径应全部覆盖**:下面是程序从文档逻辑树中枚举的全部决策路径,请逐一确认每条路径都有对应的 function_unit:
|
||||
|
||||
{logic_tree_paths}
|
||||
|
||||
## 关键要求
|
||||
|
||||
1. **必须覆盖所有逻辑树路径**:上面列出的每条路径必须被至少一个 function_unit 的 sources 引用。
|
||||
|
||||
2. **必须覆盖表格中的所有规则**:表格中列出的每种"限制方法"、"限制规则"都要有对应的 function_unit。
|
||||
|
||||
3. **区分"限制"与"禁止"**:文档中"行车娱乐限制"(前台应用打断)和"行车娱乐禁止"(后台应用启动限制)是两个不同的子场景,必须分别建模。
|
||||
|
||||
4. **区分不同应用类型**:系统限制、SDK 限制、其他应用的行为路径不同,必须分别建模。
|
||||
|
||||
5. **包含开关状态**:开关"开启"和"关闭"两种状态下的行为都要覆盖。
|
||||
|
||||
6. **概念和路径必须有层级**:每个 concept 指定正确的 parent;每个 function_unit 输出 path 数组。
|
||||
|
||||
## 输出格式
|
||||
|
||||
**只输出 JSON,不要有 markdown 代码块标记或其他文字**:
|
||||
|
||||
{
|
||||
"feature_name": "...",
|
||||
"concepts": [
|
||||
{"name": "国内", "aliases": [], "defined_in": ["2.7", "3.1"], "parent": null},
|
||||
{"name": "行车娱乐限制", "aliases": [], "defined_in": ["3.1", "3.1.1"], "parent": "国内"},
|
||||
...
|
||||
],
|
||||
"function_units": [
|
||||
{
|
||||
"unit_id": "FU-001",
|
||||
"name": "国内-系统限制-前台-行车打断",
|
||||
"description": "...",
|
||||
"path": ["国内", "系统限制", "前台打断"],
|
||||
"sources": [
|
||||
{"section": "3.1.1", "type": "table", "row": 2, "text_snippet": "打断:车速>=15km/h且持续5秒后..."},
|
||||
{"image_id": "rId16", "type": "logic_tree", "logic_tree_nodes": ["n2","n3","n8","n19","n21","n23","n25","n26"]}
|
||||
]
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
你是吉利汽车车机系统的需求分析专家。你的任务是基于给定的精准上下文包,为单个功能单元(Function Unit)提取详细的 **IR 规则(Intermediate Representation Rule)**。
|
||||
|
||||
## 上下文
|
||||
|
||||
下面是一个功能单元的精准上下文包,包含了从原始需求文档中提取的相关文字、表格和逻辑树:
|
||||
|
||||
### 功能单元概要
|
||||
- **unit_id**: {unit_id}
|
||||
- **unit_name**: {unit_name}
|
||||
- **unit_description**: {unit_description}
|
||||
|
||||
### 相关文字段落
|
||||
{texts}
|
||||
|
||||
### 相关表格
|
||||
{tables}
|
||||
|
||||
### 相关逻辑树
|
||||
{logic_trees}
|
||||
|
||||
### 图文冲突仲裁(如有)
|
||||
{resolved_conflicts}
|
||||
|
||||
## IR Schema
|
||||
|
||||
你需要为这个功能单元输出一个 **规则数组(rules)**。每条规则遵循以下 schema:
|
||||
|
||||
```json
|
||||
{{
|
||||
"rule_id": "{unit_id}-DOMESTIC-SYS-FG-INTERRUPT-01",
|
||||
"path": ["国内", "系统限制", "前台打断"],
|
||||
"description": "国内车型,开关开启,系统限制类应用在前台,车速>=15km/h且持续>5秒且非P档时,系统打断应用前台进程、将应用调入后台,显示Toast'在行车状态下无法使用该应用'",
|
||||
"priority": "P0",
|
||||
"sources": [
|
||||
{{"type": "table", "section": "3.1.1", "row": 2, "text_snippet": "打断:车速>=15km/h且持续5秒后..."}},
|
||||
{{"type": "logic_tree", "image_id": "rId16", "node_ids": ["n2","n3","n8","n19","n21","n23","n25","n26"], "priority": "primary_source"}}
|
||||
],
|
||||
"precondition": {{
|
||||
"geographic_scope": "国内",
|
||||
"screen_type": "any",
|
||||
"switch": "开启",
|
||||
"app_type": "系统限制",
|
||||
"app_state": "前台"
|
||||
}},
|
||||
"trigger": {{
|
||||
"operator": "AND",
|
||||
"conditions": [
|
||||
{{"signal": "车速", "operator": ">=", "value": 15, "unit": "km/h"}},
|
||||
{{"signal": "车速_持续时间", "operator": ">", "value": 5, "unit": "秒"}},
|
||||
{{"signal": "档位", "operator": "!=", "value": "P"}}
|
||||
]
|
||||
}},
|
||||
"actions": [
|
||||
{{"type": "system", "description": "打断应用前台进程"}},
|
||||
{{"type": "system", "description": "将应用调入后台"}},
|
||||
{{"type": "user_interaction", "description": "显示Toast", "content": "在行车状态下无法使用该应用"}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
## 字段说明(必读)
|
||||
|
||||
1. **rule_id**: 格式为 `{unit_id}-SCOPE-METHOD-BEHAVIOR-NN`,其中:
|
||||
- SCOPE: DOMESTIC(国内)| OVERSEAS(海外)
|
||||
- METHOD: SYS(系统限制)| SDK(SDK限制)| OTHER(其他应用)
|
||||
- BEHAVIOR: FG-INTERRUPT(前台打断)| BG-BLOCK(后台限制启动)| BG-PAUSE(后台暂停功能)| NO-RESTRICT(无限制)| SWITCH-OFF(开关关闭)
|
||||
- NN: 序号从 01 开始
|
||||
|
||||
2. **path**: 层级路径数组(必填)。从 scope 到 behavior 逐级列出,如 `["国内", "系统限制", "前台打断"]`。此字段用于程序化遍历所有功能点。
|
||||
|
||||
3. **description**: 完整但简洁地描述整个规则,必须包含:地理范围 + 开关状态 + 应用类型 + 前后台状态 + 触发条件 + 所有动作。人读取此字段即可设计测试用例。
|
||||
|
||||
4. **priority**: P0(核心安全规则)、P1(重要规则)、P2(边界情况)。
|
||||
|
||||
5. **sources**: 每条规则必须列出所有数据来源。逻辑树类型的 source 必须标记 `"priority": "primary_source"`。文字/表格类型的 source 标记 `"priority": "supplementary"`。**node_ids 必须列举该规则在逻辑树中经历的所有 decision 和 action 节点。**
|
||||
|
||||
6. **precondition**: 规则生效的前置状态条件。必须包含以下字段:
|
||||
- `geographic_scope`(必填):"国内" | "海外"
|
||||
- `screen_type`(必填):"CSD" | "PSD" | "RFD" | "any"(如文档未区分屏幕类型则填 "any")
|
||||
- `switch`:开关状态("开启" | "关闭")
|
||||
- `app_type`:应用类型
|
||||
- `app_state`:应用前后台状态("前台" | "后台")
|
||||
如某字段不适用,可省略。
|
||||
|
||||
7. **trigger**: 触发条件对象:
|
||||
- `operator`: "AND" | "OR"
|
||||
- `conditions`: 条件数组,每个条件必须有 `signal`、`operator`、`value`。有单位加 `unit`。
|
||||
- 如为瞬时事件(用户点击),用 `event` 字段。
|
||||
|
||||
8. **actions**: 每个动作必须有 `type`("system" | "user_interaction")和 `description`。
|
||||
- `"user_interaction"` 类型必须有 `content` 字段,填写**确切的提示文案**。
|
||||
- **禁止使用占位符**:content 不能是"文案由业务定义"、"待定"、"自定义"等。如果文档中给出了文案,必须原样填入。如果文档确实未给出文案,填写 `"(文档未指定)"` 并标注。
|
||||
|
||||
## Few-shot 示例
|
||||
|
||||
### 示例 1:行车娱乐限制(前台打断)
|
||||
|
||||
**输入上下文**:国内车型,开关开启,系统限制类应用在前台,车速>=15km/h且持续>5秒且非P档时,打断应用并显示Toast"在行车状态下无法使用该应用"。
|
||||
|
||||
**期望输出**:
|
||||
|
||||
```json
|
||||
{{
|
||||
"rule_id": "FU-001-DOMESTIC-SYS-FG-INTERRUPT-01",
|
||||
"path": ["国内", "系统限制", "前台打断"],
|
||||
"description": "国内车型,开关开启,系统限制类应用在前台,当车速>=15km/h且持续超过5秒且非P档时,系统打断应用前台进程、将应用调入后台,并弹出Toast提示'在行车状态下无法使用该应用'",
|
||||
"priority": "P0",
|
||||
"sources": [
|
||||
{{"type": "table", "section": "3.1.1", "row": 2, "text_snippet": "行车娱乐限制:目标应用/功能处于前台时 ○ 打断:车速>=15km/h且持续5秒后...", "priority": "supplementary"}},
|
||||
{{"type": "logic_tree", "image_id": "rId16", "node_ids": ["n2","n3","n8","n19","n21","n23","n25","n26"], "priority": "primary_source"}}
|
||||
],
|
||||
"precondition": {{
|
||||
"geographic_scope": "国内",
|
||||
"screen_type": "any",
|
||||
"switch": "开启",
|
||||
"app_type": "系统限制",
|
||||
"app_state": "前台"
|
||||
}},
|
||||
"trigger": {{
|
||||
"operator": "AND",
|
||||
"conditions": [
|
||||
{{"signal": "车速", "operator": ">=", "value": 15, "unit": "km/h"}},
|
||||
{{"signal": "车速_持续时间", "operator": ">", "value": 5, "unit": "秒"}},
|
||||
{{"signal": "档位", "operator": "!=", "value": "P"}}
|
||||
]
|
||||
}},
|
||||
"actions": [
|
||||
{{"type": "system", "description": "打断应用前台进程"}},
|
||||
{{"type": "system", "description": "将应用调入后台"}},
|
||||
{{"type": "user_interaction", "description": "显示Toast", "content": "在行车状态下无法使用该应用"}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
### 示例 2:行车娱乐禁止(后台启动拦截)
|
||||
|
||||
**输入上下文**:国内车型,开关开启,应用在后台,非P档时阻止应用启动,提示"请在P挡时使用该功能/应用"。
|
||||
|
||||
**期望输出**:
|
||||
|
||||
```json
|
||||
{{
|
||||
"rule_id": "FU-002-DOMESTIC-SYS-BG-BLOCK-01",
|
||||
"path": ["国内", "系统限制", "后台限制启动"],
|
||||
"description": "国内车型,开关开启,目标应用处于后台,当用户尝试启动应用且档位非P档时,系统限制应用/功能启用,并弹出Toast提示'请在P挡时使用该功能/应用'",
|
||||
"priority": "P0",
|
||||
"sources": [
|
||||
{{"type": "table", "section": "3.1.1", "row": 2, "text_snippet": "行车娱乐禁止:目标应用/功能处于后台时 ○ 限制:非P挡时,限制目标应用/功能启用...", "priority": "supplementary"}},
|
||||
{{"type": "logic_tree", "image_id": "rId17", "node_ids": ["n1","n2","n5","n7"], "priority": "primary_source"}}
|
||||
],
|
||||
"precondition": {{
|
||||
"geographic_scope": "国内",
|
||||
"screen_type": "any",
|
||||
"switch": "开启",
|
||||
"app_state": "后台"
|
||||
}},
|
||||
"trigger": {{
|
||||
"operator": "AND",
|
||||
"conditions": [
|
||||
{{"signal": "应用请求启动", "operator": "==", "value": true}},
|
||||
{{"signal": "档位", "operator": "!=", "value": "P"}}
|
||||
]
|
||||
}},
|
||||
"actions": [
|
||||
{{"type": "system", "description": "限制应用/功能启用"}},
|
||||
{{"type": "user_interaction", "description": "显示Toast", "content": "请在P挡时使用该功能/应用"}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
## 关键要求
|
||||
|
||||
1. **逻辑树为唯一权威来源**:触发条件和动作序列必须严格按逻辑树路径建模。文字/表格描述仅用于补充确切措辞(如 Toast 文案),不得覆盖或曲解逻辑树路径。在 sources 中,逻辑树类型标记 `"priority": "primary_source"`,文字/表格标记 `"priority": "supplementary"`。
|
||||
|
||||
2. **信号和数值必须精确**:禁止写"车速超过阈值",必须写 `{{"signal": "车速", "operator": ">=", "value": 15, "unit": "km/h"}}`。
|
||||
|
||||
3. **条件必须完整**:逻辑树中的每个 decision 条件必须对应 trigger.conditions 中的一条。如果文档说"车速>=15km/h 且持续超过5秒 且非P档",这三个条件必须全部出现。
|
||||
|
||||
4. **每条规则必须自包含**:人仅凭一条 rule JSON 就能设计出对应的测试用例。必须包含:geographic_scope、screen_type、开关状态、应用类型、前后台状态、完整触发条件、所有动作及确切 Toast 文案、来源引用。
|
||||
|
||||
5. **禁止占位符**:`"user_interaction"` 类型的 `content` 不能是"文案由业务定义"、"待定"、"自定义"。如文档确实未给出文案,填 `"(文档未指定)"`。
|
||||
|
||||
6. **逻辑树节点必须追踪**:在 sources 中列出该规则在逻辑树中经历的所有 decision 节点和 action 节点。
|
||||
|
||||
7. **多条规则**:如果一个功能单元包含多个独立行为分支,输出多条规则分别描述。
|
||||
|
||||
8. **开关关闭状态**:开关关闭时所有限制失效,这也必须作为一条规则输出(path: ["...", "开关关闭", "无限制"])。
|
||||
|
||||
{format_feedback}
|
||||
|
||||
## 输出格式
|
||||
|
||||
**只输出 JSON 数组,不要有任何其他文字或 markdown 标记**:
|
||||
|
||||
[
|
||||
{{ ... }},
|
||||
{{ ... }}
|
||||
]
|
||||
|
||||
注意:即使只有一个规则,也必须用数组格式 `[...]`。
|
||||
@@ -1,105 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""Low-level OpenAI-compatible LLM client with retry and token tracking.
|
||||
|
||||
Usage::
|
||||
|
||||
llm = LLMClient()
|
||||
content = llm.chat("qwen3.5-flash", [{"role": "user", "content": "Hello"}])
|
||||
print(llm.usage)
|
||||
"""
|
||||
|
||||
IMAGE_MODEL = "qwen3-vl-plus"
|
||||
TEXT_MODEL = "qwen3.5-flash-2026-02-23"
|
||||
TIMEOUT = 120
|
||||
MAX_RETRIES = 3
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
timeout: int | None = None,
|
||||
):
|
||||
key = os.environ.get("DASHSCOPE_API_KEY", "")
|
||||
if not key:
|
||||
raise ValueError("DASHSCOPE_API_KEY environment variable is not set.")
|
||||
self._client = OpenAI(api_key=key, base_url=base_url)
|
||||
self._timeout = timeout or self.TIMEOUT
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
|
||||
@property
|
||||
def usage(self) -> dict:
|
||||
"""Return accumulated token counts as ``{prompt, completion, total}``."""
|
||||
return {
|
||||
"prompt_tokens": self._prompt_tokens,
|
||||
"completion_tokens": self._completion_tokens,
|
||||
"total_tokens": self._prompt_tokens + self._completion_tokens,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def estimate_tokens(text: str) -> int:
|
||||
"""Quick token estimate. CJK ≈1.7/token, others ≈3.0/token."""
|
||||
cjk = sum(1 for c in text if '一' <= c <= '鿿' or ' ' <= c <= '〿')
|
||||
other = len(text) - cjk
|
||||
return max(1, int(cjk / 1.7 + other / 3.0))
|
||||
|
||||
@staticmethod
|
||||
def estimate_image_tokens() -> int:
|
||||
"""Fixed estimate for one vision-model image (~500 tokens)."""
|
||||
return 500
|
||||
|
||||
def chat(
|
||||
self, model: str, messages: list[dict], *, timeout: int | None = None,
|
||||
response_format: dict | None = None,
|
||||
) -> str:
|
||||
"""Send a chat completion request and return the response content.
|
||||
|
||||
Automatically retries on failure and accumulates token usage.
|
||||
"""
|
||||
label = f"chat({model})"
|
||||
|
||||
def _call():
|
||||
t0 = time.time()
|
||||
kwargs = dict(model=model, messages=messages, timeout=timeout or self._timeout)
|
||||
if response_format is not None:
|
||||
kwargs["response_format"] = response_format
|
||||
kwargs["temperature"] = 0
|
||||
resp = self._client.chat.completions.create(**kwargs)
|
||||
content = resp.choices[0].message.content
|
||||
usg = resp.usage
|
||||
if usg:
|
||||
self._prompt_tokens += usg.prompt_tokens
|
||||
self._completion_tokens += usg.completion_tokens
|
||||
elapsed = time.time() - t0
|
||||
logger.info("%s: %d chars in %.1fs", label, len(content) if content else 0, elapsed)
|
||||
if not content:
|
||||
raise RuntimeError("Empty response from LLM")
|
||||
return content
|
||||
|
||||
return self._retry(_call, label)
|
||||
|
||||
def _retry(self, fn, label: str) -> str:
|
||||
"""Call *fn()* with exponential-backoff retry."""
|
||||
last_error: Optional[Exception] = None
|
||||
for attempt in range(self.MAX_RETRIES):
|
||||
try:
|
||||
return fn()
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.warning(
|
||||
"%s error (attempt %d/%d): %s",
|
||||
label, attempt + 1, self.MAX_RETRIES, e,
|
||||
)
|
||||
if attempt < self.MAX_RETRIES - 1:
|
||||
time.sleep(2 ** attempt)
|
||||
raise RuntimeError(f"{label}: all retries exhausted") from last_error
|
||||
@@ -1,359 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate JSON intermediate representation from ``_parsed.json`` or ``_updated.json``.
|
||||
|
||||
Sends the JSON document directly to the LLM for analysis. If the document exceeds
|
||||
``MAX_ANALYSIS_TOKENS``, sections are batched greedily without splitting any
|
||||
individual section. Conflict corrections from ``resolved_conflicts`` are included
|
||||
so the output respects user arbitration decisions.
|
||||
|
||||
Usage::
|
||||
|
||||
python scripts/ir_generator.py output/<basename>_updated.json [output_dir] [--dry-run]
|
||||
|
||||
Output: ``<basename>_ir.json``
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from LLM import LLMClient
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
RATE_LIMIT_DELAY = 0.5
|
||||
MAX_ANALYSIS_TOKENS = 6000 # max content size per LLM call
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PROMPT = """你是一个需求文档分析助手。请分析以下需求文档的JSON内容,输出结构化JSON。
|
||||
|
||||
## 已知修正(来自冲突检测)
|
||||
以下内容已确认修正,生成JSON时请**使用修正后的值**,不要同时输出两个版本。
|
||||
{conflict_context}
|
||||
|
||||
## 待分析内容(JSON格式)
|
||||
|
||||
{content}
|
||||
|
||||
## JSON字段说明
|
||||
- sections: 文档章节列表,每个章节含 source(章节标题)和 blocks(内容块数组)
|
||||
- blocks: 类型含 para(段落,字段 text)和 table(表格,字段 rows,每行含 columns 数组)
|
||||
- image_sources: 图片所在章节映射,key 为图片 rid
|
||||
- image_analysis: 图片分析结果,每个含 rid、type(流程图/架构图/状态图等)、description
|
||||
- resolved_conflicts: 已知修正列表,每个含 section、conflict_type、correction、source
|
||||
|
||||
## 功能点定义
|
||||
|
||||
只有满足以下**全部条件**的才视为功能点:
|
||||
1. 描述了一个**系统或软件要实现的具体行为**(有触发条件、执行动作、状态变化或逻辑规则)
|
||||
2. 该行为直接由**系统或框架**执行(不是人的操作流程、管理流程)
|
||||
3. 对用户或系统有**可观察的效果**
|
||||
|
||||
**以下内容不是功能点,不要输出:**
|
||||
- 术语/缩略词定义(
|
||||
- 文档背景、范围说明(如 "本文档涵盖xxx")
|
||||
- 变更日志、版本记录、编制人信息
|
||||
- 文档结构描述(如 "产品简介用户场景说明")
|
||||
- 纯文本的概述、没有具体行为的介绍
|
||||
|
||||
## 决策树/流程图分解规则(重要)
|
||||
|
||||
图片分析(image_analysis)中的流程图和决策树描述包含丰富的功能逻辑,**必须完全分解**:
|
||||
|
||||
1. **每个叶子路径 = 一个独立 function**:从根节点到每个最终结果的完整路径,都拆成一个 function
|
||||
2. **每个判断分支 = 一个独立 function**:菱形判断节点的每个分支方向和对应的结果,单独作为一个 function
|
||||
3. **不同约束条件 = 不同 function**:例如"通过接入SDK限制"和"通过系统限制"是不同约束机制,必须分别列出
|
||||
4. **不要合并不同路径**:即使最终结果相同,只要到达路径不同,就是不同的 function
|
||||
|
||||
## 输出格式
|
||||
|
||||
只输出功能点,每个功能点格式如下:
|
||||
|
||||
{
|
||||
"function": "功能名称",
|
||||
"source": {
|
||||
"section": "章节名",
|
||||
"location": "原文位置(如:正文第1段、表格1第2行、图片rId13)"
|
||||
},
|
||||
"trigger": {
|
||||
"type": "AND或者OR",
|
||||
"conditions": [
|
||||
"触发条件1",
|
||||
"触发条件2"
|
||||
]
|
||||
},
|
||||
"actions": {
|
||||
"场景/角色": [
|
||||
"动作1",
|
||||
"动作2"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
## 输出原则
|
||||
|
||||
1. **只输出功能点**,没有功能点就输出空数组 []
|
||||
2. 每个功能点**必须**包含 source.section 和 source.location
|
||||
3. location 必须是具体的原文位置标签(如 "正文第1段"、"表格1"、"图片rId13")
|
||||
4. **一个 function 只对应一种行为逻辑(一条完整路径)**。决策树中的每个分支路径(从根到叶子)必须拆成独立 function,conditions 中明确写出该路径上的所有判断条件和分支方向。
|
||||
5. **穷举所有分支**:流程图/决策树中的每一条分支路径都要输出对应的 function,不能遗漏任何子逻辑。
|
||||
6. 没有 trigger 或 actions 的字段直接**省略**,不要写 null 或空列表/空对象
|
||||
7. 所有功能点全部列出,**宁多勿漏**
|
||||
8. **已知修正**中确认的信息,使用修正后的值
|
||||
9. 输出一个JSON数组,不要用 ```json 代码块包裹,直接输出纯JSON
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _parse_llm_response(raw: str) -> list | dict | str | None:
|
||||
"""Parse JSON from LLM response, handling markdown code fences."""
|
||||
if raw is None:
|
||||
return None
|
||||
stripped = raw.strip()
|
||||
if stripped.startswith("```"):
|
||||
nl = stripped.find("\n")
|
||||
stripped = stripped[nl + 1:] if nl != -1 else stripped[3:]
|
||||
if stripped.endswith("```"):
|
||||
stripped = stripped[:-3]
|
||||
try:
|
||||
return json.loads(stripped)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(" Failed to parse JSON, returning raw text")
|
||||
return raw
|
||||
|
||||
|
||||
def _build_conflict_context(
|
||||
section_name: str | None,
|
||||
resolved_conflicts: list[dict],
|
||||
) -> str:
|
||||
"""Build conflict correction context for a section, or all if section_name is None."""
|
||||
if section_name is None:
|
||||
relevant = resolved_conflicts
|
||||
else:
|
||||
relevant = [c for c in resolved_conflicts if c.get("section", "") == section_name]
|
||||
if not relevant:
|
||||
return "没有"
|
||||
|
||||
lines: list[str] = []
|
||||
for c in relevant:
|
||||
correction = c.get("correction", "")
|
||||
conflict_type = c.get("conflict_type", "")
|
||||
source = c.get("source", "")
|
||||
lines.append(f"- 冲突类型:{conflict_type},依据:{source}")
|
||||
lines.append(f" 修正后的值:{correction}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM analysis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _analyze_content(
|
||||
content: str,
|
||||
conflict_context: str,
|
||||
llm: LLMClient,
|
||||
*,
|
||||
dry_run: bool = False,
|
||||
) -> list[dict]:
|
||||
"""Send content to the LLM and return IR entries."""
|
||||
prompt = PROMPT.replace("{conflict_context}", conflict_context).replace("{content}", content)
|
||||
|
||||
if dry_run:
|
||||
est = llm.estimate_tokens(prompt)
|
||||
logger.info(" [DRY RUN] prompt ~%d tokens", est)
|
||||
return []
|
||||
|
||||
try:
|
||||
raw = llm.chat(
|
||||
model=LLMClient.TEXT_MODEL,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
logger.info(" Response: %d chars", len(raw))
|
||||
except RuntimeError as e:
|
||||
logger.error(" Analysis failed: %s", e)
|
||||
return []
|
||||
|
||||
parsed = _parse_llm_response(raw)
|
||||
if isinstance(parsed, list):
|
||||
return parsed
|
||||
elif isinstance(parsed, dict):
|
||||
return [parsed]
|
||||
else:
|
||||
logger.warning(" Unparseable response, raw length: %d", len(raw))
|
||||
return []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def generate_ir(
|
||||
parsed_path: str,
|
||||
output_dir: str = "output",
|
||||
*,
|
||||
dry_run: bool = False,
|
||||
) -> dict:
|
||||
"""Read parsed/updated JSON and generate JSON IR.
|
||||
|
||||
Produces ``<basename>_ir.json`` in *output_dir*.
|
||||
"""
|
||||
with open(parsed_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
basename = os.path.splitext(os.path.basename(parsed_path))[0]
|
||||
for suffix in ("_parsed", "_updated"):
|
||||
if basename.endswith(suffix):
|
||||
basename = basename[:-len(suffix)]
|
||||
break
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
llm = LLMClient()
|
||||
ir_output: list[dict] = []
|
||||
|
||||
sections = data.get("sections", [])
|
||||
image_sources = data.get("image_sources", {})
|
||||
image_analysis = data.get("image_analysis", [])
|
||||
resolved_conflicts = data.get("resolved_conflicts", [])
|
||||
|
||||
# Build full document JSON to measure size
|
||||
full_doc = {
|
||||
"sections": sections,
|
||||
"image_sources": image_sources,
|
||||
"image_analysis": image_analysis,
|
||||
}
|
||||
full_json = json.dumps(full_doc, ensure_ascii=False)
|
||||
total_chars = len(full_json)
|
||||
logger.info("Total document JSON chars: %d", total_chars)
|
||||
|
||||
if total_chars < MAX_ANALYSIS_TOKENS:
|
||||
logger.info("Document fits in one request (< %d chars)", MAX_ANALYSIS_TOKENS)
|
||||
conflict_ctx = _build_conflict_context(None, resolved_conflicts)
|
||||
entries = _analyze_content(full_json, conflict_ctx, llm, dry_run=dry_run)
|
||||
ir_output.extend(entries)
|
||||
else:
|
||||
logger.info("Document is large (>= %d chars), batching sections", MAX_ANALYSIS_TOKENS)
|
||||
|
||||
# Filter to non-empty sections, measure effective size per section
|
||||
# (section JSON + image_sources + image_analysis for images in that section)
|
||||
sec_sizes = []
|
||||
for sec in sections:
|
||||
if not sec.get("blocks"):
|
||||
continue
|
||||
sec_json = json.dumps(sec, ensure_ascii=False)
|
||||
sec_chars = len(sec_json)
|
||||
# Add image overhead for this section
|
||||
sec_name = sec.get("source", "")
|
||||
sec_rids = [rid for rid, src in image_sources.items()
|
||||
if src.get("section", "") == sec_name]
|
||||
if sec_rids:
|
||||
overhead_doc = {
|
||||
"image_sources": {rid: image_sources[rid] for rid in sec_rids},
|
||||
"image_analysis": [img for img in image_analysis
|
||||
if img.get("rid", "") in sec_rids],
|
||||
}
|
||||
sec_chars += len(json.dumps(overhead_doc, ensure_ascii=False))
|
||||
sec_sizes.append((sec, sec_chars))
|
||||
|
||||
# Greedy batch: never split a section, keep adding until next exceeds limit
|
||||
i = 0
|
||||
while i < len(sec_sizes):
|
||||
batch = []
|
||||
batch_size = 0
|
||||
while i < len(sec_sizes) and batch_size + sec_sizes[i][1] <= MAX_ANALYSIS_TOKENS:
|
||||
batch.append(sec_sizes[i][0])
|
||||
batch_size += sec_sizes[i][1]
|
||||
i += 1
|
||||
|
||||
if not batch:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Collect sections and their images for this batch
|
||||
batch_names = [s.get("source", "") for s in batch]
|
||||
batch_image_sources = {
|
||||
rid: src for rid, src in image_sources.items()
|
||||
if src.get("section", "") in batch_names
|
||||
}
|
||||
batch_images = [
|
||||
img for img in image_analysis
|
||||
if image_sources.get(img.get("rid", ""), {}).get("section", "") in batch_names
|
||||
]
|
||||
|
||||
batch_doc = {
|
||||
"sections": batch,
|
||||
"image_sources": batch_image_sources,
|
||||
"image_analysis": batch_images,
|
||||
}
|
||||
batch_json = json.dumps(batch_doc, ensure_ascii=False)
|
||||
|
||||
# Merge conflict contexts
|
||||
ctx_parts = []
|
||||
for sn in batch_names:
|
||||
ctx = _build_conflict_context(sn, resolved_conflicts)
|
||||
if ctx != "没有":
|
||||
ctx_parts.append(ctx)
|
||||
conflict_ctx = "\n".join(ctx_parts) if ctx_parts else "没有"
|
||||
|
||||
label = " + ".join(batch_names)
|
||||
logger.info("Batch [%s]: %d sections, %d chars", label, len(batch), len(batch_json))
|
||||
entries = _analyze_content(batch_json, conflict_ctx, llm, dry_run=dry_run)
|
||||
ir_output.extend(entries)
|
||||
time.sleep(RATE_LIMIT_DELAY)
|
||||
|
||||
# ---- save ----------------------------------------------------------------
|
||||
ir_path = os.path.join(output_dir, f"{basename}_ir.json")
|
||||
os.makedirs(os.path.dirname(ir_path) or ".", exist_ok=True)
|
||||
with open(ir_path, "w", encoding="utf-8") as f:
|
||||
json.dump(ir_output, f, ensure_ascii=False, indent=2)
|
||||
logger.info("Saved: %s (%d entries)", ir_path, len(ir_output))
|
||||
|
||||
# ---- summary -------------------------------------------------------------
|
||||
usg = llm.usage
|
||||
logger.info("Tokens: %d prompt + %d completion = %d total",
|
||||
usg["prompt_tokens"], usg["completion_tokens"], usg["total_tokens"])
|
||||
logger.info("Output: %s", ir_path)
|
||||
|
||||
return {"ir": ir_output, "path": ir_path}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate JSON intermediate representation from parsed/updated JSON.",
|
||||
)
|
||||
parser.add_argument("input", metavar="parsed.json",
|
||||
help="Path to _parsed.json or _updated.json")
|
||||
parser.add_argument("output_dir", nargs="?", default="output", metavar="output_dir",
|
||||
help="Directory for output files (default: output/)")
|
||||
parser.add_argument("--dry-run", action="store_true",
|
||||
help="Print token estimates without calling the API.")
|
||||
|
||||
args = parser.parse_args()
|
||||
generate_ir(args.input, args.output_dir, dry_run=args.dry_run)
|
||||
@@ -0,0 +1,717 @@
|
||||
"""
|
||||
Stage 1: Ensemble Semantic Index Generation.
|
||||
|
||||
Generates N parallel LLM calls with different temperatures (e.g., 0.0, 0.3, 0.7),
|
||||
then deterministically merges the results via ensemble_merge (pure Python, no LLM).
|
||||
The merged output includes confidence scores for each concept and function_unit.
|
||||
|
||||
Outputs:
|
||||
- output/semantic_index_r1.json (T=0.0 raw)
|
||||
- output/semantic_index_r2.json (T=0.3 raw)
|
||||
- output/semantic_index_r3.json (T=0.7 raw)
|
||||
- output/semantic_index.json (ensemble-merged final)
|
||||
"""
|
||||
|
||||
import concurrent.futures
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import config
|
||||
from ensemble_merge import ensemble_merge
|
||||
|
||||
|
||||
# ---- Path Enumeration (for prompt embedding) ----
|
||||
|
||||
|
||||
def _traverse_nested(node: dict, image_id: str, path_nodes: list,
|
||||
branch_taken: str | None) -> list[dict]:
|
||||
"""DFS traversal of a logic_tree_nested node, returning leaf path records."""
|
||||
node_id = node.get("id", "?")
|
||||
node_type = node.get("type", "?")
|
||||
node_name = node.get("name", "")
|
||||
|
||||
path_nodes = path_nodes + [{
|
||||
"id": node_id,
|
||||
"type": node_type,
|
||||
"label": node_name,
|
||||
"branch_taken": branch_taken,
|
||||
}]
|
||||
|
||||
if node_type == "end":
|
||||
return [_make_path_record(path_nodes, image_id)]
|
||||
|
||||
children = node.get("children", [])
|
||||
if not children:
|
||||
return [_make_path_record(path_nodes, image_id)]
|
||||
|
||||
all_paths = []
|
||||
for child in children:
|
||||
# Decision nodes have {condition, node} wrappers; others are direct node dicts
|
||||
if node_type == "decision":
|
||||
condition = child.get("condition", "")
|
||||
child_node = child.get("node", child)
|
||||
else:
|
||||
condition = "(implicit)"
|
||||
child_node = child
|
||||
|
||||
all_paths.extend(
|
||||
_traverse_nested(child_node, image_id, path_nodes, condition)
|
||||
)
|
||||
|
||||
return all_paths
|
||||
|
||||
|
||||
def _make_path_record(path_nodes: list, image_id: str) -> dict:
|
||||
"""Build a path record from a completed node chain."""
|
||||
action_nodes = [n for n in path_nodes if n["type"] == "action"]
|
||||
decision_nodes = [n for n in path_nodes if n["type"] == "decision"]
|
||||
node_ids = [n["id"] for n in path_nodes]
|
||||
|
||||
return {
|
||||
"path_id": f"PATH-{image_id}-{'-'.join(node_ids)}",
|
||||
"nodes": path_nodes,
|
||||
"meaning": _describe_path(path_nodes),
|
||||
"image_id": image_id,
|
||||
"action_nodes": action_nodes,
|
||||
"decision_nodes": decision_nodes,
|
||||
"node_ids": node_ids,
|
||||
}
|
||||
|
||||
|
||||
def enumerate_logic_tree_paths(nested_tree: dict, image_id: str = "") -> list[dict]:
|
||||
"""Enumerate all root-to-leaf paths from a logic_tree_nested structure.
|
||||
|
||||
Uses the nested tree directly (no flat-list adjacency). Decision nodes
|
||||
fork by {condition, node} branches; other nodes have direct children.
|
||||
"""
|
||||
if not nested_tree:
|
||||
return []
|
||||
return _traverse_nested(nested_tree, image_id, [], None)
|
||||
|
||||
|
||||
def _describe_path(path_nodes: list[dict]) -> str:
|
||||
"""Generate a human-readable description of a logic tree path."""
|
||||
parts = []
|
||||
for n in path_nodes:
|
||||
label = n["label"]
|
||||
if n["branch_taken"] and n["branch_taken"] != "(implicit)":
|
||||
label = f"{label} → {n['branch_taken']}"
|
||||
parts.append(label)
|
||||
return " → ".join(parts)
|
||||
|
||||
|
||||
def enumerate_all_paths(doc: dict) -> dict[str, list[dict]]:
|
||||
"""Enumerate paths for all logic trees in the document.
|
||||
|
||||
Uses logic_tree_nested when available (proper tree), falling back to
|
||||
flat logic_tree. Returns {image_id: [path, ...]}.
|
||||
"""
|
||||
result = {}
|
||||
for img in doc.get("image_analysis", []):
|
||||
rid = img.get("rid", "")
|
||||
if not rid:
|
||||
continue
|
||||
nested = img.get("logic_tree_nested")
|
||||
if nested:
|
||||
result[rid] = enumerate_logic_tree_paths(nested, image_id=rid)
|
||||
else:
|
||||
lt = img.get("logic_tree")
|
||||
if lt and lt.get("nodes"):
|
||||
lt["image_id"] = rid
|
||||
result[rid] = _enumerate_flat_tree(lt)
|
||||
elif lt:
|
||||
result[rid] = []
|
||||
return result
|
||||
|
||||
|
||||
def _enumerate_flat_tree(tree: dict) -> list[dict]:
|
||||
"""Fallback: enumerate paths from flat logic_tree using adjacency.
|
||||
Handles start/process/action/state nodes as implicit chain links.
|
||||
"""
|
||||
nodes = tree.get("nodes", [])
|
||||
if not nodes:
|
||||
return []
|
||||
node_map = {n["id"]: n for n in nodes}
|
||||
image_id = tree.get("image_id", "")
|
||||
|
||||
# Find root: first start/state node, or first process node, or first node
|
||||
root = None
|
||||
for n in nodes:
|
||||
if n["type"] in ("start", "state"):
|
||||
root = n
|
||||
break
|
||||
if root is None:
|
||||
for n in nodes:
|
||||
if n["type"] == "process":
|
||||
root = n
|
||||
break
|
||||
if root is None:
|
||||
root = nodes[0]
|
||||
|
||||
adj = _build_adjacency(nodes, node_map)
|
||||
paths = []
|
||||
|
||||
def dfs(current_id, visited, path_nodes, branch_taken):
|
||||
if current_id in visited:
|
||||
return
|
||||
new_visited = visited | {current_id}
|
||||
node = node_map.get(current_id)
|
||||
if node is None:
|
||||
return
|
||||
|
||||
path_nodes = path_nodes + [{
|
||||
"id": current_id,
|
||||
"type": node["type"],
|
||||
"label": node.get("description") or node.get("condition", ""),
|
||||
"branch_taken": branch_taken,
|
||||
}]
|
||||
|
||||
outgoing = adj.get(current_id, [])
|
||||
if not outgoing:
|
||||
action_nodes = [n for n in path_nodes if n["type"] == "action"]
|
||||
decision_nodes = [n for n in path_nodes if n["type"] == "decision"]
|
||||
node_ids = [n["id"] for n in path_nodes]
|
||||
paths.append({
|
||||
"path_id": f"PATH-{image_id}-{'-'.join(node_ids)}",
|
||||
"nodes": path_nodes,
|
||||
"meaning": _describe_path(path_nodes),
|
||||
"image_id": image_id,
|
||||
"action_nodes": action_nodes,
|
||||
"decision_nodes": decision_nodes,
|
||||
"node_ids": node_ids,
|
||||
})
|
||||
else:
|
||||
for branch_val, target_id in outgoing:
|
||||
dfs(target_id, new_visited, path_nodes, branch_val)
|
||||
|
||||
dfs(root["id"], set(), [], None)
|
||||
return paths
|
||||
|
||||
|
||||
def _build_adjacency(nodes, node_map):
|
||||
"""Build {node_id: [(branch_value, target_id)]} adjacency for flat trees.
|
||||
|
||||
Handles: decision branches (explicit), non-branching nodes (implicit sequential).
|
||||
"""
|
||||
NON_BRANCHING = {"start", "process", "state", "action"}
|
||||
|
||||
adj = {}
|
||||
has_explicit_incoming = set()
|
||||
for n in nodes:
|
||||
for br in n.get("branches", []):
|
||||
has_explicit_incoming.add(br["target"])
|
||||
|
||||
for i, node in enumerate(nodes):
|
||||
nid = node["id"]
|
||||
adj.setdefault(nid, [])
|
||||
|
||||
# Explicit edges from decision nodes
|
||||
for br in node.get("branches", []):
|
||||
adj[nid].append((br["value"], br["target"]))
|
||||
|
||||
# Implicit edges for non-branching nodes (start/process/state/action)
|
||||
if node["type"] in NON_BRANCHING and not node.get("branches"):
|
||||
j = i + 1
|
||||
targets = []
|
||||
while j < len(nodes):
|
||||
next_node = nodes[j]
|
||||
next_nid = next_node["id"]
|
||||
if next_nid in has_explicit_incoming:
|
||||
break
|
||||
if next_node["type"] in NON_BRANCHING | {"end"}:
|
||||
targets.append(next_nid)
|
||||
has_explicit_incoming.add(next_nid)
|
||||
j += 1
|
||||
continue
|
||||
elif next_node["type"] == "decision":
|
||||
if not targets:
|
||||
targets.append(next_nid)
|
||||
break
|
||||
j += 1
|
||||
for t in targets:
|
||||
adj[nid].append(("(implicit)", t))
|
||||
|
||||
return adj
|
||||
|
||||
|
||||
def format_paths_for_prompt(all_paths: dict[str, list[dict]]) -> str:
|
||||
"""Format enumerated paths as a readable list for the LLM prompt."""
|
||||
if not all_paths:
|
||||
return "(无逻辑树路径)"
|
||||
|
||||
lines = []
|
||||
for image_id, paths in all_paths.items():
|
||||
lines.append(f"\n### {image_id} 的全部决策路径(共 {len(paths)} 条):")
|
||||
for i, path in enumerate(paths, 1):
|
||||
lines.append(f"\n**路径 {i}** (ID: {path['path_id']})")
|
||||
lines.append(f" 含义: {path['meaning']}")
|
||||
lines.append(f" 节点: {path['node_ids']}")
|
||||
lines.append(f" 决策节点: {[n['id'] for n in path['decision_nodes']]}")
|
||||
lines.append(f" 动作节点: {[n['id'] for n in path['action_nodes']]}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ---- Document Formatting ----
|
||||
|
||||
|
||||
def format_document_for_prompt(doc: dict) -> str:
|
||||
"""Render the full parsed document as a readable string for the LLM prompt."""
|
||||
lines = []
|
||||
|
||||
lines.append("=== SECTIONS ===")
|
||||
for i, section in enumerate(doc.get("sections", [])):
|
||||
source = section.get("source", f"(无标题-章节{i})")
|
||||
lines.append(f"\n--- Section: {source} ---")
|
||||
|
||||
for block in section.get("blocks", []):
|
||||
if block["type"] == "para":
|
||||
lines.append(f"[段落 {block['index']}] {block['text']}")
|
||||
elif block["type"] == "table":
|
||||
lines.append(f"[表格 {block.get('table', '?')}]")
|
||||
headers = block.get("headers", [])
|
||||
lines.append(f" 表头: {' | '.join(headers)}")
|
||||
for row in block.get("rows", []):
|
||||
cols = row.get("columns", [])
|
||||
cell_texts = []
|
||||
for c in cols:
|
||||
cell_texts.append(
|
||||
f"[行{c.get('row','?')}]{c.get('name','')}: {c.get('text','')}"
|
||||
)
|
||||
lines.append(f" {'; '.join(cell_texts)}")
|
||||
|
||||
images = section.get("images", [])
|
||||
if images:
|
||||
lines.append(f" 图片引用: {', '.join(images)}")
|
||||
|
||||
lines.append("\n\n=== IMAGE_ANALYSIS (流程图逻辑树) ===")
|
||||
for img in doc.get("image_analysis", []):
|
||||
rid = img.get("rid", "?")
|
||||
img_type = img.get("type", "?")
|
||||
lines.append(f"\n--- Image: {rid} (type={img_type}) ---")
|
||||
lines.append(f" 描述: {img.get('description', '')[:300]}")
|
||||
|
||||
lt = img.get("logic_tree")
|
||||
if lt:
|
||||
lines.append(f" 逻辑树根节点: {lt.get('root', '?')}")
|
||||
lines.append(" 节点详情:")
|
||||
for node in lt.get("nodes", []):
|
||||
nid = node.get("id", "?")
|
||||
ntype = node.get("type", "?")
|
||||
desc = node.get("description", "") or node.get("condition", "")
|
||||
lines.append(f" [{ntype}] {nid}: {desc}")
|
||||
branches = node.get("branches", [])
|
||||
if branches:
|
||||
for br in branches:
|
||||
lines.append(f" → {br['value']} → {br['target']}")
|
||||
|
||||
conflicts = doc.get("resolved_conflicts", [])
|
||||
if conflicts:
|
||||
lines.append("\n\n=== RESOLVED_CONFLICTS (图文冲突仲裁) ===")
|
||||
for c in conflicts:
|
||||
lines.append(
|
||||
f" [{c.get('conflict_type','?')}] {c.get('section','?')}: "
|
||||
f"以{c.get('source','?')}为准 — {c.get('correction','')}"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ---- Prompt Building ----
|
||||
|
||||
|
||||
def build_prompt(doc: dict, feedback: str = "", all_paths: dict | None = None) -> str:
|
||||
"""Load the prompt template and inject the formatted document + paths + feedback."""
|
||||
template_path = Path(config.PROMPTS_DIR) / "step1_semantic_index.txt"
|
||||
template = template_path.read_text(encoding="utf-8")
|
||||
|
||||
formatted_doc = format_document_for_prompt(doc)
|
||||
prompt = template.replace("{document_json}", formatted_doc)
|
||||
|
||||
if all_paths is None:
|
||||
all_paths = enumerate_all_paths(doc)
|
||||
path_text = format_paths_for_prompt(all_paths)
|
||||
prompt = prompt.replace("{logic_tree_paths}", path_text)
|
||||
|
||||
if feedback:
|
||||
prompt = prompt.replace("{feedback}", feedback)
|
||||
else:
|
||||
prompt = prompt.replace("{feedback}", "")
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
# ---- Validation ----
|
||||
|
||||
|
||||
def _quick_validate(
|
||||
semantic_index: dict, doc: dict, all_paths: dict | None = None
|
||||
) -> tuple[bool, dict]:
|
||||
"""Validate semantic index and return (passed, gaps).
|
||||
|
||||
Uses a single COVERAGE_TARGET threshold (default 0.95).
|
||||
"""
|
||||
gaps = {
|
||||
"missing_paths": [],
|
||||
"missing_concepts": [],
|
||||
"format_issues": [],
|
||||
"parent_issues": [],
|
||||
}
|
||||
|
||||
units = semantic_index.get("function_units", [])
|
||||
concepts = semantic_index.get("concepts", [])
|
||||
|
||||
# --- Check function_units non-empty ---
|
||||
if not units:
|
||||
gaps["format_issues"].append("function_units 为空")
|
||||
return False, gaps
|
||||
|
||||
# --- Check each function_unit has path ---
|
||||
for fu in units:
|
||||
uid = fu.get("unit_id", "?")
|
||||
if not fu.get("path"):
|
||||
gaps["format_issues"].append(f"{uid}: 缺少 path 字段")
|
||||
if not fu.get("sources"):
|
||||
gaps["format_issues"].append(f"{uid}: 缺少 sources")
|
||||
|
||||
# --- Logic tree node coverage ---
|
||||
all_nodes = _collect_logic_tree_nodes(doc)
|
||||
referenced = _collect_referenced_nodes(units)
|
||||
|
||||
threshold = config.COVERAGE_TARGET
|
||||
|
||||
for image_id, node_set in all_nodes.items():
|
||||
ref_set = referenced.get(image_id, set())
|
||||
checkable = {
|
||||
nid for nid, ntype in node_set.items()
|
||||
if ntype in ("decision", "action")
|
||||
}
|
||||
if not checkable:
|
||||
continue
|
||||
covered = checkable & ref_set
|
||||
coverage = len(covered) / len(checkable) if checkable else 1.0
|
||||
|
||||
if coverage < threshold:
|
||||
missing = checkable - ref_set
|
||||
gaps["missing_paths"].append(
|
||||
f"{image_id}: 覆盖率 {coverage:.0%} < {threshold:.0%}, "
|
||||
f"未覆盖节点: {sorted(missing)}"
|
||||
)
|
||||
|
||||
# --- Check logic tree path consistency ---
|
||||
# A unit's logic_tree_nodes must form a valid (connected) path in the tree.
|
||||
if all_paths is not None:
|
||||
for fu in units:
|
||||
uid = fu.get("unit_id", "?")
|
||||
for src in fu.get("sources", []):
|
||||
if src.get("type") != "logic_tree":
|
||||
continue
|
||||
image_id = src.get("image_id", "")
|
||||
unit_nodes = set(src.get("logic_tree_nodes", []))
|
||||
if not unit_nodes:
|
||||
continue
|
||||
# Check if there exists a path containing all these nodes
|
||||
valid = False
|
||||
for path in all_paths.get(image_id, []):
|
||||
path_nodes = set(path.get("node_ids", []))
|
||||
if unit_nodes.issubset(path_nodes):
|
||||
valid = True
|
||||
break
|
||||
if not valid:
|
||||
gaps["format_issues"].append(
|
||||
f"{uid}: logic_tree_nodes 不构成有效路径 "
|
||||
f"(image={image_id}, nodes={sorted(unit_nodes)})"
|
||||
)
|
||||
|
||||
# --- Check for trivial units (only state/switch nodes, no actions) ---
|
||||
if all_paths is not None:
|
||||
for fu in units:
|
||||
uid = fu.get("unit_id", "?")
|
||||
has_logic_ref = False
|
||||
has_action = False
|
||||
has_non_trivial_decision = False
|
||||
for src in fu.get("sources", []):
|
||||
if src.get("type") != "logic_tree":
|
||||
continue
|
||||
has_logic_ref = True
|
||||
node_ids = src.get("logic_tree_nodes", [])
|
||||
node_types = {}
|
||||
for image_id, nset in all_nodes.items():
|
||||
for nid in node_ids:
|
||||
if nid in nset:
|
||||
node_types[nid] = nset[nid]
|
||||
for nid in node_ids:
|
||||
ntype = node_types.get(nid, "")
|
||||
if ntype == "action":
|
||||
has_action = True
|
||||
# Count decisions beyond first level (e.g., n1/n2 are just root+switch)
|
||||
decisions = [nid for nid in node_ids
|
||||
if node_types.get(nid, "") == "decision"]
|
||||
if len(decisions) > 1:
|
||||
has_non_trivial_decision = True
|
||||
if has_logic_ref and not has_action and not has_non_trivial_decision:
|
||||
gaps["format_issues"].append(
|
||||
f"{uid}: 可能为空壳单元(仅有state/开关节点,无action或深层decision)"
|
||||
)
|
||||
|
||||
# --- Concept parent validity ---
|
||||
concept_names = {c["name"] for c in concepts}
|
||||
for c in concepts:
|
||||
name = c.get("name", "?")
|
||||
parent = c.get("parent") # can be None for scope-level
|
||||
if parent is not None and parent not in concept_names:
|
||||
gaps["parent_issues"].append(
|
||||
f"concept '{name}' 的 parent '{parent}' 不存在"
|
||||
)
|
||||
# Warn about scope-level concepts without parent=null
|
||||
for c in concepts:
|
||||
if c.get("parent") is not None:
|
||||
continue
|
||||
name = c.get("name", "")
|
||||
# Scope-level concepts (国内/海外) should have parent=null
|
||||
if name not in ("国内", "海外", ""):
|
||||
gaps["parent_issues"].append(
|
||||
f"concept '{name}' 的 parent 为 null,但它可能不是 scope 概念"
|
||||
)
|
||||
|
||||
# --- Check for missing scope concepts ---
|
||||
if "国内" not in concept_names:
|
||||
gaps["missing_concepts"].append("缺少 scope 概念: 国内")
|
||||
if "海外" not in concept_names and any(
|
||||
"海外" in s.get("source", "") for s in doc.get("sections", [])
|
||||
):
|
||||
gaps["missing_concepts"].append("缺少 scope 概念: 海外")
|
||||
|
||||
passed = (
|
||||
not gaps["missing_paths"]
|
||||
and not gaps["format_issues"]
|
||||
and not gaps["parent_issues"]
|
||||
)
|
||||
return passed, gaps
|
||||
|
||||
|
||||
def _collect_logic_tree_nodes(doc: dict) -> dict[str, dict[str, str]]:
|
||||
"""Return {image_id: {node_id: node_type}} for all logic trees."""
|
||||
result = {}
|
||||
for img in doc.get("image_analysis", []):
|
||||
lt = img.get("logic_tree")
|
||||
rid = img.get("rid", "")
|
||||
if lt and rid:
|
||||
result[rid] = {n["id"]: n["type"] for n in lt.get("nodes", [])}
|
||||
return result
|
||||
|
||||
|
||||
def _collect_referenced_nodes(units: list[dict]) -> dict[str, set[str]]:
|
||||
"""Return {image_id: {referenced node_ids}} across all function_units."""
|
||||
refs = {}
|
||||
for fu in units:
|
||||
for src in fu.get("sources", []):
|
||||
if src.get("type") == "logic_tree":
|
||||
image_id = src.get("image_id", "")
|
||||
if image_id not in refs:
|
||||
refs[image_id] = set()
|
||||
refs[image_id].update(src.get("logic_tree_nodes", []))
|
||||
return refs
|
||||
|
||||
|
||||
# ---- LLM Calls ----
|
||||
|
||||
|
||||
def extract_json_from_response(text: str) -> str:
|
||||
"""Robustly extract JSON from LLM response."""
|
||||
m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text)
|
||||
if m:
|
||||
return m.group(1).strip()
|
||||
|
||||
start = text.find("{")
|
||||
if start == -1:
|
||||
raise ValueError("No JSON object found in LLM response")
|
||||
|
||||
depth = 0
|
||||
for i in range(start, len(text)):
|
||||
if text[i] == "{":
|
||||
depth += 1
|
||||
elif text[i] == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start : i + 1]
|
||||
|
||||
raise ValueError("Unclosed JSON object in LLM response")
|
||||
|
||||
|
||||
def call_llm(prompt: str, max_retries: int = 2,
|
||||
temperature: float | None = None) -> dict:
|
||||
"""Send prompt to LLM, return parsed JSON dict.
|
||||
|
||||
Args:
|
||||
temperature: Override config.TEMPERATURE. If None, uses config default.
|
||||
"""
|
||||
client = config.llm_client()
|
||||
temp = temperature if temperature is not None else config.TEMPERATURE
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
print(f" LLM 调用 T={temp} (尝试 {attempt + 1}/{max_retries + 1})...", flush=True)
|
||||
try:
|
||||
resp = client.chat.completions.create(
|
||||
model=config.MODEL_NAME,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个精确的 JSON 输出引擎。只输出合法的 JSON。",
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
temperature=temp,
|
||||
max_tokens=config.MAX_TOKENS,
|
||||
)
|
||||
content = resp.choices[0].message.content
|
||||
if content is None:
|
||||
raise RuntimeError("LLM returned empty response")
|
||||
|
||||
json_str = extract_json_from_response(content)
|
||||
return json.loads(json_str)
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
print(f" JSON 解析失败: {e}")
|
||||
if attempt < max_retries:
|
||||
time.sleep(2)
|
||||
|
||||
raise RuntimeError("无法从 LLM 响应中解析 JSON")
|
||||
|
||||
|
||||
# ---- Ensemble Orchestration ----
|
||||
|
||||
|
||||
def run_ensemble_semantic_index(doc: dict) -> dict:
|
||||
"""Run N parallel LLM calls at different temperatures, then ensemble-merge.
|
||||
|
||||
1. Enumerate all logic tree paths (once).
|
||||
2. Build the prompt (once — no iterative feedback needed).
|
||||
3. Launch len(ENSEMBLE_TEMPERATURES) parallel LLM calls via ThreadPoolExecutor.
|
||||
4. Collect all results.
|
||||
5. Call ensemble_merge() for deterministic merge.
|
||||
6. Validate final output with _quick_validate().
|
||||
7. Save individual version outputs + merged output.
|
||||
"""
|
||||
all_paths = enumerate_all_paths(doc)
|
||||
print(f" 已枚举逻辑树路径: {sum(len(v) for v in all_paths.values())} 条")
|
||||
|
||||
prompt = build_prompt(doc, "", all_paths)
|
||||
print(f" Prompt 长度: {len(prompt)} 字符")
|
||||
|
||||
temperatures = config.ENSEMBLE_TEMPERATURES
|
||||
print(f" 集成温度: {temperatures}")
|
||||
|
||||
# Parallel LLM calls
|
||||
raw_results: list[tuple[int, float, dict]] = []
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=len(temperatures)
|
||||
) as executor:
|
||||
future_to_meta = {}
|
||||
for i, temp in enumerate(temperatures):
|
||||
future = executor.submit(call_llm, prompt, 2, temp)
|
||||
future_to_meta[future] = (i, temp)
|
||||
|
||||
for future in concurrent.futures.as_completed(future_to_meta):
|
||||
idx, temp = future_to_meta[future]
|
||||
try:
|
||||
si = future.result()
|
||||
n_units = len(si.get("function_units", []))
|
||||
n_concepts = len(si.get("concepts", []))
|
||||
print(f" T={temp}: {n_concepts} 概念, {n_units} 功能单元")
|
||||
raw_results.append((idx, temp, si))
|
||||
except Exception as e:
|
||||
print(f" T={temp}: FAIL — {e}")
|
||||
raw_results.append((idx, temp, {
|
||||
"feature_name": "", "concepts": [], "function_units": []
|
||||
}))
|
||||
|
||||
if not raw_results:
|
||||
raise RuntimeError("所有集成的 LLM 调用均失败")
|
||||
|
||||
# Sort by temperature for determinism
|
||||
raw_results.sort(key=lambda x: x[1])
|
||||
semantic_indices = [r[2] for r in raw_results]
|
||||
|
||||
# Save individual version outputs
|
||||
version_paths = {
|
||||
0: config.SEMANTIC_INDEX_R1_JSON,
|
||||
1: config.SEMANTIC_INDEX_R2_JSON,
|
||||
2: config.SEMANTIC_INDEX_R3_JSON,
|
||||
}
|
||||
for i, si in enumerate(semantic_indices):
|
||||
out_path = version_paths.get(i)
|
||||
if out_path:
|
||||
config.save_json(si, out_path)
|
||||
print(f" 保存版本 {i} (T={temperatures[i]}): {out_path}")
|
||||
|
||||
# Ensemble merge
|
||||
print(f"\n 集成合并 {len(semantic_indices)} 个版本...")
|
||||
merged = ensemble_merge(semantic_indices)
|
||||
merged["ensemble_temperatures"] = list(temperatures)
|
||||
|
||||
# Validate
|
||||
passed, gaps = _quick_validate(merged, doc, all_paths)
|
||||
merged["validation_passed"] = passed
|
||||
merged["validation_gaps"] = {
|
||||
k: v for k, v in gaps.items() if v
|
||||
}
|
||||
|
||||
# Print summary
|
||||
cs = merged.get("confidence_summary", {})
|
||||
print(f" 合并后: {cs.get('total_concepts', 0)} 概念, "
|
||||
f"{cs.get('total_units', 0)} 功能单元")
|
||||
print(f" 置信度: high={cs.get('high', 0)}, medium={cs.get('medium', 0)}, "
|
||||
f"low={cs.get('low', 0)}")
|
||||
print(f" 验证: {'PASS' if passed else 'GAPS FOUND'}")
|
||||
if not passed:
|
||||
for k, v in gaps.items():
|
||||
if v:
|
||||
print(f" {k}: {len(v)} 个问题")
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
# ---- Main ----
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("阶段一:集成语义索引 (Ensemble Semantic Index)")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. Load input
|
||||
print(f"\n[1/3] 加载输入文档: {config.INPUT_JSON}")
|
||||
doc = config.load_input_document()
|
||||
print(f" 已加载 {len(doc.get('sections', []))} 个 section, "
|
||||
f"{len(doc.get('image_analysis', []))} 张图片分析")
|
||||
|
||||
# 2. Run ensemble generation + merge
|
||||
print(f"\n[2/3] 运行集成语义索引 ({len(config.ENSEMBLE_TEMPERATURES)} 个温度版本)...")
|
||||
merged_index = run_ensemble_semantic_index(doc)
|
||||
|
||||
# 3. Save outputs
|
||||
print(f"\n[3/3] 保存最终语义索引: {config.SEMANTIC_INDEX_JSON}")
|
||||
config.save_json(merged_index, config.SEMANTIC_INDEX_JSON)
|
||||
|
||||
# Also save path enumeration for downstream use
|
||||
all_paths = enumerate_all_paths(doc)
|
||||
config.save_json(
|
||||
{"logic_tree_paths": {k: v for k, v in all_paths.items()}},
|
||||
config.PATH_ENUM_JSON,
|
||||
)
|
||||
print(f" 路径枚举: {config.PATH_ENUM_JSON}")
|
||||
|
||||
cs = merged_index.get("confidence_summary", {})
|
||||
n_concepts = cs.get("total_concepts", len(merged_index.get("concepts", [])))
|
||||
n_units = cs.get("total_units", len(merged_index.get("function_units", [])))
|
||||
n_versions = merged_index.get("ensemble_versions", len(config.ENSEMBLE_TEMPERATURES))
|
||||
print(f"\n完成! {n_versions} 版本集成, {n_concepts} 个概念, {n_units} 个功能单元.")
|
||||
print(f"输出: {config.SEMANTIC_INDEX_JSON}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
Stage 2.5: Branch Coverage Auto-Completion.
|
||||
|
||||
1. Enumerates all root-to-leaf paths in every logic tree
|
||||
2. Compares paths against existing IR rules to find uncovered paths
|
||||
3. Generates synthetic function_units for uncovered paths
|
||||
4. Calls LLM (same extract_rules_for_unit) to produce rules for synthetic units
|
||||
5. Iterates up to MAX_RETRIES_PER_STAGE rounds to reach COVERAGE_TARGET
|
||||
|
||||
Outputs:
|
||||
- output/path_enumeration.json
|
||||
- output/ir_autocomplete_fragments.json
|
||||
"""
|
||||
|
||||
import concurrent.futures
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import config
|
||||
|
||||
|
||||
# ---- Path Enumeration (shared with step1, duplicated for module independence) ----
|
||||
|
||||
|
||||
def enumerate_all_paths(doc: dict) -> dict[str, list[dict]]:
|
||||
"""Enumerate all root-to-leaf paths for every logic tree."""
|
||||
from step1_semantic_index import enumerate_all_paths as _enum
|
||||
return _enum(doc)
|
||||
|
||||
|
||||
# ---- Coverage Analysis ----
|
||||
|
||||
|
||||
def find_referenced_path_ids(rules: list[dict]) -> dict[str, set[str]]:
|
||||
"""Map each rule to the set of logic tree nodes it references.
|
||||
|
||||
Returns {rule_id: set of "image_id:node_id" pairs}
|
||||
"""
|
||||
result = {}
|
||||
for rule in rules:
|
||||
rid = rule.get("rule_id", "?")
|
||||
refs = set()
|
||||
for src in rule.get("sources", []):
|
||||
if src.get("type") == "logic_tree":
|
||||
image_id = src.get("image_id", "")
|
||||
for nid in src.get("node_ids", []):
|
||||
refs.add(f"{image_id}:{nid}")
|
||||
result[rid] = refs
|
||||
return result
|
||||
|
||||
|
||||
def compute_path_coverage(
|
||||
all_paths: dict[str, list[dict]], rules: list[dict]
|
||||
) -> tuple[list[dict], list[dict], dict]:
|
||||
"""Compute coverage of enumerated paths by existing rules.
|
||||
|
||||
Returns (covered_paths, uncovered_paths, stats).
|
||||
A path is "covered" if at least one rule's node_ids form a superset
|
||||
of the path's decision+action nodes for that image.
|
||||
"""
|
||||
# Build per-rule node sets keyed by image_id
|
||||
rule_node_sets = {} # {rule_id: {image_id: set(node_ids)}}
|
||||
for rule in rules:
|
||||
rid = rule.get("rule_id", "?")
|
||||
rule_node_sets[rid] = {}
|
||||
for src in rule.get("sources", []):
|
||||
if src.get("type") == "logic_tree":
|
||||
image_id = src.get("image_id", "")
|
||||
rule_node_sets[rid].setdefault(image_id, set()).update(
|
||||
src.get("node_ids", [])
|
||||
)
|
||||
|
||||
covered = []
|
||||
uncovered = []
|
||||
|
||||
for image_id, paths in all_paths.items():
|
||||
for path in paths:
|
||||
# Get checkable nodes for this path (decision + action)
|
||||
checkable = set(
|
||||
n["id"] for n in path["nodes"]
|
||||
if n["type"] in ("decision", "action")
|
||||
)
|
||||
if not checkable:
|
||||
# Path with no decision/action nodes — trivially covered
|
||||
covered.append(path)
|
||||
continue
|
||||
|
||||
path_covered = False
|
||||
for rid, img_sets in rule_node_sets.items():
|
||||
rule_nodes = img_sets.get(image_id, set())
|
||||
if checkable.issubset(rule_nodes):
|
||||
path_covered = True
|
||||
break
|
||||
|
||||
if path_covered:
|
||||
covered.append(path)
|
||||
else:
|
||||
uncovered.append(path)
|
||||
|
||||
total = len(covered) + len(uncovered)
|
||||
stats = {
|
||||
"total_paths": total,
|
||||
"covered_paths": len(covered),
|
||||
"uncovered_paths": len(uncovered),
|
||||
"coverage_pct": round(len(covered) / total * 100, 1) if total > 0 else 100.0,
|
||||
}
|
||||
return covered, uncovered, stats
|
||||
|
||||
|
||||
# ---- Synthetic Function Unit Generation ----
|
||||
|
||||
|
||||
def generate_synthetic_unit(path: dict, unit_seq: int) -> dict:
|
||||
"""Create a synthetic function_unit from an uncovered logic tree path.
|
||||
|
||||
Infers preconditions and trigger from the decision nodes along the path.
|
||||
"""
|
||||
node_map = {n["id"]: n for n in path["nodes"]}
|
||||
|
||||
# Infer switch state from path
|
||||
switch = _infer_switch_state(path)
|
||||
|
||||
# Infer app_type from path
|
||||
app_type = _infer_app_type(path)
|
||||
|
||||
# Infer app_state from path
|
||||
app_state = _infer_app_state(path)
|
||||
|
||||
# Infer geographic_scope from section context
|
||||
scope = _infer_scope(path)
|
||||
|
||||
# Build description from path meaning
|
||||
description = f"自动补全: {path.get('meaning', '')}"
|
||||
if switch:
|
||||
description = f"开关{switch}, {description}"
|
||||
|
||||
# Build path list
|
||||
path_labels = []
|
||||
if scope:
|
||||
path_labels.append(scope)
|
||||
if switch:
|
||||
path_labels.append(f"开关{switch}")
|
||||
if app_type:
|
||||
path_labels.append(app_type)
|
||||
if app_state:
|
||||
path_labels.append(app_state)
|
||||
# Add behavior from terminal action
|
||||
action_nodes = path.get("action_nodes", [])
|
||||
if action_nodes:
|
||||
last_action = action_nodes[-1].get("label", "")
|
||||
path_labels.append(last_action[:20])
|
||||
|
||||
unit_id = f"FU-AUTO-{path['image_id']}-{unit_seq:03d}"
|
||||
seq = f"{unit_seq:03d}"
|
||||
|
||||
return {
|
||||
"unit_id": unit_id,
|
||||
"name": f"自动补全-{path.get('meaning', '')[:60]}",
|
||||
"description": description,
|
||||
"path": path_labels,
|
||||
"auto_generated": True,
|
||||
"sources": [
|
||||
{
|
||||
"section": "",
|
||||
"type": "logic_tree",
|
||||
"image_id": path["image_id"],
|
||||
"logic_tree_nodes": path.get("node_ids", []),
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _infer_switch_state(path: dict) -> str:
|
||||
"""Infer switch state from decision nodes in path."""
|
||||
for n in path["nodes"]:
|
||||
label = n.get("label", "")
|
||||
branch = n.get("branch_taken", "")
|
||||
if "开关" in label and n["type"] == "decision":
|
||||
if branch == "开启":
|
||||
return "开启"
|
||||
elif branch == "关闭":
|
||||
return "关闭"
|
||||
return ""
|
||||
|
||||
|
||||
def _infer_app_type(path: dict) -> str:
|
||||
"""Infer app type from state nodes in path."""
|
||||
type_map = {
|
||||
"其他应用": "其他应用",
|
||||
"SDK限制": "SDK限制",
|
||||
"通过接入SDK限制的应用": "SDK限制",
|
||||
"系统限制": "系统限制",
|
||||
"通过系统限制应用": "系统限制",
|
||||
}
|
||||
for n in path["nodes"]:
|
||||
if n["type"] == "state":
|
||||
for key, val in type_map.items():
|
||||
if key in n.get("label", ""):
|
||||
return val
|
||||
return ""
|
||||
|
||||
|
||||
def _infer_app_state(path: dict) -> str:
|
||||
"""Infer app state (前台/后台) from decision nodes."""
|
||||
for n in path["nodes"]:
|
||||
label = n.get("label", "")
|
||||
branch = n.get("branch_taken", "")
|
||||
if "前台" in label:
|
||||
if branch == "是":
|
||||
return "前台"
|
||||
elif branch == "否":
|
||||
return "后台"
|
||||
return ""
|
||||
|
||||
|
||||
def _infer_scope(path: dict) -> str:
|
||||
"""Infer geographic scope. Defaults to 国内."""
|
||||
return "国内"
|
||||
|
||||
|
||||
# ---- LLM Extraction for Synthetic Units ----
|
||||
|
||||
|
||||
def extract_rules_for_synthetic_units(
|
||||
synthetic_units: list[dict], doc: dict, max_retries: int | None = None
|
||||
) -> list[dict]:
|
||||
"""Extract IR rules for synthetic function_units using step2's LLM logic."""
|
||||
from step2_ir_extraction import (
|
||||
build_document_lookup,
|
||||
extract_context_package,
|
||||
extract_rules_for_unit,
|
||||
)
|
||||
|
||||
if max_retries is None:
|
||||
max_retries = config.MAX_RETRIES_PER_STAGE
|
||||
|
||||
sections_by_source, image_by_rid, conflicts_by_section = build_document_lookup(doc)
|
||||
|
||||
fragments = []
|
||||
for unit in synthetic_units:
|
||||
pkg = extract_context_package(
|
||||
unit, doc, sections_by_source, image_by_rid, conflicts_by_section
|
||||
)
|
||||
# Enrich pkg with unit's own path and description
|
||||
pkg["unit_path"] = unit.get("path", [])
|
||||
pkg["unit_description"] = unit.get("description", pkg["unit_description"])
|
||||
|
||||
try:
|
||||
rules = extract_rules_for_unit(pkg, max_retries)
|
||||
except Exception as e:
|
||||
rules = []
|
||||
|
||||
fragments.append({
|
||||
"unit_id": unit["unit_id"],
|
||||
"unit_name": unit.get("name", ""),
|
||||
"rules": rules,
|
||||
"auto_generated": True,
|
||||
})
|
||||
print(f" {unit['unit_id']}: {len(rules)} 条规则")
|
||||
|
||||
return fragments
|
||||
|
||||
|
||||
# ---- Iterative Auto-Completion ----
|
||||
|
||||
|
||||
def run_autocomplete(
|
||||
all_paths: dict[str, list[dict]],
|
||||
existing_rules: list[dict],
|
||||
doc: dict,
|
||||
) -> tuple[list[dict], dict]:
|
||||
"""Run iterative auto-completion. Returns (autocomplete_fragments, final_stats)."""
|
||||
print(f"\n 初始路径覆盖率分析...")
|
||||
covered, uncovered, stats = compute_path_coverage(all_paths, existing_rules)
|
||||
print(f" 覆盖: {stats['covered_paths']}/{stats['total_paths']} "
|
||||
f"({stats['coverage_pct']}%)")
|
||||
|
||||
if not uncovered:
|
||||
print(f" 所有路径已覆盖,无需自动补全")
|
||||
return [], stats
|
||||
|
||||
print(f" 未覆盖路径: {len(uncovered)} 条")
|
||||
|
||||
all_fragments = []
|
||||
best_stats = stats
|
||||
|
||||
for round_n in range(1, config.MAX_RETRIES_PER_STAGE + 1):
|
||||
if not uncovered:
|
||||
break
|
||||
|
||||
print(f"\n--- 自动补全 第 {round_n} 轮 ---")
|
||||
print(f" 为 {len(uncovered)} 条未覆盖路径生成合成单元...")
|
||||
|
||||
# Generate synthetic units
|
||||
start_seq = (round_n - 1) * len(uncovered) + 1
|
||||
synthetic_units = [
|
||||
generate_synthetic_unit(path, start_seq + i)
|
||||
for i, path in enumerate(uncovered)
|
||||
]
|
||||
|
||||
# Extract rules via LLM
|
||||
max_llm_workers = min(2, len(synthetic_units))
|
||||
if len(synthetic_units) <= 1:
|
||||
fragments = extract_rules_for_synthetic_units(synthetic_units, doc)
|
||||
else:
|
||||
# Sequential to avoid flooding the API
|
||||
fragments = extract_rules_for_synthetic_units(synthetic_units, doc)
|
||||
|
||||
all_fragments.extend(fragments)
|
||||
|
||||
# Re-compute coverage
|
||||
all_rules = existing_rules + [
|
||||
rule for f in fragments for rule in f.get("rules", [])
|
||||
]
|
||||
covered, uncovered, stats = compute_path_coverage(all_paths, all_rules)
|
||||
print(f" 第 {round_n} 轮后覆盖: {stats['covered_paths']}/{stats['total_paths']} "
|
||||
f"({stats['coverage_pct']}%)")
|
||||
|
||||
if stats["coverage_pct"] > best_stats["coverage_pct"]:
|
||||
best_stats = stats
|
||||
|
||||
if stats["coverage_pct"] >= config.COVERAGE_TARGET * 100:
|
||||
print(f" 达到目标覆盖率 {config.COVERAGE_TARGET:.0%},停止")
|
||||
break
|
||||
|
||||
# If coverage didn't improve, try a different approach next round
|
||||
uncovered_decision_nodes = set()
|
||||
for p in uncovered:
|
||||
for n in p.get("decision_nodes", []):
|
||||
uncovered_decision_nodes.add(n.get("label", ""))
|
||||
if not uncovered_decision_nodes:
|
||||
print(f" 无更多可补全路径,停止")
|
||||
break
|
||||
|
||||
return all_fragments, best_stats
|
||||
|
||||
|
||||
# ---- Main ----
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("阶段 2.5:分支覆盖自动补全")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. Load inputs
|
||||
print(f"\n[1/5] 加载输入...")
|
||||
doc = config.load_input_document()
|
||||
fragments = config.load_json(config.IR_FRAGMENTS_JSON)
|
||||
|
||||
all_rules = []
|
||||
for f in fragments:
|
||||
all_rules.extend(f.get("rules", []))
|
||||
|
||||
print(f" 已有规则: {len(all_rules)} 条")
|
||||
|
||||
# 2. Enumerate paths
|
||||
print(f"\n[2/5] 枚举逻辑树路径...")
|
||||
all_paths = enumerate_all_paths(doc)
|
||||
total_paths = sum(len(v) for v in all_paths.values())
|
||||
print(f" 共 {total_paths} 条路径")
|
||||
|
||||
# Save path enumeration for downstream audit
|
||||
path_enum_data = {
|
||||
"logic_tree_paths": {
|
||||
k: [{kk: vv for kk, vv in p.items() if kk != "nodes"} for p in v]
|
||||
for k, v in all_paths.items()
|
||||
},
|
||||
"total_paths": total_paths,
|
||||
}
|
||||
config.save_json(path_enum_data, config.PATH_ENUM_JSON)
|
||||
|
||||
# 3. Run auto-completion
|
||||
print(f"\n[3/5] 运行自动补全...")
|
||||
autocomplete_fragments, final_stats = run_autocomplete(
|
||||
all_paths, all_rules, doc
|
||||
)
|
||||
|
||||
# 4. Save
|
||||
print(f"\n[4/5] 保存自动补全片段...")
|
||||
config.save_json(
|
||||
autocomplete_fragments, config.IR_AUTOCOMPLETE_FRAGMENTS_JSON
|
||||
)
|
||||
print(f" 输出: {config.IR_AUTOCOMPLETE_FRAGMENTS_JSON}")
|
||||
print(f" 生成 {len(autocomplete_fragments)} 个补全片段")
|
||||
|
||||
# 5. Summary
|
||||
print(f"\n[5/5] 完成!")
|
||||
print(f" 最终路径覆盖: {final_stats['covered_paths']}/{final_stats['total_paths']} "
|
||||
f"({final_stats['coverage_pct']}%)")
|
||||
|
||||
if final_stats["coverage_pct"] < config.COVERAGE_TARGET * 100:
|
||||
remaining = final_stats["total_paths"] - final_stats["covered_paths"]
|
||||
print(f" WARN: {remaining} 条路径仍未覆盖,将在审计报告中列出")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,508 @@
|
||||
"""
|
||||
Stage 2: Per Function Unit IR Extraction.
|
||||
|
||||
For each function unit from the semantic index, constructs a precision context
|
||||
package and calls the LLM to extract detailed IR rules.
|
||||
|
||||
Runs multiple LLM calls in parallel (up to MAX_CONCURRENCY).
|
||||
|
||||
Output: output/ir_fragments.json
|
||||
"""
|
||||
|
||||
import concurrent.futures
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import config
|
||||
|
||||
|
||||
MAX_CONCURRENCY = 3 # Max parallel LLM calls
|
||||
|
||||
|
||||
def load_semantic_index() -> dict:
|
||||
"""Load the semantic index from Stage 1."""
|
||||
return config.load_json(config.SEMANTIC_INDEX_JSON)
|
||||
|
||||
|
||||
def build_document_lookup(doc: dict):
|
||||
"""Build lookup structures for fast context extraction from the document."""
|
||||
|
||||
# sections_by_source: "3.1.1" -> section dict
|
||||
sections_by_source = {}
|
||||
for section in doc.get("sections", []):
|
||||
source = section.get("source", "")
|
||||
# Normalize: extract leading number like "3.1.1"
|
||||
parts = source.split()
|
||||
if parts:
|
||||
key = parts[0].strip()
|
||||
sections_by_source[key] = section
|
||||
|
||||
# image_by_rid: "rId16" -> image_analysis entry
|
||||
image_by_rid = {}
|
||||
for img in doc.get("image_analysis", []):
|
||||
rid = img.get("rid", "")
|
||||
if rid:
|
||||
image_by_rid[rid] = img
|
||||
|
||||
# Conflicts indexed by section
|
||||
conflicts_by_section = {}
|
||||
for c in doc.get("resolved_conflicts", []):
|
||||
section = c.get("section", "")
|
||||
key = section.split()[0] if section else ""
|
||||
conflicts_by_section.setdefault(key, []).append(c)
|
||||
|
||||
return sections_by_source, image_by_rid, conflicts_by_section
|
||||
|
||||
|
||||
def extract_context_package(
|
||||
fu: dict, doc: dict, sections_by_source: dict, image_by_rid: dict,
|
||||
conflicts_by_section: dict
|
||||
) -> dict:
|
||||
"""Build a precision context package for a single function unit."""
|
||||
texts = []
|
||||
tables = []
|
||||
logic_trees = []
|
||||
seen_sections = set()
|
||||
seen_images = set()
|
||||
|
||||
for src in fu.get("sources", []):
|
||||
src_type = src.get("type", "")
|
||||
section_key = src.get("section", "").split()[0] if src.get("section") else ""
|
||||
|
||||
# --- Text source ---
|
||||
if src_type in ("table", "para") and section_key:
|
||||
if section_key in seen_sections:
|
||||
continue
|
||||
seen_sections.add(section_key)
|
||||
|
||||
section = sections_by_source.get(section_key)
|
||||
if section is None:
|
||||
# Fuzzy match by prefix
|
||||
for key in sections_by_source:
|
||||
if key.startswith(section_key):
|
||||
section = sections_by_source[key]
|
||||
break
|
||||
|
||||
if section:
|
||||
for block in section.get("blocks", []):
|
||||
if block["type"] == "para":
|
||||
texts.append({
|
||||
"section": section_key,
|
||||
"text": block["text"]
|
||||
})
|
||||
elif block["type"] == "table":
|
||||
row_num = src.get("row") if src_type == "table" else None
|
||||
if row_num is not None:
|
||||
# Extract only the specific row
|
||||
matching_rows = []
|
||||
for r in block.get("rows", []):
|
||||
for c in r.get("columns", []):
|
||||
if c.get("row") == row_num:
|
||||
matching_rows.append({
|
||||
"headers": block.get("headers", []),
|
||||
"cells": {
|
||||
col["name"]: col["text"]
|
||||
for col in r["columns"]
|
||||
},
|
||||
"row": row_num
|
||||
})
|
||||
break
|
||||
tables.append({
|
||||
"section": section_key,
|
||||
"headers": block.get("headers", []),
|
||||
"rows": matching_rows,
|
||||
"all_rows": [
|
||||
{
|
||||
"row": col.get("row"),
|
||||
"name": col.get("name"),
|
||||
"text": col.get("text")
|
||||
}
|
||||
for row in block.get("rows", [])
|
||||
for col in row.get("columns", [])
|
||||
]
|
||||
})
|
||||
else:
|
||||
# Include full table
|
||||
tables.append({
|
||||
"section": section_key,
|
||||
"headers": block.get("headers", []),
|
||||
"all_rows": [
|
||||
{
|
||||
"row": col.get("row"),
|
||||
"name": col.get("name"),
|
||||
"text": col.get("text")
|
||||
}
|
||||
for row in block.get("rows", [])
|
||||
for col in row.get("columns", [])
|
||||
]
|
||||
})
|
||||
|
||||
# --- Logic tree source ---
|
||||
if src_type == "logic_tree":
|
||||
image_id = src.get("image_id", "")
|
||||
if not image_id or image_id in seen_images:
|
||||
continue
|
||||
seen_images.add(image_id)
|
||||
|
||||
img = image_by_rid.get(image_id)
|
||||
if img:
|
||||
lt = img.get("logic_tree")
|
||||
if lt:
|
||||
logic_trees.append({
|
||||
"image_id": image_id,
|
||||
"description": img.get("description", ""),
|
||||
"tree": lt
|
||||
})
|
||||
|
||||
# Include relevant resolved conflicts
|
||||
relevant_conflicts = []
|
||||
for section_key in seen_sections:
|
||||
for c in conflicts_by_section.get(section_key, []):
|
||||
relevant_conflicts.append(c)
|
||||
|
||||
return {
|
||||
"unit_id": fu["unit_id"],
|
||||
"unit_name": fu.get("name", ""),
|
||||
"unit_description": fu.get("description", ""),
|
||||
"unit_path": fu.get("path", []),
|
||||
"texts": texts,
|
||||
"tables": tables,
|
||||
"logic_trees": logic_trees,
|
||||
"resolved_conflicts": relevant_conflicts
|
||||
}
|
||||
|
||||
|
||||
def format_context_package(pkg: dict) -> str:
|
||||
"""Format a context package as a readable string for the prompt."""
|
||||
parts = []
|
||||
|
||||
# Texts
|
||||
parts.append("【文字段落】")
|
||||
for i, t in enumerate(pkg.get("texts", [])):
|
||||
parts.append(f"[{t.get('section', '?')}] {t.get('text', '')}")
|
||||
if not pkg.get("texts"):
|
||||
parts.append("(无)")
|
||||
|
||||
# Tables
|
||||
parts.append("\n【表格数据】")
|
||||
for i, tbl in enumerate(pkg.get("tables", [])):
|
||||
parts.append(f"表格 {i+1} (section={tbl.get('section', '?')})")
|
||||
headers = tbl.get("headers", [])
|
||||
parts.append(f" 表头: {headers}")
|
||||
parts.append(" 全部行数据:")
|
||||
for row in tbl.get("all_rows", []):
|
||||
parts.append(
|
||||
f" 行{row.get('row','?')}[{row.get('name','?')}]: {row.get('text','')}"
|
||||
)
|
||||
# Highlight matched rows if any
|
||||
matched = tbl.get("rows", [])
|
||||
if matched:
|
||||
parts.append(" <重点关注行>:")
|
||||
for mr in matched:
|
||||
parts.append(f" 行{mr.get('row','?')}: {mr.get('cells', {})}")
|
||||
if not pkg.get("tables"):
|
||||
parts.append("(无)")
|
||||
|
||||
# Logic trees
|
||||
parts.append("\n【逻辑树】")
|
||||
for i, lt in enumerate(pkg.get("logic_trees", [])):
|
||||
parts.append(f"逻辑树 {i+1} (image_id={lt.get('image_id', '?')})")
|
||||
parts.append(f" 描述: {lt.get('description', '')[:200]}")
|
||||
tree = lt.get("tree", {})
|
||||
parts.append(f" 根: {tree.get('root', '?')}")
|
||||
parts.append(" 节点:")
|
||||
for node in tree.get("nodes", []):
|
||||
nid = node.get("id", "?")
|
||||
ntype = node.get("type", "?")
|
||||
desc = node.get("description", "") or node.get("condition", "")
|
||||
parts.append(f" [{ntype}] {nid}: {desc}")
|
||||
for br in node.get("branches", []):
|
||||
parts.append(f" → {br['value']} → {br['target']}")
|
||||
if not pkg.get("logic_trees"):
|
||||
parts.append("(无)")
|
||||
|
||||
# Conflicts
|
||||
conflicts = pkg.get("resolved_conflicts", [])
|
||||
if conflicts:
|
||||
parts.append("\n【图文冲突仲裁】")
|
||||
for c in conflicts:
|
||||
parts.append(
|
||||
f" [{c.get('conflict_type', '?')}] 以{c.get('source', '?')}为准: "
|
||||
f"{c.get('correction', '')}"
|
||||
)
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _escape_json_for_format(s: str) -> str:
|
||||
"""Escape curly braces in a JSON string for use with str.format()."""
|
||||
return s.replace("{", "{{").replace("}", "}}")
|
||||
|
||||
|
||||
def build_prompt(pkg: dict, format_feedback: str = "") -> str:
|
||||
"""Build the LLM prompt for a single function unit."""
|
||||
template_path = Path(config.PROMPTS_DIR) / "step2_ir_extraction.txt"
|
||||
template = template_path.read_text(encoding="utf-8")
|
||||
|
||||
prompt = template.format(
|
||||
unit_id=pkg["unit_id"],
|
||||
unit_name=_escape_json_for_format(pkg["unit_name"]),
|
||||
unit_description=_escape_json_for_format(pkg["unit_description"]),
|
||||
texts=_escape_json_for_format(
|
||||
json.dumps(pkg.get("texts", []), ensure_ascii=False, indent=2)
|
||||
),
|
||||
tables=_escape_json_for_format(
|
||||
json.dumps(pkg.get("tables", []), ensure_ascii=False, indent=2)
|
||||
),
|
||||
logic_trees=_escape_json_for_format(
|
||||
json.dumps(pkg.get("logic_trees", []), ensure_ascii=False, indent=2)
|
||||
),
|
||||
resolved_conflicts=_escape_json_for_format(
|
||||
json.dumps(pkg.get("resolved_conflicts", []), ensure_ascii=False, indent=2)
|
||||
),
|
||||
format_feedback=_escape_json_for_format(format_feedback),
|
||||
)
|
||||
return prompt
|
||||
|
||||
|
||||
def extract_json_from_response(text: str) -> str:
|
||||
"""Extract JSON array from LLM response."""
|
||||
m = re.search(r"```(?:json)?\s*(\[[\s\S]*?\])\s*```", text)
|
||||
if m:
|
||||
return m.group(1).strip()
|
||||
|
||||
# Find outermost [ ... ]
|
||||
start = text.find("[")
|
||||
if start == -1:
|
||||
raise ValueError("No JSON array found in LLM response")
|
||||
|
||||
depth = 0
|
||||
for i in range(start, len(text)):
|
||||
if text[i] == "[":
|
||||
depth += 1
|
||||
elif text[i] == "]":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start : i + 1]
|
||||
|
||||
raise ValueError("Unclosed JSON array in LLM response")
|
||||
|
||||
|
||||
def _check_rule_fields(rules: list[dict]) -> tuple[bool, list[dict]]:
|
||||
"""Validate each rule has required fields. Returns (passed, failures).
|
||||
|
||||
Each failure: {rule_id, field, issue}
|
||||
"""
|
||||
failures = []
|
||||
for j, rule in enumerate(rules):
|
||||
if not isinstance(rule, dict):
|
||||
failures.append({"rule_id": f"rule[{j}]", "field": "-", "issue": "规则不是 dict"})
|
||||
continue
|
||||
rid = rule.get("rule_id") or f"rule[{j}]"
|
||||
|
||||
if not rule.get("path"):
|
||||
failures.append({"rule_id": rid, "field": "path", "issue": "缺少 path 字段(必填)"})
|
||||
|
||||
precond = rule.get("precondition") or {}
|
||||
if not precond.get("geographic_scope"):
|
||||
failures.append({"rule_id": rid, "field": "precondition.geographic_scope", "issue": "缺少 geographic_scope(必填)"})
|
||||
|
||||
for k, action in enumerate(rule.get("actions") or []):
|
||||
if not isinstance(action, dict):
|
||||
continue
|
||||
if action.get("type") == "user_interaction":
|
||||
content = action.get("content") or ""
|
||||
if not content:
|
||||
failures.append({
|
||||
"rule_id": rid, "field": f"actions[{k}].content",
|
||||
"issue": "user_interaction 的 content 为空"
|
||||
})
|
||||
elif any(ph in content for ph in ["文案由业务定义", "待定", "自定义"]):
|
||||
failures.append({
|
||||
"rule_id": rid, "field": f"actions[{k}].content",
|
||||
"issue": f"content 包含占位符: '{content}'"
|
||||
})
|
||||
|
||||
trigger = rule.get("trigger") or {}
|
||||
for k, cond in enumerate(trigger.get("conditions") or []):
|
||||
if isinstance(cond, dict):
|
||||
if not cond.get("signal"):
|
||||
failures.append({
|
||||
"rule_id": rid, "field": f"trigger.conditions[{k}].signal",
|
||||
"issue": "缺少 signal"
|
||||
})
|
||||
if not cond.get("operator"):
|
||||
failures.append({
|
||||
"rule_id": rid, "field": f"trigger.conditions[{k}].operator",
|
||||
"issue": "缺少 operator"
|
||||
})
|
||||
if "value" not in cond:
|
||||
failures.append({
|
||||
"rule_id": rid, "field": f"trigger.conditions[{k}].value",
|
||||
"issue": "缺少 value"
|
||||
})
|
||||
|
||||
return len(failures) == 0, failures
|
||||
|
||||
|
||||
def _build_fix_prompt(failures: list[dict]) -> str:
|
||||
"""Build a format-fix instruction block for the prompt."""
|
||||
if not failures:
|
||||
return ""
|
||||
|
||||
lines = [
|
||||
"\n## 上一轮格式问题修正\n",
|
||||
"上一轮输出的规则存在以下格式问题,请修正后重新输出:\n",
|
||||
]
|
||||
for f in failures:
|
||||
lines.append(f"- **{f['rule_id']}.{f['field']}**: {f['issue']}")
|
||||
|
||||
lines.append("\n请修正以上所有问题,重新输出完整的规则数组。")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def extract_rules_for_unit(pkg: dict, max_retries: int | None = None) -> list[dict]:
|
||||
"""Call LLM for one function unit, return its IR rules.
|
||||
|
||||
Includes format validation with auto-fix retries.
|
||||
"""
|
||||
if max_retries is None:
|
||||
max_retries = config.MAX_RETRIES_PER_STAGE
|
||||
client = config.llm_client()
|
||||
prompt = build_prompt(pkg)
|
||||
last_failures = []
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
# Append format feedback on retry
|
||||
if attempt > 0 and last_failures:
|
||||
fix_text = _build_fix_prompt(last_failures)
|
||||
prompt = build_prompt(pkg, format_feedback=fix_text)
|
||||
|
||||
try:
|
||||
resp = client.chat.completions.create(
|
||||
model=config.MODEL_NAME,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个精确的 JSON 输出引擎。只输出合法的 JSON 数组。",
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
temperature=config.TEMPERATURE,
|
||||
max_tokens=config.MAX_TOKENS,
|
||||
)
|
||||
content = resp.choices[0].message.content
|
||||
if content is None:
|
||||
raise RuntimeError("LLM returned empty response")
|
||||
|
||||
json_str = extract_json_from_response(content)
|
||||
rules = json.loads(json_str)
|
||||
if not isinstance(rules, list):
|
||||
raise ValueError(f"Expected JSON array, got {type(rules).__name__}")
|
||||
|
||||
# Format validation
|
||||
passed, failures = _check_rule_fields(rules)
|
||||
if passed:
|
||||
return rules
|
||||
|
||||
# Format issues found — retry with fix instructions
|
||||
print(f" 格式问题 ({len(failures)} 个): {[f['field'] for f in failures[:5]]}")
|
||||
last_failures = failures
|
||||
if attempt < max_retries:
|
||||
time.sleep(1)
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
print(f" JSON 解析失败 (尝试 {attempt + 1}): {e}")
|
||||
last_failures = [{"rule_id": "?", "field": "json", "issue": str(e)}]
|
||||
if attempt < max_retries:
|
||||
time.sleep(2)
|
||||
|
||||
# Exhausted retries — return what we have (even if imperfect)
|
||||
print(f" WARN: {pkg['unit_id']} 格式修复耗尽了 {max_retries} 次重试")
|
||||
return []
|
||||
|
||||
|
||||
def extract_all_rules(
|
||||
semantic_index: dict, doc: dict
|
||||
) -> list[dict]:
|
||||
"""Extract IR rules for all function units. Runs in parallel up to MAX_CONCURRENCY."""
|
||||
sections_by_source, image_by_rid, conflicts_by_section = build_document_lookup(doc)
|
||||
function_units = semantic_index.get("function_units", [])
|
||||
|
||||
print(f" 共 {len(function_units)} 个功能单元待处理")
|
||||
print(f" 最大并发: {MAX_CONCURRENCY}")
|
||||
|
||||
# Build context packages (serial — fast)
|
||||
packages = []
|
||||
for fu in function_units:
|
||||
pkg = extract_context_package(
|
||||
fu, doc, sections_by_source, image_by_rid, conflicts_by_section
|
||||
)
|
||||
packages.append(pkg)
|
||||
|
||||
# Run LLM calls in parallel
|
||||
fragments = []
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_CONCURRENCY) as executor:
|
||||
futures = {}
|
||||
for i, pkg in enumerate(packages):
|
||||
future = executor.submit(extract_rules_for_unit, pkg)
|
||||
futures[future] = (i, pkg["unit_id"], pkg["unit_name"])
|
||||
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
i, uid, uname = futures[future]
|
||||
try:
|
||||
rules = future.result()
|
||||
fragments.append({
|
||||
"unit_id": uid,
|
||||
"unit_name": uname,
|
||||
"rules": rules
|
||||
})
|
||||
print(f" [OK] {uid} ({uname}): {len(rules)} 条规则")
|
||||
except Exception as e:
|
||||
print(f" [FAIL] {uid} ({uname}): 失败 — {e}")
|
||||
fragments.append({
|
||||
"unit_id": uid,
|
||||
"unit_name": uname,
|
||||
"rules": [],
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# Sort by unit_id to maintain stable ordering
|
||||
fragments.sort(key=lambda f: f["unit_id"])
|
||||
return fragments
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("阶段二:逐功能单元 IR 提取")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. Load inputs
|
||||
print(f"\n[1/3] 加载输入...")
|
||||
semantic_index = load_semantic_index()
|
||||
doc = config.load_input_document()
|
||||
n_units = len(semantic_index.get("function_units", []))
|
||||
print(f" 语义索引: {n_units} 个功能单元")
|
||||
|
||||
# 2. Extract rules
|
||||
print(f"\n[2/3] 逐单元提取 IR 规则...")
|
||||
fragments = extract_all_rules(semantic_index, doc)
|
||||
|
||||
# 3. Save
|
||||
print(f"\n[3/3] 保存 IR 片段...")
|
||||
config.save_json(fragments, config.IR_FRAGMENTS_JSON)
|
||||
|
||||
total_rules = sum(len(f["rules"]) for f in fragments)
|
||||
failed_units = [f for f in fragments if f.get("error")]
|
||||
print(f"\n完成! {len(fragments)} 个功能单元, 共 {total_rules} 条规则")
|
||||
if failed_units:
|
||||
print(f" [WARN] {len(failed_units)} 个单元提取失败: "
|
||||
f"{[f['unit_id'] for f in failed_units]}")
|
||||
print(f"输出: {config.IR_FRAGMENTS_JSON}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,472 @@
|
||||
"""
|
||||
Tests for ensemble_merge.py — all pure Python, no LLM calls, no file I/O.
|
||||
|
||||
Each test uses hardcoded mock data to verify one piece of the merge logic.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from ensemble_merge import (
|
||||
concept_name_similarity,
|
||||
cluster_concepts,
|
||||
merge_concept_cluster,
|
||||
unit_node_jaccard,
|
||||
path_similarity,
|
||||
unit_similarity,
|
||||
cluster_function_units,
|
||||
pick_best_representative,
|
||||
compute_confidence_versions,
|
||||
ensemble_merge_concepts,
|
||||
ensemble_merge_function_units,
|
||||
ensemble_merge,
|
||||
_collect_logic_tree_nodes,
|
||||
)
|
||||
|
||||
PASS = "[PASS]"
|
||||
FAIL = "[FAIL]"
|
||||
|
||||
# ---- Mock helpers ----
|
||||
|
||||
def _mk_unit(unit_id, name, path, logic_tree_nodes, description="", sources=None):
|
||||
"""Create a minimal function_unit dict for testing."""
|
||||
if sources is None:
|
||||
srcs = []
|
||||
if logic_tree_nodes:
|
||||
srcs.append({
|
||||
"image_id": "rId16",
|
||||
"type": "logic_tree",
|
||||
"logic_tree_nodes": logic_tree_nodes,
|
||||
})
|
||||
if not srcs:
|
||||
srcs.append({
|
||||
"section": "3.1",
|
||||
"type": "table",
|
||||
"text_snippet": "test",
|
||||
})
|
||||
else:
|
||||
srcs = sources
|
||||
return {
|
||||
"unit_id": unit_id,
|
||||
"name": name,
|
||||
"description": description or f"desc for {name}",
|
||||
"path": path,
|
||||
"sources": srcs,
|
||||
}
|
||||
|
||||
|
||||
def _mk_concept(name, parent=None, aliases=None, defined_in=None):
|
||||
"""Create a minimal concept dict for testing."""
|
||||
return {
|
||||
"name": name,
|
||||
"aliases": aliases or [],
|
||||
"defined_in": defined_in or ["3.1"],
|
||||
"parent": parent,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 1: concept_name_similarity
|
||||
# =============================================================================
|
||||
|
||||
def test_concept_name_similarity_exact():
|
||||
assert concept_name_similarity("国内", "国内") == 1.0
|
||||
assert concept_name_similarity("行车娱乐限制", "行车娱乐限制") == 1.0
|
||||
|
||||
def test_concept_name_similarity_substring():
|
||||
sim = concept_name_similarity("国内行车娱乐限制", "行车娱乐限制")
|
||||
assert sim >= 0.85, f"expected >= 0.85, got {sim}"
|
||||
|
||||
def test_concept_name_similarity_different():
|
||||
sim = concept_name_similarity("国内", "海外")
|
||||
assert sim < 0.7, f"expected < 0.7, got {sim}"
|
||||
|
||||
def test_concept_name_similarity_seq_matcher():
|
||||
sim = concept_name_similarity("前台打断", "前台应用打断")
|
||||
assert 0.6 < sim < 0.95, f"expected 0.6-0.95, got {sim}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 2: _collect_logic_tree_nodes
|
||||
# =============================================================================
|
||||
|
||||
def test_collect_logic_tree_nodes():
|
||||
unit = _mk_unit("U1", "test", ["A"], ["n1", "n2", "n3"])
|
||||
nodes = _collect_logic_tree_nodes(unit)
|
||||
assert nodes == {"n1", "n2", "n3"}
|
||||
|
||||
def test_collect_logic_tree_nodes_empty():
|
||||
unit = _mk_unit("U2", "test", ["A"], [], sources=[{"section": "3.1", "type": "table"}])
|
||||
nodes = _collect_logic_tree_nodes(unit)
|
||||
assert nodes == set()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 3: unit_node_jaccard
|
||||
# =============================================================================
|
||||
|
||||
def test_unit_node_jaccard_identical():
|
||||
u1 = _mk_unit("U1", "a", ["A"], ["n1", "n2", "n3"])
|
||||
u2 = _mk_unit("U2", "b", ["A"], ["n1", "n2", "n3"])
|
||||
assert unit_node_jaccard(u1, u2) == 1.0
|
||||
|
||||
def test_unit_node_jaccard_partial():
|
||||
u1 = _mk_unit("U1", "a", ["A"], ["n1", "n2", "n3", "n4"])
|
||||
u2 = _mk_unit("U2", "b", ["A"], ["n1", "n2", "n3"])
|
||||
# intersection=3, union=4
|
||||
assert abs(unit_node_jaccard(u1, u2) - 0.75) < 0.01
|
||||
|
||||
def test_unit_node_jaccard_disjoint():
|
||||
u1 = _mk_unit("U1", "a", ["A"], ["n1", "n2"])
|
||||
u2 = _mk_unit("U2", "b", ["B"], ["n3", "n4"])
|
||||
assert unit_node_jaccard(u1, u2) == 0.0
|
||||
|
||||
def test_unit_node_jaccard_both_empty():
|
||||
u1 = _mk_unit("U1", "a", ["A"], [], sources=[{"section": "3.1", "type": "table"}])
|
||||
u2 = _mk_unit("U2", "b", ["B"], [], sources=[{"section": "3.1", "type": "table"}])
|
||||
assert unit_node_jaccard(u1, u2) == 0.0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 4: path_similarity
|
||||
# =============================================================================
|
||||
|
||||
def test_path_similarity_identical():
|
||||
assert path_similarity(
|
||||
["国内", "系统限制", "前台打断"],
|
||||
["国内", "系统限制", "前台打断"],
|
||||
) == 1.0
|
||||
|
||||
def test_path_similarity_partial():
|
||||
sim = path_similarity(
|
||||
["国内", "系统限制", "前台打断"],
|
||||
["国内", "系统限制", "后台限制启动"],
|
||||
)
|
||||
# 2/3 set overlap, sequential 3/5 ≈ 0.6
|
||||
assert 0.4 < sim < 0.9, f"expected 0.4-0.9, got {sim}"
|
||||
|
||||
def test_path_similarity_different():
|
||||
sim = path_similarity(["国内"], ["海外"])
|
||||
assert sim < 0.7, f"expected < 0.7, got {sim}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 5: unit_similarity
|
||||
# =============================================================================
|
||||
|
||||
def test_unit_similarity_identical():
|
||||
u = _mk_unit("U1", "国内-系统限制-前台打断",
|
||||
["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19"])
|
||||
assert unit_similarity(u, u) > 0.99
|
||||
|
||||
def test_unit_similarity_different():
|
||||
u1 = _mk_unit("U1", "a", ["国内", "系统限制", "前台打断"], ["n1", "n2", "n3"])
|
||||
u2 = _mk_unit("U2", "b", ["海外", "SDK限制"], ["n10", "n11", "n12"])
|
||||
assert unit_similarity(u1, u2) < 0.3
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 6: cluster_concepts
|
||||
# =============================================================================
|
||||
|
||||
def test_cluster_concepts_identical():
|
||||
v0 = [_mk_concept("国内"), _mk_concept("海外"), _mk_concept("系统限制", parent="国内")]
|
||||
v1 = [_mk_concept("国内"), _mk_concept("海外"), _mk_concept("系统限制", parent="国内")]
|
||||
v2 = [_mk_concept("国内"), _mk_concept("海外"), _mk_concept("系统限制", parent="国内")]
|
||||
clusters = cluster_concepts([v0, v1, v2])
|
||||
# Should have exactly 3 clusters (国内, 海外, 系统限制)
|
||||
assert len(clusters) == 3, f"expected 3 clusters, got {len(clusters)}"
|
||||
for c in clusters:
|
||||
assert len(c) == 3, f"expected each cluster to have 3 members, got {len(c)}"
|
||||
|
||||
def test_cluster_concepts_name_variation():
|
||||
v0 = [_mk_concept("国内行车娱乐限制", parent="国内")]
|
||||
v1 = [_mk_concept("行车娱乐限制", parent="国内")]
|
||||
v2 = [_mk_concept("国内行车娱乐限制", parent="国内")]
|
||||
clusters = cluster_concepts([v0, v1, v2])
|
||||
assert len(clusters) == 1, f"expected 1 cluster, got {len(clusters)}"
|
||||
assert len(clusters[0]) == 3, f"expected 3 members, got {len(clusters[0])}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 7: merge_concept_cluster
|
||||
# =============================================================================
|
||||
|
||||
def test_merge_concept_cluster():
|
||||
cluster = [
|
||||
(0, _mk_concept("国内行车娱乐限制", parent="国内", aliases=["限制"])),
|
||||
(1, _mk_concept("行车娱乐限制", parent="国内", aliases=["行车限制"])),
|
||||
(2, _mk_concept("行车娱乐限制", parent="国内", aliases=["限制"])),
|
||||
]
|
||||
merged, conf = merge_concept_cluster(cluster, 3)
|
||||
assert "行车娱乐限制" in merged["name"]
|
||||
assert merged["parent"] == "国内"
|
||||
assert set(merged["aliases"]) == {"限制", "行车限制"}
|
||||
assert conf in ("high", "medium")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 8: cluster_function_units
|
||||
# =============================================================================
|
||||
|
||||
def test_cluster_function_units_all_agree():
|
||||
u0 = _mk_unit("U-001", "国内-系统限制-前台打断",
|
||||
["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
|
||||
"switch ON, system app, foreground, speed>=15, non-P, interrupt + toast")
|
||||
u1 = _mk_unit("U-001", "国内-系统限制-前台打断",
|
||||
["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
|
||||
"switch ON, system app, foreground, speed>=15, non-P, interrupt + toast")
|
||||
u2 = _mk_unit("U-001", "国内-系统限制-前台打断",
|
||||
["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
|
||||
"switch ON, system app, foreground, interrupt")
|
||||
clusters = cluster_function_units([[u0], [u1], [u2]])
|
||||
assert len(clusters) == 1, f"expected 1 cluster, got {len(clusters)}"
|
||||
assert len(clusters[0]) == 3
|
||||
|
||||
def test_cluster_function_units_partial_agree():
|
||||
u0 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19"])
|
||||
u1 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19"])
|
||||
u2 = _mk_unit("U-002", "禁止", ["国内", "系统限制", "后台限制启动"],
|
||||
["n5", "n6"])
|
||||
clusters = cluster_function_units([[u0], [u1], [u2]])
|
||||
# u0+u1 in one cluster, u2 in another
|
||||
assert len(clusters) == 2, f"expected 2 clusters, got {len(clusters)}"
|
||||
cluster_sizes = sorted(len(c) for c in clusters)
|
||||
assert cluster_sizes == [1, 2], f"expected cluster sizes [1,2], got {cluster_sizes}"
|
||||
|
||||
def test_cluster_function_units_all_disagree():
|
||||
u0 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"], ["n1", "n2", "n3"])
|
||||
u1 = _mk_unit("U-002", "禁止", ["国内", "系统限制", "后台限制启动"], ["n5", "n6"])
|
||||
u2 = _mk_unit("U-003", "SDK", ["国内", "SDK限制"], ["n10", "n11"])
|
||||
clusters = cluster_function_units([[u0], [u1], [u2]])
|
||||
assert len(clusters) == 3, f"expected 3 clusters, got {len(clusters)}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 9: pick_best_representative
|
||||
# =============================================================================
|
||||
|
||||
def test_pick_best_representative_prefers_rich():
|
||||
u0 = _mk_unit("U-001", "short", ["国内", "系统限制"],
|
||||
["n1", "n2", "n3"],
|
||||
description="short desc")
|
||||
u1 = _mk_unit("U-001", "detailed", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
|
||||
description="very detailed description of the full rule behavior " * 5)
|
||||
cluster = [(0, u0), (1, u1)]
|
||||
best = pick_best_representative(cluster)
|
||||
# u1 should win: more nodes, longer description, though u0 has lower temp
|
||||
assert best["name"] == "detailed"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 10: compute_confidence_versions
|
||||
# =============================================================================
|
||||
|
||||
def test_confidence_high_unanimous():
|
||||
assert compute_confidence_versions(3, 3, True) == "high"
|
||||
|
||||
def test_confidence_high_two_of_three_with_t0():
|
||||
assert compute_confidence_versions(2, 3, True) == "high"
|
||||
|
||||
def test_confidence_medium_two_of_three_without_t0():
|
||||
assert compute_confidence_versions(2, 3, False) == "medium"
|
||||
|
||||
def test_confidence_low_one_of_three():
|
||||
assert compute_confidence_versions(1, 3, False) == "low"
|
||||
|
||||
def test_confidence_high_all_two_versions():
|
||||
assert compute_confidence_versions(2, 2, True) == "high"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 11: ensemble_merge_concepts
|
||||
# =============================================================================
|
||||
|
||||
def test_ensemble_merge_concepts():
|
||||
v0 = [_mk_concept("国内"), _mk_concept("海外"),
|
||||
_mk_concept("国内行车娱乐限制", parent="国内")]
|
||||
v1 = [_mk_concept("国内"), _mk_concept("海外"),
|
||||
_mk_concept("行车娱乐限制", parent="国内",
|
||||
aliases=["限制"], defined_in=["3.1", "3.1.1"])]
|
||||
v2 = [_mk_concept("国内"), _mk_concept("海外"),
|
||||
_mk_concept("行车娱乐限制", parent="国内")]
|
||||
|
||||
merged = ensemble_merge_concepts([v0, v1, v2])
|
||||
# Should merge the 3 concepts across 3 versions into 3 clusters
|
||||
assert len(merged) == 3, f"expected 3 merged concepts, got {len(merged)}"
|
||||
for c in merged:
|
||||
assert "confidence" in c
|
||||
assert "ensemble_support" in c
|
||||
assert c["ensemble_support"] == "3/3"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 12: ensemble_merge_function_units
|
||||
# =============================================================================
|
||||
|
||||
def test_ensemble_merge_function_units():
|
||||
u0 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
|
||||
description="full description A")
|
||||
u1 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25", "n26"],
|
||||
description="full description B (more detail)")
|
||||
u2 = _mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n21", "n23", "n25"],
|
||||
description="partial description")
|
||||
|
||||
merged = ensemble_merge_function_units([[u0], [u1], [u2]])
|
||||
assert len(merged) == 1, f"expected 1 unit, got {len(merged)}"
|
||||
unit = merged[0]
|
||||
assert unit["confidence"] == "high"
|
||||
assert unit["ensemble_support"] == "3/3"
|
||||
assert unit["source_versions"] == 3
|
||||
assert unit["unit_id"].startswith("FU-ENS-")
|
||||
# Should have picked u1 (more detail)
|
||||
assert "more detail" in unit["description"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 13: ensemble_merge full integration
|
||||
# =============================================================================
|
||||
|
||||
def test_ensemble_merge_full():
|
||||
v0 = {
|
||||
"feature_name": "行车娱乐限制",
|
||||
"concepts": [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")],
|
||||
"function_units": [
|
||||
_mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n25", "n26"]),
|
||||
_mk_unit("U-002", "后台禁止", ["国内", "系统限制", "后台限制启动"],
|
||||
["n5", "n6"]),
|
||||
],
|
||||
}
|
||||
v1 = {
|
||||
"feature_name": "行车娱乐限制",
|
||||
"concepts": [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")],
|
||||
"function_units": [
|
||||
_mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n25", "n26"]),
|
||||
_mk_unit("U-003", "SDK自定义", ["国内", "SDK限制", "自定义限制"],
|
||||
["n10", "n11"]),
|
||||
],
|
||||
}
|
||||
v2 = {
|
||||
"feature_name": "行车娱乐限制",
|
||||
"concepts": [_mk_concept("国内"), _mk_concept("系统限制", parent="国内")],
|
||||
"function_units": [
|
||||
_mk_unit("U-001", "打断", ["国内", "系统限制", "前台打断"],
|
||||
["n1", "n2", "n3", "n8", "n19", "n25", "n26"]),
|
||||
],
|
||||
}
|
||||
|
||||
result = ensemble_merge([v0, v1, v2])
|
||||
|
||||
assert result["feature_name"] == "行车娱乐限制"
|
||||
assert result["ensemble_versions"] == 3
|
||||
|
||||
units = result["function_units"]
|
||||
concepts = result["concepts"]
|
||||
|
||||
# Concepts: 国内 + 系统限制
|
||||
assert len(concepts) == 2
|
||||
|
||||
# Units: 打断 (3 versions → high), 后台禁止 (1 version → low), SDK (1 version → low)
|
||||
assert len(units) == 3
|
||||
|
||||
high_units = [u for u in units if u["confidence"] == "high"]
|
||||
low_units = [u for u in units if u["confidence"] == "low"]
|
||||
assert len(high_units) == 1
|
||||
assert len(low_units) == 2
|
||||
|
||||
# All units should have ensemble fields
|
||||
for u in units:
|
||||
assert "confidence" in u
|
||||
assert "ensemble_support" in u
|
||||
assert "source_versions" in u
|
||||
|
||||
# Confidence summary
|
||||
cs = result["confidence_summary"]
|
||||
assert cs["total_units"] == 3
|
||||
assert cs["high"] == 1
|
||||
assert cs["low"] == 2
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Runner
|
||||
# =============================================================================
|
||||
|
||||
def run_all_tests():
|
||||
print("=" * 60)
|
||||
print("Ensemble Merge 测试 (纯 Python, 无 LLM)")
|
||||
print("=" * 60)
|
||||
|
||||
tests = [
|
||||
("concept_name_similarity exact", test_concept_name_similarity_exact),
|
||||
("concept_name_similarity substring", test_concept_name_similarity_substring),
|
||||
("concept_name_similarity different", test_concept_name_similarity_different),
|
||||
("concept_name_similarity seq_matcher", test_concept_name_similarity_seq_matcher),
|
||||
("collect_logic_tree_nodes", test_collect_logic_tree_nodes),
|
||||
("collect_logic_tree_nodes empty", test_collect_logic_tree_nodes_empty),
|
||||
("unit_node_jaccard identical", test_unit_node_jaccard_identical),
|
||||
("unit_node_jaccard partial", test_unit_node_jaccard_partial),
|
||||
("unit_node_jaccard disjoint", test_unit_node_jaccard_disjoint),
|
||||
("unit_node_jaccard both_empty", test_unit_node_jaccard_both_empty),
|
||||
("path_similarity identical", test_path_similarity_identical),
|
||||
("path_similarity partial", test_path_similarity_partial),
|
||||
("path_similarity different", test_path_similarity_different),
|
||||
("unit_similarity identical", test_unit_similarity_identical),
|
||||
("unit_similarity different", test_unit_similarity_different),
|
||||
("cluster_concepts identical", test_cluster_concepts_identical),
|
||||
("cluster_concepts name variation", test_cluster_concepts_name_variation),
|
||||
("merge_concept_cluster", test_merge_concept_cluster),
|
||||
("cluster_function_units all_agree", test_cluster_function_units_all_agree),
|
||||
("cluster_function_units partial_agree", test_cluster_function_units_partial_agree),
|
||||
("cluster_function_units all_disagree", test_cluster_function_units_all_disagree),
|
||||
("pick_best_representative", test_pick_best_representative_prefers_rich),
|
||||
("confidence high unanimous", test_confidence_high_unanimous),
|
||||
("confidence high 2/3 with t0", test_confidence_high_two_of_three_with_t0),
|
||||
("confidence medium 2/3 no t0", test_confidence_medium_two_of_three_without_t0),
|
||||
("confidence low 1/3", test_confidence_low_one_of_three),
|
||||
("confidence high 2/2", test_confidence_high_all_two_versions),
|
||||
("ensemble_merge_concepts", test_ensemble_merge_concepts),
|
||||
("ensemble_merge_function_units", test_ensemble_merge_function_units),
|
||||
("ensemble_merge full", test_ensemble_merge_full),
|
||||
]
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
for name, test_fn in tests:
|
||||
try:
|
||||
test_fn()
|
||||
print(f" {PASS} {name}")
|
||||
passed += 1
|
||||
except AssertionError as e:
|
||||
print(f" {FAIL} {name}: {e}")
|
||||
failed += 1
|
||||
except Exception as e:
|
||||
print(f" {FAIL} {name}: unexpected {type(e).__name__}: {e}")
|
||||
failed += 1
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
if failed == 0:
|
||||
print(f"{PASS} 所有 {passed} 个测试通过!")
|
||||
else:
|
||||
print(f"{FAIL} {failed}/{passed + failed} 个测试失败")
|
||||
print(f"{'='*60}")
|
||||
|
||||
return failed == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
Tests for Stage 1 (Semantic Index).
|
||||
|
||||
Validates that the generated semantic_index.json meets all completeness
|
||||
and structural requirements, including the new iterative features:
|
||||
- function_units have path fields
|
||||
- concepts have parent references
|
||||
- logic tree node coverage meets thresholds
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
import config
|
||||
|
||||
|
||||
PASS = "[PASS]"
|
||||
FAIL = "[FAIL]"
|
||||
WARN = "[WARN]"
|
||||
|
||||
|
||||
def load_inputs():
|
||||
"""Load semantic_index.json and the original parsed document."""
|
||||
try:
|
||||
si = config.load_json(config.SEMANTIC_INDEX_JSON)
|
||||
except FileNotFoundError:
|
||||
print(f"{FAIL} semantic_index.json 未找到: {config.SEMANTIC_INDEX_JSON}")
|
||||
print(" 请先运行 step1_semantic_index.py")
|
||||
sys.exit(1)
|
||||
doc = config.load_input_document()
|
||||
return si, doc
|
||||
|
||||
|
||||
def build_image_index(doc: dict) -> dict[str, dict]:
|
||||
"""Build lookup: image rId -> image_analysis entry."""
|
||||
idx = {}
|
||||
for img in doc.get("image_analysis", []):
|
||||
rid = img.get("rid", "")
|
||||
if rid:
|
||||
idx[rid] = img
|
||||
return idx
|
||||
|
||||
|
||||
def build_logic_tree_node_index(doc: dict) -> dict[str, set[str]]:
|
||||
"""Build lookup: image rId -> set of all node IDs in that logic_tree."""
|
||||
idx = {}
|
||||
for img in doc.get("image_analysis", []):
|
||||
rid = img.get("rid", "")
|
||||
lt = img.get("logic_tree")
|
||||
if lt and rid:
|
||||
node_ids = {n["id"] for n in lt.get("nodes", [])}
|
||||
idx[rid] = node_ids
|
||||
return idx
|
||||
|
||||
|
||||
def check_unit_ids(units: list[dict]) -> list[str]:
|
||||
"""Check that every function_unit has a non-empty unit_id and name."""
|
||||
errors = []
|
||||
seen_ids = set()
|
||||
for i, fu in enumerate(units):
|
||||
uid = fu.get("unit_id", "")
|
||||
name = fu.get("name", "")
|
||||
if not uid:
|
||||
errors.append(f"function_unit[{i}]: unit_id 为空")
|
||||
elif uid in seen_ids:
|
||||
errors.append(f"function_unit[{i}]: unit_id '{uid}' 重复")
|
||||
seen_ids.add(uid)
|
||||
if not name:
|
||||
errors.append(f"function_unit[{i}] ({uid}): name 为空")
|
||||
return errors
|
||||
|
||||
|
||||
def check_unit_paths(units: list[dict]) -> list[str]:
|
||||
"""Check that every function_unit has a non-empty path array."""
|
||||
errors = []
|
||||
for fu in units:
|
||||
uid = fu.get("unit_id", "?")
|
||||
path = fu.get("path", [])
|
||||
if not path:
|
||||
errors.append(f"{uid}: path 字段为空或缺失")
|
||||
elif not isinstance(path, list):
|
||||
errors.append(f"{uid}: path 必须是数组")
|
||||
return errors
|
||||
|
||||
|
||||
def check_concept_parents(concepts: list[dict]) -> list[str]:
|
||||
"""Check that non-scope concepts have valid parent references."""
|
||||
errors = []
|
||||
concept_names = {c.get("name", "") for c in concepts}
|
||||
scope_concepts = {"国内", "海外"}
|
||||
|
||||
for c in concepts:
|
||||
name = c.get("name", "?")
|
||||
parent = c.get("parent", "")
|
||||
|
||||
if name in scope_concepts:
|
||||
# Scope concepts should have no parent
|
||||
if parent:
|
||||
errors.append(f"scope 概念 '{name}' 不应有 parent (当前: '{parent}')")
|
||||
else:
|
||||
# Non-scope concepts must have a parent
|
||||
if not parent:
|
||||
errors.append(f"概念 '{name}' 缺少 parent 字段")
|
||||
elif parent not in concept_names:
|
||||
errors.append(f"概念 '{name}' 的 parent '{parent}' 不存在于 concepts 中")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_sources_exist(
|
||||
units: list[dict], image_index: dict[str, dict], node_index: dict[str, set[str]]
|
||||
) -> list[str]:
|
||||
"""Check that all source references point to real content."""
|
||||
errors = []
|
||||
for fu in units:
|
||||
uid = fu.get("unit_id", "?")
|
||||
sources = fu.get("sources", [])
|
||||
if not sources:
|
||||
errors.append(f"{uid}: sources 为空,必须至少引用一张图片或一段文字")
|
||||
continue
|
||||
|
||||
has_text = False
|
||||
has_image = False
|
||||
|
||||
for j, src in enumerate(sources):
|
||||
src_type = src.get("type", "")
|
||||
if src_type in ("table", "para"):
|
||||
has_text = True
|
||||
section = src.get("section", "")
|
||||
if not section:
|
||||
errors.append(f"{uid}.sources[{j}]: 缺少 section")
|
||||
elif src_type == "logic_tree":
|
||||
has_image = True
|
||||
image_id = src.get("image_id", "")
|
||||
if not image_id:
|
||||
errors.append(f"{uid}.sources[{j}]: logic_tree 缺少 image_id")
|
||||
continue
|
||||
if image_id not in image_index:
|
||||
errors.append(
|
||||
f"{uid}.sources[{j}]: image_id '{image_id}' "
|
||||
f"在 image_analysis 中不存在"
|
||||
)
|
||||
continue
|
||||
node_ids = src.get("logic_tree_nodes", [])
|
||||
if node_ids and image_id in node_index:
|
||||
valid_nodes = node_index[image_id]
|
||||
for nid in node_ids:
|
||||
if nid not in valid_nodes:
|
||||
errors.append(
|
||||
f"{uid}.sources[{j}]: 节点 '{nid}' 在 "
|
||||
f"{image_id} 的逻辑树中不存在"
|
||||
)
|
||||
elif not node_ids:
|
||||
errors.append(
|
||||
f"{uid}.sources[{j}]: logic_tree 类型但未提供 logic_tree_nodes"
|
||||
)
|
||||
|
||||
if not has_text and not has_image:
|
||||
errors.append(f"{uid}: 必须至少引用一个文本或图片来源")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_logic_tree_coverage(
|
||||
units: list[dict], node_index: dict[str, set[str]]
|
||||
) -> list[str]:
|
||||
"""Check that decision and action nodes in logic trees are covered."""
|
||||
warnings = []
|
||||
for image_id, all_nodes in node_index.items():
|
||||
referenced = set()
|
||||
for fu in units:
|
||||
for src in fu.get("sources", []):
|
||||
if src.get("image_id") == image_id:
|
||||
for nid in src.get("logic_tree_nodes", []):
|
||||
referenced.add(nid)
|
||||
|
||||
uncovered = all_nodes - referenced
|
||||
if uncovered:
|
||||
doc = config.load_input_document()
|
||||
node_types = {}
|
||||
for img in doc.get("image_analysis", []):
|
||||
if img.get("rid") == image_id:
|
||||
lt = img.get("logic_tree", {})
|
||||
for n in lt.get("nodes", []):
|
||||
node_types[n["id"]] = n.get("type", "?")
|
||||
break
|
||||
|
||||
decision_action_uncovered = [
|
||||
n for n in uncovered if node_types.get(n) in ("decision", "action")
|
||||
]
|
||||
if decision_action_uncovered:
|
||||
warnings.append(
|
||||
f"{image_id}: {len(decision_action_uncovered)} 个 "
|
||||
f"decision/action 节点未被引用: {decision_action_uncovered}"
|
||||
)
|
||||
|
||||
return warnings
|
||||
|
||||
|
||||
def check_ensemble_confidence(units: list[dict]) -> list[str]:
|
||||
"""Check that every function_unit has confidence, ensemble_support, source_versions."""
|
||||
errors = []
|
||||
valid_conf = {"high", "medium", "low"}
|
||||
for fu in units:
|
||||
uid = fu.get("unit_id", "?")
|
||||
conf = fu.get("confidence", "")
|
||||
if not conf:
|
||||
errors.append(f"{uid}: 缺少 confidence 字段")
|
||||
elif conf not in valid_conf:
|
||||
errors.append(f"{uid}: confidence='{conf}' 无效 (期望 high/medium/low)")
|
||||
support = fu.get("ensemble_support", "")
|
||||
if not support:
|
||||
errors.append(f"{uid}: 缺少 ensemble_support 字段")
|
||||
if "source_versions" not in fu:
|
||||
errors.append(f"{uid}: 缺少 source_versions 字段")
|
||||
return errors
|
||||
|
||||
|
||||
def check_confidence_summary(si: dict) -> list[str]:
|
||||
"""Check that confidence_summary counts match actual unit/concept confidence."""
|
||||
errors = []
|
||||
cs = si.get("confidence_summary", {})
|
||||
if not cs:
|
||||
errors.append("缺少 confidence_summary 字段")
|
||||
return errors
|
||||
|
||||
units = si.get("function_units", [])
|
||||
concepts = si.get("concepts", [])
|
||||
|
||||
# Count actual confidence levels
|
||||
unit_high = sum(1 for u in units if u.get("confidence") == "high")
|
||||
unit_medium = sum(1 for u in units if u.get("confidence") == "medium")
|
||||
unit_low = sum(1 for u in units if u.get("confidence") == "low")
|
||||
concept_high = sum(1 for c in concepts if c.get("confidence") == "high")
|
||||
concept_medium = sum(1 for c in concepts if c.get("confidence") == "medium")
|
||||
concept_low = sum(1 for c in concepts if c.get("confidence") == "low")
|
||||
|
||||
if cs.get("total_units", 0) != len(units):
|
||||
errors.append(f"confidence_summary.total_units={cs.get('total_units')} != 实际 {len(units)}")
|
||||
if cs.get("high", 0) != unit_high:
|
||||
errors.append(f"confidence_summary.high={cs.get('high')} != 实际 {unit_high}")
|
||||
if cs.get("medium", 0) != unit_medium:
|
||||
errors.append(f"confidence_summary.medium={cs.get('medium')} != 实际 {unit_medium}")
|
||||
if cs.get("low", 0) != unit_low:
|
||||
errors.append(f"confidence_summary.low={cs.get('low')} != 实际 {unit_low}")
|
||||
if cs.get("total_concepts", 0) != len(concepts):
|
||||
errors.append(f"confidence_summary.total_concepts={cs.get('total_concepts')} != 实际 {len(concepts)}")
|
||||
if cs.get("concept_high", 0) != concept_high:
|
||||
errors.append(f"confidence_summary.concept_high={cs.get('concept_high')} != 实际 {concept_high}")
|
||||
if cs.get("concept_medium", 0) != concept_medium:
|
||||
errors.append(f"confidence_summary.concept_medium={cs.get('concept_medium')} != 实际 {concept_medium}")
|
||||
if cs.get("concept_low", 0) != concept_low:
|
||||
errors.append(f"confidence_summary.concept_low={cs.get('concept_low')} != 实际 {concept_low}")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
print("=" * 60)
|
||||
print("Step 1 自检测试")
|
||||
print("=" * 60)
|
||||
|
||||
si, doc = load_inputs()
|
||||
units = si.get("function_units", [])
|
||||
concepts = si.get("concepts", [])
|
||||
image_index = build_image_index(doc)
|
||||
node_index = build_logic_tree_node_index(doc)
|
||||
|
||||
all_errors = []
|
||||
all_warnings = []
|
||||
|
||||
# Test 1: unit_id and name validity
|
||||
errors = check_unit_ids(units)
|
||||
if errors:
|
||||
print(f"\n{FAIL} unit_id/name 检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} unit_id/name 检查: 全部通过 ({len(units)} 个功能单元)")
|
||||
|
||||
# Test 2: path fields
|
||||
errors = check_unit_paths(units)
|
||||
if errors:
|
||||
print(f"\n{FAIL} path 字段检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} path 字段检查: 全部通过")
|
||||
|
||||
# Test 3: concept parent references
|
||||
errors = check_concept_parents(concepts)
|
||||
if errors:
|
||||
print(f"\n{FAIL} concept parent 检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} concept parent 检查: 全部通过 ({len(concepts)} 个概念)")
|
||||
|
||||
# Test 4: source references exist
|
||||
errors = check_sources_exist(units, image_index, node_index)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 来源引用检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 来源引用检查: 全部通过")
|
||||
|
||||
# Test 5: Logic tree coverage
|
||||
warnings = check_logic_tree_coverage(units, node_index)
|
||||
if warnings:
|
||||
print(f"\n{WARN} 逻辑树节点覆盖率: {len(warnings)} 个警告")
|
||||
for w in warnings:
|
||||
print(f" - {w}")
|
||||
all_warnings.extend(warnings)
|
||||
else:
|
||||
print(f"\n{PASS} 逻辑树节点覆盖率: 全部通过")
|
||||
|
||||
# Test 6: Ensemble confidence fields on function_units
|
||||
errors = check_ensemble_confidence(units)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 集成置信度字段: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 集成置信度字段: 全部通过")
|
||||
|
||||
# Test 7: Confidence summary consistency
|
||||
errors = check_confidence_summary(si)
|
||||
if errors:
|
||||
print(f"\n{FAIL} confidence_summary 一致性: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
cs = si.get("confidence_summary", {})
|
||||
print(f"\n{PASS} confidence_summary 一致性: "
|
||||
f"high={cs.get('high',0)}, medium={cs.get('medium',0)}, "
|
||||
f"low={cs.get('low',0)}")
|
||||
|
||||
# Summary
|
||||
print(f"\n{'='*60}")
|
||||
total_failures = len(all_errors)
|
||||
total_warnings = len(all_warnings)
|
||||
|
||||
if total_failures == 0 and total_warnings == 0:
|
||||
print(f"{PASS} 所有测试通过!")
|
||||
elif total_failures == 0:
|
||||
print(f"{WARN} 全部通过但有 {total_warnings} 个警告")
|
||||
else:
|
||||
print(f"{FAIL} 测试失败: {total_failures} 个错误, {total_warnings} 个警告")
|
||||
print("\n请检查 LLM 输出质量,可能需要调整 Prompt 并重新运行 step1_semantic_index.py")
|
||||
|
||||
print(f"\n统计:")
|
||||
print(f" 功能单元数: {len(units)}")
|
||||
print(f" 概念数: {len(concepts)}")
|
||||
print(f" 逻辑树图片数: {len(node_index)}")
|
||||
|
||||
return total_failures == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
Tests for Stage 2 (IR Extraction).
|
||||
|
||||
Validates that ir_fragments.json meets quality and structural requirements:
|
||||
- All fragments have non-empty rules
|
||||
- All rules have path arrays
|
||||
- All rules have precondition.geographic_scope and precondition.screen_type
|
||||
- All trigger conditions have signal/operator/value
|
||||
- user_interaction content is non-empty and not a placeholder
|
||||
- No duplicate rule_ids (across all fragments)
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
import config
|
||||
|
||||
|
||||
PASS = "[PASS]"
|
||||
FAIL = "[FAIL]"
|
||||
WARN = "[WARN]"
|
||||
|
||||
# Forbidden placeholder phrases in user_interaction content
|
||||
FORBIDDEN_PLACEHOLDERS = [
|
||||
"文案由业务定义", "待定", "自定义", "TBD", "todo", "TODO"
|
||||
]
|
||||
|
||||
|
||||
def load_fragments():
|
||||
"""Load ir_fragments.json."""
|
||||
try:
|
||||
return config.load_json(config.IR_FRAGMENTS_JSON)
|
||||
except FileNotFoundError:
|
||||
print(f"{FAIL} ir_fragments.json 未找到: {config.IR_FRAGMENTS_JSON}")
|
||||
print(" 请先运行 step2_ir_extraction.py")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def check_non_empty_rules(fragments: list[dict]) -> list[str]:
|
||||
"""Every fragment must have at least one rule."""
|
||||
errors = []
|
||||
for f in fragments:
|
||||
uid = f.get("unit_id", "?")
|
||||
rules = f.get("rules", [])
|
||||
if not rules:
|
||||
if f.get("error"):
|
||||
errors.append(f"{uid}: 提取失败 — {f['error']}")
|
||||
else:
|
||||
errors.append(f"{uid}: rules 为空")
|
||||
return errors
|
||||
|
||||
|
||||
def check_rule_paths(fragments: list[dict]) -> list[str]:
|
||||
"""Every rule must have a non-empty path array."""
|
||||
errors = []
|
||||
for f in fragments:
|
||||
uid = f.get("unit_id", "?")
|
||||
for j, rule in enumerate(f.get("rules", [])):
|
||||
rid = rule.get("rule_id", f"rule[{j}]")
|
||||
path = rule.get("path", [])
|
||||
if not path:
|
||||
errors.append(f"{rid}: path 字段为空或缺失")
|
||||
elif not isinstance(path, list):
|
||||
errors.append(f"{rid}: path 必须是数组")
|
||||
return errors
|
||||
|
||||
|
||||
def check_precondition_fields(fragments: list[dict]) -> list[str]:
|
||||
"""Every rule must have precondition with geographic_scope and screen_type."""
|
||||
errors = []
|
||||
for f in fragments:
|
||||
uid = f.get("unit_id", "?")
|
||||
for j, rule in enumerate(f.get("rules", [])):
|
||||
rid = rule.get("rule_id", f"rule[{j}]")
|
||||
precond = rule.get("precondition", {})
|
||||
if not precond:
|
||||
errors.append(f"{rid}: precondition 缺失")
|
||||
continue
|
||||
if not precond.get("geographic_scope"):
|
||||
errors.append(f"{rid}: precondition.geographic_scope 缺失")
|
||||
if "screen_type" not in precond:
|
||||
errors.append(f"{rid}: precondition.screen_type 缺失")
|
||||
return errors
|
||||
|
||||
|
||||
def check_user_interaction_content(fragments: list[dict]) -> list[str]:
|
||||
"""user_interaction actions must have non-empty, non-placeholder content."""
|
||||
errors = []
|
||||
for f in fragments:
|
||||
uid = f.get("unit_id", "?")
|
||||
for j, rule in enumerate(f.get("rules", [])):
|
||||
rid = rule.get("rule_id", f"rule[{j}]")
|
||||
for k, action in enumerate(rule.get("actions", [])):
|
||||
if action.get("type") != "user_interaction":
|
||||
continue
|
||||
content = action.get("content", "")
|
||||
if not content:
|
||||
errors.append(
|
||||
f"{rid}.actions[{k}]: user_interaction 的 content 为空"
|
||||
)
|
||||
elif any(ph in content for ph in FORBIDDEN_PLACEHOLDERS):
|
||||
errors.append(
|
||||
f"{rid}.actions[{k}]: content 包含占位符: '{content}'"
|
||||
)
|
||||
return errors
|
||||
|
||||
|
||||
def check_sources_have_logic_tree_nodes(fragments: list[dict]) -> list[str]:
|
||||
"""Every rule should reference at least one logic tree node in its sources."""
|
||||
errors = []
|
||||
for f in fragments:
|
||||
uid = f.get("unit_id", "?")
|
||||
for j, rule in enumerate(f.get("rules", [])):
|
||||
rid = rule.get("rule_id", f"rule[{j}]")
|
||||
sources = rule.get("sources", [])
|
||||
has_logic_tree = any(
|
||||
src.get("type") == "logic_tree" and src.get("node_ids")
|
||||
for src in sources
|
||||
)
|
||||
if not has_logic_tree:
|
||||
has_text = any(
|
||||
src.get("type") in ("table", "para") for src in sources
|
||||
)
|
||||
if not has_text:
|
||||
errors.append(f"{rid}: sources 中既无逻辑树引用也无文字引用")
|
||||
return errors
|
||||
|
||||
|
||||
def check_trigger_conditions(fragments: list[dict]) -> list[str]:
|
||||
"""Every trigger condition must have signal, operator, value."""
|
||||
errors = []
|
||||
for f in fragments:
|
||||
uid = f.get("unit_id", "?")
|
||||
for j, rule in enumerate(f.get("rules", [])):
|
||||
rid = rule.get("rule_id", f"rule[{j}]")
|
||||
trigger = rule.get("trigger", {})
|
||||
conditions = trigger.get("conditions", [])
|
||||
|
||||
if trigger.get("event") is not None:
|
||||
continue
|
||||
|
||||
for k, cond in enumerate(conditions):
|
||||
signal = cond.get("signal", "")
|
||||
operator = cond.get("operator", "")
|
||||
has_value = "value" in cond
|
||||
|
||||
if not signal:
|
||||
errors.append(f"{rid}.condition[{k}]: 缺少 signal")
|
||||
if not operator:
|
||||
errors.append(f"{rid}.condition[{k}]: 缺少 operator")
|
||||
if not has_value:
|
||||
errors.append(f"{rid}.condition[{k}]: 缺少 value")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_duplicate_rule_ids(fragments: list[dict]) -> list[str]:
|
||||
"""Check for duplicate rule_ids across all fragments."""
|
||||
all_rule_ids = []
|
||||
for f in fragments:
|
||||
for rule in f.get("rules", []):
|
||||
rid = rule.get("rule_id", "")
|
||||
if rid:
|
||||
all_rule_ids.append(rid)
|
||||
|
||||
duplicates = [rid for rid, count in Counter(all_rule_ids).items() if count > 1]
|
||||
errors = []
|
||||
if duplicates:
|
||||
errors.append(f"重复 rule_id: {duplicates}")
|
||||
return errors
|
||||
|
||||
|
||||
def check_action_types(fragments: list[dict]) -> list[str]:
|
||||
"""Verify that actions have valid types."""
|
||||
valid_types = {"system", "user_interaction"}
|
||||
errors = []
|
||||
for f in fragments:
|
||||
for j, rule in enumerate(f.get("rules", [])):
|
||||
rid = rule.get("rule_id", f"rule[{j}]")
|
||||
for k, action in enumerate(rule.get("actions", [])):
|
||||
atype = action.get("type", "")
|
||||
if atype not in valid_types:
|
||||
errors.append(
|
||||
f"{rid}.action[{k}]: type='{atype}' 无效, "
|
||||
f"应为 {valid_types}"
|
||||
)
|
||||
if atype == "user_interaction" and "content" not in action:
|
||||
errors.append(
|
||||
f"{rid}.action[{k}]: user_interaction 类型缺少 content 字段"
|
||||
)
|
||||
return errors
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
print("=" * 60)
|
||||
print("Step 2 自检测试")
|
||||
print("=" * 60)
|
||||
|
||||
fragments = load_fragments()
|
||||
all_errors = []
|
||||
total_units = len(fragments)
|
||||
total_rules = sum(len(f.get("rules", [])) for f in fragments)
|
||||
|
||||
# Test 1: Non-empty rules
|
||||
errors = check_non_empty_rules(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 非空规则检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 非空规则检查: 全部通过 ({total_units} 个片段)")
|
||||
|
||||
# Test 2: Rule path arrays
|
||||
errors = check_rule_paths(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 规则 path 字段: {len(errors)} 个错误")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
if len(errors) > 10:
|
||||
print(f" ... 还有 {len(errors) - 10} 个")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 规则 path 字段: 全部通过")
|
||||
|
||||
# Test 3: Precondition fields
|
||||
errors = check_precondition_fields(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} precondition 字段: {len(errors)} 个错误")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
if len(errors) > 10:
|
||||
print(f" ... 还有 {len(errors) - 10} 个")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} precondition 字段: 全部通过")
|
||||
|
||||
# Test 4: user_interaction content
|
||||
errors = check_user_interaction_content(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} user_interaction content: {len(errors)} 个错误")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
if len(errors) > 10:
|
||||
print(f" ... 还有 {len(errors) - 10} 个")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} user_interaction content: 全部通过")
|
||||
|
||||
# Test 5: Sources have logic tree references
|
||||
errors = check_sources_have_logic_tree_nodes(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 来源节点引用: {len(errors)} 个规则缺少来源引用")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
if len(errors) > 10:
|
||||
print(f" ... 还有 {len(errors) - 10} 个")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 来源节点引用: 全部通过")
|
||||
|
||||
# Test 6: Trigger conditions completeness
|
||||
errors = check_trigger_conditions(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 触发条件完整性: {len(errors)} 个条件不完整")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
if len(errors) > 10:
|
||||
print(f" ... 还有 {len(errors) - 10} 个")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 触发条件完整性: 全部通过")
|
||||
|
||||
# Test 7: No duplicate rule_ids
|
||||
errors = check_duplicate_rule_ids(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} rule_id 唯一性: 发现重复")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} rule_id 唯一性: 全部通过")
|
||||
|
||||
# Test 8: Valid action types
|
||||
errors = check_action_types(fragments)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 动作类型检查: {len(errors)} 个问题")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 动作类型检查: 全部通过")
|
||||
|
||||
# Summary
|
||||
print(f"\n{'='*60}")
|
||||
total_failures = len(all_errors)
|
||||
|
||||
if total_failures == 0:
|
||||
print(f"{PASS} 所有测试通过!")
|
||||
else:
|
||||
print(f"{FAIL} 测试失败: {total_failures} 个错误")
|
||||
print("\n建议:")
|
||||
print(" 1. 检查 ir_fragments.json 中出错的规则")
|
||||
print(" 2. 如果某些功能单元的规则为空,检查上下文包是否丢失了关键信息")
|
||||
print(" 3. 调整 Prompt (prompts/step2_ir_extraction.txt) 后重新运行")
|
||||
|
||||
print(f"\n统计:")
|
||||
print(f" 功能单元数: {total_units}")
|
||||
print(f" 规则总数: {total_rules}")
|
||||
error_units = sum(1 for f in fragments if f.get("error"))
|
||||
if error_units:
|
||||
print(f" 提取失败的单元: {error_units}")
|
||||
|
||||
return total_failures == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Tests for Stage 2.5 (Branch Coverage Auto-Completion).
|
||||
|
||||
Validates:
|
||||
- Path enumeration exists and is non-empty
|
||||
- Auto-complete fragments have valid structure
|
||||
- No duplicate unit_ids in autocomplete fragments
|
||||
- Path coverage improved after autocomplete (if applicable)
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
import config
|
||||
|
||||
|
||||
PASS = "[PASS]"
|
||||
FAIL = "[FAIL]"
|
||||
WARN = "[WARN]"
|
||||
|
||||
|
||||
def load_path_enumeration():
|
||||
"""Load path_enumeration.json."""
|
||||
try:
|
||||
return config.load_json(config.PATH_ENUM_JSON)
|
||||
except FileNotFoundError:
|
||||
print(f"{FAIL} path_enumeration.json 未找到: {config.PATH_ENUM_JSON}")
|
||||
print(" 请先运行 step2_5_branch_coverage.py")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def load_autocomplete_fragments():
|
||||
"""Load ir_autocomplete_fragments.json, or return [] if absent."""
|
||||
path = config.IR_AUTOCOMPLETE_FRAGMENTS_JSON
|
||||
if not Path(path).exists():
|
||||
return None
|
||||
return config.load_json(path)
|
||||
|
||||
|
||||
def check_path_enumeration(data: dict) -> list[str]:
|
||||
"""Check path enumeration has valid structure."""
|
||||
errors = []
|
||||
paths = data.get("logic_tree_paths", {})
|
||||
if not paths:
|
||||
errors.append("logic_tree_paths 为空")
|
||||
total = data.get("total_paths", 0)
|
||||
if total <= 0:
|
||||
errors.append(f"total_paths = {total}, 期望 > 0")
|
||||
|
||||
for image_id, image_paths in paths.items():
|
||||
if not image_paths:
|
||||
errors.append(f"{image_id}: 路径列表为空")
|
||||
continue
|
||||
for i, p in enumerate(image_paths):
|
||||
if not p.get("path_id"):
|
||||
errors.append(f"{image_id}[{i}]: 缺少 path_id")
|
||||
if not p.get("image_id"):
|
||||
errors.append(f"{image_id}[{i}]: 缺少 image_id")
|
||||
if not p.get("node_ids"):
|
||||
errors.append(f"{image_id}[{i}]: 缺少 node_ids")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_autocomplete_fragments(fragments: list[dict] | None) -> list[str]:
|
||||
"""Check auto-complete fragments have valid structure."""
|
||||
if fragments is None:
|
||||
return ["ir_autocomplete_fragments.json 未生成 (可能无需补全)"]
|
||||
|
||||
errors = []
|
||||
seen_unit_ids = set()
|
||||
|
||||
for frag in fragments:
|
||||
uid = frag.get("unit_id", "")
|
||||
if not uid:
|
||||
errors.append("fragment 缺少 unit_id")
|
||||
continue
|
||||
if uid in seen_unit_ids:
|
||||
errors.append(f"unit_id '{uid}' 重复")
|
||||
seen_unit_ids.add(uid)
|
||||
|
||||
if not frag.get("auto_generated"):
|
||||
errors.append(f"{uid}: auto_generated 应为 true")
|
||||
|
||||
rules = frag.get("rules", [])
|
||||
for j, rule in enumerate(rules):
|
||||
rid = rule.get("rule_id", f"rule[{j}]")
|
||||
if not rule.get("path"):
|
||||
errors.append(f"{rid}: path 字段缺失")
|
||||
precond = rule.get("precondition", {})
|
||||
if not precond.get("geographic_scope"):
|
||||
errors.append(f"{rid}: precondition.geographic_scope 缺失")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
print("=" * 60)
|
||||
print("Step 2.5 自检测试")
|
||||
print("=" * 60)
|
||||
|
||||
all_errors = []
|
||||
|
||||
# Test 1: Path enumeration exists
|
||||
try:
|
||||
path_data = load_path_enumeration()
|
||||
except SystemExit:
|
||||
return False
|
||||
|
||||
errors = check_path_enumeration(path_data)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 路径枚举检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
total = path_data.get("total_paths", 0)
|
||||
n_images = len(path_data.get("logic_tree_paths", {}))
|
||||
print(f"\n{PASS} 路径枚举检查: {total} 条路径, {n_images} 个逻辑树")
|
||||
|
||||
# Test 2: Auto-complete fragments
|
||||
fragments = load_autocomplete_fragments()
|
||||
errors = check_autocomplete_fragments(fragments)
|
||||
|
||||
if fragments is None:
|
||||
print(f"\n{WARN} 自动补全片段: 未生成 (可能所有路径已覆盖)")
|
||||
elif errors:
|
||||
print(f"\n{FAIL} 自动补全片段检查: {len(errors)} 个错误")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
auto_rules = sum(len(f.get("rules", [])) for f in fragments)
|
||||
print(f"\n{PASS} 自动补全片段检查: "
|
||||
f"{len(fragments)} 个片段, {auto_rules} 条规则")
|
||||
|
||||
# Summary
|
||||
print(f"\n{'='*60}")
|
||||
total_failures = len(all_errors)
|
||||
|
||||
if total_failures == 0:
|
||||
print(f"{PASS} 所有测试通过!")
|
||||
else:
|
||||
print(f"{FAIL} 测试失败: {total_failures} 个错误")
|
||||
|
||||
return total_failures == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Tests for Stage 3 (Merge & Audit).
|
||||
|
||||
Validates:
|
||||
- ir_final.json exists and is well-formed
|
||||
- No duplicate rule_ids
|
||||
- All rule_ids follow new hierarchical naming convention
|
||||
- All rules have path arrays
|
||||
- ir_audit_report.md exists and contains all required sections
|
||||
"""
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
import config
|
||||
|
||||
|
||||
PASS = "[PASS]"
|
||||
FAIL = "[FAIL]"
|
||||
WARN = "[WARN]"
|
||||
|
||||
|
||||
def load_ir_final():
|
||||
"""Load ir_final.json."""
|
||||
try:
|
||||
return config.load_json(config.IR_FINAL_JSON)
|
||||
except FileNotFoundError:
|
||||
print(f"{FAIL} ir_final.json 未找到: {config.IR_FINAL_JSON}")
|
||||
print(" 请先运行 step3_merge_and_audit.py")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def load_audit_report():
|
||||
"""Load ir_audit_report.md if it exists."""
|
||||
try:
|
||||
with open(config.IR_AUDIT_REPORT_MD, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
print(f"{FAIL} ir_audit_report.md 未找到: {config.IR_AUDIT_REPORT_MD}")
|
||||
print(" 请先运行 step3_merge_and_audit.py")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def check_rule_ids(ir: dict) -> list[str]:
|
||||
"""Check for duplicate rule_ids and hierarchical naming convention.
|
||||
|
||||
Format: DRL-001-DOMESTIC-SYS-FG-INTERRUPT-01
|
||||
"""
|
||||
errors = []
|
||||
rules = ir.get("rules", [])
|
||||
rule_ids = [r.get("rule_id", "") for r in rules]
|
||||
|
||||
# No duplicates
|
||||
duplicates = [rid for rid, count in Counter(rule_ids).items() if count > 1]
|
||||
if duplicates:
|
||||
errors.append(f"重复 rule_id: {duplicates}")
|
||||
|
||||
# New hierarchical naming convention
|
||||
pattern = re.compile(
|
||||
r"^[A-Z]+-\d{3}-(DOMESTIC|OVERSEAS)-"
|
||||
r"(SYS|SDK|OTHER)-"
|
||||
r"(FG-INTERRUPT|BG-BLOCK|BG-PAUSE|NO-RESTRICT|SWITCH-OFF)-\d{2}$"
|
||||
)
|
||||
for rid in rule_ids:
|
||||
if rid and not pattern.match(rid):
|
||||
errors.append(
|
||||
f"rule_id 命名不规范: '{rid}' "
|
||||
f"(期望: FEATURE-SCOPE-METHOD-BEHAVIOR-NN)"
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_top_level_structure(ir: dict) -> list[str]:
|
||||
"""Check that ir_final has the required top-level fields."""
|
||||
errors = []
|
||||
for field in ["feature", "feature_id", "rules"]:
|
||||
if field not in ir:
|
||||
errors.append(f"ir_final 缺少顶层字段: {field}")
|
||||
|
||||
if not isinstance(ir.get("rules"), list):
|
||||
errors.append("ir_final.rules 必须是数组")
|
||||
elif len(ir["rules"]) == 0:
|
||||
errors.append("ir_final.rules 为空")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_rule_paths(rules: list[dict]) -> list[str]:
|
||||
"""Every rule must have a non-empty path array."""
|
||||
errors = []
|
||||
for rule in rules:
|
||||
rid = rule.get("rule_id", "?")
|
||||
path = rule.get("path", [])
|
||||
if not path:
|
||||
errors.append(f"{rid}: path 字段为空或缺失")
|
||||
return errors
|
||||
|
||||
|
||||
def check_rule_completeness(rules: list[dict]) -> list[str]:
|
||||
"""Check each rule has all required fields."""
|
||||
errors = []
|
||||
required_fields = [
|
||||
"rule_id", "description", "priority", "sources",
|
||||
"precondition", "trigger", "actions"
|
||||
]
|
||||
for i, rule in enumerate(rules):
|
||||
rid = rule.get("rule_id", f"rule[{i}]")
|
||||
for field in required_fields:
|
||||
if field not in rule:
|
||||
errors.append(f"{rid}: 缺少字段 '{field}'")
|
||||
if not rule.get("sources"):
|
||||
errors.append(f"{rid}: sources 为空")
|
||||
if not rule.get("actions"):
|
||||
errors.append(f"{rid}: actions 为空")
|
||||
# Check precondition fields
|
||||
precond = rule.get("precondition", {})
|
||||
if not precond.get("geographic_scope"):
|
||||
errors.append(f"{rid}: precondition.geographic_scope 缺失")
|
||||
if "screen_type" not in precond:
|
||||
errors.append(f"{rid}: precondition.screen_type 缺失")
|
||||
return errors
|
||||
|
||||
|
||||
def check_audit_report(report: str) -> list[str]:
|
||||
"""Check audit report has all required sections."""
|
||||
errors = []
|
||||
|
||||
required_sections = [
|
||||
"逻辑树路径覆盖率",
|
||||
"表格枚举覆盖",
|
||||
"开关状态",
|
||||
"一致性扫描报告",
|
||||
"自动补全摘要",
|
||||
"规则清单",
|
||||
]
|
||||
for section in required_sections:
|
||||
if section not in report:
|
||||
errors.append(f"审计报告缺少章节: {section}")
|
||||
|
||||
# Should have the human review notice
|
||||
if "人工审查" not in report:
|
||||
errors.append("审计报告缺少人工审查提示")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
print("=" * 60)
|
||||
print("Step 3 自检测试")
|
||||
print("=" * 60)
|
||||
|
||||
ir = load_ir_final()
|
||||
report = load_audit_report()
|
||||
rules = ir.get("rules", [])
|
||||
all_errors = []
|
||||
|
||||
# Test 1: Top-level structure
|
||||
errors = check_top_level_structure(ir)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 顶层结构检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 顶层结构检查: 通过 "
|
||||
f"(feature={ir.get('feature')}, feature_id={ir.get('feature_id')})")
|
||||
|
||||
# Test 2: rule_id uniqueness and naming
|
||||
errors = check_rule_ids(ir)
|
||||
if errors:
|
||||
print(f"\n{FAIL} rule_id 检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} rule_id 检查: 全部通过 ({len(rules)} 个唯一 ID, 层次化格式)")
|
||||
|
||||
# Test 3: Rule path fields
|
||||
errors = check_rule_paths(rules)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 规则 path 字段: {len(errors)} 个错误")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 规则 path 字段: 全部通过")
|
||||
|
||||
# Test 4: Rule field completeness
|
||||
errors = check_rule_completeness(rules)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 规则字段完整性: {len(errors)} 个错误")
|
||||
for e in errors[:10]:
|
||||
print(f" - {e}")
|
||||
if len(errors) > 10:
|
||||
print(f" ... 还有 {len(errors) - 10} 个")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 规则字段完整性: 全部通过")
|
||||
|
||||
# Test 5: Audit report content
|
||||
errors = check_audit_report(report)
|
||||
if errors:
|
||||
print(f"\n{FAIL} 审计报告检查: {len(errors)} 个错误")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
all_errors.extend(errors)
|
||||
else:
|
||||
print(f"\n{PASS} 审计报告检查: 全部通过 (6 个章节)")
|
||||
|
||||
# Summary
|
||||
print(f"\n{'='*60}")
|
||||
total_failures = len(all_errors)
|
||||
|
||||
if total_failures == 0:
|
||||
print(f"{PASS} 所有测试通过!")
|
||||
print(f"\n最终交付物:")
|
||||
print(f" - {config.IR_FINAL_JSON} ({len(rules)} 条规则)")
|
||||
print(f" - {config.IR_AUDIT_REPORT_MD}")
|
||||
else:
|
||||
print(f"{FAIL} 测试失败: {total_failures} 个错误")
|
||||
print("\n建议: 检查 ir_fragments.json 和合并逻辑,修复问题后重新运行 step3_merge_and_audit.py")
|
||||
|
||||
return total_failures == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -0,0 +1 @@
|
||||
# Tests package for document_analyzer
|
||||
@@ -0,0 +1 @@
|
||||
# QE Acceptance Tests for document_analyzer
|
||||
@@ -0,0 +1,186 @@
|
||||
"""Pytest configuration and shared fixtures for QE acceptance tests.
|
||||
|
||||
Usage::
|
||||
|
||||
pytest tests/acceptance/ -v --run-acceptance [--acceptance-runs=3]
|
||||
|
||||
Environment variables:
|
||||
DASHSCOPE_API_KEY — LLM API key (required for Layers B/C)
|
||||
TEST_IR_PATH — path to IR JSON to validate (default: ir_final.json sample)
|
||||
TEST_PARSED_PATH — path to _parsed.json or _updated.json for coverage analysis
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
# ── Path setup ──────────────────────────────────────────────────────────────
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
sys.path.insert(0, str(_PROJECT_ROOT))
|
||||
|
||||
|
||||
def _skill_path(skill_name: str) -> str:
|
||||
return str(_PROJECT_ROOT / "skills" / skill_name / "scripts")
|
||||
|
||||
|
||||
# ── pytest configuration ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--run-acceptance",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run QE acceptance tests (requires DASHSCOPE_API_KEY)",
|
||||
)
|
||||
parser.addoption(
|
||||
"--acceptance-runs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of IR generation runs for Layer B stability testing (default: 1 = skip)",
|
||||
)
|
||||
parser.addoption(
|
||||
"--ir-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to IR JSON file to validate",
|
||||
)
|
||||
parser.addoption(
|
||||
"--parsed-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to _parsed.json or _updated.json for coverage analysis",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
"acceptance: QE acceptance test (requires --run-acceptance flag and DASHSCOPE_API_KEY)",
|
||||
)
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
acceptance_dir = str(_PROJECT_ROOT / "tests" / "acceptance")
|
||||
acceptance_items = [i for i in items if str(i.fspath).startswith(acceptance_dir)]
|
||||
non_acceptance_items = [i for i in items if not str(i.fspath).startswith(acceptance_dir)]
|
||||
|
||||
if not config.getoption("--run-acceptance"):
|
||||
skip_msg = pytest.mark.skip(reason="Need --run-acceptance flag to run")
|
||||
for item in acceptance_items:
|
||||
item.add_marker(skip_msg)
|
||||
# Don't skip non-acceptance tests
|
||||
return
|
||||
|
||||
if not os.environ.get("DASHSCOPE_API_KEY"):
|
||||
skip_msg = pytest.mark.skip(reason="DASHSCOPE_API_KEY not set")
|
||||
for item in acceptance_items:
|
||||
item.add_marker(skip_msg)
|
||||
|
||||
|
||||
# ── Shared fixtures ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def project_root() -> Path:
|
||||
return _PROJECT_ROOT
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def ir_path(request) -> str:
|
||||
"""Path to the IR JSON file under test."""
|
||||
path = (
|
||||
request.config.getoption("--ir-path")
|
||||
or os.environ.get("TEST_IR_PATH")
|
||||
or str(
|
||||
Path.home()
|
||||
/ ".openclaw/workspace/skills/doc_parser_skill/output/ir_final.json"
|
||||
)
|
||||
)
|
||||
if not os.path.exists(path):
|
||||
pytest.skip(f"IR file not found: {path}")
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def ir_data(ir_path: str) -> dict:
|
||||
"""Load the IR JSON data."""
|
||||
with open(ir_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def parsed_path(request) -> str | None:
|
||||
"""Path to the corresponding _parsed.json or _updated.json."""
|
||||
path = (
|
||||
request.config.getoption("--parsed-path")
|
||||
or os.environ.get("TEST_PARSED_PATH")
|
||||
or str(
|
||||
_PROJECT_ROOT
|
||||
/ "skills/ir_generation_skill/车机娱乐系统禁止功能文档_精简_updated.json"
|
||||
)
|
||||
)
|
||||
if os.path.exists(path):
|
||||
return path
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def parsed_data(parsed_path: str | None) -> dict | None:
|
||||
"""Load the parsed document JSON for coverage analysis."""
|
||||
if parsed_path is None:
|
||||
return None
|
||||
with open(parsed_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llm_client():
|
||||
"""Create an LLMClient instance for acceptance tests.
|
||||
|
||||
Uses the DashScope-compatible LLMClient from the project.
|
||||
"""
|
||||
sys.path.insert(0, _skill_path("doc_parser_skill"))
|
||||
from LLM import LLMClient
|
||||
|
||||
return LLMClient()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def acceptance_runs(request) -> int:
|
||||
return request.config.getoption("--acceptance-runs", default=1)
|
||||
|
||||
|
||||
# ── Pipeline runner ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def run_ir_pipeline():
|
||||
"""Return a callable that runs the IR generation pipeline on a parsed JSON.
|
||||
|
||||
Usage::
|
||||
|
||||
ir_data, ir_path = run_ir_pipeline(parsed_json_path, output_dir)
|
||||
"""
|
||||
sys.path.insert(0, _skill_path("ir_generation_skill"))
|
||||
from ir_generator import generate_ir
|
||||
|
||||
def _run(parsed_path: str, output_dir: str | None = None) -> tuple[dict, str]:
|
||||
"""Run IR generation and return (ir_data, ir_path)."""
|
||||
out = output_dir or tempfile.mkdtemp(prefix="qe_acceptance_")
|
||||
result = generate_ir(parsed_path, out, dry_run=False)
|
||||
ir_list = result.get("ir", [])
|
||||
ir_path = result.get("path", "")
|
||||
# ir_generator produces a list; wrap to match rich format expectations
|
||||
# for schema validation we accept both formats
|
||||
return ir_list, ir_path
|
||||
|
||||
return _run
|
||||
@@ -0,0 +1,325 @@
|
||||
"""Rich IR schema definition and validators for the document_analyzer QE framework.
|
||||
|
||||
Target format is the production IR (``ir_final.json``):
|
||||
{feature, feature_id, config_defaults?, rules: [{rule_id, path, description,
|
||||
priority, sources, precondition, trigger, actions}]}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
# ── Constants ────────────────────────────────────────────────────────────────
|
||||
|
||||
VALID_SOURCE_TYPES = {"table", "logic_tree", "text"}
|
||||
VALID_ACTION_TYPES = {"system", "user_interaction"}
|
||||
VALID_PRIORITIES = {"P0", "P1", "P2"}
|
||||
VALID_TRIGGER_OPERATORS = {"AND", "OR"}
|
||||
|
||||
# rule_id pattern: FEAT-NNN-SCOPE-TYPE-...-PATH-NN (variable middle segments)
|
||||
RULE_ID_RE = re.compile(
|
||||
r"^[A-Z]+-\d+(-[A-Z]+)+-\d+$"
|
||||
)
|
||||
|
||||
|
||||
# ── Validation helpers ──────────────────────────────────────────────────────
|
||||
|
||||
def _check(condition: bool, message: str) -> list[str]:
|
||||
"""Return a list with an error message if *condition* is False, else empty list."""
|
||||
return [] if condition else [message]
|
||||
|
||||
|
||||
def validate_rule(rule: dict, index: int = 0) -> list[str]:
|
||||
"""Validate a single rule dict. Returns a (possibly empty) list of error strings."""
|
||||
errors: list[str] = []
|
||||
label = f"rules[{index}]"
|
||||
|
||||
if not isinstance(rule, dict):
|
||||
return [f"{label}: not a dict"]
|
||||
|
||||
# ── required top-level fields ──
|
||||
for field in ("rule_id", "description"):
|
||||
errors.extend(_check(
|
||||
isinstance(rule.get(field), str) and bool(rule[field].strip()),
|
||||
f'{label}.{field}: required non-empty string',
|
||||
))
|
||||
|
||||
# sources is a list, not a string — validated separately below
|
||||
|
||||
# ── rule_id naming ──
|
||||
rid = rule.get("rule_id", "")
|
||||
if rid and isinstance(rid, str):
|
||||
errors.extend(_check(
|
||||
bool(RULE_ID_RE.match(rid)),
|
||||
f'{label}.rule_id: "{rid}" does not match pattern FEAT-NNN-SCOPE-TYPE-PATH-NN',
|
||||
))
|
||||
|
||||
# ── priority ──
|
||||
priority = rule.get("priority")
|
||||
if priority is not None:
|
||||
errors.extend(_check(
|
||||
priority in VALID_PRIORITIES,
|
||||
f'{label}.priority: "{priority}" not in {VALID_PRIORITIES}',
|
||||
))
|
||||
|
||||
# ── path ──
|
||||
path = rule.get("path")
|
||||
if path is not None:
|
||||
if not isinstance(path, list):
|
||||
errors.append(f"{label}.path: must be a list")
|
||||
elif len(path) == 0:
|
||||
errors.append(f"{label}.path: must not be empty")
|
||||
elif not all(isinstance(p, str) and p.strip() for p in path):
|
||||
errors.append(f"{label}.path: all segments must be non-empty strings")
|
||||
|
||||
# ── sources[] ──
|
||||
sources = rule.get("sources", [])
|
||||
if not isinstance(sources, list):
|
||||
errors.append(f"{label}.sources: must be a list")
|
||||
elif len(sources) == 0:
|
||||
errors.append(f"{label}.sources: must have at least one source")
|
||||
else:
|
||||
for si, src in enumerate(sources):
|
||||
errors.extend(_validate_source(src, f"{label}.sources[{si}]"))
|
||||
|
||||
# ── precondition ──
|
||||
precondition = rule.get("precondition")
|
||||
if precondition is not None:
|
||||
if not isinstance(precondition, dict):
|
||||
errors.append(f"{label}.precondition: must be a dict")
|
||||
elif len(precondition) == 0:
|
||||
errors.append(f"{label}.precondition: must not be empty")
|
||||
|
||||
# ── trigger ──
|
||||
trigger = rule.get("trigger")
|
||||
if trigger is not None:
|
||||
if not isinstance(trigger, dict):
|
||||
errors.append(f"{label}.trigger: must be a dict")
|
||||
else:
|
||||
errors.extend(_validate_trigger(trigger, f"{label}.trigger"))
|
||||
|
||||
# ── actions ──
|
||||
actions = rule.get("actions")
|
||||
if actions is not None:
|
||||
if not isinstance(actions, list):
|
||||
errors.append(f"{label}.actions: must be a list")
|
||||
else:
|
||||
for ai, act in enumerate(actions):
|
||||
errors.extend(_validate_action(act, f"{label}.actions[{ai}]"))
|
||||
|
||||
# ── no null values at any depth ──
|
||||
errors.extend(_find_nulls(rule, label))
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def _validate_source(src: dict, label: str) -> list[str]:
|
||||
errors: list[str] = []
|
||||
if not isinstance(src, dict):
|
||||
return [f"{label}: not a dict"]
|
||||
|
||||
stype = src.get("type", "")
|
||||
errors.extend(_check(
|
||||
stype in VALID_SOURCE_TYPES,
|
||||
f'{label}.type: "{stype}" not in {VALID_SOURCE_TYPES}',
|
||||
))
|
||||
|
||||
priority = src.get("priority", "")
|
||||
if priority:
|
||||
errors.extend(_check(
|
||||
priority in ("primary_source", "supplementary"),
|
||||
f'{label}.priority: "{priority}" must be primary_source or supplementary',
|
||||
))
|
||||
|
||||
# type-specific fields
|
||||
if stype == "table":
|
||||
errors.extend(_check(
|
||||
isinstance(src.get("section"), str) and bool(src["section"].strip()),
|
||||
f"{label}.section: required non-empty string for table source",
|
||||
))
|
||||
errors.extend(_check(
|
||||
isinstance(src.get("row"), int),
|
||||
f"{label}.row: required int for table source",
|
||||
))
|
||||
elif stype == "logic_tree":
|
||||
errors.extend(_check(
|
||||
isinstance(src.get("image_id"), str) and bool(src["image_id"].strip()),
|
||||
f"{label}.image_id: required non-empty string for logic_tree source",
|
||||
))
|
||||
node_ids = src.get("node_ids", [])
|
||||
errors.extend(_check(
|
||||
isinstance(node_ids, list) and len(node_ids) > 0,
|
||||
f"{label}.node_ids: required non-empty list for logic_tree source",
|
||||
))
|
||||
elif stype == "text":
|
||||
errors.extend(_check(
|
||||
isinstance(src.get("section"), str) and bool(src["section"].strip()),
|
||||
f"{label}.section: required non-empty string for text source",
|
||||
))
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def _validate_trigger(trigger: dict, label: str) -> list[str]:
|
||||
errors: list[str] = []
|
||||
operator = trigger.get("operator", "")
|
||||
errors.extend(_check(
|
||||
operator in VALID_TRIGGER_OPERATORS,
|
||||
f'{label}.operator: "{operator}" not in {VALID_TRIGGER_OPERATORS}',
|
||||
))
|
||||
|
||||
conditions = trigger.get("conditions")
|
||||
if conditions is not None:
|
||||
if not isinstance(conditions, list):
|
||||
errors.append(f"{label}.conditions: must be a list")
|
||||
else:
|
||||
for ci, cond in enumerate(conditions):
|
||||
if not isinstance(cond, dict):
|
||||
errors.append(f"{label}.conditions[{ci}]: not a dict")
|
||||
else:
|
||||
errors.extend(_check(
|
||||
isinstance(cond.get("signal"), str) and bool(cond["signal"].strip()),
|
||||
f"{label}.conditions[{ci}].signal: required non-empty string",
|
||||
))
|
||||
errors.extend(_check(
|
||||
"operator" in cond,
|
||||
f"{label}.conditions[{ci}].operator: required",
|
||||
))
|
||||
# empty conditions is valid (e.g. "switch always off, no conditions")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def _validate_action(action: dict, label: str) -> list[str]:
|
||||
errors: list[str] = []
|
||||
if not isinstance(action, dict):
|
||||
return [f"{label}: not a dict"]
|
||||
|
||||
atype = action.get("type", "")
|
||||
errors.extend(_check(
|
||||
atype in VALID_ACTION_TYPES,
|
||||
f'{label}.type: "{atype}" not in {VALID_ACTION_TYPES}',
|
||||
))
|
||||
errors.extend(_check(
|
||||
isinstance(action.get("description"), str) and bool(action["description"].strip()),
|
||||
f"{label}.description: required non-empty string",
|
||||
))
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def _find_nulls(obj: Any, label: str) -> list[str]:
|
||||
"""Find any None values at any depth in *obj*."""
|
||||
errors: list[str] = []
|
||||
if obj is None:
|
||||
return [f"{label}: null value"]
|
||||
elif isinstance(obj, dict):
|
||||
for k, v in obj.items():
|
||||
errors.extend(_find_nulls(v, f"{label}.{k}"))
|
||||
elif isinstance(obj, list):
|
||||
for i, v in enumerate(obj):
|
||||
errors.extend(_find_nulls(v, f"{label}[{i}]"))
|
||||
return errors
|
||||
|
||||
|
||||
# ── Top-level validation ────────────────────────────────────────────────────
|
||||
|
||||
def validate_ir(ir_data: dict) -> dict:
|
||||
"""Validate the entire IR document.
|
||||
|
||||
Returns:
|
||||
{
|
||||
"valid": bool,
|
||||
"errors": [str, ...],
|
||||
"stats": {total_rules, valid_rules, has_config_defaults, ...}
|
||||
}
|
||||
"""
|
||||
errors: list[str] = []
|
||||
stats = {"total_rules": 0, "valid_rules": 0, "has_config_defaults": False, "features": 0}
|
||||
|
||||
if not isinstance(ir_data, dict):
|
||||
return {"valid": False, "errors": ["IR root is not a dict"], "stats": stats}
|
||||
|
||||
# top-level required fields
|
||||
for field in ("feature", "feature_id", "rules"):
|
||||
if field not in ir_data:
|
||||
errors.append(f"root.{field}: missing required field")
|
||||
elif field in ("feature", "feature_id") and not (
|
||||
isinstance(ir_data[field], str) and ir_data[field].strip()
|
||||
):
|
||||
errors.append(f"root.{field}: must be non-empty string")
|
||||
|
||||
# config_defaults (optional)
|
||||
if "config_defaults" in ir_data:
|
||||
stats["has_config_defaults"] = True
|
||||
cd = ir_data["config_defaults"]
|
||||
if not isinstance(cd, dict):
|
||||
errors.append("root.config_defaults: must be a dict")
|
||||
|
||||
# rules array
|
||||
rules = ir_data.get("rules", [])
|
||||
if not isinstance(rules, list):
|
||||
errors.append("root.rules: must be a list")
|
||||
else:
|
||||
stats["total_rules"] = len(rules)
|
||||
if len(rules) == 0:
|
||||
errors.append("root.rules: must have at least one rule")
|
||||
else:
|
||||
for i, rule in enumerate(rules):
|
||||
rule_errors = validate_rule(rule, i)
|
||||
if rule_errors:
|
||||
errors.extend(rule_errors)
|
||||
else:
|
||||
stats["valid_rules"] += 1
|
||||
|
||||
# feature count
|
||||
if isinstance(ir_data.get("feature_id"), str):
|
||||
stats["features"] = 1
|
||||
|
||||
return {
|
||||
"valid": len(errors) == 0,
|
||||
"errors": errors,
|
||||
"stats": stats,
|
||||
}
|
||||
|
||||
|
||||
# ── Summary helpers ─────────────────────────────────────────────────────────
|
||||
|
||||
def schema_checklist(ir_data: dict) -> list[dict]:
|
||||
"""Run individual checks and return a checklist for reporting.
|
||||
|
||||
Each item: {"check": str, "passed": bool, "detail": str}
|
||||
"""
|
||||
report = validate_ir(ir_data)
|
||||
checks: list[dict] = []
|
||||
|
||||
def _add(name: str, passed: bool, detail: str = ""):
|
||||
checks.append({"check": name, "passed": passed, "detail": detail})
|
||||
|
||||
# Top-level
|
||||
_add("root is dict", isinstance(ir_data, dict))
|
||||
_add("root.feature present", isinstance(ir_data.get("feature"), str) and bool(ir_data["feature"].strip()))
|
||||
_add("root.feature_id present", isinstance(ir_data.get("feature_id"), str) and bool(ir_data["feature_id"].strip()))
|
||||
_add("root.rules is non-empty list", isinstance(ir_data.get("rules"), list) and len(ir_data["rules"]) > 0)
|
||||
|
||||
# Per-rule checks
|
||||
rules = ir_data.get("rules", []) if isinstance(ir_data, dict) else []
|
||||
rule_ids = []
|
||||
for i, rule in enumerate(rules):
|
||||
if not isinstance(rule, dict):
|
||||
continue
|
||||
rid = rule.get("rule_id", f"rules[{i}]")
|
||||
rule_ids.append(rid)
|
||||
|
||||
errs = validate_rule(rule, i)
|
||||
_add(f"{rid}: valid", len(errs) == 0, "; ".join(errs) if errs else "")
|
||||
|
||||
# Aggregate checks
|
||||
_add("no duplicate rule_ids", len(rule_ids) == len(set(rule_ids)),
|
||||
f"duplicates: {[r for r in rule_ids if rule_ids.count(r) > 1]}" if len(rule_ids) != len(set(rule_ids)) else "")
|
||||
|
||||
_add("all rules valid", report["valid"],
|
||||
f"{report['stats']['valid_rules']}/{report['stats']['total_rules']} valid")
|
||||
|
||||
return checks
|
||||
@@ -0,0 +1,178 @@
|
||||
"""Structured JSON report generation for QE acceptance test results.
|
||||
|
||||
Produces a unified report with three-layer verdict:
|
||||
Layer A – Schema compliance
|
||||
Layer B – Structural coverage + stability
|
||||
Layer C – LLM QE expert audit
|
||||
|
||||
Final verdict: PASS (releasable) or FAIL (blocked).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def generate_report(
|
||||
schema_result: dict,
|
||||
coverage_result: dict,
|
||||
audit_result: dict | None,
|
||||
*,
|
||||
commit: str = "",
|
||||
branch: str = "main",
|
||||
output_path: str | None = None,
|
||||
) -> dict:
|
||||
"""Assemble the three-layer report and return it.
|
||||
|
||||
Args:
|
||||
schema_result: ``{"verdict": "PASS"|"FAIL", "total_checks": N, "passed": N, "failed": N}``
|
||||
coverage_result: ``{"verdict": "PASS"|"FAIL", "coverage_rate": float,
|
||||
"stability": {"runs": N, "values": [...], "std": float}}``
|
||||
audit_result: ``{"verdict": "ACCEPT"|"REJECT", "inadequate_ratio": float,
|
||||
"rationale": str, "section_assessments": [...]}`` or None
|
||||
commit: git commit SHA
|
||||
branch: branch name
|
||||
output_path: if set, write the report JSON to this path
|
||||
|
||||
Returns the report dict.
|
||||
"""
|
||||
layers: dict[str, Any] = {
|
||||
"A_schema": schema_result,
|
||||
"B_coverage": coverage_result,
|
||||
}
|
||||
if audit_result is not None:
|
||||
layers["C_qe_audit"] = audit_result
|
||||
|
||||
# ── final verdict ──
|
||||
a_pass = schema_result.get("verdict") == "PASS"
|
||||
b_pass = coverage_result.get("verdict") == "PASS"
|
||||
c_pass = (
|
||||
audit_result is None
|
||||
or audit_result.get("verdict") == "ACCEPT"
|
||||
)
|
||||
all_pass = a_pass and b_pass and c_pass
|
||||
|
||||
report = {
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||
"commit": commit,
|
||||
"branch": branch,
|
||||
"layers": layers,
|
||||
"final_verdict": "PASS" if all_pass else "FAIL",
|
||||
"releasable": all_pass,
|
||||
"failure_details": _failure_details(layers),
|
||||
}
|
||||
|
||||
if output_path:
|
||||
out = Path(output_path)
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
out.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def _failure_details(layers: dict) -> list[str]:
|
||||
"""Summarise which layers failed and why."""
|
||||
details: list[str] = []
|
||||
|
||||
schema = layers.get("A_schema", {})
|
||||
if schema.get("verdict") != "PASS":
|
||||
details.append(
|
||||
f"Layer A (Schema): {schema.get('failed', '?')}/{schema.get('total_checks', '?')} checks failed"
|
||||
)
|
||||
|
||||
coverage = layers.get("B_coverage", {})
|
||||
if coverage.get("verdict") != "PASS":
|
||||
cv = coverage.get("coverage_rate", "?")
|
||||
details.append(f"Layer B (Coverage): rate={cv} (threshold: 0.70)")
|
||||
|
||||
audit = layers.get("C_qe_audit", {})
|
||||
if audit.get("verdict") == "REJECT":
|
||||
details.append(
|
||||
f"Layer C (QE Audit): REJECT — inadequate_ratio={audit.get('inadequate_ratio', '?')}"
|
||||
)
|
||||
|
||||
return details
|
||||
|
||||
|
||||
# ── Layer-specific result builders ──────────────────────────────────────────
|
||||
|
||||
def schema_verdict(errors: list[str], stats: dict) -> dict:
|
||||
"""Build Layer A result from schema validation errors & stats."""
|
||||
total = stats.get("total_rules", 0)
|
||||
valid = stats.get("valid_rules", 0)
|
||||
failed_checks = len(errors) + (total - valid)
|
||||
|
||||
return {
|
||||
"verdict": "PASS" if failed_checks == 0 else "FAIL",
|
||||
"total_checks": max(total, 1), # at minimum, we checked the root
|
||||
"passed": valid if failed_checks == 0 else valid,
|
||||
"failed": failed_checks,
|
||||
"rule_pass_rate": round(valid / max(total, 1), 2) if total > 0 else 0,
|
||||
"sample_errors": errors[:10], # first 10 for the report
|
||||
}
|
||||
|
||||
|
||||
def coverage_verdict(
|
||||
coverage_rate: float,
|
||||
stability_std: float,
|
||||
stability_values: list[float],
|
||||
*,
|
||||
coverage_threshold: float = 0.70,
|
||||
stability_threshold: float = 0.05,
|
||||
section_coverage: dict | None = None,
|
||||
table_coverage: dict | None = None,
|
||||
diagram_coverage: dict | None = None,
|
||||
) -> dict:
|
||||
"""Build Layer B result from coverage metrics."""
|
||||
b1_pass = coverage_rate >= coverage_threshold
|
||||
b2_pass = stability_std <= stability_threshold
|
||||
both_pass = b1_pass and b2_pass
|
||||
|
||||
result: dict[str, Any] = {
|
||||
"verdict": "PASS" if both_pass else "FAIL",
|
||||
"coverage_rate": round(coverage_rate, 3),
|
||||
"coverage_threshold": coverage_threshold,
|
||||
"coverage_pass": b1_pass,
|
||||
"stability": {
|
||||
"runs": len(stability_values),
|
||||
"values": [round(v, 3) for v in stability_values],
|
||||
"std": round(stability_std, 4),
|
||||
"threshold": stability_threshold,
|
||||
"pass": b2_pass,
|
||||
},
|
||||
}
|
||||
|
||||
if section_coverage:
|
||||
result["section_coverage"] = section_coverage
|
||||
if table_coverage:
|
||||
result["table_coverage"] = table_coverage
|
||||
if diagram_coverage:
|
||||
result["diagram_coverage"] = diagram_coverage
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def audit_verdict(audit_data: dict, *, inadequate_threshold: float = 0.30) -> dict:
|
||||
"""Build Layer C result from LLM QE audit.
|
||||
|
||||
*audit_data* should contain:
|
||||
inadequate_ratio: float
|
||||
rationale: str
|
||||
section_assessments: list[dict]
|
||||
"""
|
||||
ratio = audit_data.get("inadequate_ratio", 1.0)
|
||||
passed = ratio <= inadequate_threshold
|
||||
|
||||
return {
|
||||
"verdict": "ACCEPT" if passed else "REJECT",
|
||||
"inadequate_ratio": round(ratio, 3),
|
||||
"threshold": inadequate_threshold,
|
||||
"rationale": audit_data.get("rationale", ""),
|
||||
"total_sections": audit_data.get("total_functional_sections", 0),
|
||||
"adequate": audit_data.get("adequate", 0),
|
||||
"inadequate": audit_data.get("inadequate", 0),
|
||||
"not_applicable": audit_data.get("not_applicable", 0),
|
||||
}
|
||||
@@ -0,0 +1,558 @@
|
||||
"""QE Acceptance Test — Three-layer main branch health check.
|
||||
|
||||
Layer A (Schema): structural correctness of IR
|
||||
Layer B (Coverage): structural source-traceability coverage + stability
|
||||
Layer C (QE Audit): LLM as QE expert — functional coverage assessment
|
||||
|
||||
Final verdict: all three layers must pass for main to be releasable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
import statistics
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from .ir_schema import validate_ir, schema_checklist
|
||||
from .report import generate_report, schema_verdict, coverage_verdict, audit_verdict
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Layer A: SCHEMA — deterministic structural validation
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def test_layer_a_schema(ir_data: dict, request):
|
||||
"""Validate IR structure: required fields, types, naming conventions, no nulls."""
|
||||
report = validate_ir(ir_data)
|
||||
checks = schema_checklist(ir_data)
|
||||
|
||||
# Build Layer A result
|
||||
a_errors = report["errors"]
|
||||
a_stats = report["stats"]
|
||||
a_result = schema_verdict(a_errors, a_stats)
|
||||
a_result["checks"] = checks
|
||||
|
||||
# Store for downstream layers & report
|
||||
_stash(request, "layer_a", a_result)
|
||||
|
||||
# Assert
|
||||
assert report["valid"], (
|
||||
f"Schema validation FAILED ({len(a_errors)} errors)\n"
|
||||
+ "\n".join(f" - {e}" for e in a_errors[:20])
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Layer B: STRUCTURAL COVERAGE + STABILITY
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# Section titles that are NOT functional requirements
|
||||
NON_FUNCTIONAL_PATTERNS = [
|
||||
re.compile(p) for p in [
|
||||
r"编制.*变更.*日志",
|
||||
r"文档背景",
|
||||
r"文档范围",
|
||||
r"术语解释",
|
||||
r"参考",
|
||||
r"附录",
|
||||
r"版本",
|
||||
r"变更记录",
|
||||
r"目录",
|
||||
r"前言",
|
||||
r"概述",
|
||||
r"简介",
|
||||
r"概述.*背景",
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
def _is_functional_section(section_name: str) -> bool:
|
||||
"""Heuristic: exclude background, glossary, changelog, scope sections.
|
||||
|
||||
Sections that are purely structural — preface, glossary, changelog — are excluded.
|
||||
Sections with numbering like '3.1.1' are always considered functional.
|
||||
"""
|
||||
# Numbered sections are functional
|
||||
if _section_number(section_name) != section_name:
|
||||
return True
|
||||
for pat in NON_FUNCTIONAL_PATTERNS:
|
||||
if pat.search(section_name):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _extract_content_units(parsed_data: dict) -> dict:
|
||||
"""Extract countable content units from parsed JSON.
|
||||
|
||||
Returns:
|
||||
{"sections": [{"name": ..., "number": ...}, ...],
|
||||
"table_rows": int, "diagram_images": [rid, ...]}
|
||||
"""
|
||||
sections = parsed_data.get("sections", [])
|
||||
|
||||
functional_sections: list[dict] = []
|
||||
total_table_rows = 0
|
||||
|
||||
for sec in sections:
|
||||
name = sec.get("source", "")
|
||||
if _is_functional_section(name):
|
||||
functional_sections.append({
|
||||
"name": name,
|
||||
"number": _section_number(name),
|
||||
})
|
||||
|
||||
for block in sec.get("blocks", []):
|
||||
if block.get("type") == "table":
|
||||
rows = block.get("rows", [])
|
||||
total_table_rows += len(rows)
|
||||
|
||||
# Diagram-type images from image_analysis
|
||||
diagram_rids: list[str] = []
|
||||
for img in parsed_data.get("image_analysis", []):
|
||||
img_type = img.get("type", "")
|
||||
if img_type in ("flowchart", "logic_tree", "architecture",
|
||||
"state", "sequence", "activity"):
|
||||
diagram_rids.append(img.get("rid", ""))
|
||||
|
||||
return {
|
||||
"functional_sections": functional_sections,
|
||||
"table_rows": total_table_rows,
|
||||
"diagram_images": diagram_rids,
|
||||
}
|
||||
|
||||
|
||||
def _section_number(section_name: str) -> str:
|
||||
"""Extract leading section number, e.g. '3.1.1 系统限制' → '3.1.1'."""
|
||||
import re
|
||||
m = re.match(r"^([\d.]+)", section_name)
|
||||
return m.group(1) if m else section_name
|
||||
|
||||
|
||||
def _section_matches(sec_ref: str, func_sections: list[dict]) -> str | None:
|
||||
"""Find a functional section matching *sec_ref*. Returns the section name or None.
|
||||
|
||||
Matching: exact match → starts-with match → number match → substring match.
|
||||
"""
|
||||
# exact
|
||||
for s in func_sections:
|
||||
if s["name"] == sec_ref:
|
||||
return s["name"]
|
||||
# starts with section number
|
||||
for s in func_sections:
|
||||
if s["name"].startswith(sec_ref) or sec_ref.startswith(s["name"]):
|
||||
return s["name"]
|
||||
# number match
|
||||
sec_num = _section_number(sec_ref)
|
||||
if sec_num:
|
||||
for s in func_sections:
|
||||
if s["number"] == sec_num:
|
||||
return s["name"]
|
||||
# substring
|
||||
for s in func_sections:
|
||||
if sec_ref in s["name"] or s["name"] in sec_ref:
|
||||
return s["name"]
|
||||
return None
|
||||
|
||||
|
||||
def _measure_coverage(ir_data: dict, parsed_data: dict) -> dict:
|
||||
"""Compute structural coverage of IR over parsed document.
|
||||
|
||||
Returns:
|
||||
{
|
||||
"section_coverage": {total, covered, rate, uncovered},
|
||||
"table_coverage": {total_rows, covered_rows, rate},
|
||||
"diagram_coverage": {total, covered, rate},
|
||||
"overall_rate": float,
|
||||
}
|
||||
"""
|
||||
units = _extract_content_units(parsed_data)
|
||||
rules = ir_data.get("rules", [])
|
||||
|
||||
# ── section coverage ──
|
||||
func_sections = units["functional_sections"]
|
||||
covered_sections: set[str] = set()
|
||||
for rule in rules:
|
||||
for src in rule.get("sources", []):
|
||||
sec_ref = src.get("section", "")
|
||||
if sec_ref:
|
||||
matched = _section_matches(sec_ref, func_sections)
|
||||
if matched:
|
||||
covered_sections.add(matched)
|
||||
|
||||
section_coverage = {
|
||||
"total": len(func_sections),
|
||||
"covered": len(covered_sections),
|
||||
"rate": round(len(covered_sections) / max(len(func_sections), 1), 3),
|
||||
"uncovered": [s["name"] for s in func_sections
|
||||
if s["name"] not in covered_sections],
|
||||
}
|
||||
|
||||
# ── table row coverage ──
|
||||
covered_rows: set[tuple] = set()
|
||||
for rule in rules:
|
||||
for src in rule.get("sources", []):
|
||||
if src.get("type") == "table":
|
||||
sec = src.get("section", "")
|
||||
row = src.get("row")
|
||||
if sec and row is not None:
|
||||
covered_rows.add((sec, row))
|
||||
|
||||
total_rows = units["table_rows"]
|
||||
table_coverage = {
|
||||
"total_rows": total_rows,
|
||||
"covered_rows": len(covered_rows),
|
||||
"rate": round(len(covered_rows) / max(total_rows, 1), 3),
|
||||
}
|
||||
|
||||
# ── diagram coverage ──
|
||||
diagram_rids = units["diagram_images"]
|
||||
covered_rids: set[str] = set()
|
||||
for rule in rules:
|
||||
for src in rule.get("sources", []):
|
||||
if src.get("type") == "logic_tree":
|
||||
img_id = src.get("image_id", "")
|
||||
if img_id and img_id in diagram_rids:
|
||||
covered_rids.add(img_id)
|
||||
|
||||
diagram_coverage = {
|
||||
"total": len(diagram_rids),
|
||||
"covered": len(covered_rids),
|
||||
"rate": round(len(covered_rids) / max(len(diagram_rids), 1), 3),
|
||||
"uncovered": [r for r in diagram_rids if r not in covered_rids],
|
||||
}
|
||||
|
||||
# ── overall ──
|
||||
rates = [
|
||||
section_coverage["rate"],
|
||||
table_coverage["rate"],
|
||||
diagram_coverage["rate"],
|
||||
]
|
||||
overall = round(sum(rates) / len(rates), 3) if rates else 0.0
|
||||
|
||||
return {
|
||||
"section_coverage": section_coverage,
|
||||
"table_coverage": table_coverage,
|
||||
"diagram_coverage": diagram_coverage,
|
||||
"overall_rate": overall,
|
||||
}
|
||||
|
||||
|
||||
def test_layer_b_coverage(
|
||||
ir_data: dict,
|
||||
parsed_data: dict | None,
|
||||
ir_path: str,
|
||||
acceptance_runs: int,
|
||||
run_ir_pipeline,
|
||||
request,
|
||||
):
|
||||
"""Measure structural coverage and (optionally) coverage stability."""
|
||||
if parsed_data is None:
|
||||
pytest.skip("No parsed JSON available for coverage analysis")
|
||||
|
||||
# ── B1: single-run coverage ──
|
||||
cov = _measure_coverage(ir_data, parsed_data)
|
||||
|
||||
# ── B2: stability (multi-run) ──
|
||||
stability_values: list[float] = [cov["overall_rate"]]
|
||||
stability_std = 0.0
|
||||
|
||||
if acceptance_runs > 1:
|
||||
parsed_path = request.config.getoption("--parsed-path")
|
||||
if parsed_path and os.path.exists(parsed_path):
|
||||
for _ in range(acceptance_runs - 1):
|
||||
try:
|
||||
ir_list, _ = run_ir_pipeline(parsed_path)
|
||||
# Convert list-format IR to dict for coverage measurement
|
||||
run_ir = _wrap_list_ir(ir_list)
|
||||
run_cov = _measure_coverage(run_ir, parsed_data)
|
||||
stability_values.append(run_cov["overall_rate"])
|
||||
time.sleep(0.5) # rate limiting between runs
|
||||
except Exception as e:
|
||||
pytest.fail(f"Stability run failed: {e}")
|
||||
|
||||
if len(stability_values) > 1:
|
||||
stability_std = statistics.stdev(stability_values)
|
||||
|
||||
# Build Layer B result
|
||||
b_result = coverage_verdict(
|
||||
coverage_rate=cov["overall_rate"],
|
||||
stability_std=stability_std,
|
||||
stability_values=stability_values,
|
||||
section_coverage=cov["section_coverage"],
|
||||
table_coverage=cov["table_coverage"],
|
||||
diagram_coverage=cov["diagram_coverage"],
|
||||
)
|
||||
_stash(request, "layer_b", b_result)
|
||||
|
||||
# Assert — both B1 and B2 must pass
|
||||
assert b_result["coverage_pass"], (
|
||||
f"Coverage {cov['overall_rate']:.1%} < threshold 70%\n"
|
||||
f" Sections: {cov['section_coverage']['covered']}/{cov['section_coverage']['total']} "
|
||||
f"({cov['section_coverage']['rate']:.1%})\n"
|
||||
f" Uncovered: {cov['section_coverage']['uncovered']}\n"
|
||||
f" Table rows: {cov['table_coverage']['covered_rows']}/{cov['table_coverage']['total_rows']} "
|
||||
f"({cov['table_coverage']['rate']:.1%})\n"
|
||||
f" Diagrams: {cov['diagram_coverage']['covered']}/{cov['diagram_coverage']['total']} "
|
||||
f"({cov['diagram_coverage']['rate']:.1%})\n"
|
||||
f" Uncovered diagrams: {cov['diagram_coverage']['uncovered']}"
|
||||
)
|
||||
|
||||
if len(stability_values) > 1:
|
||||
assert b_result["stability"]["pass"], (
|
||||
f"Coverage stability std={stability_std:.4f} > threshold 0.05\n"
|
||||
f" Values across {len(stability_values)} runs: {stability_values}"
|
||||
)
|
||||
|
||||
|
||||
def _wrap_list_ir(ir_list: list) -> dict:
|
||||
"""Wrap a list-format IR (from ir_generator.py) into a dict for schema compat."""
|
||||
# Convert simple format to rich format for coverage measurement
|
||||
rules = []
|
||||
for i, entry in enumerate(ir_list):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
rule = {
|
||||
"rule_id": f"GEN-001-RULE-{i:03d}",
|
||||
"description": entry.get("function", ""),
|
||||
"path": [],
|
||||
"priority": "P2",
|
||||
"sources": [],
|
||||
"precondition": {},
|
||||
"trigger": entry.get("trigger", {"operator": "AND", "conditions": []}),
|
||||
"actions": [],
|
||||
}
|
||||
# Convert source
|
||||
src = entry.get("source", {})
|
||||
if src.get("section"):
|
||||
rule["sources"].append({
|
||||
"type": "text",
|
||||
"section": src["section"],
|
||||
"paragraph": 1,
|
||||
"text_snippet": src.get("location", ""),
|
||||
"priority": "primary_source",
|
||||
})
|
||||
rules.append(rule)
|
||||
|
||||
return {
|
||||
"feature": "generated",
|
||||
"feature_id": "GEN-001",
|
||||
"rules": rules,
|
||||
}
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Layer C: LLM QE EXPERT AUDIT
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
QE_AUDITOR_PROMPT = """你是一个资深 QE 专家,负责审查需求文档的 IR(中间表示层)是否充分覆盖了源文档的所有可测试功能点。
|
||||
|
||||
你不是 IR 的生成者,你是独立的质量审计员。你的职责是判断 IR 的功能覆盖率是否充分。
|
||||
|
||||
## 审计输入
|
||||
|
||||
### Layer B 结构化覆盖率数据(参考)
|
||||
{coverage_summary}
|
||||
|
||||
### 源文档内容(Parsed JSON)
|
||||
{parsed_content}
|
||||
|
||||
### 生成的 IR(待审计)
|
||||
{ir_content}
|
||||
|
||||
## 审计要求
|
||||
|
||||
对源文档中的每个章节逐一评估其功能需求是否被 IR 充分覆盖。
|
||||
|
||||
**判断标准**:
|
||||
- **adequate**(充分覆盖):该章节的所有功能需求在 IR 中都有对应的 rule,包括触发条件、执行动作
|
||||
- **inadequate**(覆盖不足):该章节存在功能需求未在 IR 中体现,或描述不完整(缺少触发条件或动作)
|
||||
- **not_applicable**(不适用):该章节为背景介绍、术语定义、变更日志等,不包含功能需求
|
||||
|
||||
**注意**:
|
||||
- 如果某个章节涉及多个决策路径(如流程图),检查 IR 是否覆盖了每条路径
|
||||
- 表格中的每个功能行都应被至少一个 IR rule 覆盖
|
||||
- 图片分析中的流程图/决策树节点应被 IR 引用
|
||||
|
||||
## 输出格式
|
||||
|
||||
请严格输出以下 JSON 格式(不要包含代码块标记):
|
||||
|
||||
{{
|
||||
"total_functional_sections": <number>,
|
||||
"adequate": <number>,
|
||||
"inadequate": <number>,
|
||||
"not_applicable": <number>,
|
||||
"inadequate_ratio": <float>,
|
||||
"verdict": "ACCEPT 或 REJECT",
|
||||
"rationale": "<一句话说明接受或拒绝的理由>",
|
||||
"section_assessments": [
|
||||
{{
|
||||
"section": "<章节名>",
|
||||
"assessment": "adequate | inadequate | not_applicable",
|
||||
"reason": "<评估理由>",
|
||||
"missing": ["<缺失项1>", "<缺失项2>"] // 仅 inadequate 时需要
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
verdict 判定规则:
|
||||
- inadequate_ratio ≤ 0.30 → "ACCEPT"(风险可控)
|
||||
- inadequate_ratio > 0.30 → "REJECT"(功能点认知差异大,需要补充 IR)
|
||||
"""
|
||||
|
||||
|
||||
def test_layer_c_qe_audit(
|
||||
ir_data: dict, parsed_data: dict | None, llm_client, request
|
||||
):
|
||||
"""LLM QE expert audit of functional coverage."""
|
||||
if parsed_data is None:
|
||||
pytest.skip("No parsed JSON available — cannot run QE audit")
|
||||
|
||||
# ── get Layer B summary for context ──
|
||||
layer_b = _unstash(request, "layer_b") or {}
|
||||
cov_summary = json.dumps(
|
||||
{
|
||||
"coverage_rate": layer_b.get("coverage_rate", "N/A"),
|
||||
"section_coverage": layer_b.get("section_coverage", {}),
|
||||
"diagram_coverage": layer_b.get("diagram_coverage", {}),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
# ── prepare content (trim to avoid token overflow) ──
|
||||
parsed_str = json.dumps(parsed_data, ensure_ascii=False)
|
||||
ir_str = json.dumps(ir_data, ensure_ascii=False)
|
||||
|
||||
max_parsed = 12000
|
||||
max_ir = 8000
|
||||
if len(parsed_str) > max_parsed:
|
||||
parsed_str = parsed_str[:max_parsed] + "\n...[truncated]"
|
||||
if len(ir_str) > max_ir:
|
||||
ir_str = ir_str[:max_ir] + "\n...[truncated]"
|
||||
|
||||
prompt = QE_AUDITOR_PROMPT.format(
|
||||
coverage_summary=cov_summary,
|
||||
parsed_content=parsed_str,
|
||||
ir_content=ir_str,
|
||||
)
|
||||
|
||||
# ── call LLM ──
|
||||
try:
|
||||
raw = llm_client.chat(
|
||||
model=llm_client.TEXT_MODEL,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
except Exception as e:
|
||||
pytest.fail(f"QE audit LLM call failed: {e}")
|
||||
|
||||
# ── parse response ──
|
||||
audit_data = _parse_json_response(raw)
|
||||
if audit_data is None:
|
||||
pytest.fail(f"QE audit returned unparseable response:\n{raw[:500]}")
|
||||
|
||||
# Build Layer C result
|
||||
c_result = audit_verdict(audit_data)
|
||||
c_result["raw_assessments"] = audit_data.get("section_assessments", [])
|
||||
_stash(request, "layer_c", c_result)
|
||||
|
||||
# Assert
|
||||
assert c_result["verdict"] == "ACCEPT", (
|
||||
f"QE Audit REJECTED — inadequate_ratio={c_result['inadequate_ratio']:.1%} > 30%\n"
|
||||
f" Rationale: {c_result['rationale']}\n"
|
||||
f" Adequate: {c_result['adequate']}, Inadequate: {c_result['inadequate']}"
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Final report (runs last)
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def test_final_report(ir_data: dict, ir_path: str, request):
|
||||
"""Generate the final three-layer JSON report.
|
||||
|
||||
This test always passes (report generation). The verdicts from layers A/B/C
|
||||
determine the final releasable status, but the report itself is informational.
|
||||
"""
|
||||
layer_a = _unstash(request, "layer_a") or {"verdict": "SKIPPED"}
|
||||
layer_b = _unstash(request, "layer_b") or {"verdict": "SKIPPED"}
|
||||
layer_c = _unstash(request, "layer_c") or {"verdict": "SKIPPED"}
|
||||
|
||||
report_path = request.config.getoption("--json-report-file", None) or str(
|
||||
Path.cwd() / "acceptance-report.json"
|
||||
)
|
||||
|
||||
report = generate_report(
|
||||
layer_a,
|
||||
layer_b,
|
||||
layer_c,
|
||||
commit=os.environ.get("GITEA_SHA", ""),
|
||||
branch=os.environ.get("GITEA_BRANCH", "main"),
|
||||
output_path=report_path,
|
||||
)
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'='*60}")
|
||||
print(f"QE ACCEPTANCE REPORT")
|
||||
print(f"{'='*60}")
|
||||
print(f" Layer A (Schema): {layer_a.get('verdict', '?')}")
|
||||
print(f" Layer B (Coverage): {layer_b.get('verdict', '?')} "
|
||||
f"(rate={layer_b.get('coverage_rate', '?')})")
|
||||
print(f" Layer C (QE Audit): {layer_c.get('verdict', '?')}")
|
||||
print(f" {'─'*40}")
|
||||
print(f" FINAL: {report['final_verdict']} | "
|
||||
f"Releasable: {report['releasable']}")
|
||||
print(f" Report: {report_path}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Fail if any layer failed (aggregate assertion)
|
||||
failures = report.get("failure_details", [])
|
||||
if failures:
|
||||
pytest.fail(
|
||||
"Acceptance tests FAILED:\n" + "\n".join(f" - {f}" for f in failures)
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Helpers
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
import os # noqa: E402
|
||||
|
||||
# Module-level stash for sharing results across tests in the same module.
|
||||
# Each test function stores its result here; later tests read earlier results.
|
||||
_module_stash: dict[str, dict] = {}
|
||||
|
||||
|
||||
def _stash(request, key: str, value: dict):
|
||||
"""Store a result dict for cross-test access within this module."""
|
||||
_module_stash[key] = value
|
||||
|
||||
|
||||
def _unstash(request, key: str) -> dict | None:
|
||||
"""Retrieve a stashed result."""
|
||||
return _module_stash.get(key)
|
||||
|
||||
|
||||
def _parse_json_response(raw: str) -> dict | None:
|
||||
"""Parse JSON from an LLM response, handling markdown code fences."""
|
||||
if not raw:
|
||||
return None
|
||||
text = raw.strip()
|
||||
if text.startswith("```"):
|
||||
nl = text.find("\n")
|
||||
text = text[nl + 1:] if nl != -1 else text[3:]
|
||||
if text.endswith("```"):
|
||||
text = text[:-3]
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
+10
-5
@@ -55,12 +55,17 @@ def test_import_detect_conflicts():
|
||||
|
||||
# -- IR generation tests ------------------------------------------------------
|
||||
|
||||
def test_import_ir_generator():
|
||||
"""ir_generator module should be importable."""
|
||||
def test_import_ir_main():
|
||||
"""ir_generation main module should be importable (new project structure)."""
|
||||
os.environ.setdefault("DASHSCOPE_API_KEY", "test-fake-key")
|
||||
_import_from_skill("ir_generation_skill", "ir_generator")
|
||||
import ir_generator
|
||||
assert hasattr(ir_generator, "generate_ir")
|
||||
skill_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"skills", "ir_generation_skill"
|
||||
)
|
||||
if skill_dir not in sys.path:
|
||||
sys.path.insert(0, skill_dir)
|
||||
import main
|
||||
assert hasattr(main, "main")
|
||||
|
||||
|
||||
# -- Resolution application tests ---------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user