From 965bf36e8d55f1c6298360d13cc83fb240277cad Mon Sep 17 00:00:00 2001 From: Xuwznln <18435084+Xuwznln@users.noreply.github.com> Date: Sun, 11 Jan 2026 21:25:36 +0800 Subject: [PATCH] Add restart. Temp allow action message. --- unilabos/app/main.py | 15 ++- unilabos/app/utils.py | 144 ++++++++++++++++++++++++ unilabos/app/web/server.py | 42 ++++++- unilabos/app/ws_client.py | 66 ++++++++--- unilabos/ros/main_slave_run.py | 9 +- unilabos/ros/nodes/presets/host_node.py | 105 ++++++++++++++--- 6 files changed, 342 insertions(+), 39 deletions(-) create mode 100644 unilabos/app/utils.py diff --git a/unilabos/app/main.py b/unilabos/app/main.py index 3ad7310..8ec26c0 100644 --- a/unilabos/app/main.py +++ b/unilabos/app/main.py @@ -19,6 +19,11 @@ if unilabos_dir not in sys.path: from unilabos.utils.banner_print import print_status, print_unilab_banner from unilabos.config.config import load_config, BasicConfig, HTTPConfig +from unilabos.app.utils import cleanup_for_restart + +# Global restart flags (used by ws_client and web/server) +_restart_requested: bool = False +_restart_reason: str = "" def load_config_from_file(config_path): @@ -503,13 +508,19 @@ def main(): time.sleep(1) else: start_backend(**args_dict) - start_server( + restart_requested = start_server( open_browser=not args_dict["disable_browser"], port=BasicConfig.port, ) + if restart_requested: + print_status("[Main] Restart requested, cleaning up...", "info") + cleanup_for_restart() + return else: start_backend(**args_dict) - start_server( + + # 启动服务器(默认支持WebSocket触发重启) + restart_requested = start_server( open_browser=not args_dict["disable_browser"], port=BasicConfig.port, ) diff --git a/unilabos/app/utils.py b/unilabos/app/utils.py new file mode 100644 index 0000000..d10c2e0 --- /dev/null +++ b/unilabos/app/utils.py @@ -0,0 +1,144 @@ +""" +UniLabOS 应用工具函数 + +提供清理、重启等工具函数 +""" + +import gc +import os +import threading +import time + +from unilabos.utils.banner_print import print_status + + +def cleanup_for_restart() -> bool: + """ + Clean up all resources for restart without exiting the process. + + This function prepares the system for re-initialization by: + 1. Stopping all communication clients + 2. Destroying ROS nodes + 3. Resetting singletons + 4. Waiting for threads to finish + + Returns: + bool: True if cleanup was successful, False otherwise + """ + print_status("[Restart] Starting cleanup for restart...", "info") + + # Step 1: Stop WebSocket communication client + print_status("[Restart] Step 1: Stopping WebSocket client...", "info") + try: + from unilabos.app.communication import get_communication_client + + comm_client = get_communication_client() + if comm_client is not None: + comm_client.stop() + print_status("[Restart] WebSocket client stopped", "info") + except Exception as e: + print_status(f"[Restart] Error stopping WebSocket: {e}", "warning") + + # Step 2: Get HostNode and cleanup ROS + print_status("[Restart] Step 2: Cleaning up ROS nodes...", "info") + try: + from unilabos.ros.nodes.presets.host_node import HostNode + import rclpy + from rclpy.timer import Timer + + host_instance = HostNode.get_instance(timeout=5) + if host_instance is not None: + print_status(f"[Restart] Found HostNode: {host_instance.device_id}", "info") + + # Gracefully shutdown background threads + print_status("[Restart] Shutting down background threads...", "info") + HostNode.shutdown_background_threads(timeout=5.0) + print_status("[Restart] Background threads shutdown complete", "info") + + # Stop discovery timer + if hasattr(host_instance, "_discovery_timer") and isinstance(host_instance._discovery_timer, Timer): + host_instance._discovery_timer.cancel() + print_status("[Restart] Discovery timer cancelled", "info") + + # Destroy device nodes + device_count = len(host_instance.devices_instances) + print_status(f"[Restart] Destroying {device_count} device instances...", "info") + for device_id, device_node in list(host_instance.devices_instances.items()): + try: + if hasattr(device_node, "ros_node_instance") and device_node.ros_node_instance is not None: + device_node.ros_node_instance.destroy_node() + print_status(f"[Restart] Device {device_id} destroyed", "info") + except Exception as e: + print_status(f"[Restart] Error destroying device {device_id}: {e}", "warning") + + # Clear devices instances + host_instance.devices_instances.clear() + host_instance.devices_names.clear() + + # Destroy host node + try: + host_instance.destroy_node() + print_status("[Restart] HostNode destroyed", "info") + except Exception as e: + print_status(f"[Restart] Error destroying HostNode: {e}", "warning") + + # Reset HostNode state + HostNode.reset_state() + print_status("[Restart] HostNode state reset", "info") + + # Shutdown executor first (to stop executor.spin() gracefully) + if hasattr(rclpy, "__executor") and rclpy.__executor is not None: + try: + rclpy.__executor.shutdown() + rclpy.__executor = None # Clear for restart + print_status("[Restart] ROS executor shutdown complete", "info") + except Exception as e: + print_status(f"[Restart] Error shutting down executor: {e}", "warning") + + # Shutdown rclpy + if rclpy.ok(): + rclpy.shutdown() + print_status("[Restart] rclpy shutdown complete", "info") + + except ImportError as e: + print_status(f"[Restart] ROS modules not available: {e}", "warning") + except Exception as e: + print_status(f"[Restart] Error in ROS cleanup: {e}", "warning") + return False + + # Step 3: Reset communication client singleton + print_status("[Restart] Step 3: Resetting singletons...", "info") + try: + from unilabos.app import communication + + if hasattr(communication, "_communication_client"): + communication._communication_client = None + print_status("[Restart] Communication client singleton reset", "info") + except Exception as e: + print_status(f"[Restart] Error resetting communication singleton: {e}", "warning") + + # Step 4: Wait for threads to finish + print_status("[Restart] Step 4: Waiting for threads to finish...", "info") + time.sleep(3) # Give threads time to finish + + # Check remaining threads + remaining_threads = [] + for t in threading.enumerate(): + if t.name != "MainThread" and t.is_alive(): + remaining_threads.append(t.name) + + if remaining_threads: + print_status( + f"[Restart] Warning: {len(remaining_threads)} threads still running: {remaining_threads}", "warning" + ) + else: + print_status("[Restart] All threads stopped", "info") + + # Step 5: Force garbage collection + print_status("[Restart] Step 5: Running garbage collection...", "info") + gc.collect() + gc.collect() # Run twice for weak references + print_status("[Restart] Garbage collection complete", "info") + + print_status("[Restart] Cleanup complete. Ready for re-initialization.", "info") + return True diff --git a/unilabos/app/web/server.py b/unilabos/app/web/server.py index 2a85d10..8d09016 100644 --- a/unilabos/app/web/server.py +++ b/unilabos/app/web/server.py @@ -6,7 +6,6 @@ Web服务器模块 import webbrowser -import uvicorn from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from starlette.responses import Response @@ -96,7 +95,7 @@ def setup_server() -> FastAPI: return app -def start_server(host: str = "0.0.0.0", port: int = 8002, open_browser: bool = True) -> None: +def start_server(host: str = "0.0.0.0", port: int = 8002, open_browser: bool = True) -> bool: """ 启动服务器 @@ -104,7 +103,14 @@ def start_server(host: str = "0.0.0.0", port: int = 8002, open_browser: bool = T host: 服务器主机 port: 服务器端口 open_browser: 是否自动打开浏览器 + + Returns: + bool: True if restart was requested, False otherwise """ + import threading + import time + from uvicorn import Config, Server + # 设置服务器 setup_server() @@ -123,7 +129,37 @@ def start_server(host: str = "0.0.0.0", port: int = 8002, open_browser: bool = T # 启动服务器 info(f"[Web] 启动FastAPI服务器: {host}:{port}") - uvicorn.run(app, host=host, port=port, log_config=log_config) + + # 使用支持重启的模式 + config = Config(app=app, host=host, port=port, log_config=log_config) + server = Server(config) + + # 启动服务器线程 + server_thread = threading.Thread(target=server.run, daemon=True, name="uvicorn_server") + server_thread.start() + + info("[Web] Server started, monitoring for restart requests...") + + # 监控重启标志 + import unilabos.app.main as main_module + + while server_thread.is_alive(): + if hasattr(main_module, "_restart_requested") and main_module._restart_requested: + info( + f"[Web] Restart requested via WebSocket, reason: {getattr(main_module, '_restart_reason', 'unknown')}" + ) + main_module._restart_requested = False + + # 停止服务器 + server.should_exit = True + server_thread.join(timeout=5) + + info("[Web] Server stopped, ready for restart") + return True + + time.sleep(1) + + return False # 当脚本直接运行时启动服务器 diff --git a/unilabos/app/ws_client.py b/unilabos/app/ws_client.py index 4933b61..4c87d36 100644 --- a/unilabos/app/ws_client.py +++ b/unilabos/app/ws_client.py @@ -488,11 +488,16 @@ class MessageProcessor: async for message in self.websocket: try: data = json.loads(message) + message_type = data.get("action", "") + message_data = data.get("data") if self.session_id and self.session_id == data.get("edge_session"): - await self._process_message(data) + await self._process_message(message_type, message_data) else: - logger.trace(f"[MessageProcessor] 收到一条归属 {data.get('edge_session')} 的旧消息:{data}") - logger.debug(f"[MessageProcessor] 跳过了一条归属 {data.get('edge_session')} 的旧消息: {data.get('action')}") + if message_type.endswith("_material"): + logger.trace(f"[MessageProcessor] 收到一条归属 {data.get('edge_session')} 的旧消息:{data}") + logger.debug(f"[MessageProcessor] 跳过了一条归属 {data.get('edge_session')} 的旧消息: {data.get('action')}") + else: + await self._process_message(message_type, message_data) except json.JSONDecodeError: logger.error(f"[MessageProcessor] Invalid JSON received: {message}") except Exception as e: @@ -558,11 +563,8 @@ class MessageProcessor: finally: logger.debug("[MessageProcessor] Send handler stopped") - async def _process_message(self, data: Dict[str, Any]): + async def _process_message(self, message_type: str, message_data: Dict[str, Any]): """处理收到的消息""" - message_type = data.get("action", "") - message_data = data.get("data") - logger.debug(f"[MessageProcessor] Processing message: {message_type}") try: @@ -575,16 +577,19 @@ class MessageProcessor: elif message_type == "cancel_action" or message_type == "cancel_task": await self._handle_cancel_action(message_data) elif message_type == "add_material": + # noinspection PyTypeChecker await self._handle_resource_tree_update(message_data, "add") elif message_type == "update_material": + # noinspection PyTypeChecker await self._handle_resource_tree_update(message_data, "update") elif message_type == "remove_material": + # noinspection PyTypeChecker await self._handle_resource_tree_update(message_data, "remove") # elif message_type == "session_id": # self.session_id = message_data.get("session_id") # logger.info(f"[MessageProcessor] Session ID: {self.session_id}") - elif message_type == "request_reload": - await self._handle_request_reload(message_data) + elif message_type == "request_restart": + await self._handle_request_restart(message_data) else: logger.debug(f"[MessageProcessor] Unknown message type: {message_type}") @@ -894,19 +899,48 @@ class MessageProcessor: ) thread.start() - async def _handle_request_reload(self, data: Dict[str, Any]): + async def _handle_request_restart(self, data: Dict[str, Any]): """ - 处理重载请求 + 处理重启请求 - 当LabGo发送request_reload时,重新发送设备注册信息 + 当LabGo发送request_restart时,执行清理并触发重启 """ reason = data.get("reason", "unknown") - logger.info(f"[MessageProcessor] Received reload request, reason: {reason}") + delay = data.get("delay", 2) # 默认延迟2秒 + logger.info(f"[MessageProcessor] Received restart request, reason: {reason}, delay: {delay}s") - # 重新发送host_node_ready信息 + # 发送确认消息 if self.websocket_client: - self.websocket_client.publish_host_ready() - logger.info("[MessageProcessor] Re-sent host_node_ready after reload request") + await self.websocket_client.send_message({ + "action": "restart_acknowledged", + "data": {"reason": reason, "delay": delay} + }) + + # 设置全局重启标志 + import unilabos.app.main as main_module + main_module._restart_requested = True + main_module._restart_reason = reason + + # 延迟后执行清理 + await asyncio.sleep(delay) + + # 在新线程中执行清理,避免阻塞当前事件循环 + def do_cleanup(): + import time + time.sleep(0.5) # 给当前消息处理完成的时间 + logger.info(f"[MessageProcessor] Starting cleanup for restart, reason: {reason}") + try: + from unilabos.app.utils import cleanup_for_restart + if cleanup_for_restart(): + logger.info("[MessageProcessor] Cleanup successful, main() will restart") + else: + logger.error("[MessageProcessor] Cleanup failed") + except Exception as e: + logger.error(f"[MessageProcessor] Error during cleanup: {e}") + + cleanup_thread = threading.Thread(target=do_cleanup, name="RestartCleanupThread", daemon=True) + cleanup_thread.start() + logger.info(f"[MessageProcessor] Restart cleanup scheduled") async def _send_action_state_response( self, device_id: str, action_name: str, task_id: str, job_id: str, typ: str, free: bool, need_more: int diff --git a/unilabos/ros/main_slave_run.py b/unilabos/ros/main_slave_run.py index b79c368..c24f9e8 100644 --- a/unilabos/ros/main_slave_run.py +++ b/unilabos/ros/main_slave_run.py @@ -1,4 +1,5 @@ import json + # from nt import device_encoding import threading import time @@ -55,7 +56,11 @@ def main( ) -> None: """主函数""" - rclpy.init(args=rclpy_init_args) + # Support restart - check if rclpy is already initialized + if not rclpy.ok(): + rclpy.init(args=rclpy_init_args) + else: + logger.info("[ROS] rclpy already initialized, reusing context") executor = rclpy.__executor = MultiThreadedExecutor() # 创建主机节点 host_node = HostNode( @@ -88,7 +93,7 @@ def main( joint_republisher = JointRepublisher("joint_republisher", host_node.resource_tracker) # lh_joint_pub = LiquidHandlerJointPublisher( # resources_config=resources_list, resource_tracker=host_node.resource_tracker - # ) + # ) executor.add_node(resource_mesh_manager) executor.add_node(joint_republisher) # executor.add_node(lh_joint_pub) diff --git a/unilabos/ros/nodes/presets/host_node.py b/unilabos/ros/nodes/presets/host_node.py index e0a66bf..69c12f8 100644 --- a/unilabos/ros/nodes/presets/host_node.py +++ b/unilabos/ros/nodes/presets/host_node.py @@ -70,6 +70,8 @@ class HostNode(BaseROS2DeviceNode): _instance: ClassVar[Optional["HostNode"]] = None _ready_event: ClassVar[threading.Event] = threading.Event() + _shutting_down: ClassVar[bool] = False # Flag to signal shutdown to background threads + _background_threads: ClassVar[List[threading.Thread]] = [] # Track all background threads for cleanup _device_action_status: ClassVar[collections.defaultdict[str, DeviceActionStatus]] = collections.defaultdict( DeviceActionStatus ) @@ -81,6 +83,48 @@ class HostNode(BaseROS2DeviceNode): return cls._instance return None + @classmethod + def shutdown_background_threads(cls, timeout: float = 5.0) -> None: + """ + Gracefully shutdown all background threads for clean exit or restart. + + This method: + 1. Sets shutdown flag to stop background operations + 2. Waits for background threads to finish with timeout + 3. Cleans up finished threads from tracking list + + Args: + timeout: Maximum time to wait for each thread (seconds) + """ + cls._shutting_down = True + + # Wait for background threads to finish + active_threads = [] + for t in cls._background_threads: + if t.is_alive(): + t.join(timeout=timeout) + if t.is_alive(): + active_threads.append(t.name) + + if active_threads: + logger.warning(f"[Host Node] Some background threads still running: {active_threads}") + + # Clear the thread list + cls._background_threads.clear() + logger.info(f"[Host Node] Background threads shutdown complete") + + @classmethod + def reset_state(cls) -> None: + """ + Reset the HostNode singleton state for restart or clean exit. + Call this after destroying the instance. + """ + cls._instance = None + cls._ready_event.clear() + cls._shutting_down = False + cls._background_threads.clear() + logger.info("[Host Node] State reset complete") + def __init__( self, device_id: str, @@ -294,12 +338,37 @@ class HostNode(BaseROS2DeviceNode): bridge.publish_host_ready() self.lab_logger().debug(f"Host ready signal sent via {bridge.__class__.__name__}") - def _send_re_register(self, sclient): - sclient.wait_for_service() - request = SerialCommand.Request() - request.command = "" - future = sclient.call_async(request) - response = future.result() + def _send_re_register(self, sclient, device_namespace: str): + """ + Send re-register command to a device. This is a one-time operation. + + Args: + sclient: The service client + device_namespace: The device namespace for logging + """ + try: + # Use timeout to prevent indefinite blocking + if not sclient.wait_for_service(timeout_sec=10.0): + self.lab_logger().debug(f"[Host Node] Re-register timeout for {device_namespace}") + return + + # Check shutdown flag after wait + if self._shutting_down: + self.lab_logger().debug(f"[Host Node] Re-register aborted for {device_namespace} (shutdown)") + return + + request = SerialCommand.Request() + request.command = "" + future = sclient.call_async(request) + # Use timeout for result as well + future.result(timeout_sec=5.0) + self.lab_logger().debug(f"[Host Node] Re-register completed for {device_namespace}") + except Exception as e: + # Gracefully handle destruction during shutdown + if "destruction was requested" in str(e) or self._shutting_down: + self.lab_logger().debug(f"[Host Node] Re-register aborted for {device_namespace} (cleanup)") + else: + self.lab_logger().warning(f"[Host Node] Re-register failed for {device_namespace}: {e}") def _discover_devices(self) -> None: """ @@ -331,23 +400,27 @@ class HostNode(BaseROS2DeviceNode): self._create_action_clients_for_device(device_id, namespace) self._online_devices.add(device_key) sclient = self.create_client(SerialCommand, f"/srv{namespace}/re_register_device") - threading.Thread( + t = threading.Thread( target=self._send_re_register, - args=(sclient,), + args=(sclient, namespace), daemon=True, name=f"ROSDevice{self.device_id}_re_register_device_{namespace}", - ).start() + ) + self._background_threads.append(t) + t.start() elif device_key not in self._online_devices: # 设备重新上线 self.lab_logger().info(f"[Host Node] Device reconnected: {device_key}") self._online_devices.add(device_key) sclient = self.create_client(SerialCommand, f"/srv{namespace}/re_register_device") - threading.Thread( + t = threading.Thread( target=self._send_re_register, - args=(sclient,), + args=(sclient, namespace), daemon=True, name=f"ROSDevice{self.device_id}_re_register_device_{namespace}", - ).start() + ) + self._background_threads.append(t) + t.start() # 检测离线设备 offline_devices = self._online_devices - current_devices @@ -705,13 +778,14 @@ class HostNode(BaseROS2DeviceNode): raise ValueError(f"ActionClient {action_id} not found.") action_client: ActionClient = self._action_clients[action_id] + # 遍历action_kwargs下的所有子dict,将"sample_uuid"的值赋给"sample_id" def assign_sample_id(obj): if isinstance(obj, dict): if "sample_uuid" in obj: obj["sample_id"] = obj["sample_uuid"] obj.pop("sample_uuid") - for k,v in obj.items(): + for k, v in obj.items(): if k != "unilabos_extra": assign_sample_id(v) elif isinstance(obj, list): @@ -742,9 +816,7 @@ class HostNode(BaseROS2DeviceNode): self.lab_logger().info(f"[Host Node] Goal {action_id} ({item.job_id}) accepted") self._goals[item.job_id] = goal_handle goal_future = goal_handle.get_result_async() - goal_future.add_done_callback( - lambda f: self.get_result_callback(item, action_id, f) - ) + goal_future.add_done_callback(lambda f: self.get_result_callback(item, action_id, f)) goal_future.result() def feedback_callback(self, item: "QueueItem", action_id: str, feedback_msg) -> None: @@ -1167,6 +1239,7 @@ class HostNode(BaseROS2DeviceNode): """ try: from unilabos.app.web import http_client + data = json.loads(request.command) if "uuid" in data and data["uuid"] is not None: http_req = http_client.resource_tree_get([data["uuid"]], data["with_children"])