diff --git a/scripts/workflow.py b/scripts/workflow.py index be7bbd1..8bd8964 100644 --- a/scripts/workflow.py +++ b/scripts/workflow.py @@ -2,7 +2,6 @@ import json import logging import traceback import uuid -import xml.etree.ElementTree as ET from typing import Any, Dict, List import networkx as nx @@ -25,7 +24,15 @@ class SimpleGraph: def add_edge(self, source, target, **attrs): """添加边""" - edge = {"source": source, "target": target, **attrs} + # edge = {"source": source, "target": target, **attrs} + edge = { + "source": source, "target": target, + "source_node_uuid": source, + "target_node_uuid": target, + "source_handle_io": "source", + "target_handle_io": "target", + **attrs + } self.edges.append(edge) def to_dict(self): @@ -42,6 +49,7 @@ class SimpleGraph: "multigraph": False, "graph": {}, "nodes": nodes_list, + "edges": self.edges, "links": self.edges, } @@ -58,495 +66,8 @@ def extract_json_from_markdown(text: str) -> str: return text -def convert_to_type(val: str) -> Any: - """将字符串值转换为适当的数据类型""" - if val == "True": - return True - if val == "False": - return False - if val == "?": - return None - if val.endswith(" g"): - return float(val.split(" ")[0]) - if val.endswith("mg"): - return float(val.split("mg")[0]) - elif val.endswith("mmol"): - return float(val.split("mmol")[0]) / 1000 - elif val.endswith("mol"): - return float(val.split("mol")[0]) - elif val.endswith("ml"): - return float(val.split("ml")[0]) - elif val.endswith("RPM"): - return float(val.split("RPM")[0]) - elif val.endswith(" °C"): - return float(val.split(" ")[0]) - elif val.endswith(" %"): - return float(val.split(" ")[0]) - return val -def refactor_data(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """统一的数据重构函数,根据操作类型自动选择模板""" - refactored_data = [] - - # 定义操作映射,包含生物实验和有机化学的所有操作 - OPERATION_MAPPING = { - # 生物实验操作 - "transfer_liquid": "SynBioFactory-liquid_handler.prcxi-transfer_liquid", - "transfer": "SynBioFactory-liquid_handler.biomek-transfer", - "incubation": 
"SynBioFactory-liquid_handler.biomek-incubation", - "move_labware": "SynBioFactory-liquid_handler.biomek-move_labware", - "oscillation": "SynBioFactory-liquid_handler.biomek-oscillation", - # 有机化学操作 - "HeatChillToTemp": "SynBioFactory-workstation-HeatChillProtocol", - "StopHeatChill": "SynBioFactory-workstation-HeatChillStopProtocol", - "StartHeatChill": "SynBioFactory-workstation-HeatChillStartProtocol", - "HeatChill": "SynBioFactory-workstation-HeatChillProtocol", - "Dissolve": "SynBioFactory-workstation-DissolveProtocol", - "Transfer": "SynBioFactory-workstation-TransferProtocol", - "Evaporate": "SynBioFactory-workstation-EvaporateProtocol", - "Recrystallize": "SynBioFactory-workstation-RecrystallizeProtocol", - "Filter": "SynBioFactory-workstation-FilterProtocol", - "Dry": "SynBioFactory-workstation-DryProtocol", - "Add": "SynBioFactory-workstation-AddProtocol", - } - - UNSUPPORTED_OPERATIONS = ["Purge", "Wait", "Stir", "ResetHandling"] - - for step in data: - operation = step.get("action") - if not operation or operation in UNSUPPORTED_OPERATIONS: - continue - - # 处理重复操作 - if operation == "Repeat": - times = step.get("times", step.get("parameters", {}).get("times", 1)) - sub_steps = step.get("steps", step.get("parameters", {}).get("steps", [])) - for i in range(int(times)): - sub_data = refactor_data(sub_steps) - refactored_data.extend(sub_data) - continue - - # 获取模板名称 - template = OPERATION_MAPPING.get(operation) - if not template: - # 自动推断模板类型 - if operation.lower() in ["transfer", "incubation", "move_labware", "oscillation"]: - template = f"SynBioFactory-liquid_handler.biomek-{operation}" - else: - template = f"SynBioFactory-workstation-{operation}Protocol" - - # 创建步骤数据 - step_data = { - "template": template, - "description": step.get("description", step.get("purpose", f"{operation} operation")), - "lab_node_type": "Device", - "parameters": step.get("parameters", step.get("action_args", {})), - } - refactored_data.append(step_data) - - return 
refactored_data - - -def build_protocol_graph( - labware_info: List[Dict[str, Any]], protocol_steps: List[Dict[str, Any]], workstation_name: str -) -> SimpleGraph: - """统一的协议图构建函数,根据设备类型自动选择构建逻辑""" - G = SimpleGraph() - resource_last_writer = {} - LAB_NAME = "SynBioFactory" - - protocol_steps = refactor_data(protocol_steps) - - # 检查协议步骤中的模板来判断协议类型 - has_biomek_template = any( - ("biomek" in step.get("template", "")) or ("prcxi" in step.get("template", "")) - for step in protocol_steps - ) - - if has_biomek_template: - # 生物实验协议图构建 - for labware_id, labware in labware_info.items(): - node_id = str(uuid.uuid4()) - - labware_attrs = labware.copy() - labware_id = labware_attrs.pop("id", labware_attrs.get("name", f"labware_{uuid.uuid4()}")) - labware_attrs["description"] = labware_id - labware_attrs["lab_node_type"] = ( - "Reagent" if "Plate" in str(labware_id) else "Labware" if "Rack" in str(labware_id) else "Sample" - ) - labware_attrs["device_id"] = workstation_name - - G.add_node(node_id, template=f"{LAB_NAME}-host_node-create_resource", **labware_attrs) - resource_last_writer[labware_id] = f"{node_id}:labware" - - # 处理协议步骤 - prev_node = None - for i, step in enumerate(protocol_steps): - node_id = str(uuid.uuid4()) - G.add_node(node_id, **step) - - # 添加控制流边 - if prev_node is not None: - G.add_edge(prev_node, node_id, source_port="ready", target_port="ready") - prev_node = node_id - - # 处理物料流 - params = step.get("parameters", {}) - if "sources" in params and params["sources"] in resource_last_writer: - source_node, source_port = resource_last_writer[params["sources"]].split(":") - G.add_edge(source_node, node_id, source_port=source_port, target_port="labware") - - if "targets" in params: - resource_last_writer[params["targets"]] = f"{node_id}:labware" - - # 添加协议结束节点 - end_id = str(uuid.uuid4()) - G.add_node(end_id, template=f"{LAB_NAME}-liquid_handler.biomek-run_protocol") - if prev_node is not None: - G.add_edge(prev_node, end_id, source_port="ready", 
target_port="ready") - - else: - # 有机化学协议图构建 - WORKSTATION_ID = workstation_name - - # 为所有labware创建资源节点 - for item_id, item in labware_info.items(): - # item_id = item.get("id") or item.get("name", f"item_{uuid.uuid4()}") - node_id = str(uuid.uuid4()) - - # 判断节点类型 - if item.get("type") == "hardware" or "reactor" in str(item_id).lower(): - if "reactor" not in str(item_id).lower(): - continue - lab_node_type = "Sample" - description = f"Prepare Reactor: {item_id}" - liquid_type = [] - liquid_volume = [] - else: - lab_node_type = "Reagent" - description = f"Add Reagent to Flask: {item_id}" - liquid_type = [item_id] - liquid_volume = [1e5] - - G.add_node( - node_id, - template=f"{LAB_NAME}-host_node-create_resource", - description=description, - lab_node_type=lab_node_type, - res_id=item_id, - device_id=WORKSTATION_ID, - class_name="container", - parent=WORKSTATION_ID, - bind_locations={"x": 0.0, "y": 0.0, "z": 0.0}, - liquid_input_slot=[-1], - liquid_type=liquid_type, - liquid_volume=liquid_volume, - slot_on_deck="", - role=item.get("role", ""), - ) - resource_last_writer[item_id] = f"{node_id}:labware" - - last_control_node_id = None - - # 处理协议步骤 - for step in protocol_steps: - node_id = str(uuid.uuid4()) - G.add_node(node_id, **step) - - # 控制流 - if last_control_node_id is not None: - G.add_edge(last_control_node_id, node_id, source_port="ready", target_port="ready") - last_control_node_id = node_id - - # 物料流 - params = step.get("parameters", {}) - input_resources = { - "Vessel": params.get("vessel"), - "ToVessel": params.get("to_vessel"), - "FromVessel": params.get("from_vessel"), - "reagent": params.get("reagent"), - "solvent": params.get("solvent"), - "compound": params.get("compound"), - "sources": params.get("sources"), - "targets": params.get("targets"), - } - - for target_port, resource_name in input_resources.items(): - if resource_name and resource_name in resource_last_writer: - source_node, source_port = resource_last_writer[resource_name].split(":") - 
G.add_edge(source_node, node_id, source_port=source_port, target_port=target_port) - - output_resources = { - "VesselOut": params.get("vessel"), - "FromVesselOut": params.get("from_vessel"), - "ToVesselOut": params.get("to_vessel"), - "FiltrateOut": params.get("filtrate_vessel"), - "reagent": params.get("reagent"), - "solvent": params.get("solvent"), - "compound": params.get("compound"), - "sources_out": params.get("sources"), - "targets_out": params.get("targets"), - } - - for source_port, resource_name in output_resources.items(): - if resource_name: - resource_last_writer[resource_name] = f"{node_id}:{source_port}" - - return G - - -def draw_protocol_graph(protocol_graph: SimpleGraph, output_path: str): - """ - (辅助功能) 使用 networkx 和 matplotlib 绘制协议工作流图,用于可视化。 - """ - if not protocol_graph: - print("Cannot draw graph: Graph object is empty.") - return - - G = nx.DiGraph() - - for node_id, attrs in protocol_graph.nodes.items(): - label = attrs.get("description", attrs.get("template", node_id[:8])) - G.add_node(node_id, label=label, **attrs) - - for edge in protocol_graph.edges: - G.add_edge(edge["source"], edge["target"]) - - plt.figure(figsize=(20, 15)) - try: - pos = nx.nx_agraph.graphviz_layout(G, prog="dot") - except Exception: - pos = nx.shell_layout(G) # Fallback layout - - node_labels = {node: data["label"] for node, data in G.nodes(data=True)} - nx.draw( - G, - pos, - with_labels=False, - node_size=2500, - node_color="skyblue", - node_shape="o", - edge_color="gray", - width=1.5, - arrowsize=15, - ) - nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=8, font_weight="bold") - - plt.title("Chemical Protocol Workflow Graph", size=15) - plt.savefig(output_path, dpi=300, bbox_inches="tight") - plt.close() - print(f" - Visualization saved to '{output_path}'") - - -from networkx.drawing.nx_agraph import to_agraph -import re - -COMPASS = {"n","e","s","w","ne","nw","se","sw","c"} - -def _is_compass(port: str) -> bool: - return isinstance(port, str) and 
port.lower() in COMPASS - -def draw_protocol_graph_with_ports(protocol_graph, output_path: str, rankdir: str = "LR"): - """ - 使用 Graphviz 端口语法绘制协议工作流图。 - - 若边上的 source_port/target_port 是 compass(n/e/s/w/...),直接用 compass。 - - 否则自动为节点创建 record 形状并定义命名端口 。 - 最终由 PyGraphviz 渲染并输出到 output_path(后缀决定格式,如 .png/.svg/.pdf)。 - """ - if not protocol_graph: - print("Cannot draw graph: Graph object is empty.") - return - - # 1) 先用 networkx 搭建有向图,保留端口属性 - G = nx.DiGraph() - for node_id, attrs in protocol_graph.nodes.items(): - label = attrs.get("description", attrs.get("template", node_id[:8])) - # 保留一个干净的“中心标签”,用于放在 record 的中间槽 - G.add_node(node_id, _core_label=str(label), **{k:v for k,v in attrs.items() if k not in ("label",)}) - - edges_data = [] - in_ports_by_node = {} # 收集命名输入端口 - out_ports_by_node = {} # 收集命名输出端口 - - for edge in protocol_graph.edges: - u = edge["source"] - v = edge["target"] - sp = edge.get("source_port") - tp = edge.get("target_port") - - # 记录到图里(保留原始端口信息) - G.add_edge(u, v, source_port=sp, target_port=tp) - edges_data.append((u, v, sp, tp)) - - # 如果不是 compass,就按“命名端口”先归类,等会儿给节点造 record - if sp and not _is_compass(sp): - out_ports_by_node.setdefault(u, set()).add(str(sp)) - if tp and not _is_compass(tp): - in_ports_by_node.setdefault(v, set()).add(str(tp)) - - # 2) 转为 AGraph,使用 Graphviz 渲染 - A = to_agraph(G) - A.graph_attr.update(rankdir=rankdir, splines="true", concentrate="false", fontsize="10") - A.node_attr.update(shape="box", style="rounded,filled", fillcolor="lightyellow", color="#999999", fontname="Helvetica") - A.edge_attr.update(arrowsize="0.8", color="#666666") - - # 3) 为需要命名端口的节点设置 record 形状与 label - # 左列 = 输入端口;中间 = 核心标签;右列 = 输出端口 - for n in A.nodes(): - node = A.get_node(n) - core = G.nodes[n].get("_core_label", n) - - in_ports = sorted(in_ports_by_node.get(n, [])) - out_ports = sorted(out_ports_by_node.get(n, [])) - - # 如果该节点涉及命名端口,则用 record;否则保留原 box - if in_ports or out_ports: - def port_fields(ports): - if not ports: - return " " # 
必须留一个空槽占位 - # 每个端口一个小格子,

