Files
pzhang_zywl 4a8032665f
CI / test (pull_request) Successful in 8s
fix: ensemble 温度从 3 个增至 4 个增加多样性 - Closes #75
新增 t=0.5 温度变体,提高 ensemble 多样性以捕获更多功能单元。

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-02 18:55:16 +08:00

178 lines
6.4 KiB
Python

"""
Shared configuration for the IR Generation pipeline.
Reads API keys from a secrets.yaml file, falling back to environment variables.
"""
import os
import sys
import json
import yaml
# ---- Paths ----
# The directory layout is derived from this module's own location:
#   BASE_DIR       -> directory containing this file
#   WORKSPACE_DIR  -> parent of BASE_DIR
#   PROJECT_ROOT   -> parent of WORKSPACE_DIR
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
WORKSPACE_DIR = os.path.dirname(BASE_DIR)
PROJECT_ROOT = os.path.dirname(WORKSPACE_DIR)

# Every generated artifact lives below PROJECT_ROOT/output.
PROJECT_OUTPUT = os.path.join(PROJECT_ROOT, "output")
IR_OUTPUT, FINAL_OUTPUT = (
    os.path.join(PROJECT_OUTPUT, "ir"),     # intermediate outputs
    os.path.join(PROJECT_OUTPUT, "final"),  # final deliverables
)

# Legacy locations, kept for the doc_parser integration.
DOC_PARSER_OUTPUT = os.path.join(WORKSPACE_DIR, "doc_parser_skill", "output")
PROMPTS_DIR = os.path.join(BASE_DIR, "prompts")
TESTS_DIR = os.path.join(BASE_DIR, "tests")
OUTPUT_DIR = IR_OUTPUT  # alias retained for older callers

# Parsed-PRD input JSON. There is deliberately no hard-coded default: it must
# come from the IR_INPUT_JSON env var, the CLI, or set_input_file(), so the
# pipeline never silently processes the wrong document.
INPUT_JSON = os.environ.get("IR_INPUT_JSON")
def set_input_file(path: str) -> None:
    """Point the pipeline at a different parsed-PRD JSON file.

    Rebinds the module-level ``INPUT_JSON``, which ``load_input_document``
    falls back to when it is called without an explicit path.

    Args:
        path: Location of the parsed PRD JSON file.
    """
    global INPUT_JSON
    INPUT_JSON = path
# Secrets file -- candidate locations, searched in priority order:
#   1. $IR_SECRETS_PATH (when set; a leading "~" is expanded)
#   2. ~/.openclaw/config/secrets.yaml
#   3. ~/.openclaw/workspace-document-analyzer/config/secrets.yaml
_SECRETS_CANDIDATES = [
    os.path.join(os.path.expanduser("~"), ".openclaw", "config", "secrets.yaml"),
    os.path.join(
        os.path.expanduser("~"), ".openclaw", "workspace-document-analyzer",
        "config", "secrets.yaml",
    ),
]
_SECRETS_PATH = os.environ.get("IR_SECRETS_PATH", "")
if _SECRETS_PATH:
    # Expand "~" here: values supplied via .env files, Docker ENV or systemd
    # units are not tilde-expanded by a shell, and file APIs treat an
    # unexpanded "~" as a literal directory name.
    _SECRETS_PATH = os.path.expanduser(_SECRETS_PATH)
    _SECRETS_CANDIDATES.insert(0, _SECRETS_PATH)
# Highest-priority candidate; kept as a public name for backward compatibility.
SECRETS_YAML = _SECRETS_CANDIDATES[0]
# ---- Intermediate outputs (all under PROJECT_OUTPUT/ir/) ----
SEMANTIC_INDEX_R1_JSON, SEMANTIC_INDEX_R2_JSON, SEMANTIC_INDEX_R3_JSON = (
    os.path.join(IR_OUTPUT, "semantic_index_r1.json"),
    os.path.join(IR_OUTPUT, "semantic_index_r2.json"),
    os.path.join(IR_OUTPUT, "semantic_index_r3.json"),
)
SEMANTIC_INDEX_JSON = os.path.join(IR_OUTPUT, "semantic_index.json")
IR_FRAGMENTS_JSON = os.path.join(IR_OUTPUT, "ir_fragments.json")
PATH_ENUM_JSON = os.path.join(IR_OUTPUT, "path_enumeration.json")
IR_AUTOCOMPLETE_FRAGMENTS_JSON = os.path.join(
    IR_OUTPUT, "ir_autocomplete_fragments.json"
)

