import base64
import json
import logging
import os
import re
from typing import Optional

from LLM import LLMClient

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Prompt loading
# ---------------------------------------------------------------------------


def _load_prompt() -> str:
    """Load PROMPT_IMAGE from external file, falling back to inline default.

    Reads ``../prompts/image_prompt.md`` (relative to this module) as UTF-8
    when that file exists; otherwise returns the inline nested-tree prompt.
    """
    prompt_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "prompts")
    prompt_path = os.path.join(prompt_dir, "image_prompt.md")
    if os.path.isfile(prompt_path):
        with open(prompt_path, "r", encoding="utf-8") as f:
            return f.read()

    # Fallback inline prompt (nested tree format). It lists the accepted
    # English type labels explicitly: without them the model replies with a
    # Chinese label that never matches ImageParser._VALID_TYPES, and every
    # image would be classified as "other".
    return """请分析这张图片,判断类型并输出文字描述和(如适用)结构化逻辑树。

## 判断图片类型

如果是 **流程图 / 架构图 / 状态图 / 时序图 / 活动图**,你需要输出三项内容:
1. 类型标签
2. **嵌套逻辑树 JSON**(见下方格式)
3. 文字描述

如果是 **其他类型**(UI原型图 / 界面截图 / 设计稿 / 手机屏幕截图 / 网页截图等),只输出类型标签和简要文字描述。

类型标签只能取以下值之一:flowchart(流程图), architecture(架构图), state(状态图), sequence(时序图), activity(活动图), other(其他类型)。

## 嵌套逻辑树 JSON 格式(仅流程图/架构图/状态图/时序图/活动图需要)

**核心原则:用嵌套的 `children` 数组表达流程的层级关系,而不是用 id 引用。**

节点类型:`start`(起始), `end`(结束), `process`(处理/状态), `decision`(判断), `action`(动作)

非判断节点的 `children` 是子节点数组。`end` 节点无 `children`。

判断节点的 `children` 格式:

```json
{"condition": "是", "node": {"id": "n6", "name": "...", "type": "action", "children": [...]}}
```

每条从根到 `end` 的路径必须是完整逻辑链。decision 必须穷举所有分支。

节点 id 使用 "n1", "n2", "n3"... 格式。

## 输出格式

type: <类型标签>
logic_tree: {...}
description: 该图片的详细文字描述。"""


PROMPT_IMAGE = _load_prompt()

# ---------------------------------------------------------------------------
# Response-parsing helpers
# ---------------------------------------------------------------------------

# Label line such as ``type: flowchart`` or ``类型：flowchart``. The
# look-behind keeps words like ``prototype:`` from being taken as the label.
_TYPE_RE = re.compile(r"(?<![A-Za-z_])(?:type|类型)\s*[:：]\s*(\S+)", re.IGNORECASE)
# Section headers, anchored at line start so JSON keys / prose don't match.
_LOGIC_TREE_RE = re.compile(r"(?m)^logic_tree:\s*")
_DESCRIPTION_RE = re.compile(r"(?m)^description:\s*")
# ``logic_tree:`` header plus an optional opening code fence (e.g. ```json).
_LOGIC_TREE_BLOCK_RE = re.compile(r"(?m)^logic_tree:\s*(```[\w-]*\s*)?")
# Decoration an LLM tends to wrap around the type label: markdown emphasis,
# quotes, placeholder brackets and trailing punctuation.
_TYPE_LABEL_JUNK = "`*\"'<>.,;，。；"


def _child_list(node: dict) -> list:
    """Return ``node["children"]`` when it is a list, otherwise ``[]``.

    LLM output may contain ``"children": null`` or a non-list value; treating
    those as "no children" keeps the tree walkers from crashing on trees that
    only failed (advisory) validation.
    """
    children = node.get("children")
    return children if isinstance(children, list) else []


# ---------------------------------------------------------------------------
# ImageParser
# ---------------------------------------------------------------------------


