import re
import uuid
from typing import Dict, List, Any, Tuple, Optional

# The drawing helpers are auxiliary; the graph builders and RegistryAdapter are
# pure Python. Guarding these imports keeps the module importable where the
# plotting stack is not installed -- the draw_* functions raise ImportError then.
try:
    import networkx as nx
    from networkx.drawing.nx_agraph import to_agraph
except ImportError:  # pragma: no cover - depends on the environment
    nx = None
    to_agraph = None

try:
    import matplotlib.pyplot as plt
except ImportError:  # pragma: no cover - depends on the environment
    plt = None

Json = Dict[str, Any]


# ---------------- Graph ----------------
class WorkflowGraph:
    """Minimal directed graph for workflow export.

    Node parameters live in a single-level ``params`` dict; wiring between nodes
    is recorded both in each node's ``inputs`` and as edges. The graph can be
    exported as a sequential node list (``to_dict``) or in node-link form
    (``to_node_link_dict``).
    """

    def __init__(self):
        # node id -> attribute dict
        self.nodes: Dict[str, Dict[str, Any]] = {}
        # edge dicts in insertion order; parallel edges are allowed
        self.edges: List[Dict[str, Any]] = []

    def add_node(self, node_id: str, **attrs):
        """Add (or replace) node ``node_id`` with the given attributes."""
        self.nodes[node_id] = attrs

    def add_edge(self, source: str, target: str, **attrs):
        """Append an edge from ``source`` to ``target``.

        ``source_handle_io``/``target_handle_io`` default to ``"source"``/``"target"``;
        any other keyword argument is stored on the edge verbatim.
        """
        edge = {
            "source": source,
            "target": target,
            "source_node_uuid": source,
            "target_node_uuid": target,
            # popped before **attrs is expanded so they are not duplicated
            "source_handle_io": attrs.pop("source_handle_io", "source"),
            "target_handle_io": attrs.pop("target_handle_io", "target"),
            **attrs,
        }
        self.edges.append(edge)

    def _materialize_wiring_into_inputs(self, obj: Any, inputs: Dict[str, Any],
                                        variable_sources: Dict[str, Dict[str, Any]],
                                        target_node_id: str,
                                        base_path: List[str]) -> Tuple[Any, bool]:
        """Replace ``{"__var__": name}`` markers in ``obj`` with ``"${name}"`` placeholders.

        For each marker whose variable is found in ``variable_sources``, an entry
        keyed by the dotted path of the marker (e.g. ``"params.foo.bar.0"``) is
        written into ``inputs`` and an edge from the source node to
        ``target_node_id`` is added.

        Returns:
            ``(obj with markers replaced, True if any marker was seen)``.
        """
        has_var = False

        def walk(node: Any, path: List[str]):
            nonlocal has_var
            if isinstance(node, dict):
                if "__var__" in node:
                    has_var = True
                    varname = node["__var__"]
                    placeholder = f"${{{varname}}}"
                    src = variable_sources.get(varname)
                    if src:
                        key = ".".join(path)  # e.g. "params.foo.bar.0"
                        inputs[key] = {"node": src["node_id"], "output": src.get("output_name", "result")}
                        self.add_edge(str(src["node_id"]), target_node_id,
                                      source_handle_io=src.get("output_name", "result"),
                                      target_handle_io=key)
                    # the whole marker dict is replaced by the placeholder string
                    return placeholder
                return {k: walk(v, path + [k]) for k, v in node.items()}
            if isinstance(node, list):
                return [walk(v, path + [str(i)]) for i, v in enumerate(node)]
            return node

        replaced = walk(obj, base_path[:])
        return replaced, has_var

    def add_workflow_node(self,
                          node_id: int,
                          *,
                          device_key: Optional[str] = None,      # instance name, e.g. "ser"
                          resource_name: Optional[str] = None,   # registry key (formerly device_class)
                          module: Optional[str] = None,
                          template_name: Optional[str] = None,   # action/template name (formerly action_key)
                          params: Dict[str, Any],
                          variable_sources: Dict[str, Dict[str, Any]],
                          add_ready_if_no_vars: bool = True,
                          prev_node_id: Optional[int] = None,
                          **extra_attrs) -> None:
        """Add a workflow node with single-level ``params``.

        ``{"__var__": name}`` markers inside ``params`` become ``"${name}"`` and,
        when ``name`` is in ``variable_sources`` (``{"node_id": ..., "output_name": ...}``),
        are wired as inputs/edges. If ``params`` contains no marker and
        ``add_ready_if_no_vars`` is set, the node is chained after ``prev_node_id``
        (or ``-1`` when it is None) through the ``ready`` port. ``extra_attrs`` are
        merged into the stored node object.
        """
        node_id_str = str(node_id)
        inputs: Dict[str, Any] = {}
        params, has_var = self._materialize_wiring_into_inputs(
            params, inputs, variable_sources, node_id_str, base_path=["params"]
        )

        if add_ready_if_no_vars and not has_var:
            # "-1" is used when there is no predecessor; no node with that id is created here.
            last_id = str(prev_node_id) if prev_node_id is not None else "-1"
            inputs["ready"] = {"node": int(last_id), "output": "ready"}
            self.add_edge(last_id, node_id_str, source_handle_io="ready", target_handle_io="ready")

        node_obj = {
            "device_key": device_key,
            "resource_name": resource_name,   # new name (formerly device_class)
            "module": module,
            "template_name": template_name,   # new name (formerly action_key)
            "params": params,
            "inputs": inputs,
        }
        node_obj.update(extra_attrs or {})
        self.add_node(node_id_str, parameters=node_obj)

    # Sequential workflow export (wiring lives in ``inputs``; edges are not returned)
    def to_dict(self) -> List[Dict[str, Any]]:
        """Export nodes as a flat list, merging each node's ``parameters`` into its top level.

        Numeric ids come first in numeric order, followed by all other ids in
        lexicographic order of their string form.
        """
        result = []
        for node_id, attrs in self.nodes.items():
            node = {"id": node_id}
            params = dict(attrs.get("parameters", {}) or {})
            flat = {k: v for k, v in attrs.items() if k != "parameters"}
            flat.update(params)
            node.update(flat)
            result.append(node)

        def sort_key(n: Dict[str, Any]):
            # A tuple key keeps ints and strs in separate buckets; comparing an
            # int id with a str id directly would raise TypeError. isdecimal()
            # (unlike isdigit()) accepts exactly what int() can parse.
            text = str(n["id"])
            if text.isdecimal():
                return (0, int(text), "")
            return (1, 0, text)

        return sorted(result, key=sort_key)

    # node-link export (includes edges)
    def to_node_link_dict(self) -> Dict[str, Any]:
        """Export in node-link form, merging each node's ``parameters`` into its top level.

        ``"edges"`` and ``"links"`` both reference this graph's own edge list.
        """
        nodes_list = []
        for node_id, attrs in self.nodes.items():
            node_attrs = attrs.copy()
            params = node_attrs.pop("parameters", {}) or {}
            node_attrs.update(params)
            nodes_list.append({"id": node_id, **node_attrs})
        return {"directed": True, "multigraph": False, "graph": {},
                "nodes": nodes_list, "edges": self.edges, "links": self.edges}