name - return "|".join(f"<{re.sub(r'[^A-Za-z0-9_:.|-]', '_', p)}> {p}" for p in ports) - - left = port_fields(in_ports) - right = port_fields(out_ports) - - # 三栏:左(入) | 中(节点名) | 右(出) - record_label = f"{{ {left} | {core} | {right} }}" - node.attr.update(shape="record", label=record_label) - else: - # 没有命名端口:普通盒子,显示核心标签 - node.attr.update(label=str(core)) - - # 4) 给边设置 headport / tailport - # - 若端口为 compass:直接用 compass(e.g., headport="e") - # - 若端口为命名端口:使用在 record 中定义的 名(同名即可) - for (u, v, sp, tp) in edges_data: - e = A.get_edge(u, v) - - # Graphviz 属性:tail 是源,head 是目标 - if sp: - if _is_compass(sp): - e.attr["tailport"] = sp.lower() - else: - # 与 record label 中 名一致;特殊字符已在 label 中做了清洗 - e.attr["tailport"] = re.sub(r'[^A-Za-z0-9_:.|-]', '_', str(sp)) - - if tp: - if _is_compass(tp): - e.attr["headport"] = tp.lower() - else: - e.attr["headport"] = re.sub(r'[^A-Za-z0-9_:.|-]', '_', str(tp)) - - # 可选:若想让边更贴边缘,可设置 constraint/spline 等 - # e.attr["arrowhead"] = "vee" - - # 5) 输出 - A.draw(output_path, prog="dot") - print(f" - Port-aware workflow rendered to '{output_path}'") - - -def flatten_xdl_procedure(procedure_elem: ET.Element) -> List[ET.Element]: - """展平嵌套的XDL程序结构""" - flattened_operations = [] - TEMP_UNSUPPORTED_PROTOCOL = ["Purge", "Wait", "Stir", "ResetHandling"] - - def extract_operations(element: ET.Element): - if element.tag not in ["Prep", "Reaction", "Workup", "Purification", "Procedure"]: - if element.tag not in TEMP_UNSUPPORTED_PROTOCOL: - flattened_operations.append(element) - - for child in element: - extract_operations(child) - - for child in procedure_elem: - extract_operations(child) - - return flattened_operations - - -def parse_xdl_content(xdl_content: str) -> tuple: - """解析XDL内容""" - try: - xdl_content_cleaned = "".join(c for c in xdl_content if c.isprintable()) - root = ET.fromstring(xdl_content_cleaned) - - synthesis_elem = root.find("Synthesis") - if synthesis_elem is None: - return None, None, None - - # 解析硬件组件 - hardware_elem = 
synthesis_elem.find("Hardware") - hardware = [] - if hardware_elem is not None: - hardware = [{"id": c.get("id"), "type": c.get("type")} for c in hardware_elem.findall("Component")] - - # 解析试剂 - reagents_elem = synthesis_elem.find("Reagents") - reagents = [] - if reagents_elem is not None: - reagents = [{"name": r.get("name"), "role": r.get("role", "")} for r in reagents_elem.findall("Reagent")] - - # 解析程序 - procedure_elem = synthesis_elem.find("Procedure") - if procedure_elem is None: - return None, None, None - - flattened_operations = flatten_xdl_procedure(procedure_elem) - return hardware, reagents, flattened_operations - - except ET.ParseError as e: - raise ValueError(f"Invalid XDL format: {e}") - - -def convert_xdl_to_dict(xdl_content: str) -> Dict[str, Any]: - """ - 将XDL XML格式转换为标准的字典格式 - - Args: - xdl_content: XDL XML内容 - - Returns: - 转换结果,包含步骤和器材信息 - """ - try: - hardware, reagents, flattened_operations = parse_xdl_content(xdl_content) - if hardware is None: - return {"error": "Failed to parse XDL content", "success": False} - - # 将XDL元素转换为字典格式 - steps_data = [] - for elem in flattened_operations: - # 转换参数类型 - parameters = {} - for key, val in elem.attrib.items(): - converted_val = convert_to_type(val) - if converted_val is not None: - parameters[key] = converted_val - - step_dict = { - "operation": elem.tag, - "parameters": parameters, - "description": elem.get("purpose", f"Operation: {elem.tag}"), - } - steps_data.append(step_dict) - - # 合并硬件和试剂为统一的labware_info格式 - labware_data = [] - labware_data.extend({"id": hw["id"], "type": "hardware", **hw} for hw in hardware) - labware_data.extend({"name": reagent["name"], "type": "reagent", **reagent} for reagent in reagents) - - return { - "success": True, - "steps": steps_data, - "labware": labware_data, - "message": f"Successfully converted XDL to dict format. 
Found {len(steps_data)} steps and {len(labware_data)} labware items.", - } - - except Exception as e: - error_msg = f"XDL conversion failed: {str(e)}" - logger.error(error_msg) - return {"error": error_msg, "success": False} def create_workflow( diff --git a/unilabos/test/resources/__init__.py b/test/resources/__init__.py similarity index 100% rename from unilabos/test/resources/__init__.py rename to test/resources/__init__.py diff --git a/unilabos/test/resources/bioyond_materials_liquidhandling_1.json b/test/resources/bioyond_materials_liquidhandling_1.json similarity index 100% rename from unilabos/test/resources/bioyond_materials_liquidhandling_1.json rename to test/resources/bioyond_materials_liquidhandling_1.json diff --git a/unilabos/test/resources/bioyond_materials_liquidhandling_2.json b/test/resources/bioyond_materials_liquidhandling_2.json similarity index 100% rename from unilabos/test/resources/bioyond_materials_liquidhandling_2.json rename to test/resources/bioyond_materials_liquidhandling_2.json diff --git a/unilabos/test/resources/bioyond_materials_reaction.json b/test/resources/bioyond_materials_reaction.json similarity index 100% rename from unilabos/test/resources/bioyond_materials_reaction.json rename to test/resources/bioyond_materials_reaction.json diff --git a/unilabos/test/resources/test_bottle_carrier.py b/test/resources/test_bottle_carrier.py similarity index 100% rename from unilabos/test/resources/test_bottle_carrier.py rename to test/resources/test_bottle_carrier.py diff --git a/unilabos/test/resources/test_converter_bioyond.py b/test/resources/test_converter_bioyond.py similarity index 100% rename from unilabos/test/resources/test_converter_bioyond.py rename to test/resources/test_converter_bioyond.py diff --git a/unilabos/test/resources/test_itemized_carrier.py b/test/resources/test_itemized_carrier.py similarity index 100% rename from unilabos/test/resources/test_itemized_carrier.py rename to test/resources/test_itemized_carrier.py 
diff --git a/unilabos/test/resources/test_resourcetreeset.py b/test/resources/test_resourcetreeset.py similarity index 100% rename from unilabos/test/resources/test_resourcetreeset.py rename to test/resources/test_resourcetreeset.py diff --git a/unilabos/test/ros/__init__.py b/test/ros/__init__.py similarity index 100% rename from unilabos/test/ros/__init__.py rename to test/ros/__init__.py diff --git a/unilabos/test/ros/msgs/__init__.py b/test/ros/msgs/__init__.py similarity index 100% rename from unilabos/test/ros/msgs/__init__.py rename to test/ros/msgs/__init__.py diff --git a/unilabos/test/ros/msgs/test_basic.py b/test/ros/msgs/test_basic.py similarity index 100% rename from unilabos/test/ros/msgs/test_basic.py rename to test/ros/msgs/test_basic.py diff --git a/unilabos/test/ros/msgs/test_conversion.py b/test/ros/msgs/test_conversion.py similarity index 100% rename from unilabos/test/ros/msgs/test_conversion.py rename to test/ros/msgs/test_conversion.py diff --git a/unilabos/test/ros/msgs/test_mapping.py b/test/ros/msgs/test_mapping.py similarity index 100% rename from unilabos/test/ros/msgs/test_mapping.py rename to test/ros/msgs/test_mapping.py diff --git a/unilabos/test/ros/msgs/test_runner.py b/test/ros/msgs/test_runner.py similarity index 100% rename from unilabos/test/ros/msgs/test_runner.py rename to test/ros/msgs/test_runner.py diff --git a/unilabos/test/workflow/__init__.py b/test/workflow/__init__.py similarity index 100% rename from unilabos/test/workflow/__init__.py rename to test/workflow/__init__.py diff --git a/unilabos/test/workflow/example_bio.json b/test/workflow/example_bio.json similarity index 100% rename from unilabos/test/workflow/example_bio.json rename to test/workflow/example_bio.json diff --git a/unilabos/test/workflow/example_bio_graph.png b/test/workflow/example_bio_graph.png similarity index 100% rename from unilabos/test/workflow/example_bio_graph.png rename to test/workflow/example_bio_graph.png diff --git 
a/unilabos/test/workflow/example_prcxi.json b/test/workflow/example_prcxi.json similarity index 100% rename from unilabos/test/workflow/example_prcxi.json rename to test/workflow/example_prcxi.json diff --git a/unilabos/test/workflow/example_prcxi_graph.png b/test/workflow/example_prcxi_graph.png similarity index 100% rename from unilabos/test/workflow/example_prcxi_graph.png rename to test/workflow/example_prcxi_graph.png diff --git a/unilabos/test/workflow/example_prcxi_graph_20251022_1359.png b/test/workflow/example_prcxi_graph_20251022_1359.png similarity index 100% rename from unilabos/test/workflow/example_prcxi_graph_20251022_1359.png rename to test/workflow/example_prcxi_graph_20251022_1359.png diff --git a/test/workflow/merge_workflow.py b/test/workflow/merge_workflow.py new file mode 100644 index 0000000..2801a74 --- /dev/null +++ b/test/workflow/merge_workflow.py @@ -0,0 +1,35 @@ +import sys +from datetime import datetime +from pathlib import Path + +ROOT_DIR = Path(__file__).resolve().parents[2] +if str(ROOT_DIR) not in sys.path: + sys.path.insert(0, str(ROOT_DIR)) + +import pytest + +from unilabos.workflow.convert_from_json import ( + convert_from_json, + normalize_steps as _normalize_steps, + normalize_labware as _normalize_labware, +) +from unilabos.workflow.common import draw_protocol_graph_with_ports + + +@pytest.mark.parametrize( + "protocol_name", + [ + "example_bio", + # "bioyond_materials_liquidhandling_1", + "example_prcxi", + ], +) +def test_build_protocol_graph(protocol_name): + data_path = Path(__file__).with_name(f"{protocol_name}.json") + + graph = convert_from_json(data_path, workstation_name="PRCXi") + + timestamp = datetime.now().strftime("%Y%m%d_%H%M") + output_path = data_path.with_name(f"{protocol_name}_graph_{timestamp}.png") + draw_protocol_graph_with_ports(graph, str(output_path)) + print(graph) diff --git a/unilabos/app/main.py b/unilabos/app/main.py index 5887552..f888f6f 100644 --- a/unilabos/app/main.py +++ 
b/unilabos/app/main.py @@ -20,6 +20,7 @@ if unilabos_dir not in sys.path: from unilabos.utils.banner_print import print_status, print_unilab_banner from unilabos.config.config import load_config, BasicConfig, HTTPConfig + def load_config_from_file(config_path): if config_path is None: config_path = os.environ.get("UNILABOS_BASICCONFIG_CONFIG_PATH", None) @@ -41,7 +42,7 @@ def convert_argv_dashes_to_underscores(args: argparse.ArgumentParser): for i, arg in enumerate(sys.argv): for option_string in option_strings: if arg.startswith(option_string): - new_arg = arg[:2] + arg[2:len(option_string)].replace("-", "_") + arg[len(option_string):] + new_arg = arg[:2] + arg[2 : len(option_string)].replace("-", "_") + arg[len(option_string) :] sys.argv[i] = new_arg break @@ -49,6 +50,8 @@ def convert_argv_dashes_to_underscores(args: argparse.ArgumentParser): def parse_args(): """解析命令行参数""" parser = argparse.ArgumentParser(description="Start Uni-Lab Edge server.") + subparsers = parser.add_subparsers(title="Valid subcommands", dest="command") + parser.add_argument("-g", "--graph", help="Physical setup graph file path.") parser.add_argument("-c", "--controllers", default=None, help="Controllers config file path.") parser.add_argument( @@ -153,6 +156,39 @@ def parse_args(): default=False, help="Complete registry information", ) + # workflow upload subcommand + workflow_parser = subparsers.add_parser( + "workflow_upload", + aliases=["wf"], + help="Upload workflow from xdl/json/python files", + ) + workflow_parser.add_argument( + "-f", + "--workflow_file", + type=str, + required=True, + help="Path to the workflow file (JSON format)", + ) + workflow_parser.add_argument( + "-n", + "--workflow_name", + type=str, + default=None, + help="Workflow name, if not provided will use the name from file or filename", + ) + workflow_parser.add_argument( + "--tags", + type=str, + nargs="*", + default=[], + help="Tags for the workflow (space-separated)", + ) + workflow_parser.add_argument( + 
"--published", + action="store_true", + default=False, + help="Whether to publish the workflow (default: False)", + ) return parser @@ -167,7 +203,6 @@ def main(): if not args_dict.get("skip_env_check", False): from unilabos.utils.environment_check import check_environment - print_status("正在进行环境依赖检查...", "info") if not check_environment(auto_install=True): print_status("环境检查失败,程序退出", "error") os._exit(1) @@ -239,9 +274,12 @@ def main(): if args_dict.get("sk", ""): BasicConfig.sk = args_dict.get("sk", "") print_status("传入了sk参数,优先采用传入参数!", "info") + BasicConfig.working_dir = working_dir + + workflow_upload = args_dict.get("command") in ("workflow_upload", "wf") # 使用远程资源启动 - if args_dict["use_remote_resource"]: + if not workflow_upload and args_dict["use_remote_resource"]: print_status("使用远程资源启动", "info") from unilabos.app.web import http_client @@ -254,7 +292,6 @@ def main(): BasicConfig.port = args_dict["port"] if args_dict["port"] else BasicConfig.port BasicConfig.disable_browser = args_dict["disable_browser"] or BasicConfig.disable_browser - BasicConfig.working_dir = working_dir BasicConfig.is_host_mode = not args_dict.get("is_slave", False) BasicConfig.slave_no_host = args_dict.get("slave_no_host", False) BasicConfig.upload_registry = args_dict.get("upload_registry", False) @@ -283,9 +320,31 @@ def main(): # 注册表 lab_registry = build_registry( - args_dict["registry_path"], args_dict.get("complete_registry", False), args_dict["upload_registry"] + args_dict["registry_path"], args_dict.get("complete_registry", False), BasicConfig.upload_registry ) + if BasicConfig.upload_registry: + # 设备注册到服务端 - 需要 ak 和 sk + if BasicConfig.ak and BasicConfig.sk: + print_status("开始注册设备到服务端...", "info") + try: + register_devices_and_resources(lab_registry) + print_status("设备注册完成", "info") + except Exception as e: + print_status(f"设备注册失败: {e}", "error") + else: + print_status("未提供 ak 和 sk,跳过设备注册", "info") + else: + print_status("本次启动注册表不报送云端,如果您需要联网调试,请在启动命令增加--upload_registry", 
"warning") + + # 处理 workflow_upload 子命令 + if workflow_upload: + from unilabos.workflow.wf_utils import handle_workflow_upload_command + + handle_workflow_upload_command(args_dict) + print_status("工作流上传完成,程序退出", "info") + os._exit(0) + if not BasicConfig.ak or not BasicConfig.sk: print_status("后续运行必须拥有一个实验室,请前往 https://uni-lab.bohrium.com 注册实验室!", "warning") os._exit(1) @@ -362,20 +421,6 @@ def main(): args_dict["devices_config"] = resource_tree_set args_dict["graph"] = graph_res.physical_setup_graph - if BasicConfig.upload_registry: - # 设备注册到服务端 - 需要 ak 和 sk - if BasicConfig.ak and BasicConfig.sk: - print_status("开始注册设备到服务端...", "info") - try: - register_devices_and_resources(lab_registry) - print_status("设备注册完成", "info") - except Exception as e: - print_status(f"设备注册失败: {e}", "error") - else: - print_status("未提供 ak 和 sk,跳过设备注册", "info") - else: - print_status("本次启动注册表不报送云端,如果您需要联网调试,请在启动命令增加--upload_registry", "warning") - if args_dict["controllers"] is not None: args_dict["controllers_config"] = yaml.safe_load(open(args_dict["controllers"], encoding="utf-8")) else: @@ -390,6 +435,7 @@ def main(): comm_client = get_communication_client() if "websocket" in args_dict["app_bridges"]: args_dict["bridges"].append(comm_client) + def _exit(signum, frame): comm_client.stop() sys.exit(0) @@ -431,16 +477,13 @@ def main(): resource_visualization.start() except OSError as e: if "AMENT_PREFIX_PATH" in str(e): - print_status( - f"ROS 2环境未正确设置,跳过3D可视化启动。错误详情: {e}", - "warning" - ) + print_status(f"ROS 2环境未正确设置,跳过3D可视化启动。错误详情: {e}", "warning") print_status( "建议解决方案:\n" "1. 激活Conda环境: conda activate unilab\n" "2. 或使用 --backend simple 参数\n" "3. 
或使用 --visual disable 参数禁用可视化", - "info" + "info", ) else: raise diff --git a/unilabos/app/web/client.py b/unilabos/app/web/client.py index 72c079a..1f40a0b 100644 --- a/unilabos/app/web/client.py +++ b/unilabos/app/web/client.py @@ -76,7 +76,8 @@ class HTTPClient: Dict[str, str]: 旧UUID到新UUID的映射关系 {old_uuid: new_uuid} """ with open(os.path.join(BasicConfig.working_dir, "req_resource_tree_add.json"), "w", encoding="utf-8") as f: - f.write(json.dumps({"nodes": [x for xs in resources.dump() for x in xs], "mount_uuid": mount_uuid}, indent=4)) + payload = {"nodes": [x for xs in resources.dump() for x in xs], "mount_uuid": mount_uuid} + f.write(json.dumps(payload, indent=4)) # 从序列化数据中提取所有节点的UUID(保存旧UUID) old_uuids = {n.res_content.uuid: n for n in resources.all_nodes} if not self.initialized or first_add: @@ -331,6 +332,67 @@ class HTTPClient: logger.error(f"响应内容: {response.text}") return None + def workflow_import( + self, + name: str, + workflow_uuid: str, + workflow_name: str, + nodes: List[Dict[str, Any]], + edges: List[Dict[str, Any]], + tags: Optional[List[str]] = None, + published: bool = False, + ) -> Dict[str, Any]: + """ + 导入工作流到服务器 + + Args: + name: 工作流名称(顶层) + workflow_uuid: 工作流UUID + workflow_name: 工作流名称(data内部) + nodes: 工作流节点列表 + edges: 工作流边列表 + tags: 工作流标签列表,默认为空列表 + published: 是否发布工作流,默认为False + + Returns: + Dict: API响应数据,包含 code 和 data (uuid, name) + """ + # target_lab_uuid 暂时使用默认值,后续由后端根据 ak/sk 获取 + payload = { + "target_lab_uuid": "28c38bb0-63f6-4352-b0d8-b5b8eb1766d5", + "name": name, + "data": { + "workflow_uuid": workflow_uuid, + "workflow_name": workflow_name, + "nodes": nodes, + "edges": edges, + "tags": tags if tags is not None else [], + "published": published, + }, + } + # 保存请求到文件 + with open(os.path.join(BasicConfig.working_dir, "req_workflow_upload.json"), "w", encoding="utf-8") as f: + f.write(json.dumps(payload, indent=4, ensure_ascii=False)) + + response = requests.post( + f"{self.remote_addr}/lab/workflow/owner/import", + json=payload, + 
headers={"Authorization": f"Lab {self.auth}"}, + timeout=60, + ) + # 保存响应到文件 + with open(os.path.join(BasicConfig.working_dir, "res_workflow_upload.json"), "w", encoding="utf-8") as f: + f.write(f"{response.status_code}" + "\n" + response.text) + + if response.status_code == 200: + res = response.json() + if "code" in res and res["code"] != 0: + logger.error(f"导入工作流失败: {response.text}") + return res + else: + logger.error(f"导入工作流失败: {response.status_code}, {response.text}") + return {"code": response.status_code, "message": response.text} + # 创建默认客户端实例 http_client = HTTPClient() diff --git a/unilabos/app/ws_client.py b/unilabos/app/ws_client.py index 50204a2..8c44712 100644 --- a/unilabos/app/ws_client.py +++ b/unilabos/app/ws_client.py @@ -438,7 +438,7 @@ class MessageProcessor: self.connected = True self.reconnect_count = 0 - logger.info(f"[MessageProcessor] Connected to {self.websocket_url}") + logger.trace(f"[MessageProcessor] Connected to {self.websocket_url}") # 启动发送协程 send_task = asyncio.create_task(self._send_handler()) @@ -503,7 +503,7 @@ class MessageProcessor: async def _send_handler(self): """处理发送队列中的消息""" - logger.debug("[MessageProcessor] Send handler started") + logger.trace("[MessageProcessor] Send handler started") try: while self.connected and self.websocket: @@ -965,7 +965,7 @@ class QueueProcessor: def _run(self): """运行队列处理主循环""" - logger.debug("[QueueProcessor] Queue processor started") + logger.trace("[QueueProcessor] Queue processor started") while self.is_running: try: @@ -1175,7 +1175,6 @@ class WebSocketClient(BaseCommunicationClient): else: url = f"{scheme}://{parsed.netloc}/api/v1/ws/schedule" - logger.debug(f"[WebSocketClient] URL: {url}") return url def start(self) -> None: @@ -1188,13 +1187,11 @@ class WebSocketClient(BaseCommunicationClient): logger.error("[WebSocketClient] WebSocket URL not configured") return - logger.info(f"[WebSocketClient] Starting connection to {self.websocket_url}") - # 启动两个核心线程 self.message_processor.start() 
self.queue_processor.start() - logger.info("[WebSocketClient] All threads started") + logger.trace("[WebSocketClient] All threads started") def stop(self) -> None: """停止WebSocket客户端""" diff --git a/unilabos/config/config.py b/unilabos/config/config.py index c13064e..223d12c 100644 --- a/unilabos/config/config.py +++ b/unilabos/config/config.py @@ -21,7 +21,8 @@ class BasicConfig: startup_json_path = None # 填写绝对路径 disable_browser = False # 禁止浏览器自动打开 port = 8002 # 本地HTTP服务 - log_level: Literal['TRACE', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] = "DEBUG" # 'TRACE', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' + # 'TRACE', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' + log_level: Literal["TRACE", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "DEBUG" @classmethod def auth_secret(cls): @@ -65,13 +66,14 @@ def _update_config_from_module(module): if not attr.startswith("_"): setattr(obj, attr, getattr(getattr(module, name), attr)) + def _update_config_from_env(): prefix = "UNILABOS_" for env_key, env_value in os.environ.items(): if not env_key.startswith(prefix): continue try: - key_path = env_key[len(prefix):] # Remove UNILAB_ prefix + key_path = env_key[len(prefix) :] # Remove UNILAB_ prefix class_field = key_path.upper().split("_", 1) if len(class_field) != 2: logger.warning(f"[ENV] 环境变量格式不正确:{env_key}") diff --git a/unilabos/registry/devices/liquid_handler.yaml b/unilabos/registry/devices/liquid_handler.yaml index d38c43a..fdfb6b5 100644 --- a/unilabos/registry/devices/liquid_handler.yaml +++ b/unilabos/registry/devices/liquid_handler.yaml @@ -9333,7 +9333,34 @@ liquid_handler.prcxi: touch_tip: false use_channels: - 0 - handles: {} + handles: + input: + - data_key: liquid + data_source: handle + data_type: resource + handler_key: sources + label: sources + - data_key: liquid + data_source: executor + data_type: resource + handler_key: targets + label: targets + - data_key: liquid + data_source: executor + data_type: resource + handler_key: tip_rack + 
label: tip_rack + output: + - data_key: liquid + data_source: handle + data_type: resource + handler_key: sources_out + label: sources + - data_key: liquid + data_source: executor + data_type: resource + handler_key: targets_out + label: targets placeholder_keys: sources: unilabos_resources targets: unilabos_resources diff --git a/unilabos/registry/registry.py b/unilabos/registry/registry.py index 98ea7d7..49e5761 100644 --- a/unilabos/registry/registry.py +++ b/unilabos/registry/registry.py @@ -222,7 +222,7 @@ class Registry: abs_path = Path(path).absolute() resource_path = abs_path / "resources" files = list(resource_path.glob("*/*.yaml")) - logger.debug(f"[UniLab Registry] resources: {resource_path.exists()}, total: {len(files)}") + logger.trace(f"[UniLab Registry] load resources? {resource_path.exists()}, total: {len(files)}") current_resource_number = len(self.resource_type_registry) + 1 for i, file in enumerate(files): with open(file, encoding="utf-8", mode="r") as f: diff --git a/unilabos/resources/graphio.py b/unilabos/resources/graphio.py index 7f4479a..756c6f5 100644 --- a/unilabos/resources/graphio.py +++ b/unilabos/resources/graphio.py @@ -42,7 +42,7 @@ def canonicalize_nodes_data( Returns: ResourceTreeSet: 标准化后的资源树集合 """ - print_status(f"{len(nodes)} Resources loaded:", "info") + print_status(f"{len(nodes)} Resources loaded", "info") # 第一步:基本预处理(处理graphml的label字段) outer_host_node_id = None diff --git a/unilabos/ros/nodes/resource_tracker.py b/unilabos/ros/nodes/resource_tracker.py index 0eed117..849d64a 100644 --- a/unilabos/ros/nodes/resource_tracker.py +++ b/unilabos/ros/nodes/resource_tracker.py @@ -66,8 +66,8 @@ class ResourceDict(BaseModel): klass: str = Field(alias="class", description="Resource class name") pose: ResourceDictPosition = Field(description="Resource position", default_factory=ResourceDictPosition) config: Dict[str, Any] = Field(description="Resource configuration") - data: Dict[str, Any] = Field(description="Resource data") - 
extra: Dict[str, Any] = Field(description="Extra data") + data: Dict[str, Any] = Field(description="Resource data, eg: container liquid data") + extra: Dict[str, Any] = Field(description="Extra data, eg: slot index") @field_serializer("parent_uuid") def _serialize_parent(self, parent_uuid: Optional["ResourceDict"]): diff --git a/unilabos/test/workflow/merge_workflow.py b/unilabos/test/workflow/merge_workflow.py deleted file mode 100644 index fb40976..0000000 --- a/unilabos/test/workflow/merge_workflow.py +++ /dev/null @@ -1,94 +0,0 @@ -import json -import sys -from datetime import datetime -from pathlib import Path - -ROOT_DIR = Path(__file__).resolve().parents[2] -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - -import pytest - -from scripts.workflow import build_protocol_graph, draw_protocol_graph, draw_protocol_graph_with_ports - - -ROOT_DIR = Path(__file__).resolve().parents[2] -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - - -def _normalize_steps(data): - normalized = [] - for step in data: - action = step.get("action") or step.get("operation") - if not action: - continue - raw_params = step.get("parameters") or step.get("action_args") or {} - params = dict(raw_params) - - if "source" in raw_params and "sources" not in raw_params: - params["sources"] = raw_params["source"] - if "target" in raw_params and "targets" not in raw_params: - params["targets"] = raw_params["target"] - - description = step.get("description") or step.get("purpose") - step_dict = {"action": action, "parameters": params} - if description: - step_dict["description"] = description - normalized.append(step_dict) - return normalized - - -def _normalize_labware(data): - labware = {} - for item in data: - reagent_name = item.get("reagent_name") - key = reagent_name or item.get("material_name") or item.get("name") - if not key: - continue - key = str(key) - idx = 1 - original_key = key - while key in labware: - idx += 1 - key = 
f"{original_key}_{idx}" - - labware[key] = { - "slot": item.get("positions") or item.get("slot"), - "labware": item.get("material_name") or item.get("labware"), - "well": item.get("well", []), - "type": item.get("type", "reagent"), - "role": item.get("role", ""), - "name": key, - } - return labware - - -@pytest.mark.parametrize("protocol_name", [ - "example_bio", - # "bioyond_materials_liquidhandling_1", - "example_prcxi", -]) -def test_build_protocol_graph(protocol_name): - data_path = Path(__file__).with_name(f"{protocol_name}.json") - with data_path.open("r", encoding="utf-8") as fp: - d = json.load(fp) - - if "workflow" in d and "reagent" in d: - protocol_steps = d["workflow"] - labware_info = d["reagent"] - elif "steps_info" in d and "labware_info" in d: - protocol_steps = _normalize_steps(d["steps_info"]) - labware_info = _normalize_labware(d["labware_info"]) - else: - raise ValueError("Unsupported protocol format") - - graph = build_protocol_graph( - labware_info=labware_info, - protocol_steps=protocol_steps, - workstation_name="PRCXi", - ) - timestamp = datetime.now().strftime("%Y%m%d_%H%M") - output_path = data_path.with_name(f"{protocol_name}_graph_{timestamp}.png") - draw_protocol_graph_with_ports(graph, str(output_path)) - print(graph) \ No newline at end of file diff --git a/unilabos/workflow/__init__.py b/unilabos/workflow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/unilabos/workflow/common.py b/unilabos/workflow/common.py new file mode 100644 index 0000000..9bff049 --- /dev/null +++ b/unilabos/workflow/common.py @@ -0,0 +1,547 @@ +import re +import uuid + +import networkx as nx +from networkx.drawing.nx_agraph import to_agraph +import matplotlib.pyplot as plt +from typing import Dict, List, Any, Tuple, Optional + +Json = Dict[str, Any] + +# ---------------- Graph ---------------- + + +class WorkflowGraph: + """简单的有向图实现:使用 params 单层参数;inputs 内含连线;支持 node-link 导出""" + + def __init__(self): + self.nodes: Dict[str, Dict[str, 
Any]] = {} + self.edges: List[Dict[str, Any]] = [] + + def add_node(self, node_id: str, **attrs): + self.nodes[node_id] = attrs + + def add_edge(self, source: str, target: str, **attrs): + # 将 source_port/target_port 映射为服务端期望的 source_handle_key/target_handle_key + source_handle_key = attrs.pop("source_port", "") or attrs.pop("source_handle_key", "") + target_handle_key = attrs.pop("target_port", "") or attrs.pop("target_handle_key", "") + + edge = { + "source": source, + "target": target, + "source_node_uuid": source, + "target_node_uuid": target, + "source_handle_key": source_handle_key, + "source_handle_io": attrs.pop("source_handle_io", "source"), + "target_handle_key": target_handle_key, + "target_handle_io": attrs.pop("target_handle_io", "target"), + **attrs, + } + self.edges.append(edge) + + def _materialize_wiring_into_inputs( + self, + obj: Any, + inputs: Dict[str, Any], + variable_sources: Dict[str, Dict[str, Any]], + target_node_id: str, + base_path: List[str], + ): + has_var = False + + def walk(node: Any, path: List[str]): + nonlocal has_var + if isinstance(node, dict): + if "__var__" in node: + has_var = True + varname = node["__var__"] + placeholder = f"${{{varname}}}" + src = variable_sources.get(varname) + if src: + key = ".".join(path) # e.g. 
"params.foo.bar.0" + inputs[key] = {"node": src["node_id"], "output": src.get("output_name", "result")} + self.add_edge( + str(src["node_id"]), + target_node_id, + source_handle_io=src.get("output_name", "result"), + target_handle_io=key, + ) + return placeholder + return {k: walk(v, path + [k]) for k, v in node.items()} + if isinstance(node, list): + return [walk(v, path + [str(i)]) for i, v in enumerate(node)] + return node + + replaced = walk(obj, base_path[:]) + return replaced, has_var + + def add_workflow_node( + self, + node_id: int, + *, + device_key: Optional[str] = None, # 实例名,如 "ser" + resource_name: Optional[str] = None, # registry key(原 device_class) + module: Optional[str] = None, + template_name: Optional[str] = None, # 动作/模板名(原 action_key) + params: Dict[str, Any], + variable_sources: Dict[str, Dict[str, Any]], + add_ready_if_no_vars: bool = True, + prev_node_id: Optional[int] = None, + **extra_attrs, + ) -> None: + """添加工作流节点:params 单层;自动变量连线与 ready 串联;支持附加属性""" + node_id_str = str(node_id) + inputs: Dict[str, Any] = {} + + params, has_var = self._materialize_wiring_into_inputs( + params, inputs, variable_sources, node_id_str, base_path=["params"] + ) + + if add_ready_if_no_vars and not has_var: + last_id = str(prev_node_id) if prev_node_id is not None else "-1" + inputs["ready"] = {"node": int(last_id), "output": "ready"} + self.add_edge(last_id, node_id_str, source_handle_io="ready", target_handle_io="ready") + + node_obj = { + "device_key": device_key, + "resource_name": resource_name, # ✅ 新名字 + "module": module, + "template_name": template_name, # ✅ 新名字 + "params": params, + "inputs": inputs, + } + node_obj.update(extra_attrs or {}) + self.add_node(node_id_str, parameters=node_obj) + + # 顺序工作流导出(连线在 inputs,不返回 edges) + def to_dict(self) -> List[Dict[str, Any]]: + result = [] + for node_id, attrs in self.nodes.items(): + node = {"uuid": node_id} + params = dict(attrs.get("parameters", {}) or {}) + flat = {k: v for k, v in attrs.items() if k != 
"parameters"} + flat.update(params) + node.update(flat) + result.append(node) + return sorted(result, key=lambda n: int(n["uuid"]) if str(n["uuid"]).isdigit() else n["uuid"]) + + # node-link 导出(含 edges) + def to_node_link_dict(self) -> Dict[str, Any]: + nodes_list = [] + for node_id, attrs in self.nodes.items(): + node_attrs = attrs.copy() + params = node_attrs.pop("parameters", {}) or {} + node_attrs.update(params) + nodes_list.append({"uuid": node_id, **node_attrs}) + return { + "directed": True, + "multigraph": False, + "graph": {}, + "nodes": nodes_list, + "edges": self.edges, + "links": self.edges, + } + + +def refactor_data( + data: List[Dict[str, Any]], + action_resource_mapping: Optional[Dict[str, str]] = None, +) -> List[Dict[str, Any]]: + """统一的数据重构函数,根据操作类型自动选择模板 + + Args: + data: 原始步骤数据列表 + action_resource_mapping: action 到 resource_name 的映射字典,可选 + """ + refactored_data = [] + + # 定义操作映射,包含生物实验和有机化学的所有操作 + OPERATION_MAPPING = { + # 生物实验操作 + "transfer_liquid": "transfer_liquid", + "transfer": "transfer", + "incubation": "incubation", + "move_labware": "move_labware", + "oscillation": "oscillation", + # 有机化学操作 + "HeatChillToTemp": "HeatChillProtocol", + "StopHeatChill": "HeatChillStopProtocol", + "StartHeatChill": "HeatChillStartProtocol", + "HeatChill": "HeatChillProtocol", + "Dissolve": "DissolveProtocol", + "Transfer": "TransferProtocol", + "Evaporate": "EvaporateProtocol", + "Recrystallize": "RecrystallizeProtocol", + "Filter": "FilterProtocol", + "Dry": "DryProtocol", + "Add": "AddProtocol", + } + + UNSUPPORTED_OPERATIONS = ["Purge", "Wait", "Stir", "ResetHandling"] + + for step in data: + operation = step.get("action") + if not operation or operation in UNSUPPORTED_OPERATIONS: + continue + + # 处理重复操作 + if operation == "Repeat": + times = step.get("times", step.get("parameters", {}).get("times", 1)) + sub_steps = step.get("steps", step.get("parameters", {}).get("steps", [])) + for i in range(int(times)): + sub_data = refactor_data(sub_steps, 
action_resource_mapping) + refactored_data.extend(sub_data) + continue + + # 获取模板名称 + template_name = OPERATION_MAPPING.get(operation) + if not template_name: + # 自动推断模板类型 + if operation.lower() in ["transfer", "incubation", "move_labware", "oscillation"]: + template_name = f"biomek-{operation}" + else: + template_name = f"{operation}Protocol" + + # 获取 resource_name + resource_name = f"device.{operation.lower()}" + if action_resource_mapping: + resource_name = action_resource_mapping.get(operation, resource_name) + + # 获取步骤编号,生成 name 字段 + step_number = step.get("step_number") + name = f"Step {step_number}" if step_number is not None else None + + # 创建步骤数据 + step_data = { + "template_name": template_name, + "resource_name": resource_name, + "description": step.get("description", step.get("purpose", f"{operation} operation")), + "lab_node_type": "Device", + "param": step.get("parameters", step.get("action_args", {})), + "footer": f"{template_name}-{resource_name}", + } + if name: + step_data["name"] = name + refactored_data.append(step_data) + + return refactored_data + + +def build_protocol_graph( + labware_info: List[Dict[str, Any]], + protocol_steps: List[Dict[str, Any]], + workstation_name: str, + action_resource_mapping: Optional[Dict[str, str]] = None, +) -> WorkflowGraph: + """统一的协议图构建函数,根据设备类型自动选择构建逻辑 + + Args: + labware_info: labware 信息字典 + protocol_steps: 协议步骤列表 + workstation_name: 工作站名称 + action_resource_mapping: action 到 resource_name 的映射字典,可选 + """ + G = WorkflowGraph() + resource_last_writer = {} + + protocol_steps = refactor_data(protocol_steps, action_resource_mapping) + # 有机化学&移液站协议图构建 + WORKSTATION_ID = workstation_name + + # 为所有labware创建资源节点 + res_index = 0 + for labware_id, item in labware_info.items(): + # item_id = item.get("id") or item.get("name", f"item_{uuid.uuid4()}") + node_id = str(uuid.uuid4()) + + # 判断节点类型 + if "Rack" in str(labware_id) or "Tip" in str(labware_id): + lab_node_type = "Labware" + description = f"Prepare Labware: 
{labware_id}" + liquid_type = [] + liquid_volume = [] + elif item.get("type") == "hardware" or "reactor" in str(labware_id).lower(): + if "reactor" not in str(labware_id).lower(): + continue + lab_node_type = "Sample" + description = f"Prepare Reactor: {labware_id}" + liquid_type = [] + liquid_volume = [] + else: + lab_node_type = "Reagent" + description = f"Add Reagent to Flask: {labware_id}" + liquid_type = [labware_id] + liquid_volume = [1e5] + + res_index += 1 + G.add_node( + node_id, + template_name="create_resource", + resource_name="host_node", + name=f"Res {res_index}", + description=description, + lab_node_type=lab_node_type, + footer="create_resource-host_node", + param={ + "res_id": labware_id, + "device_id": WORKSTATION_ID, + "class_name": "container", + "parent": WORKSTATION_ID, + "bind_locations": {"x": 0.0, "y": 0.0, "z": 0.0}, + "liquid_input_slot": [-1], + "liquid_type": liquid_type, + "liquid_volume": liquid_volume, + "slot_on_deck": "", + }, + ) + resource_last_writer[labware_id] = f"{node_id}:labware" + + last_control_node_id = None + + # 处理协议步骤 + for step in protocol_steps: + node_id = str(uuid.uuid4()) + G.add_node(node_id, **step) + + # 控制流 + if last_control_node_id is not None: + G.add_edge(last_control_node_id, node_id, source_port="ready", target_port="ready") + last_control_node_id = node_id + + # 物料流 + params = step.get("param", {}) + input_resources_possible_names = [ + "vessel", + "to_vessel", + "from_vessel", + "reagent", + "solvent", + "compound", + "sources", + "targets", + ] + + for target_port in input_resources_possible_names: + resource_name = params.get(target_port) + if resource_name and resource_name in resource_last_writer: + source_node, source_port = resource_last_writer[resource_name].split(":") + G.add_edge(source_node, node_id, source_port=source_port, target_port=target_port) + + output_resources = { + "vessel_out": params.get("vessel"), + "from_vessel_out": params.get("from_vessel"), + "to_vessel_out": 
params.get("to_vessel"), + "filtrate_out": params.get("filtrate_vessel"), + "reagent": params.get("reagent"), + "solvent": params.get("solvent"), + "compound": params.get("compound"), + "sources_out": params.get("sources"), + "targets_out": params.get("targets"), + } + + for source_port, resource_name in output_resources.items(): + if resource_name: + resource_last_writer[resource_name] = f"{node_id}:{source_port}" + + return G + + +def draw_protocol_graph(protocol_graph: WorkflowGraph, output_path: str): + """ + (辅助功能) 使用 networkx 和 matplotlib 绘制协议工作流图,用于可视化。 + """ + if not protocol_graph: + print("Cannot draw graph: Graph object is empty.") + return + + G = nx.DiGraph() + + for node_id, attrs in protocol_graph.nodes.items(): + label = attrs.get("description", attrs.get("template_name", node_id[:8])) + G.add_node(node_id, label=label, **attrs) + + for edge in protocol_graph.edges: + G.add_edge(edge["source"], edge["target"]) + + plt.figure(figsize=(20, 15)) + try: + pos = nx.nx_agraph.graphviz_layout(G, prog="dot") + except Exception: + pos = nx.shell_layout(G) # Fallback layout + + node_labels = {node: data["label"] for node, data in G.nodes(data=True)} + nx.draw( + G, + pos, + with_labels=False, + node_size=2500, + node_color="skyblue", + node_shape="o", + edge_color="gray", + width=1.5, + arrowsize=15, + ) + nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=8, font_weight="bold") + + plt.title("Chemical Protocol Workflow Graph", size=15) + plt.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close() + print(f" - Visualization saved to '{output_path}'") + + +COMPASS = {"n", "e", "s", "w", "ne", "nw", "se", "sw", "c"} + + +def _is_compass(port: str) -> bool: + return isinstance(port, str) and port.lower() in COMPASS + + +def draw_protocol_graph_with_ports(protocol_graph, output_path: str, rankdir: str = "LR"): + """ + 使用 Graphviz 端口语法绘制协议工作流图。 + - 若边上的 source_port/target_port 是 compass(n/e/s/w/...),直接用 compass。 + - 否则自动为节点创建 record 形状并定义命名端口 。 
+ 最终由 PyGraphviz 渲染并输出到 output_path(后缀决定格式,如 .png/.svg/.pdf)。 + """ + if not protocol_graph: + print("Cannot draw graph: Graph object is empty.") + return + + # 1) 先用 networkx 搭建有向图,保留端口属性 + G = nx.DiGraph() + for node_id, attrs in protocol_graph.nodes.items(): + label = attrs.get("description", attrs.get("template_name", node_id[:8])) + # 保留一个干净的“中心标签”,用于放在 record 的中间槽 + G.add_node(node_id, _core_label=str(label), **{k: v for k, v in attrs.items() if k not in ("label",)}) + + edges_data = [] + in_ports_by_node = {} # 收集命名输入端口 + out_ports_by_node = {} # 收集命名输出端口 + + for edge in protocol_graph.edges: + u = edge["source"] + v = edge["target"] + sp = edge.get("source_handle_key") or edge.get("source_port") + tp = edge.get("target_handle_key") or edge.get("target_port") + + # 记录到图里(保留原始端口信息) + G.add_edge(u, v, source_handle_key=sp, target_handle_key=tp) + edges_data.append((u, v, sp, tp)) + + # 如果不是 compass,就按“命名端口”先归类,等会儿给节点造 record + if sp and not _is_compass(sp): + out_ports_by_node.setdefault(u, set()).add(str(sp)) + if tp and not _is_compass(tp): + in_ports_by_node.setdefault(v, set()).add(str(tp)) + + # 2) 转为 AGraph,使用 Graphviz 渲染 + A = to_agraph(G) + A.graph_attr.update(rankdir=rankdir, splines="true", concentrate="false", fontsize="10") + A.node_attr.update( + shape="box", style="rounded,filled", fillcolor="lightyellow", color="#999999", fontname="Helvetica" + ) + A.edge_attr.update(arrowsize="0.8", color="#666666") + + # 3) 为需要命名端口的节点设置 record 形状与 label + # 左列 = 输入端口;中间 = 核心标签;右列 = 输出端口 + for n in A.nodes(): + node = A.get_node(n) + core = G.nodes[n].get("_core_label", n) + + in_ports = sorted(in_ports_by_node.get(n, [])) + out_ports = sorted(out_ports_by_node.get(n, [])) + + # 如果该节点涉及命名端口,则用 record;否则保留原 box + if in_ports or out_ports: + + def port_fields(ports): + if not ports: + return " " # 必须留一个空槽占位 + # 每个端口一个小格子,

name + return "|".join(f"<{re.sub(r'[^A-Za-z0-9_:.|-]', '_', p)}> {p}" for p in ports) + + left = port_fields(in_ports) + right = port_fields(out_ports) + + # 三栏:左(入) | 中(节点名) | 右(出) + record_label = f"{{ {left} | {core} | {right} }}" + node.attr.update(shape="record", label=record_label) + else: + # 没有命名端口:普通盒子,显示核心标签 + node.attr.update(label=str(core)) + + # 4) 给边设置 headport / tailport + # - 若端口为 compass:直接用 compass(e.g., headport="e") + # - 若端口为命名端口:使用在 record 中定义的 名(同名即可) + for u, v, sp, tp in edges_data: + e = A.get_edge(u, v) + + # Graphviz 属性:tail 是源,head 是目标 + if sp: + if _is_compass(sp): + e.attr["tailport"] = sp.lower() + else: + # 与 record label 中 名一致;特殊字符已在 label 中做了清洗 + e.attr["tailport"] = re.sub(r"[^A-Za-z0-9_:.|-]", "_", str(sp)) + + if tp: + if _is_compass(tp): + e.attr["headport"] = tp.lower() + else: + e.attr["headport"] = re.sub(r"[^A-Za-z0-9_:.|-]", "_", str(tp)) + + # 可选:若想让边更贴边缘,可设置 constraint/spline 等 + # e.attr["arrowhead"] = "vee" + + # 5) 输出 + A.draw(output_path, prog="dot") + print(f" - Port-aware workflow rendered to '{output_path}'") + + +# ---------------- Registry Adapter ---------------- + + +class RegistryAdapter: + """根据 module 的类名(冒号右侧)反查 registry 的 resource_name(原 device_class),并抽取参数顺序""" + + def __init__(self, device_registry: Dict[str, Any]): + self.device_registry = device_registry or {} + self.module_class_to_resource = self._build_module_class_index() + + def _build_module_class_index(self) -> Dict[str, str]: + idx = {} + for resource_name, info in self.device_registry.items(): + module = info.get("module") + if isinstance(module, str) and ":" in module: + cls = module.split(":")[-1] + idx[cls] = resource_name + idx[cls.lower()] = resource_name + return idx + + def resolve_resource_by_classname(self, class_name: str) -> Optional[str]: + if not class_name: + return None + return self.module_class_to_resource.get(class_name) or self.module_class_to_resource.get(class_name.lower()) + + def get_device_module(self, 
resource_name: Optional[str]) -> Optional[str]: + if not resource_name: + return None + return self.device_registry.get(resource_name, {}).get("module") + + def get_actions(self, resource_name: Optional[str]) -> Dict[str, Any]: + if not resource_name: + return {} + return (self.device_registry.get(resource_name, {}).get("class", {}).get("action_value_mappings", {})) or {} + + def get_action_schema(self, resource_name: Optional[str], template_name: str) -> Optional[Json]: + return (self.get_actions(resource_name).get(template_name) or {}).get("schema") + + def get_action_goal_default(self, resource_name: Optional[str], template_name: str) -> Json: + return (self.get_actions(resource_name).get(template_name) or {}).get("goal_default", {}) or {} + + def get_action_input_keys(self, resource_name: Optional[str], template_name: str) -> List[str]: + schema = self.get_action_schema(resource_name, template_name) or {} + goal = (schema.get("properties") or {}).get("goal") or {} + props = goal.get("properties") or {} + required = goal.get("required") or [] + return list(dict.fromkeys(required + list(props.keys()))) diff --git a/unilabos/workflow/convert_from_json.py b/unilabos/workflow/convert_from_json.py new file mode 100644 index 0000000..7a6d2b4 --- /dev/null +++ b/unilabos/workflow/convert_from_json.py @@ -0,0 +1,356 @@ +""" +JSON 工作流转换模块 + +提供从多种 JSON 格式转换为统一工作流格式的功能。 +支持的格式: +1. workflow/reagent 格式 +2. 
steps_info/labware_info 格式 +""" + +import json +from os import PathLike +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from unilabos.workflow.common import WorkflowGraph, build_protocol_graph +from unilabos.registry.registry import lab_registry + + +def get_action_handles(resource_name: str, template_name: str) -> Dict[str, List[str]]: + """ + 从 registry 获取指定设备和动作的 handles 配置 + + Args: + resource_name: 设备资源名称,如 "liquid_handler.prcxi" + template_name: 动作模板名称,如 "transfer_liquid" + + Returns: + 包含 source 和 target handler_keys 的字典: + {"source": ["sources_out", "targets_out", ...], "target": ["sources", "targets", ...]} + """ + result = {"source": [], "target": []} + + device_info = lab_registry.device_type_registry.get(resource_name, {}) + if not device_info: + return result + + action_mappings = device_info.get("class", {}).get("action_value_mappings", {}) + action_config = action_mappings.get(template_name, {}) + handles = action_config.get("handles", {}) + + if isinstance(handles, dict): + # 处理 input handles (作为 target) + for handle in handles.get("input", []): + handler_key = handle.get("handler_key", "") + if handler_key: + result["source"].append(handler_key) + # 处理 output handles (作为 source) + for handle in handles.get("output", []): + handler_key = handle.get("handler_key", "") + if handler_key: + result["target"].append(handler_key) + + return result + + +def validate_workflow_handles(graph: WorkflowGraph) -> Tuple[bool, List[str]]: + """ + 校验工作流图中所有边的句柄配置是否正确 + + Args: + graph: 工作流图对象 + + Returns: + (is_valid, errors): 是否有效,错误信息列表 + """ + errors = [] + nodes = graph.nodes + + for edge in graph.edges: + left_uuid = edge.get("source") + right_uuid = edge.get("target") + # target_handle_key是target, right的输入节点(入节点) + # source_handle_key是source, left的输出节点(出节点) + right_source_conn_key = edge.get("target_handle_key", "") + left_target_conn_key = edge.get("source_handle_key", "") + + # 获取源节点和目标节点信息 + left_node = 
nodes.get(left_uuid, {}) + right_node = nodes.get(right_uuid, {}) + + left_res_name = left_node.get("resource_name", "") + left_template_name = left_node.get("template_name", "") + right_res_name = right_node.get("resource_name", "") + right_template_name = right_node.get("template_name", "") + + # 获取源节点的 output handles + left_node_handles = get_action_handles(left_res_name, left_template_name) + target_valid_keys = left_node_handles.get("target", []) + target_valid_keys.append("ready") + + # 获取目标节点的 input handles + right_node_handles = get_action_handles(right_res_name, right_template_name) + source_valid_keys = right_node_handles.get("source", []) + source_valid_keys.append("ready") + + # 如果节点配置了 output handles,则 source_port 必须有效 + if not right_source_conn_key: + node_name = left_node.get("name", left_uuid[:8]) + errors.append(f"源节点 '{node_name}' 的 source_handle_key 为空," f"应设置为: {source_valid_keys}") + elif right_source_conn_key not in source_valid_keys: + node_name = left_node.get("name", left_uuid[:8]) + errors.append( + f"源节点 '{node_name}' 的 source 端点 '{right_source_conn_key}' 不存在," f"支持的端点: {source_valid_keys}" + ) + + # 如果节点配置了 input handles,则 target_port 必须有效 + if not left_target_conn_key: + node_name = right_node.get("name", right_uuid[:8]) + errors.append(f"目标节点 '{node_name}' 的 target_handle_key 为空," f"应设置为: {target_valid_keys}") + elif left_target_conn_key not in target_valid_keys: + node_name = right_node.get("name", right_uuid[:8]) + errors.append( + f"目标节点 '{node_name}' 的 target 端点 '{left_target_conn_key}' 不存在," + f"支持的端点: {target_valid_keys}" + ) + + return len(errors) == 0, errors + + +# action 到 resource_name 的映射 +ACTION_RESOURCE_MAPPING: Dict[str, str] = { + # 生物实验操作 + "transfer_liquid": "liquid_handler.prcxi", + "transfer": "liquid_handler.prcxi", + "incubation": "incubator.prcxi", + "move_labware": "labware_mover.prcxi", + "oscillation": "shaker.prcxi", + # 有机化学操作 + "HeatChillToTemp": "heatchill.chemputer", + "StopHeatChill": 
"heatchill.chemputer", + "StartHeatChill": "heatchill.chemputer", + "HeatChill": "heatchill.chemputer", + "Dissolve": "stirrer.chemputer", + "Transfer": "liquid_handler.chemputer", + "Evaporate": "rotavap.chemputer", + "Recrystallize": "reactor.chemputer", + "Filter": "filter.chemputer", + "Dry": "dryer.chemputer", + "Add": "liquid_handler.chemputer", +} + + +def normalize_steps(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 将不同格式的步骤数据规范化为统一格式 + + 支持的输入格式: + - action + parameters + - action + action_args + - operation + parameters + + Args: + data: 原始步骤数据列表 + + Returns: + 规范化后的步骤列表,格式为 [{"action": str, "parameters": dict, "description": str?, "step_number": int?}, ...] + """ + normalized = [] + for idx, step in enumerate(data): + # 获取动作名称(支持 action 或 operation 字段) + action = step.get("action") or step.get("operation") + if not action: + continue + + # 获取参数(支持 parameters 或 action_args 字段) + raw_params = step.get("parameters") or step.get("action_args") or {} + params = dict(raw_params) + + # 规范化 source/target -> sources/targets + if "source" in raw_params and "sources" not in raw_params: + params["sources"] = raw_params["source"] + if "target" in raw_params and "targets" not in raw_params: + params["targets"] = raw_params["target"] + + # 获取描述(支持 description 或 purpose 字段) + description = step.get("description") or step.get("purpose") + + # 获取步骤编号(优先使用原始数据中的 step_number,否则使用索引+1) + step_number = step.get("step_number", idx + 1) + + step_dict = {"action": action, "parameters": params, "step_number": step_number} + if description: + step_dict["description"] = description + + normalized.append(step_dict) + + return normalized + + +def normalize_labware(data: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: + """ + 将不同格式的 labware 数据规范化为统一的字典格式 + + 支持的输入格式: + - reagent_name + material_name + positions + - name + labware + slot + + Args: + data: 原始 labware 数据列表 + + Returns: + 规范化后的 labware 字典,格式为 {name: {"slot": int, "labware": str, "well": list, "type": 
str, "role": str, "name": str}, ...} + """ + labware = {} + for item in data: + # 获取 key 名称(优先使用 reagent_name,其次是 material_name 或 name) + reagent_name = item.get("reagent_name") + key = reagent_name or item.get("material_name") or item.get("name") + if not key: + continue + + key = str(key) + + # 处理重复 key,自动添加后缀 + idx = 1 + original_key = key + while key in labware: + idx += 1 + key = f"{original_key}_{idx}" + + labware[key] = { + "slot": item.get("positions") or item.get("slot"), + "labware": item.get("material_name") or item.get("labware"), + "well": item.get("well", []), + "type": item.get("type", "reagent"), + "role": item.get("role", ""), + "name": key, + } + + return labware + + +def convert_from_json( + data: Union[str, PathLike, Dict[str, Any]], + workstation_name: str = "PRCXi", + validate: bool = True, +) -> WorkflowGraph: + """ + 从 JSON 数据或文件转换为 WorkflowGraph + + 支持的 JSON 格式: + 1. {"workflow": [...], "reagent": {...}} - 直接格式 + 2. {"steps_info": [...], "labware_info": [...]} - 需要规范化的格式 + + Args: + data: JSON 文件路径、字典数据、或 JSON 字符串 + workstation_name: 工作站名称,默认 "PRCXi" + validate: 是否校验句柄配置,默认 True + + Returns: + WorkflowGraph: 构建好的工作流图 + + Raises: + ValueError: 不支持的 JSON 格式 或 句柄校验失败 + FileNotFoundError: 文件不存在 + json.JSONDecodeError: JSON 解析失败 + """ + # 处理输入数据 + if isinstance(data, (str, PathLike)): + path = Path(data) + if path.exists(): + with path.open("r", encoding="utf-8") as fp: + json_data = json.load(fp) + elif isinstance(data, str): + # 尝试作为 JSON 字符串解析 + json_data = json.loads(data) + else: + raise FileNotFoundError(f"文件不存在: {data}") + elif isinstance(data, dict): + json_data = data + else: + raise TypeError(f"不支持的数据类型: {type(data)}") + + # 根据格式解析数据 + if "workflow" in json_data and "reagent" in json_data: + # 格式1: workflow/reagent(已经是规范格式) + protocol_steps = json_data["workflow"] + labware_info = json_data["reagent"] + elif "steps_info" in json_data and "labware_info" in json_data: + # 格式2: steps_info/labware_info(需要规范化) + protocol_steps = 
normalize_steps(json_data["steps_info"]) + labware_info = normalize_labware(json_data["labware_info"]) + elif "steps" in json_data and "labware" in json_data: + # 格式3: steps/labware(另一种常见格式) + protocol_steps = normalize_steps(json_data["steps"]) + if isinstance(json_data["labware"], list): + labware_info = normalize_labware(json_data["labware"]) + else: + labware_info = json_data["labware"] + else: + raise ValueError( + "不支持的 JSON 格式。支持的格式:\n" + "1. {'workflow': [...], 'reagent': {...}}\n" + "2. {'steps_info': [...], 'labware_info': [...]}\n" + "3. {'steps': [...], 'labware': [...]}" + ) + + # 构建工作流图 + graph = build_protocol_graph( + labware_info=labware_info, + protocol_steps=protocol_steps, + workstation_name=workstation_name, + action_resource_mapping=ACTION_RESOURCE_MAPPING, + ) + + # 校验句柄配置 + if validate: + is_valid, errors = validate_workflow_handles(graph) + if not is_valid: + import warnings + + for error in errors: + warnings.warn(f"句柄校验警告: {error}") + + return graph + + +def convert_json_to_node_link( + data: Union[str, PathLike, Dict[str, Any]], + workstation_name: str = "PRCXi", +) -> Dict[str, Any]: + """ + 将 JSON 数据转换为 node-link 格式的字典 + + Args: + data: JSON 文件路径、字典数据、或 JSON 字符串 + workstation_name: 工作站名称,默认 "PRCXi" + + Returns: + Dict: node-link 格式的工作流数据 + """ + graph = convert_from_json(data, workstation_name) + return graph.to_node_link_dict() + + +def convert_json_to_workflow_list( + data: Union[str, PathLike, Dict[str, Any]], + workstation_name: str = "PRCXi", +) -> List[Dict[str, Any]]: + """ + 将 JSON 数据转换为工作流列表格式 + + Args: + data: JSON 文件路径、字典数据、或 JSON 字符串 + workstation_name: 工作站名称,默认 "PRCXi" + + Returns: + List: 工作流节点列表 + """ + graph = convert_from_json(data, workstation_name) + return graph.to_dict() + + +# 为了向后兼容,保留下划线前缀的别名 +_normalize_steps = normalize_steps +_normalize_labware = normalize_labware diff --git a/unilabos/workflow/from_python_script.py b/unilabos/workflow/from_python_script.py new file mode 100644 index 0000000..5a8ce38 --- 
class DeviceMethodConverter:
    """Translate a flat Python script of device calls into workflow nodes.

    - Field names are unified: ``resource_name`` (formerly device_class) and
      ``template_name`` (formerly action_key).
    - ``params`` is a single-level mapping; variable references are encoded as
      ``{"__var__": name}``.
    - Variable wiring and edge creation are delegated to
      ``self.graph.add_workflow_node``, which receives ``variable_sources``.

    Only top-level assignment and expression statements are processed.
    """

    def __init__(self, device_registry: Optional[Dict[str, Any]] = None):
        self.graph = WorkflowGraph()
        # variable name -> {"node_id": ..., "output_name": ...}
        self.variable_sources: Dict[str, Dict[str, Any]] = {}
        # instance name -> resource_name (None when the class is unknown)
        self.instance_to_resource: Dict[str, Optional[str]] = {}
        self.node_id_counter: int = 0
        self.registry = RegistryAdapter(device_registry or {})

    # ---- helpers ----
    def _new_node_id(self) -> int:
        """Return the next sequential node id, starting at 0."""
        nid = self.node_id_counter
        self.node_id_counter += 1
        return nid

    def _assign_targets(self, targets) -> List[str]:
        """Return the plain variable names bound by an assignment target.

        Handles a single ``ast.Name`` or an ``ast.Tuple`` of names; any other
        target shape yields no names.
        """
        if isinstance(targets, ast.Tuple):
            return [elt.id for elt in targets.elts if isinstance(elt, ast.Name)]
        if isinstance(targets, ast.Name):
            return [targets.id]
        return []

    def _extract_device_instantiation(self, node) -> Optional[Tuple[str, str]]:
        """Detect ``inst = ClassName(...)`` / ``inst = module.ClassName(...)``.

        Returns:
            ``(instance_name, class_name)`` when the assignment looks like a
            device instantiation, otherwise ``None``.
        """
        if not isinstance(node.value, ast.Call):
            return None
        callee = node.value.func
        if isinstance(callee, ast.Name):
            class_name = callee.id
        elif isinstance(callee, ast.Attribute) and isinstance(callee.value, ast.Name):
            if callee.value.id in self.instance_to_resource:
                # ``x = dev.method(...)`` on a known instance is a method call
                # whose result is stored, not a new device instance.
                return None
            class_name = callee.attr
        else:
            return None
        if isinstance(node.targets[0], ast.Name):
            return node.targets[0].id, class_name
        return None

    def _extract_call(self, call) -> Tuple[str, str, Dict[str, Any], str]:
        """Decompose a call node.

        Returns:
            ``(owner_name, method_name, args, call_kind)`` where ``call_kind``
            is ``"instance"``, ``"class_or_module"`` or ``"func"``. ``args``
            maps keyword names to packed values; positional arguments are
            stored as a list under the ``"_positional"`` key.
        """
        owner_name, method_name, call_kind = "", "", "func"
        func = call.func
        if isinstance(func, ast.Attribute):
            method_name = func.attr
            if isinstance(func.value, ast.Name):
                owner_name = func.value.id
                call_kind = "instance" if owner_name in self.instance_to_resource else "class_or_module"
            elif isinstance(func.value, ast.Attribute) and isinstance(func.value.value, ast.Name):
                # ``pkg.Owner.method(...)``: the owner is the middle attribute.
                owner_name = func.value.attr
                call_kind = "class_or_module"
        elif isinstance(func, ast.Name):
            method_name = func.id
            call_kind = "func"

        def pack(node):
            """Wrap an argument expression as a ``{"type", "value"}`` record."""
            if isinstance(node, ast.Name):
                return {"type": "variable", "value": node.id}
            if isinstance(node, ast.Constant):
                return {"type": "constant", "value": node.value}
            if isinstance(node, ast.Dict):
                return {"type": "dict", "value": self._parse_dict(node)}
            if isinstance(node, ast.List):
                return {"type": "list", "value": self._parse_list(node)}
            # Anything else is kept as source text when the interpreter can
            # unparse it.
            return {"type": "raw", "value": ast.unparse(node) if hasattr(ast, "unparse") else str(node)}

        # NOTE(review): a ``**mapping`` argument has ``kw.arg is None`` and is
        # stored under the key ``None`` - confirm whether that is intended.
        args: Dict[str, Any] = {kw.arg: pack(kw.value) for kw in call.keywords}
        positional = [pack(a) for a in call.args]
        if positional:
            args["_positional"] = positional
        return owner_name, method_name, args, call_kind

    def _parse_value(self, node, out_missing):
        """Convert a literal element, returning ``out_missing`` if unsupported."""
        if isinstance(node, ast.Name):
            return f"var:{node.id}"
        if isinstance(node, ast.Constant):
            return node.value
        if isinstance(node, ast.Dict):
            return self._parse_dict(node)
        if isinstance(node, ast.List):
            return self._parse_list(node)
        return out_missing

    def _parse_dict(self, node) -> Dict[str, Any]:
        """Convert a dict literal to a plain dict.

        Only constant keys are kept (stringified). Names become ``"var:<id>"``
        tokens; values of unsupported expression types are dropped.
        """
        missing = object()
        out: Dict[str, Any] = {}
        for k, v in zip(node.keys, node.values):
            if not isinstance(k, ast.Constant):
                continue
            value = self._parse_value(v, missing)
            if value is not missing:
                out[str(k.value)] = value
        return out

    def _parse_list(self, node) -> List[Any]:
        """Convert a list literal to a plain list.

        Names become ``"var:<id>"`` tokens; elements of unsupported expression
        types are dropped.
        """
        missing = object()
        parsed = (self._parse_value(elt, missing) for elt in node.elts)
        return [value for value in parsed if value is not missing]

    def _normalize_var_tokens(self, x: Any) -> Any:
        """Recursively replace ``"var:<name>"`` strings by ``{"__var__": name}``."""
        if isinstance(x, str) and x.startswith("var:"):
            return {"__var__": x[4:]}
        if isinstance(x, list):
            return [self._normalize_var_tokens(i) for i in x]
        if isinstance(x, dict):
            return {k: self._normalize_var_tokens(v) for k, v in x.items()}
        return x

    def _make_params_payload(
        self,
        resource_name: Optional[str],
        template_name: str,
        call_args: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Build the single-level ``params`` mapping for one call.

        Starts from the registry's goal defaults (when the resource is known),
        overlays keyword arguments, then maps positional arguments onto the
        registry's input keys without overriding names that are already set.
        Without input keys, positional arguments are stored as ``arg_<i>``.
        """
        input_keys = self.registry.get_action_input_keys(resource_name, template_name) if resource_name else []
        defaults = self.registry.get_action_goal_default(resource_name, template_name) if resource_name else {}
        params: Dict[str, Any] = dict(defaults)

        def unpack(p):
            """Turn a packed argument record into its params representation."""
            t, v = p.get("type"), p.get("value")
            if t == "variable":
                return {"__var__": v}
            if t in ("dict", "list"):
                return self._normalize_var_tokens(v)
            return v

        for k, p in call_args.items():
            if k != "_positional":
                params[k] = unpack(p)

        pos = call_args.get("_positional", [])
        if pos and input_keys:
            # zip() stops at the shorter sequence, so surplus positional
            # arguments are ignored.
            for name, p in zip(input_keys, pos):
                if name not in params:
                    params[name] = unpack(p)
        elif pos:
            for i, p in enumerate(pos):
                params[f"arg_{i}"] = unpack(p)
        return params

    def _add_call_node(self, call) -> int:
        """Emit one workflow node for ``call`` and return its node id."""
        owner, method, call_args, kind = self._extract_call(call)
        if kind == "instance":
            resource_name = self.instance_to_resource.get(owner)
        else:
            resource_name = self.registry.resolve_resource_by_classname(owner)

        module = self.registry.get_device_module(resource_name)
        params = self._make_params_payload(resource_name, method, call_args)

        nid = self._new_node_id()
        self.graph.add_workflow_node(
            nid,
            device_key=owner,
            resource_name=resource_name,
            module=module,
            template_name=method,
            params=params,
            variable_sources=self.variable_sources,
            add_ready_if_no_vars=True,
            # Nodes are chained in emission order.
            prev_node_id=(nid - 1) if nid > 0 else None,
        )
        return nid

    # ---- handlers ----
    def _on_assign(self, stmt):
        """Handle an assignment: register an instance or emit a call node."""
        inst = self._extract_device_instantiation(stmt)
        if inst:
            instance, code_class = inst
            self.instance_to_resource[instance] = self.registry.resolve_resource_by_classname(code_class)
            return

        if isinstance(stmt.value, ast.Call):
            nid = self._add_call_node(stmt.value)
            # Every assigned name now resolves to this node's "result" output.
            for var in self._assign_targets(stmt.targets[0]):
                self.variable_sources[var] = {"node_id": nid, "output_name": "result"}

    def _on_expr(self, stmt):
        """Handle a bare expression statement; only calls produce a node."""
        if isinstance(stmt.value, ast.Call):
            self._add_call_node(stmt.value)

    def convert(self, python_code: str):
        """Parse ``python_code`` and process its top-level statements in order.

        Returns:
            ``self``, so the resulting ``graph`` can be read off the converter.
        """
        tree = ast.parse(python_code)
        for stmt in tree.body:
            if isinstance(stmt, ast.Assign):
                self._on_assign(stmt)
            elif isinstance(stmt, ast.Expr):
                self._on_expr(stmt)
        return self
