diff --git a/unilabos/app/ws_client.py b/unilabos/app/ws_client.py index 61e26f00..af802445 100644 --- a/unilabos/app/ws_client.py +++ b/unilabos/app/ws_client.py @@ -221,6 +221,32 @@ class WebSocketClient(BaseCommunicationClient): except Exception as e: logger.error(f"[WebSocket] Message handler error: {str(e)}") + async def _handle_query_state(self, data: Dict[str, str]) -> None: + device_id = data.get("device_id", "") + if not device_id: + logger.error("[WebSocket] query_action_state missing device_id") + return + action_name = data.get("action_name", "") + if not action_name: + logger.error("[WebSocket] query_action_state missing action_name") + return + task_id = data.get("task_id", "") + if not task_id: + logger.error("[WebSocket] query_action_state missing task_id") + return + device_action_key = f"/devices/{device_id}/{action_name}" + action_jobs = len(HostNode.get_instance()._device_action_status[device_action_key].job_ids) + message = { + "type": "report_action_state", + "data": { + "device_id": device_id, + "action_name": action_name, + "task_id": task_id, + "free": bool(action_jobs) + }, + } + await self._send_message(message) + async def _process_message(self, input_message: Dict[str, Any]): """处理收到的消息""" message_type = input_message.get("type", "") @@ -231,6 +257,8 @@ class WebSocketClient(BaseCommunicationClient): elif message_type == "pong": # 处理pong响应 self._handle_pong_sync(data) + elif message_type == "query_action_state": + await self._handle_query_state(data) else: logger.debug(f"[WebSocket] Unknown message type: {message_type}") diff --git a/unilabos/ros/nodes/presets/host_node.py b/unilabos/ros/nodes/presets/host_node.py index 4c7053c1..619fd377 100644 --- a/unilabos/ros/nodes/presets/host_node.py +++ b/unilabos/ros/nodes/presets/host_node.py @@ -1,5 +1,6 @@ import collections import copy +from dataclasses import dataclass, field import json import threading import time @@ -42,6 +43,11 @@ from unilabos.utils.exception import DeviceClassInvalid from unilabos.utils.type_check import serialize_result_info +@dataclass +class DeviceActionStatus: + job_ids: Dict[str, float] = field(default_factory=dict) + + class HostNode(BaseROS2DeviceNode): """ 主机节点类,负责管理设备、资源和控制器 @@ -51,6 +57,9 @@ class HostNode(BaseROS2DeviceNode): _instance: ClassVar[Optional["HostNode"]] = None _ready_event: ClassVar[threading.Event] = threading.Event() + _device_action_status: ClassVar[collections.defaultdict[str, DeviceActionStatus]] = collections.defaultdict( + DeviceActionStatus + ) @classmethod def get_instance(cls, timeout=None) -> Optional["HostNode"]: @@ -630,6 +639,12 @@ class HostNode(BaseROS2DeviceNode): goal_uuid: 目标UUID,如果为None则自动生成 server_info: 服务器发送信息,包含发送时间戳等 """ + if goal_uuid is None: + u = uuid.uuid4() + else: + u = uuid.UUID(goal_uuid) + device_action_key = f"/devices/{device_id}/{action_name}" + self._device_action_status[device_action_key].job_ids[str(u)] = time.time() if action_type.startswith("UniLabJsonCommand"): if action_name.startswith("auto-"): action_name = action_name[5:] @@ -657,22 +672,18 @@ class HostNode(BaseROS2DeviceNode): self.lab_logger().info(f"[Host Node] Sending goal for {action_id}: {goal_msg}") action_client.wait_for_server() - - uuid_str = goal_uuid - if goal_uuid is not None: - u = uuid.UUID(goal_uuid) - goal_uuid_obj = UUID(uuid=list(u.bytes)) - else: - goal_uuid_obj = None + goal_uuid_obj = UUID(uuid=list(u.bytes)) future = action_client.send_goal_async( goal_msg, - feedback_callback=lambda feedback_msg: self.feedback_callback(action_id, uuid_str, feedback_msg), + feedback_callback=lambda feedback_msg: self.feedback_callback(action_id, str(u), feedback_msg), goal_uuid=goal_uuid_obj, ) - future.add_done_callback(lambda future: self.goal_response_callback(action_id, uuid_str, future)) + future.add_done_callback( + lambda future: self.goal_response_callback(device_action_key, action_id, str(u), future) + ) - def goal_response_callback(self, action_id: str, uuid_str: Optional[str], future) -> None: + def goal_response_callback(self, device_action_key: str, action_id: str, uuid_str: str, future) -> None: """目标响应回调""" goal_handle = future.result() if not goal_handle.accepted: @@ -680,30 +691,28 @@ class HostNode(BaseROS2DeviceNode): return self.lab_logger().info(f"[Host Node] Goal {action_id} ({uuid_str}) accepted") - if uuid_str: - self._goals[uuid_str] = goal_handle - goal_handle.get_result_async().add_done_callback( - lambda future: self.get_result_callback(action_id, uuid_str, future) - ) + self._goals[uuid_str] = goal_handle + goal_handle.get_result_async().add_done_callback( + lambda future: self.get_result_callback(device_action_key, action_id, uuid_str, future) + ) - def feedback_callback(self, action_id: str, uuid_str: Optional[str], feedback_msg) -> None: + def feedback_callback(self, action_id: str, uuid_str: str, feedback_msg) -> None: """反馈回调""" feedback_data = convert_from_ros_msg(feedback_msg) feedback_data.pop("goal_id") self.lab_logger().debug(f"[Host Node] Feedback for {action_id} ({uuid_str}): {feedback_data}") - if uuid_str: - for bridge in self.bridges: - if hasattr(bridge, "publish_job_status"): - bridge.publish_job_status(feedback_data, uuid_str, "running") + for bridge in self.bridges: + if hasattr(bridge, "publish_job_status"): + bridge.publish_job_status(feedback_data, uuid_str, "running") - def get_result_callback(self, action_id: str, uuid_str: Optional[str], future) -> None: + def get_result_callback(self, device_action_key: str, action_id: str, uuid_str: str, future) -> None: """获取结果回调""" result_msg = future.result().result result_data = convert_from_ros_msg(result_msg) status = "success" return_info_str = result_data.get("return_info") - + self._device_action_status[device_action_key].job_ids.pop(uuid_str) if return_info_str is not None: try: ret = json.loads(return_info_str)