# Operation name -> template name.
_OPERATION_MAPPING: Dict[str, str] = {
    # biology-lab operations
    "transfer_liquid": "transfer_liquid",
    "transfer": "transfer",
    "incubation": "incubation",
    "move_labware": "move_labware",
    "oscillation": "oscillation",
    # organic-chemistry operations
    "HeatChillToTemp": "HeatChillProtocol",
    "StopHeatChill": "HeatChillStopProtocol",
    "StartHeatChill": "HeatChillStartProtocol",
    "HeatChill": "HeatChillProtocol",
    "Dissolve": "DissolveProtocol",
    "Transfer": "TransferProtocol",
    "Evaporate": "EvaporateProtocol",
    "Recrystallize": "RecrystallizeProtocol",
    "Filter": "FilterProtocol",
    "Dry": "DryProtocol",
    "Add": "AddProtocol",
}
# Operations that are dropped from the protocol.
_UNSUPPORTED_OPERATIONS = ("Purge", "Wait", "Stir", "ResetHandling")
# Lower-cased unmapped operation names that fall back to a "biomek-" template.
_BIOMEK_FALLBACK_OPERATIONS = ("transfer", "incubation", "move_labware", "oscillation")


def refactor_data(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Normalise raw protocol steps into template-based step dicts.

    Steps without an ``action`` or with an unsupported one are skipped;
    ``Repeat`` steps are expanded inline ``times`` times. Every other step becomes
    ``{"template", "description", "lab_node_type": "Device", "parameters"}`` where
    the template comes from ``_OPERATION_MAPPING`` or is inferred from the name.
    """
    refactored_data = []

    for step in data:
        operation = step.get("action")
        if not operation or operation in _UNSUPPORTED_OPERATIONS:
            continue

        # Expand Repeat blocks.
        if operation == "Repeat":
            # "parameters" may be missing or None; the nested lookups must not
            # crash when "times"/"steps" are given at the top level.
            nested = step.get("parameters") or {}
            times = step.get("times", nested.get("times", 1))
            sub_steps = step.get("steps", nested.get("steps", []))
            for _ in range(int(times)):
                refactored_data.extend(refactor_data(sub_steps))
            continue

        template = _OPERATION_MAPPING.get(operation)
        if not template:
            # Infer the template for operations without an explicit mapping.
            if operation.lower() in _BIOMEK_FALLBACK_OPERATIONS:
                template = f"biomek-{operation}"
            else:
                template = f"{operation}Protocol"

        refactored_data.append({
            "template": template,
            "description": step.get("description", step.get("purpose", f"{operation} operation")),
            "lab_node_type": "Device",
            "parameters": step.get("parameters", step.get("action_args", {})),
        })
    return refactored_data


# Parameter keys that may name a resource consumed by a step (also used as target port names).
_INPUT_RESOURCE_KEYS = (
    "vessel", "to_vessel", "from_vessel", "reagent", "solvent", "compound", "sources", "targets",
)
# Output port name -> parameter key naming the resource the step (re)writes.
# Order matters: when two keys name the same resource, the later port wins.
_OUTPUT_RESOURCE_KEYS = {
    "vessel_out": "vessel",
    "from_vessel_out": "from_vessel",
    "to_vessel_out": "to_vessel",
    "filtrate_out": "filtrate_vessel",
    "reagent": "reagent",
    "solvent": "solvent",
    "compound": "compound",
    "sources_out": "sources",
    "targets_out": "targets",
}


def _is_hashable(value: Any) -> bool:
    """Return True if ``value`` can be used as a dict key."""
    try:
        hash(value)
    except TypeError:
        return False
    return True


def _resource_names(value: Any) -> List[Any]:
    """Normalise a step parameter into the resource names it refers to.

    A scalar maps to ``[value]``; a list/tuple (e.g. multi-well ``sources``) maps
    to its distinct items. Falsy and unhashable entries are dropped because they
    cannot be resource keys.
    """
    items = value if isinstance(value, (list, tuple)) else [value]
    return list(dict.fromkeys(item for item in items if item and _is_hashable(item)))


def build_protocol_graph(
    labware_info: Dict[str, Dict[str, Any]],
    protocol_steps: List[Dict[str, Any]],
    workstation_name: str
) -> WorkflowGraph:
    """Build the protocol graph for organic-chemistry and liquid-handling workstations.

    Args:
        labware_info: mapping of labware id -> info dict (``type`` and ``role`` are read).
        protocol_steps: raw protocol steps, normalised with ``refactor_data``.
        workstation_name: used as ``device_id`` and ``parent`` of every resource node.

    Returns:
        A WorkflowGraph with one ``create_resource`` node per kept labware entry
        and one node per step (node ids are random UUID strings). Steps are
        chained by ``ready`` edges; material edges connect the last writer of a
        resource to each step that consumes it.
    """
    G = WorkflowGraph()
    # resource name -> (node id, output port) of the node that last produced it
    resource_last_writer: Dict[Any, Tuple[str, str]] = {}
    protocol_steps = refactor_data(protocol_steps)
    workstation_id = workstation_name

    # One create_resource node per labware entry.
    for labware_id, item in labware_info.items():
        node_id = str(uuid.uuid4())
        labware_text = str(labware_id)

        if "Rack" in labware_text or "Tip" in labware_text:
            lab_node_type = "Labware"
            description = f"Prepare Labware: {labware_id}"
            liquid_type = []
            liquid_volume = []
        elif item.get("type") == "hardware" or "reactor" in labware_text.lower():
            if "reactor" not in labware_text.lower():
                continue  # non-reactor hardware gets no resource node
            lab_node_type = "Sample"
            description = f"Prepare Reactor: {labware_id}"
            liquid_type = []
            liquid_volume = []
        else:
            lab_node_type = "Reagent"
            description = f"Add Reagent to Flask: {labware_id}"
            liquid_type = [labware_id]
            liquid_volume = [1e5]

        G.add_node(
            node_id,
            template_name="create_resource",
            resource_name="host_node",
            description=description,
            lab_node_type=lab_node_type,
            params={
                "res_id": labware_id,
                "device_id": workstation_id,
                "class_name": "container",
                "parent": workstation_id,
                "bind_locations": {"x": 0.0, "y": 0.0, "z": 0.0},
                "liquid_input_slot": [-1],
                "liquid_type": liquid_type,
                "liquid_volume": liquid_volume,
                "slot_on_deck": "",
            },
            role=item.get("role", ""),
        )
        resource_last_writer[labware_id] = (node_id, "labware")

    last_control_node_id = None

    for step in protocol_steps:
        node_id = str(uuid.uuid4())
        G.add_node(node_id, **step)

        # Control flow: chain steps in order through their "ready" ports.
        if last_control_node_id is not None:
            G.add_edge(last_control_node_id, node_id, source_port="ready", target_port="ready")
        last_control_node_id = node_id

        # Material flow. "parameters" may be None when the raw step carried None.
        params = step.get("parameters") or {}

        # Inputs are wired before this step's outputs are recorded, so a step
        # never links to itself.
        for target_port in _INPUT_RESOURCE_KEYS:
            for resource_name in _resource_names(params.get(target_port)):
                writer = resource_last_writer.get(resource_name)
                if writer is not None:
                    source_node, source_port = writer
                    G.add_edge(source_node, node_id, source_port=source_port, target_port=target_port)

        for source_port, param_key in _OUTPUT_RESOURCE_KEYS.items():
            for resource_name in _resource_names(params.get(param_key)):
                resource_last_writer[resource_name] = (node_id, source_port)

    return G


def _graph_is_empty(protocol_graph: Any) -> bool:
    """True if there is no graph, or the graph has neither nodes nor edges."""
    return not protocol_graph or not (protocol_graph.nodes or protocol_graph.edges)


def draw_protocol_graph(protocol_graph: WorkflowGraph, output_path: str):
    """(Auxiliary) Plot the protocol workflow graph with networkx + matplotlib.

    Uses the Graphviz ``dot`` layout when it is available and falls back to
    ``shell_layout`` otherwise. The figure is saved to ``output_path``.

    Raises:
        ImportError: if networkx or matplotlib is not installed.
    """
    if _graph_is_empty(protocol_graph):
        print("Cannot draw graph: Graph object is empty.")
        return
    if nx is None or plt is None:
        raise ImportError("draw_protocol_graph requires networkx and matplotlib")

    G = nx.DiGraph()
    for node_id, attrs in protocol_graph.nodes.items():
        label = attrs.get("description", attrs.get("template", node_id[:8]))
        # Merge rather than passing label= alongside **attrs: attrs that already
        # contain "label" would otherwise raise TypeError.
        G.add_node(node_id, **{**attrs, "label": label})

    for edge in protocol_graph.edges:
        G.add_edge(edge["source"], edge["target"])

    fig = plt.figure(figsize=(20, 15))
    try:
        try:
            pos = nx.nx_agraph.graphviz_layout(G, prog="dot")
        except Exception:
            # pygraphviz/Graphviz unavailable or layout failed: best-effort fallback.
            pos = nx.shell_layout(G)

        # Nodes that only appear as edge endpoints (e.g. a "-1" predecessor) have
        # no "label" attribute; show their id instead.
        node_labels = {node: data.get("label", node) for node, data in G.nodes(data=True)}
        nx.draw(
            G,
            pos,
            with_labels=False,
            node_size=2500,
            node_color="skyblue",
            node_shape="o",
            edge_color="gray",
            width=1.5,
            arrowsize=15,
        )
        nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=8, font_weight="bold")

        plt.title("Chemical Protocol Workflow Graph", size=15)
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)  # release the figure even if layout/saving fails
    print(f" - Visualization saved to '{output_path}'")


COMPASS = {"n", "e", "s", "w", "ne", "nw", "se", "sw", "c"}


def _is_compass(port: str) -> bool:
    """True if ``port`` is a Graphviz compass point (case-insensitive)."""
    return isinstance(port, str) and port.lower() in COMPASS


def _port_id(port: Any) -> str:
    """Sanitise a port name for use as a record field id and as ``headport``/``tailport``.

    ``:`` separates port from compass point in ``headport``/``tailport``, and ``|``
    separates record fields, so both (and other punctuation) are replaced by ``_``.
    """
    return re.sub(r"[^A-Za-z0-9_.-]", "_", str(port))


def _record_text(text: Any) -> str:
    """Backslash-escape ``{ } | < >``, which are structural inside record labels."""
    return re.sub(r"([{}|<>])", r"\\\1", str(text))


def _record_port_fields(ports: List[str]) -> str:
    """Build one side of a record label as a ``{ <id> name | ... }`` group."""
    if not ports:
        return " "  # an empty field keeps the three-part layout
    fields = " | ".join(f"<{_port_id(p)}> {_record_text(p)}" for p in ports)
    # Nested braces flip orientation, stacking the ports perpendicular to the main row.
    return f"{{ {fields} }}"


def draw_protocol_graph_with_ports(protocol_graph, output_path: str, rankdir: str = "LR"):
    """Render the protocol workflow graph with Graphviz port syntax.

    - If an edge's ``source_port``/``target_port`` is a compass point
      (n/e/s/w/...), that compass point is used directly.
    - Otherwise the endpoint node becomes a ``record`` whose named ``<port>``
      fields are the ports used by its edges: inputs on one side, the node label
      in the middle, outputs on the other side.

    Parallel edges between the same pair of nodes (e.g. a control ``ready`` edge
    plus a material edge) are all drawn. PyGraphviz renders the result to
    ``output_path``; the suffix selects the format (e.g. .png/.svg/.pdf).

    Raises:
        ImportError: if networkx or pygraphviz is not installed.
    """
    if _graph_is_empty(protocol_graph):
        print("Cannot draw graph: Graph object is empty.")
        return
    if nx is None or to_agraph is None:
        raise ImportError("draw_protocol_graph_with_ports requires networkx and pygraphviz")

    # 1) Collect node labels and the named (non-compass) ports each node needs.
    #    Keys are str(node id), matching the node names in the AGraph.
    G = nx.MultiDiGraph()
    core_labels: Dict[str, str] = {}
    for node_id, attrs in protocol_graph.nodes.items():
        label = attrs.get("description", attrs.get("template", node_id[:8]))
        core_labels[str(node_id)] = str(label)
        G.add_node(node_id, **{k: v for k, v in attrs.items() if k != "label"})

    edges_data = []
    in_ports_by_node: Dict[str, set] = {}   # named input ports per node
    out_ports_by_node: Dict[str, set] = {}  # named output ports per node
    for edge in protocol_graph.edges:
        u = edge["source"]
        v = edge["target"]
        sp = edge.get("source_port")
        tp = edge.get("target_port")
        # Endpoints that are not registered nodes still need a graph node.
        G.add_nodes_from((u, v))
        edges_data.append((u, v, sp, tp))

        if sp and not _is_compass(sp):
            out_ports_by_node.setdefault(str(u), set()).add(str(sp))
        if tp and not _is_compass(tp):
            in_ports_by_node.setdefault(str(v), set()).add(str(tp))

    # 2) Convert to an AGraph. A multigraph converts to a non-strict AGraph, so
    #    parallel edges added below are kept rather than merged.
    A = to_agraph(G)
    A.graph_attr.update(rankdir=rankdir, splines="true", concentrate="false", fontsize="10")
    A.node_attr.update(shape="box", style="rounded,filled", fillcolor="lightyellow",
                       color="#999999", fontname="Helvetica")
    A.edge_attr.update(arrowsize="0.8", color="#666666")

    # 3) Nodes with named ports become records: { {inputs} | label | {outputs} }.
    for n in A.nodes():
        name = str(n)
        node = A.get_node(n)
        core = core_labels.get(name, name)
        in_ports = sorted(in_ports_by_node.get(name, ()))
        out_ports = sorted(out_ports_by_node.get(name, ()))

        if in_ports or out_ports:
            left = _record_port_fields(in_ports)
            right = _record_port_fields(out_ports)
            node.attr.update(shape="record", label=f"{{ {left} | {_record_text(core)} | {right} }}")
        else:
            # No named ports: plain box showing the core label.
            node.attr.update(label=core)

    # 4) Add every edge with its own key. tailport is the source side, headport
    #    the target side; named ports use the same sanitised id as the record fields.
    for key, (u, v, sp, tp) in enumerate(edges_data):
        edge_attrs = {}
        if sp:
            edge_attrs["tailport"] = sp.lower() if _is_compass(sp) else _port_id(sp)
        if tp:
            edge_attrs["headport"] = tp.lower() if _is_compass(tp) else _port_id(tp)
        A.add_edge(str(u), str(v), key=str(key), **edge_attrs)

    # 5) Render.
    A.draw(output_path, prog="dot")
    print(f" - Port-aware workflow rendered to '{output_path}'")


# ---------------- Registry Adapter ----------------
class RegistryAdapter:
    """Look up registry entries by class name and read their action schemas.

    ``resource_name`` (formerly ``device_class``) is the registry key. An entry's
    ``module`` string is split on ``:`` and the last part is treated as the class
    name. Action input keys are returned in schema order (required keys first).
    """

    def __init__(self, device_registry: Dict[str, Any]):
        self.device_registry = device_registry or {}
        # class name (as written and lower-cased) -> resource_name
        self.module_class_to_resource = self._build_module_class_index()

    def _entry(self, resource_name: Optional[str]) -> Dict[str, Any]:
        """Registry entry for ``resource_name``, or {} if it is missing or None."""
        return self.device_registry.get(resource_name) or {}

    def _build_module_class_index(self) -> Dict[str, str]:
        """Index registry keys by the class name after the last ``:`` of ``module``.

        When two entries share a class name, the later entry wins.
        """
        idx = {}
        for resource_name, info in self.device_registry.items():
            module = (info or {}).get("module")
            if isinstance(module, str) and ":" in module:
                cls = module.split(":")[-1]
                idx[cls] = resource_name
                idx[cls.lower()] = resource_name
        return idx

    def resolve_resource_by_classname(self, class_name: str) -> Optional[str]:
        """Return the resource_name for ``class_name`` (exact, then lower-case match), or None."""
        if not class_name:
            return None
        return (self.module_class_to_resource.get(class_name)
                or self.module_class_to_resource.get(class_name.lower()))

    def get_device_module(self, resource_name: Optional[str]) -> Optional[str]:
        """Return the entry's ``module`` string, or None."""
        if not resource_name:
            return None
        return self._entry(resource_name).get("module")

    def get_actions(self, resource_name: Optional[str]) -> Dict[str, Any]:
        """Return ``class.action_value_mappings`` for the entry ({} if absent)."""
        if not resource_name:
            return {}
        return (self._entry(resource_name).get("class") or {}).get("action_value_mappings") or {}

    def get_action_schema(self, resource_name: Optional[str], template_name: str) -> Optional[Json]:
        """Return the action's ``schema`` dict, or None."""
        return (self.get_actions(resource_name).get(template_name) or {}).get("schema")

    def get_action_goal_default(self, resource_name: Optional[str], template_name: str) -> Json:
        """Return the action's ``goal_default`` dict ({} if absent)."""
        return (self.get_actions(resource_name).get(template_name) or {}).get("goal_default", {}) or {}

    def get_action_input_keys(self, resource_name: Optional[str], template_name: str) -> List[str]:
        """Return goal input keys: ``required`` first, then the other ``properties``, de-duplicated."""
        schema = self.get_action_schema(resource_name, template_name) or {}
        goal = (schema.get("properties") or {}).get("goal") or {}
        props = goal.get("properties") or {}
        required = goal.get("required") or []
        return list(dict.fromkeys(list(required) + list(props.keys())))