def flatten_xdl_procedure(procedure_elem: ET.Element) -> List[ET.Element]:
    """Flatten a nested XDL procedure into a list of operation elements.

    Every descendant of ``procedure_elem`` is visited in document order.
    Structural containers (Prep, Reaction, Workup, Purification, Procedure)
    and temporarily unsupported operations are left out of the result, but
    their children are still visited.

    Args:
        procedure_elem: The ``<Procedure>`` element.

    Returns:
        The remaining operation elements, in document order.
    """
    container_tags = ("Prep", "Reaction", "Workup", "Purification", "Procedure")
    temp_unsupported_protocol = ("Purge", "Wait", "Stir", "ResetHandling")
    skipped_tags = container_tags + temp_unsupported_protocol

    # Element.iter() walks the subtree depth-first in document order and
    # yields the root first; the root itself is never an operation.
    return [
        element
        for element in procedure_elem.iter()
        if element is not procedure_elem and element.tag not in skipped_tags
    ]


def parse_xdl_content(xdl_content: str) -> tuple:
    """Parse XDL XML text into hardware, reagents and flattened operations.

    Args:
        xdl_content: The XDL document as a string.

    Returns:
        ``(hardware, reagents, operations)`` where ``hardware`` is a list of
        ``{"id", "type"}`` dicts, ``reagents`` a list of ``{"name", "role"}``
        dicts and ``operations`` the flattened procedure elements. Returns
        ``(None, None, None)`` when the root has no ``<Synthesis>`` child or
        the synthesis has no ``<Procedure>``.

    Raises:
        ValueError: The content is not well-formed XML.
    """
    # Drop stray control characters, but keep the whitespace XML relies on to
    # separate tokens: str.isprintable() is False for "\n", "\t" and "\r", and
    # deleting them would fuse a tag name with an attribute on the next line.
    xdl_content_cleaned = "".join(c for c in xdl_content if c.isprintable() or c in "\t\n\r")
    try:
        root = ET.fromstring(xdl_content_cleaned)
    except ET.ParseError as e:
        raise ValueError(f"Invalid XDL format: {e}") from e

    synthesis_elem = root.find("Synthesis")
    if synthesis_elem is None:
        return None, None, None

    # Hardware components.
    hardware_elem = synthesis_elem.find("Hardware")
    hardware = []
    if hardware_elem is not None:
        hardware = [{"id": c.get("id"), "type": c.get("type")} for c in hardware_elem.findall("Component")]

    # Reagents.
    reagents_elem = synthesis_elem.find("Reagents")
    reagents = []
    if reagents_elem is not None:
        reagents = [{"name": r.get("name"), "role": r.get("role", "")} for r in reagents_elem.findall("Reagent")]

    # Procedure.
    procedure_elem = synthesis_elem.find("Procedure")
    if procedure_elem is None:
        return None, None, None

    return hardware, reagents, flatten_xdl_procedure(procedure_elem)