# ---- Final deliverables (under PROJECT_OUTPUT/final/) ----
IR_FINAL_JSON, IR_AUDIT_REPORT_MD = (
    os.path.join(FINAL_OUTPUT, "ir_final.json"),
    os.path.join(FINAL_OUTPUT, "ir_audit_report.md"),
)
# ---- LLM API ----
# Provider selection via IR_PROVIDER: "deepseek" | "dashscope".
LLM_PROVIDER = os.environ.get("IR_PROVIDER", "deepseek")

# Default model per provider. A single IR_MODEL env var, when set, overrides
# the model name for every provider.
PROVIDER_MODELS = {
    "deepseek": os.environ.get("IR_MODEL", "deepseek-v4-flash"),
    "dashscope": os.environ.get("IR_MODEL", "qwen-max"),
}
# NOTE(review): an unrecognised IR_PROVIDER silently falls back to the
# DeepSeek model name here, while llm_client() still looks up credentials
# under the unrecognised provider name.
MODEL_NAME = (
    PROVIDER_MODELS[LLM_PROVIDER]
    if LLM_PROVIDER in PROVIDER_MODELS
    else PROVIDER_MODELS["deepseek"]
)

# Generation limits (env: IR_MAX_TOKENS, IR_TEMPERATURE).
MAX_TOKENS = int(os.environ.get("IR_MAX_TOKENS", "16000"))
TEMPERATURE = float(os.environ.get("IR_TEMPERATURE", "0.1"))
# ---- Iteration & Quality ----
MAX_RETRIES_PER_STAGE = int(os.environ.get("IR_MAX_RETRIES", "3"))
COVERAGE_TARGET = float(os.environ.get("IR_COVERAGE_TARGET", "0.95"))

# Stage 1 ensemble temperatures (parallel multi-temperature generation).
# Each slot can be overridden individually through its env var.
ENSEMBLE_TEMPERATURES = [
    float(os.environ.get(env_key, fallback))
    for env_key, fallback in (
        ("IR_ENSEMBLE_T1", "0.0"),
        ("IR_ENSEMBLE_T2", "0.3"),
        ("IR_ENSEMBLE_T3", "0.5"),
        ("IR_ENSEMBLE_T4", "0.7"),
    )
]
def _load_secrets() -> dict[str, dict[str, str]]:
    """Load provider credentials from the first secrets.yaml that exists.

    Paths in ``_SECRETS_CANDIDATES`` are tried in order: the IR_SECRETS_PATH
    env var (if set), then ~/.openclaw/config/, then
    ~/.openclaw/workspace-document-analyzer/config/. Only the first existing
    file is read; later candidates are not consulted even if that file lacks
    the provider being looked up.

    Returns:
        A mapping such as
        ``{"deepseek": {"apiKey": "...", "baseUrl": "..."}, ...}``, or ``{}``
        when no candidate file exists or the file's content is empty/falsy.

    Raises:
        RuntimeError: If the file's top-level YAML value is not a mapping.
        yaml.YAMLError: If the file is not valid YAML.
    """
    for candidate in _SECRETS_CANDIDATES:
        if not os.path.isfile(candidate):
            continue
        with open(candidate, "r", encoding="utf-8") as fh:
            data = yaml.safe_load(fh)
        if not data:
            return {}
        if not isinstance(data, dict):
            # Without this check a list/scalar document would surface later as
            # an opaque AttributeError on ``.get`` in the caller.
            raise RuntimeError(
                f"Secrets file {candidate} must contain a mapping at the top "
                f"level, got {type(data).__name__}."
            )
        return data
    return {}
# Fallback OpenAI-compatible endpoints, used when neither the
# <PROVIDER>_BASE_URL env var nor the secrets.yaml entry supplies a baseUrl.
_DEFAULT_BASE_URLS = {
    "deepseek": "https://api.deepseek.com/v1",
    "dashscope": "https://dashscope.aliyuncs.com/compatible-mode/v1",
}


