""" Tests for Stage 2 (IR Extraction). Validates that ir_fragments.json meets quality and structural requirements: - All fragments have non-empty rules - All rules have path arrays - All rules have precondition.geographic_scope and precondition.screen_type - All trigger conditions have signal/operator/value - user_interaction content is non-empty and not a placeholder - No duplicate rule_ids (across all fragments) """ import json import sys from pathlib import Path from collections import Counter sys.path.insert(0, str(Path(__file__).parent.parent)) import config PASS = "[PASS]" FAIL = "[FAIL]" WARN = "[WARN]" # Forbidden placeholder phrases in user_interaction content FORBIDDEN_PLACEHOLDERS = [ "文案由业务定义", "待定", "自定义", "TBD", "todo", "TODO" ] def load_fragments(): """Load ir_fragments.json.""" try: return config.load_json(config.IR_FRAGMENTS_JSON) except FileNotFoundError: print(f"{FAIL} ir_fragments.json 未找到: {config.IR_FRAGMENTS_JSON}") print(" 请先运行 step2_ir_extraction.py") sys.exit(1) def check_non_empty_rules(fragments: list[dict]) -> list[str]: """Every fragment must have at least one rule.""" errors = [] for f in fragments: uid = f.get("unit_id", "?") rules = f.get("rules", []) if not rules: if f.get("error"): errors.append(f"{uid}: 提取失败 — {f['error']}") else: errors.append(f"{uid}: rules 为空") return errors def check_rule_paths(fragments: list[dict]) -> list[str]: """Every rule must have a non-empty path array.""" errors = [] for f in fragments: uid = f.get("unit_id", "?") for j, rule in enumerate(f.get("rules", [])): rid = rule.get("rule_id", f"rule[{j}]") path = rule.get("path", []) if not path: errors.append(f"{rid}: path 字段为空或缺失") elif not isinstance(path, list): errors.append(f"{rid}: path 必须是数组") return errors def check_precondition_fields(fragments: list[dict]) -> list[str]: """Every rule must have precondition with geographic_scope and screen_type.""" errors = [] for f in fragments: uid = f.get("unit_id", "?") for j, rule in enumerate(f.get("rules", [])): rid = rule.get("rule_id", f"rule[{j}]") precond = rule.get("precondition", {}) if not precond: errors.append(f"{rid}: precondition 缺失") continue if not precond.get("geographic_scope"): errors.append(f"{rid}: precondition.geographic_scope 缺失") if "screen_type" not in precond: errors.append(f"{rid}: precondition.screen_type 缺失") return errors def check_user_interaction_content(fragments: list[dict]) -> list[str]: """user_interaction actions must have non-empty, non-placeholder content.""" errors = [] for f in fragments: uid = f.get("unit_id", "?") for j, rule in enumerate(f.get("rules", [])): rid = rule.get("rule_id", f"rule[{j}]") for k, action in enumerate(rule.get("actions", [])): if action.get("type") != "user_interaction": continue content = action.get("content", "") if not content: errors.append( f"{rid}.actions[{k}]: user_interaction 的 content 为空" ) elif any(ph in content for ph in FORBIDDEN_PLACEHOLDERS): errors.append( f"{rid}.actions[{k}]: content 包含占位符: '{content}'" ) return errors def check_sources_have_logic_tree_nodes(fragments: list[dict]) -> list[str]: """Every rule should reference at least one logic tree node in its sources.""" errors = [] for f in fragments: uid = f.get("unit_id", "?") for j, rule in enumerate(f.get("rules", [])): rid = rule.get("rule_id", f"rule[{j}]") sources = rule.get("sources", []) has_logic_tree = any( src.get("type") == "logic_tree" and src.get("node_ids") for src in sources ) if not has_logic_tree: has_text = any( src.get("type") in ("table", "para") for src in sources ) if not has_text: errors.append(f"{rid}: sources 中既无逻辑树引用也无文字引用") return errors def check_trigger_conditions(fragments: list[dict]) -> list[str]: """Every trigger condition must have signal, operator, value.""" errors = [] for f in fragments: uid = f.get("unit_id", "?") for j, rule in enumerate(f.get("rules", [])): rid = rule.get("rule_id", f"rule[{j}]") trigger = rule.get("trigger") or {} conditions = trigger.get("conditions", []) if trigger.get("event") is not None: continue for k, cond in enumerate(conditions): signal = cond.get("signal", "") operator = cond.get("operator", "") has_value = "value" in cond if not signal: errors.append(f"{rid}.condition[{k}]: 缺少 signal") if not operator: errors.append(f"{rid}.condition[{k}]: 缺少 operator") if not has_value: errors.append(f"{rid}.condition[{k}]: 缺少 value") return errors def check_duplicate_rule_ids(fragments: list[dict]) -> list[str]: """Check for duplicate rule_ids across all fragments.""" all_rule_ids = [] for f in fragments: for rule in f.get("rules", []): rid = rule.get("rule_id", "") if rid: all_rule_ids.append(rid) duplicates = [rid for rid, count in Counter(all_rule_ids).items() if count > 1] errors = [] if duplicates: errors.append(f"重复 rule_id: {duplicates}") return errors def check_action_types(fragments: list[dict]) -> list[str]: """Verify that actions have valid types.""" valid_types = {"system", "user_interaction"} errors = [] for f in fragments: for j, rule in enumerate(f.get("rules", [])): rid = rule.get("rule_id", f"rule[{j}]") for k, action in enumerate(rule.get("actions", [])): atype = action.get("type", "") if atype not in valid_types: errors.append( f"{rid}.action[{k}]: type='{atype}' 无效, " f"应为 {valid_types}" ) if atype == "user_interaction" and "content" not in action: errors.append( f"{rid}.action[{k}]: user_interaction 类型缺少 content 字段" ) return errors def run_all_tests(): print("=" * 60) print("Step 2 自检测试") print("=" * 60) fragments = load_fragments() all_errors = [] total_units = len(fragments) total_rules = sum(len(f.get("rules", [])) for f in fragments) # Test 1: Non-empty rules errors = check_non_empty_rules(fragments) if errors: print(f"\n{FAIL} 非空规则检查: {len(errors)} 个错误") for e in errors: print(f" - {e}") all_errors.extend(errors) else: print(f"\n{PASS} 非空规则检查: 全部通过 ({total_units} 个片段)") # Test 2: Rule path arrays errors = check_rule_paths(fragments) if errors: print(f"\n{FAIL} 规则 path 字段: {len(errors)} 个错误") for e in errors[:10]: print(f" - {e}") if len(errors) > 10: print(f" ... 还有 {len(errors) - 10} 个") all_errors.extend(errors) else: print(f"\n{PASS} 规则 path 字段: 全部通过") # Test 3: Precondition fields errors = check_precondition_fields(fragments) if errors: print(f"\n{FAIL} precondition 字段: {len(errors)} 个错误") for e in errors[:10]: print(f" - {e}") if len(errors) > 10: print(f" ... 还有 {len(errors) - 10} 个") all_errors.extend(errors) else: print(f"\n{PASS} precondition 字段: 全部通过") # Test 4: user_interaction content errors = check_user_interaction_content(fragments) if errors: print(f"\n{FAIL} user_interaction content: {len(errors)} 个错误") for e in errors[:10]: print(f" - {e}") if len(errors) > 10: print(f" ... 还有 {len(errors) - 10} 个") all_errors.extend(errors) else: print(f"\n{PASS} user_interaction content: 全部通过") # Test 5: Sources have logic tree references errors = check_sources_have_logic_tree_nodes(fragments) if errors: print(f"\n{FAIL} 来源节点引用: {len(errors)} 个规则缺少来源引用") for e in errors[:10]: print(f" - {e}") if len(errors) > 10: print(f" ... 还有 {len(errors) - 10} 个") all_errors.extend(errors) else: print(f"\n{PASS} 来源节点引用: 全部通过") # Test 6: Trigger conditions completeness errors = check_trigger_conditions(fragments) if errors: print(f"\n{FAIL} 触发条件完整性: {len(errors)} 个条件不完整") for e in errors[:10]: print(f" - {e}") if len(errors) > 10: print(f" ... 还有 {len(errors) - 10} 个") all_errors.extend(errors) else: print(f"\n{PASS} 触发条件完整性: 全部通过") # Test 7: No duplicate rule_ids errors = check_duplicate_rule_ids(fragments) if errors: print(f"\n{FAIL} rule_id 唯一性: 发现重复") for e in errors: print(f" - {e}") all_errors.extend(errors) else: print(f"\n{PASS} rule_id 唯一性: 全部通过") # Test 8: Valid action types errors = check_action_types(fragments) if errors: print(f"\n{FAIL} 动作类型检查: {len(errors)} 个问题") for e in errors[:10]: print(f" - {e}") all_errors.extend(errors) else: print(f"\n{PASS} 动作类型检查: 全部通过") # Summary print(f"\n{'='*60}") total_failures = len(all_errors) if total_failures == 0: print(f"{PASS} 所有测试通过!") else: print(f"{FAIL} 测试失败: {total_failures} 个错误") print("\n建议:") print(" 1. 检查 ir_fragments.json 中出错的规则") print(" 2. 如果某些功能单元的规则为空,检查上下文包是否丢失了关键信息") print(" 3. 调整 Prompt (prompts/step2_ir_extraction.txt) 后重新运行") print(f"\n统计:") print(f" 功能单元数: {total_units}") print(f" 规则总数: {total_rules}") error_units = sum(1 for f in fragments if f.get("error")) if error_units: print(f" 提取失败的单元: {error_units}") return total_failures == 0 # ═══════════════════════════════════════════════════════════════════════════════ # pytest discovery support # ═══════════════════════════════════════════════════════════════════════════════ import pytest # noqa: E402 def _load_fragments_or_skip(): """Load ir_fragments.json or return None.""" try: return config.load_json(config.IR_FRAGMENTS_JSON) except FileNotFoundError: return None def test_step2_non_empty_rules(): """pytest: every fragment must have at least one rule.""" fragments = _load_fragments_or_skip() if fragments is None: pytest.skip("ir_fragments.json not found — run step2_ir_extraction.py first") errors = check_non_empty_rules(fragments) assert not errors, f"non-empty rule errors: {errors}" def test_step2_rule_paths(): """pytest: every rule must have a non-empty path array.""" fragments = _load_fragments_or_skip() if fragments is None: pytest.skip("ir_fragments.json not found") errors = check_rule_paths(fragments) assert not errors, f"rule path errors: {errors[:5]}" def test_step2_precondition_fields(): """Warn: rules missing precondition fields (depends on LLM output, defense in step3).""" fragments = _load_fragments_or_skip() if fragments is None: pytest.skip("ir_fragments.json not found") errors = check_precondition_fields(fragments) if errors: print(f"\n[WARN] {len(errors)} 个规则缺少 precondition 字段 (LLM 输出变异,step3 _normalize_rule 兜底)") for e in errors[:5]: print(f" - {e}") def test_step2_user_interaction_content(): """pytest: user_interaction actions must have non-empty, non-placeholder content.""" fragments = _load_fragments_or_skip() if fragments is None: pytest.skip("ir_fragments.json not found") errors = check_user_interaction_content(fragments) assert not errors, f"user_interaction content errors: {errors[:5]}" def test_step2_sources_have_refs(): """pytest: every rule should reference at least one source (warn only — depends on LLM output).""" fragments = _load_fragments_or_skip() if fragments is None: pytest.skip("ir_fragments.json not found") errors = check_sources_have_logic_tree_nodes(fragments) if errors: print(f"\n[WARN] {len(errors)} 个规则缺少来源引用 (LLM 输出质量问题)") def test_step2_trigger_conditions(): """pytest: every trigger condition must have signal, operator, value.""" fragments = _load_fragments_or_skip() if fragments is None: pytest.skip("ir_fragments.json not found") errors = check_trigger_conditions(fragments) assert not errors, f"trigger condition errors: {errors[:5]}" def test_step2_duplicate_rule_ids(): """pytest: no duplicate rule_ids across all fragments.""" fragments = _load_fragments_or_skip() if fragments is None: pytest.skip("ir_fragments.json not found") errors = check_duplicate_rule_ids(fragments) assert not errors, f"duplicate rule_id errors: {errors}" def test_step2_action_types(): """pytest: all actions must have valid types.""" fragments = _load_fragments_or_skip() if fragments is None: pytest.skip("ir_fragments.json not found") errors = check_action_types(fragments) assert not errors, f"action type errors: {errors[:5]}" if __name__ == "__main__": success = run_all_tests() sys.exit(0 if success else 1)