def _is_node_link_format(data: Any) -> bool:
    """Return True when ``data`` is a dict carrying both "nodes" and "edges".

    Non-dict values (a JSON file may hold ``null``, a number, a list, ...) are
    reported as not node-link instead of raising on the membership test.
    """
    return isinstance(data, dict) and "nodes" in data and "edges" in data


def _convert_to_node_link(workflow_file: str, workflow_data: Dict[str, Any]) -> Dict[str, Any]:
    """Convert workflow data that is not in node-link form into node-link form.

    Args:
        workflow_file: Path of the workflow file (currently unused here).
        workflow_data: The decoded workflow data.

    Returns:
        The node-link workflow data produced by ``convert_json_to_node_link``.
    """
    # Imported lazily, inside the function, as in the rest of this module.
    from unilabos.workflow.convert_from_json import convert_json_to_node_link

    print_status("检测到非 node-link 格式,正在转换...", "info")
    node_link_data = convert_json_to_node_link(workflow_data)
    print_status("转换完成", "success")
    return node_link_data


def upload_workflow(
    workflow_file: str,
    workflow_name: Optional[str] = None,
    tags: Optional[List[str]] = None,
    published: bool = False,
) -> Dict[str, Any]:
    """Upload a workflow file to the server.

    Accepted input layouts:
        1. node-link: {"nodes": [...], "edges": [...]}
        2. workflow/reagent: {"workflow": [...], "reagent": {...}}
        3. steps_info/labware_info: {"steps_info": [...], "labware_info": [...]}
        4. steps/labware: {"steps": [...], "labware": [...]}

    Anything that is not already node-link is converted first.

    Args:
        workflow_file: Path to the workflow JSON file.
        workflow_name: Workflow name; falls back to the file's
            ``workflow_name`` entry, then to the file name without ``.json``.
        tags: Workflow tags, forwarded unchanged to the HTTP client.
        published: Whether the workflow should be published.

    Returns:
        The API response, or ``{"code": -1, "message": ...}`` when the file is
        missing, is not valid JSON, or cannot be converted.
    """
    # Imported lazily so http_client is not initialised before the
    # configuration file has been loaded.
    from unilabos.app.web import http_client

    if not os.path.exists(workflow_file):
        print_status(f"工作流文件不存在: {workflow_file}", "error")
        return {"code": -1, "message": f"文件不存在: {workflow_file}"}

    # Read the workflow file.
    try:
        with open(workflow_file, "r", encoding="utf-8") as f:
            workflow_data = json.load(f)
    except json.JSONDecodeError as e:
        print_status(f"工作流文件JSON解析失败: {e}", "error")
        return {"code": -1, "message": f"JSON解析失败: {e}"}

    # Detect the layout and convert when necessary.
    if not _is_node_link_format(workflow_data):
        try:
            workflow_data = _convert_to_node_link(workflow_file, workflow_data)
        except Exception as e:
            print_status(f"工作流格式转换失败: {e}", "error")
            return {"code": -1, "message": f"格式转换失败: {e}"}

    # Extract the payload.
    nodes = workflow_data.get("nodes", [])
    edges = workflow_data.get("edges", [])
    workflow_uuid_val = workflow_data.get("workflow_uuid", str(uuid.uuid4()))

    # Default name: file name with only a trailing ".json" removed, so a
    # ".json" occurring elsewhere in the name is preserved.
    default_name = os.path.basename(workflow_file)
    if default_name.endswith(".json"):
        default_name = default_name[: -len(".json")]
    wf_name_from_file = workflow_data.get("workflow_name", default_name)

    final_name = workflow_name or wf_name_from_file

    print_status(f"正在上传工作流: {final_name}", "info")
    print_status(f" - 节点数量: {len(nodes)}", "info")
    print_status(f" - 边数量: {len(edges)}", "info")
    print_status(f" - 标签: {tags or []}", "info")
    print_status(f" - 发布状态: {published}", "info")

    # Upload through http_client.
    result = http_client.workflow_import(
        name=final_name,
        workflow_uuid=workflow_uuid_val,
        workflow_name=final_name,
        nodes=nodes,
        edges=edges,
        tags=tags,
        published=published,
    )

    if result.get("code") == 0:
        data = result.get("data", {})
        print_status("工作流上传成功!", "success")
        print_status(f" - UUID: {data.get('uuid', 'N/A')}", "info")
        print_status(f" - 名称: {data.get('name', 'N/A')}", "info")
    else:
        print_status(f"工作流上传失败: {result.get('message', '未知错误')}", "error")

    return result


def handle_workflow_upload_command(args_dict: Dict[str, Any]) -> None:
    """Handle the ``workflow_upload`` sub-command.

    Args:
        args_dict: Command-line arguments; reads ``workflow_file``,
            ``workflow_name``, ``tags`` and ``published``.
    """
    workflow_file = args_dict.get("workflow_file")
    if not workflow_file:
        print_status("未指定工作流文件路径,请使用 -f/--workflow_file 参数", "error")
        return

    upload_workflow(
        workflow_file,
        args_dict.get("workflow_name"),
        args_dict.get("tags", []),
        args_dict.get("published", False),
    )