def _get_provider_config(provider: str) -> dict[str, str]:
    """Resolve ``{"apiKey": ..., "baseUrl": ...}`` for *provider*.

    Environment variables win over secrets.yaml: ``<PROVIDER>_API_KEY`` and
    ``<PROVIDER>_BASE_URL`` (provider name upper-cased) take precedence over
    the ``apiKey`` / ``baseUrl`` fields of the provider's secrets entry. If no
    base URL is configured anywhere, a provider-specific default is used;
    providers without a known default fall back to the DeepSeek endpoint.

    Raises:
        RuntimeError: If no API key is found in the environment or secrets.
    """
    secrets = _load_secrets()
    entry = secrets.get(provider) if isinstance(secrets, dict) else None
    # Tolerate a bare "provider:" key (YAML null) or a non-mapping entry
    # instead of crashing with AttributeError on ``.get`` below.
    if not isinstance(entry, dict):
        entry = {}
    env_prefix = provider.upper()

    api_key = (
        os.environ.get(f"{env_prefix}_API_KEY")
        or entry.get("apiKey", "")
    )
    base_url = (
        os.environ.get(f"{env_prefix}_BASE_URL")
        or entry.get("baseUrl")
        # Previously every provider defaulted to DeepSeek's URL, which sent
        # e.g. a DashScope key to the wrong service.
        or _DEFAULT_BASE_URLS.get(provider, _DEFAULT_BASE_URLS["deepseek"])
    )

    if not api_key:
        tried_paths = "\n ".join(_SECRETS_CANDIDATES)
        raise RuntimeError(
            f"No API key found for provider '{provider}'.\n"
            f"Tried secrets.yaml paths:\n {tried_paths}\n"
            f"Or set {env_prefix}_API_KEY environment variable."
        )
    return {"apiKey": api_key, "baseUrl": base_url}
def llm_client():
    """Build an OpenAI-compatible client for the active ``LLM_PROVIDER``.

    The API key and base URL come from ``_get_provider_config``. ``openai`` is
    imported inside the function, so merely importing this module does not
    require that package.
    """
    from openai import OpenAI

    provider_cfg = _get_provider_config(LLM_PROVIDER)
    api_key, base_url = provider_cfg["apiKey"], provider_cfg["baseUrl"]
    return OpenAI(api_key=api_key, base_url=base_url)
def load_input_document(path: str | None = None) -> dict:
"""Load the parsed PRD JSON document.
Args:
path: Explicit file path. If None, reads from IR_INPUT_JSON env var.
Raises:
FileNotFoundError: If no path is configured.
SystemExit: If the configured path does not exist.
"""
path = path or INPUT_JSON
if not path:
print("错误: 未指定输入文件。请通过以下任一方式指定:", file=sys.stderr)
print(" 1. 设置环境变量: IR_INPUT_JSON=<path>", file=sys.stderr)
print(" 2. 通过 main.py: python main.py --input <path>", file=sys.stderr)
print(" 3. 通过 step 脚本: python step1_semantic_index.py --input <path>", file=sys.stderr)
print(" 4. 程序调用: config.set_input_file(<path>)", file=sys.stderr)
sys.exit(1)
if not os.path.isfile(path):
print(f"错误: 输入文件不存在: {path}", file=sys.stderr)
sys.exit(1)
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
def save_json(data, path: str) -> None:
    """Write *data* to *path* as pretty-printed UTF-8 JSON.

    Missing parent directories are created. Non-ASCII characters are written
    verbatim (``ensure_ascii=False``) with 2-space indentation, and any
    existing file at *path* is overwritten.
    """
    parent = os.path.dirname(path)
    # A bare filename such as "out.json" has no directory component, and
    # os.makedirs("") raises FileNotFoundError -- only create dirs if needed.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
def load_json(path: str) -> dict:
    """Read *path* as UTF-8 text and return its decoded JSON content."""
    with open(path, encoding="utf-8") as handle:
        content = json.load(handle)
    return content