class ImageParser:
    """Vision LLM wrapper for parsing images (type + description + logic_tree).

    The nested-tree ``logic_tree`` is stored alongside a backward-compatible
    flat representation so downstream consumers are not broken.

    Usage::

        parser = ImageParser()
        result = parser.parse_image("images/img1.png")
    """

    _VALID_TYPES = {"flowchart", "architecture", "state", "sequence", "activity", "other"}

    def __init__(self, llm: LLMClient | None = None):
        # Falls back to a default-constructed client when none is injected.
        self._llm = llm or LLMClient()

    @property
    def usage(self) -> dict:
        """Usage information exposed by the underlying LLM client (``self._llm.usage``)."""
        return self._llm.usage

    def parse_image(self, image_path: str) -> Optional[dict]:
        """Parse an image and return its type, description, and optional logic_tree.

        Args:
            image_path: Path of the image file. It is sent to
                ``LLMClient.IMAGE_MODEL`` inline as a base64 ``data:`` URL,
                together with PROMPT_IMAGE.

        Returns:
            ``{"type", "description"}`` plus, when the reply contained a
            decodable logic tree, ``"logic_tree_nested"`` (the tree as the
            model produced it) and ``"logic_tree"`` (the legacy flat form
            built by :meth:`_flatten_tree`). If the LLM call raises
            ``RuntimeError`` the result is
            ``{"type": "other", "description": "", "error": <message>}``.
            ``None`` when the reply is not a string and cannot be parsed.

        Raises:
            OSError: If *image_path* cannot be opened or read.
        """
        logger.info("Parsing image: %s", image_path)
        with open(image_path, "rb") as f:
            img_b64 = base64.b64encode(f.read()).decode()
        mime = self._mime_type(image_path)

        try:
            content = self._llm.chat(
                model=LLMClient.IMAGE_MODEL,
                messages=[{
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{img_b64}"}},
                        {"type": "text", "text": PROMPT_IMAGE},
                    ],
                }],
            )
        except RuntimeError as e:
            logger.error(str(e))
            return {"type": "other", "description": "", "error": str(e)}

        parsed = self._parse_response(content)
        if parsed is None:
            return None

        ptype, description, logic_tree_nested = parsed
        result: dict = {"type": ptype, "description": description}
        if logic_tree_nested is not None:
            result["logic_tree_nested"] = logic_tree_nested
            result["logic_tree"] = self._flatten_tree(logic_tree_nested)
        return result

    # ---- internals ----------------------------------------------------------

    def _parse_response(self, content: str) -> Optional[tuple[str, str, Optional[dict]]]:
        """Extract ``(type, description, logic_tree_nested)`` from an LLM reply.

        Expects the ``type:`` / ``logic_tree:`` / ``description:`` layout
        requested by PROMPT_IMAGE but tolerates missing sections:

        * ``type`` is ``"other"`` when absent or not in ``_VALID_TYPES``.
        * ``logic_tree_nested`` is ``None`` when there is no ``logic_tree:``
          section or its JSON cannot be decoded. A tree that fails
          :meth:`_validate_flowchart` is still returned (a warning is logged).
        * ``description`` is the text after ``description:`` (minus any
          logic-tree block that follows it); without that header it is the
          text after the type label with the logic-tree block removed.

        Returns *None* when *content* is not a string (e.g. the client
        surfaced an empty reply as ``None``).
        """
        if not isinstance(content, str):
            return None
        content = content.strip()
        parsed_type = "other"
        logic_tree = None

        # --- type ---
        type_match = _TYPE_RE.search(content)
        if type_match:
            type_val = type_match.group(1).strip(_TYPE_LABEL_JUNK).lower()
            if type_val in self._VALID_TYPES:
                parsed_type = type_val

        # --- logic_tree (anchored at line start) ---
        lt_match = _LOGIC_TREE_RE.search(content)
        desc_match = _DESCRIPTION_RE.search(content)
        if lt_match:
            lt_start = lt_match.end()
            # The JSON runs until the description header when that follows it.
            lt_end = desc_match.start() if desc_match and desc_match.start() > lt_start else len(content)
            lt_raw = content[lt_start:lt_end].strip()

            logic_tree = self._extract_json(lt_raw)
            if logic_tree is not None:
                is_valid, err_msg = self._validate_flowchart(logic_tree)
                if not is_valid:
                    logger.warning("Flowchart validation warning: %s", err_msg)
            else:
                logger.info("Failed to extract logic_tree JSON. Raw block length=%d", len(lt_raw))
                logger.debug("Raw logic_tree block: %s", lt_raw[:500])
        elif parsed_type in self._VALID_TYPES - {"other"}:
            logger.info("Diagram type=%s but no logic_tree: in response. Response length=%d",
                        parsed_type, len(content))
            logger.debug("Raw response (first 500): %s", content[:500])

        # --- description ---
        if desc_match:
            # A logic_tree block placed after the description header (out-of-
            # order reply) must not be folded into the description text.
            description = self._strip_logic_tree(content[desc_match.end():])
        else:
            desc = content[type_match.end():] if type_match else content
            description = self._strip_logic_tree(desc)

        return parsed_type, description.strip(), logic_tree

    @staticmethod
    def _strip_logic_tree(text: str) -> str:
        """Remove the first ``logic_tree:`` header and the JSON object after it.

        The object's extent is found with ``json.JSONDecoder.raw_decode``, so
        nested braces are handled; an optional code fence around the object is
        removed as well. When the object cannot be decoded, the legacy
        shortest-match removal of ``logic_tree: {...}`` is applied instead.
        """
        match = _LOGIC_TREE_BLOCK_RE.search(text)
        if match is None:
            return text
        start = match.end()
        if text.startswith("{", start):
            try:
                _, end = json.JSONDecoder().raw_decode(text, start)
            except json.JSONDecodeError:
                pass
            else:
                # Swallow trailing whitespace, plus the closing fence when an
                # opening fence was consumed with the header.
                closing = re.compile(r"\s*(?:```\s*)?" if match.group(1) else r"\s*")
                end = closing.match(text, end).end()
                return text[:match.start()] + text[end:]
        return re.sub(r"(?m)^logic_tree:\s*\{.*?\}\s*", "", text, flags=re.DOTALL)

    @staticmethod
    def _validate_flowchart(tree: dict) -> tuple[bool, str]:
        """Validate a nested flowchart tree structure.

        Checks, depth-first: every node is a dict with a unique non-empty
        string ``id`` and a known ``type``; ``end`` nodes have no
        ``children``; every other node has a non-empty ``children`` list;
        decision children are ``{"condition": ..., "node": {...}}`` wrappers;
        depth does not exceed 20.

        Returns ``(is_valid, error_message)`` for the first problem found.
        Non-fatal: returns ``False`` with a warning message but the tree is
        still kept.
        """
        if not isinstance(tree, dict):
            return False, "logic_tree is not a dict"

        seen_ids: set[str] = set()

        def _walk(node: dict, depth: int = 0) -> tuple[bool, str]:
            # A decision branch's "node" comes straight from the LLM and may
            # be a string/list; report it instead of crashing on .get().
            if not isinstance(node, dict):
                return False, f"Node is not a dict (got {type(node).__name__})"
            if depth > 20:
                return False, f"Tree too deep (>20) at node {node.get('id', '?')}"

            nid = node.get("id", "")
            if not nid:
                return False, "Node missing 'id' field"
            if not isinstance(nid, str):
                return False, f"Node id must be string, got {type(nid).__name__}"
            if nid in seen_ids:
                return False, f"Duplicate node id: {nid}"
            seen_ids.add(nid)

            ntype = node.get("type", "")
            if ntype not in ("start", "end", "process", "decision", "action"):
                return False, f"Unknown node type '{ntype}' at {nid}"

            if ntype == "end":
                if "children" in node:
                    return False, f"End node {nid} should not have children"
                return True, ""

            children = node.get("children")
            if not children:
                # ntype cannot be "end" here (handled above).
                return False, f"Non-end node {nid} ({ntype}) has no children"

            if not isinstance(children, list):
                return False, f"children of {nid} is not a list"

            if ntype == "decision":
                for child in children:
                    if not isinstance(child, dict):
                        return False, f"decision child of {nid} is not a dict"
                    if "condition" not in child:
                        return False, f"decision child of {nid} missing 'condition'"
                    if "node" not in child:
                        return False, f"decision child of {nid} missing 'node'"
                    ok, err = _walk(child["node"], depth + 1)
                    if not ok:
                        return False, err
            else:
                for child in children:
                    if not isinstance(child, dict):
                        return False, f"child of {nid} is not a dict"
                    ok, err = _walk(child, depth + 1)
                    if not ok:
                        return False, err

            return True, ""

        return _walk(tree)

    @staticmethod
    def _flatten_tree(tree: dict) -> dict:
        """Convert a nested flowchart tree into the legacy flat-nodes format.

        This preserves backward compatibility with downstream consumers
        (conflict_detection_skill, ir_generator) that expect the flat format.

        Returns ``{"root": str, "nodes": list[dict]}`` where ``root`` is the
        ``name`` of the first visited node that has a ``children`` key, and
        ``nodes`` contains:

        * ``decision`` -> ``{"id", "type", "condition", "branches"}``, each
          branch ``{"value": <condition>, "target": <child id>}``; appended
          after its branch subtrees.
        * ``start`` / ``action`` / ``process`` / ``state`` ->
          ``{"id", "type", "description"}``; appended before its children.
        * ``end`` -> ``{"id", "type": "end", "description"}``; all appended
          last, in a separate traversal.

        Nodes of any other type are dropped together with their non-end
        descendants. Malformed entries (non-dict children, non-list
        ``children``, decision branches without a dict ``node``) are skipped,
        because :meth:`_validate_flowchart` is advisory and invalid trees
        still reach this method.
        """
        nodes: list[dict] = []
        root_name = ""

        def _collect(node: dict) -> None:
            nonlocal root_name
            nid = node.get("id", "")
            ntype = node.get("type", "")
            name = node.get("name", "")

            if root_name == "" and "children" in node:
                root_name = name

            if ntype == "decision":
                branches = []
                for child in _child_list(node):
                    target = child.get("node") if isinstance(child, dict) else None
                    if not isinstance(target, dict):
                        continue  # malformed branch
                    branches.append({
                        "value": child.get("condition", ""),
                        "target": target.get("id", ""),
                    })
                    _collect(target)
                nodes.append({
                    "id": nid,
                    "type": ntype,
                    "condition": name,
                    "branches": branches,
                })
            elif ntype in ("action", "process", "state", "start"):
                nodes.append({
                    "id": nid,
                    "type": ntype,
                    "description": name,
                })
                for child in _child_list(node):
                    if isinstance(child, dict):
                        _collect(child)
            # end nodes are gathered by _collect_ends below.

        _collect(tree)

        # Add end nodes from the nested tree
        ends: list[dict] = []

        def _collect_ends(node: dict) -> None:
            if node.get("type") == "end":
                ends.append({
                    "id": node.get("id", ""),
                    "type": "end",
                    "description": node.get("name", ""),
                })
                return
            for child in _child_list(node):
                if not isinstance(child, dict):
                    continue
                # Decision branches wrap the real node under "node".
                target = child["node"] if "node" in child else child
                if isinstance(target, dict):
                    _collect_ends(target)

        _collect_ends(tree)
        nodes.extend(ends)

        return {"root": root_name, "nodes": nodes}

    @staticmethod
    def extract_paths(tree: dict) -> list[list[dict]]:
        """Extract all root-to-leaf paths from a nested flowchart tree.

        Each path is a list of node dicts (each with id, name, type).
        A path ends at an ``end`` node or at any node without usable
        children; malformed children (non-dicts, decision branches without a
        dict ``node``) are ignored. Returns ``[]`` when *tree* is not a dict.
        Returns a list of paths useful for human review and LLM verification.
        """
        paths: list[list[dict]] = []
        if not isinstance(tree, dict):
            return paths

        def _walk(node: dict, current_path: list[dict]) -> None:
            entry = {"id": node.get("id", ""), "name": node.get("name", ""), "type": node.get("type", "")}
            new_path = current_path + [entry]

            if node.get("type") == "end":
                paths.append(new_path)
                return

            children = _child_list(node)
            if node.get("type") == "decision":
                # Decision branches wrap the real node under "node".
                children = [c.get("node") for c in children if isinstance(c, dict)]
            targets = [c for c in children if isinstance(c, dict)]

            if not targets:
                paths.append(new_path)
                return
            for target in targets:
                _walk(target, new_path)

        _walk(tree, [])
        return paths

    @staticmethod
    def paths_to_text(paths: list[list[dict]]) -> str:
        """Render extracted paths as human-readable text for review.

        One line per path, ``路径 <n>: step -> step -> ...``, with decision
        and end steps given Chinese tags and other steps tagged by type.
        """
        lines: list[str] = []
        for i, path in enumerate(paths, 1):
            steps = []
            for node in path:
                if node["type"] == "decision":
                    steps.append(f"[判断] {node['name']}")
                elif node["type"] == "end":
                    steps.append(f"[结束] {node['name']}")
                else:
                    steps.append(f"[{node['type']}] {node['name']}")
            lines.append(f"路径 {i}: {' -> '.join(steps)}")
        return "\n".join(lines)

    @staticmethod
    def _extract_json(text: str) -> Optional[dict]:
        """Decode the JSON object that starts at the first ``{`` in *text*.

        Any prose or code fence before the brace and any text after the
        object are ignored. ``json.JSONDecoder.raw_decode`` is used instead of
        brace counting so braces inside string values (``"name": "a}b"``) do
        not cut the object short.

        Returns the parsed dict or None.
        """
        start = text.find("{")
        if start < 0:
            return None
        try:
            obj, _ = json.JSONDecoder().raw_decode(text, start)
        except json.JSONDecodeError:
            return None
        return obj if isinstance(obj, dict) else None

    @staticmethod
    def _mime_type(image_path: str) -> str:
        """Map the file extension to a MIME type, defaulting to ``image/png``."""
        ext = os.path.splitext(image_path)[1].lstrip(".").lower()
        return {
            "png": "image/png",
            "jpg": "image/jpeg",
            "jpeg": "image/jpeg",
            "gif": "image/gif",
            "bmp": "image/bmp",
            "webp": "image/webp",
            "svg": "image/svg+xml",
            "tif": "image/tiff",
            "tiff": "image/tiff",
        }.get(ext, "image/png")