Add restart.

Temp allow action message.
This commit is contained in:
Xuwznln
2026-01-11 21:25:36 +08:00
parent aacf3497e0
commit 965bf36e8d
6 changed files with 342 additions and 39 deletions

View File

@@ -19,6 +19,11 @@ if unilabos_dir not in sys.path:
from unilabos.utils.banner_print import print_status, print_unilab_banner from unilabos.utils.banner_print import print_status, print_unilab_banner
from unilabos.config.config import load_config, BasicConfig, HTTPConfig from unilabos.config.config import load_config, BasicConfig, HTTPConfig
from unilabos.app.utils import cleanup_for_restart
# Global restart flags (used by ws_client and web/server)
_restart_requested: bool = False
_restart_reason: str = ""
def load_config_from_file(config_path): def load_config_from_file(config_path):
@@ -503,13 +508,19 @@ def main():
time.sleep(1) time.sleep(1)
else: else:
start_backend(**args_dict) start_backend(**args_dict)
start_server( restart_requested = start_server(
open_browser=not args_dict["disable_browser"], open_browser=not args_dict["disable_browser"],
port=BasicConfig.port, port=BasicConfig.port,
) )
if restart_requested:
print_status("[Main] Restart requested, cleaning up...", "info")
cleanup_for_restart()
return
else: else:
start_backend(**args_dict) start_backend(**args_dict)
start_server(
# 启动服务器默认支持WebSocket触发重启
restart_requested = start_server(
open_browser=not args_dict["disable_browser"], open_browser=not args_dict["disable_browser"],
port=BasicConfig.port, port=BasicConfig.port,
) )

144
unilabos/app/utils.py Normal file
View File

@@ -0,0 +1,144 @@
"""
UniLabOS 应用工具函数
提供清理、重启等工具函数
"""
import gc
import os
import threading
import time
from unilabos.utils.banner_print import print_status
def cleanup_for_restart() -> bool:
"""
Clean up all resources for restart without exiting the process.
This function prepares the system for re-initialization by:
1. Stopping all communication clients
2. Destroying ROS nodes
3. Resetting singletons
4. Waiting for threads to finish
Returns:
bool: True if cleanup was successful, False otherwise
"""
print_status("[Restart] Starting cleanup for restart...", "info")
# Step 1: Stop WebSocket communication client
print_status("[Restart] Step 1: Stopping WebSocket client...", "info")
try:
from unilabos.app.communication import get_communication_client
comm_client = get_communication_client()
if comm_client is not None:
comm_client.stop()
print_status("[Restart] WebSocket client stopped", "info")
except Exception as e:
print_status(f"[Restart] Error stopping WebSocket: {e}", "warning")
# Step 2: Get HostNode and cleanup ROS
print_status("[Restart] Step 2: Cleaning up ROS nodes...", "info")
try:
from unilabos.ros.nodes.presets.host_node import HostNode
import rclpy
from rclpy.timer import Timer
host_instance = HostNode.get_instance(timeout=5)
if host_instance is not None:
print_status(f"[Restart] Found HostNode: {host_instance.device_id}", "info")
# Gracefully shutdown background threads
print_status("[Restart] Shutting down background threads...", "info")
HostNode.shutdown_background_threads(timeout=5.0)
print_status("[Restart] Background threads shutdown complete", "info")
# Stop discovery timer
if hasattr(host_instance, "_discovery_timer") and isinstance(host_instance._discovery_timer, Timer):
host_instance._discovery_timer.cancel()
print_status("[Restart] Discovery timer cancelled", "info")
# Destroy device nodes
device_count = len(host_instance.devices_instances)
print_status(f"[Restart] Destroying {device_count} device instances...", "info")
for device_id, device_node in list(host_instance.devices_instances.items()):
try:
if hasattr(device_node, "ros_node_instance") and device_node.ros_node_instance is not None:
device_node.ros_node_instance.destroy_node()
print_status(f"[Restart] Device {device_id} destroyed", "info")
except Exception as e:
print_status(f"[Restart] Error destroying device {device_id}: {e}", "warning")
# Clear devices instances
host_instance.devices_instances.clear()
host_instance.devices_names.clear()
# Destroy host node
try:
host_instance.destroy_node()
print_status("[Restart] HostNode destroyed", "info")
except Exception as e:
print_status(f"[Restart] Error destroying HostNode: {e}", "warning")
# Reset HostNode state
HostNode.reset_state()
print_status("[Restart] HostNode state reset", "info")
# Shutdown executor first (to stop executor.spin() gracefully)
if hasattr(rclpy, "__executor") and rclpy.__executor is not None:
try:
rclpy.__executor.shutdown()
rclpy.__executor = None # Clear for restart
print_status("[Restart] ROS executor shutdown complete", "info")
except Exception as e:
print_status(f"[Restart] Error shutting down executor: {e}", "warning")
# Shutdown rclpy
if rclpy.ok():
rclpy.shutdown()
print_status("[Restart] rclpy shutdown complete", "info")
except ImportError as e:
print_status(f"[Restart] ROS modules not available: {e}", "warning")
except Exception as e:
print_status(f"[Restart] Error in ROS cleanup: {e}", "warning")
return False
# Step 3: Reset communication client singleton
print_status("[Restart] Step 3: Resetting singletons...", "info")
try:
from unilabos.app import communication
if hasattr(communication, "_communication_client"):
communication._communication_client = None
print_status("[Restart] Communication client singleton reset", "info")
except Exception as e:
print_status(f"[Restart] Error resetting communication singleton: {e}", "warning")
# Step 4: Wait for threads to finish
print_status("[Restart] Step 4: Waiting for threads to finish...", "info")
time.sleep(3) # Give threads time to finish
# Check remaining threads
remaining_threads = []
for t in threading.enumerate():
if t.name != "MainThread" and t.is_alive():
remaining_threads.append(t.name)
if remaining_threads:
print_status(
f"[Restart] Warning: {len(remaining_threads)} threads still running: {remaining_threads}", "warning"
)
else:
print_status("[Restart] All threads stopped", "info")
# Step 5: Force garbage collection
print_status("[Restart] Step 5: Running garbage collection...", "info")
gc.collect()
gc.collect() # Run twice for weak references
print_status("[Restart] Garbage collection complete", "info")
print_status("[Restart] Cleanup complete. Ready for re-initialization.", "info")
return True

View File

@@ -6,7 +6,6 @@ Web服务器模块
import webbrowser import webbrowser
import uvicorn
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import Response from starlette.responses import Response
@@ -96,7 +95,7 @@ def setup_server() -> FastAPI:
return app return app
def start_server(host: str = "0.0.0.0", port: int = 8002, open_browser: bool = True) -> None: def start_server(host: str = "0.0.0.0", port: int = 8002, open_browser: bool = True) -> bool:
""" """
启动服务器 启动服务器
@@ -104,7 +103,14 @@ def start_server(host: str = "0.0.0.0", port: int = 8002, open_browser: bool = T
host: 服务器主机 host: 服务器主机
port: 服务器端口 port: 服务器端口
open_browser: 是否自动打开浏览器 open_browser: 是否自动打开浏览器
Returns:
bool: True if restart was requested, False otherwise
""" """
import threading
import time
from uvicorn import Config, Server
# 设置服务器 # 设置服务器
setup_server() setup_server()
@@ -123,7 +129,37 @@ def start_server(host: str = "0.0.0.0", port: int = 8002, open_browser: bool = T
# 启动服务器 # 启动服务器
info(f"[Web] 启动FastAPI服务器: {host}:{port}") info(f"[Web] 启动FastAPI服务器: {host}:{port}")
uvicorn.run(app, host=host, port=port, log_config=log_config)
# 使用支持重启的模式
config = Config(app=app, host=host, port=port, log_config=log_config)
server = Server(config)
# 启动服务器线程
server_thread = threading.Thread(target=server.run, daemon=True, name="uvicorn_server")
server_thread.start()
info("[Web] Server started, monitoring for restart requests...")
# 监控重启标志
import unilabos.app.main as main_module
while server_thread.is_alive():
if hasattr(main_module, "_restart_requested") and main_module._restart_requested:
info(
f"[Web] Restart requested via WebSocket, reason: {getattr(main_module, '_restart_reason', 'unknown')}"
)
main_module._restart_requested = False
# 停止服务器
server.should_exit = True
server_thread.join(timeout=5)
info("[Web] Server stopped, ready for restart")
return True
time.sleep(1)
return False
# 当脚本直接运行时启动服务器 # 当脚本直接运行时启动服务器

View File

@@ -488,11 +488,16 @@ class MessageProcessor:
async for message in self.websocket: async for message in self.websocket:
try: try:
data = json.loads(message) data = json.loads(message)
message_type = data.get("action", "")
message_data = data.get("data")
if self.session_id and self.session_id == data.get("edge_session"): if self.session_id and self.session_id == data.get("edge_session"):
await self._process_message(data) await self._process_message(message_type, message_data)
else: else:
logger.trace(f"[MessageProcessor] 收到一条归属 {data.get('edge_session')} 的旧消息:{data}") if message_type.endswith("_material"):
logger.debug(f"[MessageProcessor] 跳过了一条归属 {data.get('edge_session')} 的旧消息: {data.get('action')}") logger.trace(f"[MessageProcessor] 收到一条归属 {data.get('edge_session')} 的旧消息{data}")
logger.debug(f"[MessageProcessor] 跳过了一条归属 {data.get('edge_session')} 的旧消息: {data.get('action')}")
else:
await self._process_message(message_type, message_data)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.error(f"[MessageProcessor] Invalid JSON received: {message}") logger.error(f"[MessageProcessor] Invalid JSON received: {message}")
except Exception as e: except Exception as e:
@@ -558,11 +563,8 @@ class MessageProcessor:
finally: finally:
logger.debug("[MessageProcessor] Send handler stopped") logger.debug("[MessageProcessor] Send handler stopped")
async def _process_message(self, data: Dict[str, Any]): async def _process_message(self, message_type: str, message_data: Dict[str, Any]):
"""处理收到的消息""" """处理收到的消息"""
message_type = data.get("action", "")
message_data = data.get("data")
logger.debug(f"[MessageProcessor] Processing message: {message_type}") logger.debug(f"[MessageProcessor] Processing message: {message_type}")
try: try:
@@ -575,16 +577,19 @@ class MessageProcessor:
elif message_type == "cancel_action" or message_type == "cancel_task": elif message_type == "cancel_action" or message_type == "cancel_task":
await self._handle_cancel_action(message_data) await self._handle_cancel_action(message_data)
elif message_type == "add_material": elif message_type == "add_material":
# noinspection PyTypeChecker
await self._handle_resource_tree_update(message_data, "add") await self._handle_resource_tree_update(message_data, "add")
elif message_type == "update_material": elif message_type == "update_material":
# noinspection PyTypeChecker
await self._handle_resource_tree_update(message_data, "update") await self._handle_resource_tree_update(message_data, "update")
elif message_type == "remove_material": elif message_type == "remove_material":
# noinspection PyTypeChecker
await self._handle_resource_tree_update(message_data, "remove") await self._handle_resource_tree_update(message_data, "remove")
# elif message_type == "session_id": # elif message_type == "session_id":
# self.session_id = message_data.get("session_id") # self.session_id = message_data.get("session_id")
# logger.info(f"[MessageProcessor] Session ID: {self.session_id}") # logger.info(f"[MessageProcessor] Session ID: {self.session_id}")
elif message_type == "request_reload": elif message_type == "request_restart":
await self._handle_request_reload(message_data) await self._handle_request_restart(message_data)
else: else:
logger.debug(f"[MessageProcessor] Unknown message type: {message_type}") logger.debug(f"[MessageProcessor] Unknown message type: {message_type}")
@@ -894,19 +899,48 @@ class MessageProcessor:
) )
thread.start() thread.start()
async def _handle_request_reload(self, data: Dict[str, Any]): async def _handle_request_restart(self, data: Dict[str, Any]):
""" """
处理重请求 处理重请求
当LabGo发送request_reload时重新发送设备注册信息 当LabGo发送request_restart时执行清理并触发重启
""" """
reason = data.get("reason", "unknown") reason = data.get("reason", "unknown")
logger.info(f"[MessageProcessor] Received reload request, reason: {reason}") delay = data.get("delay", 2) # 默认延迟2秒
logger.info(f"[MessageProcessor] Received restart request, reason: {reason}, delay: {delay}s")
# 重新发送host_node_ready信 # 发送确认消
if self.websocket_client: if self.websocket_client:
self.websocket_client.publish_host_ready() await self.websocket_client.send_message({
logger.info("[MessageProcessor] Re-sent host_node_ready after reload request") "action": "restart_acknowledged",
"data": {"reason": reason, "delay": delay}
})
# 设置全局重启标志
import unilabos.app.main as main_module
main_module._restart_requested = True
main_module._restart_reason = reason
# 延迟后执行清理
await asyncio.sleep(delay)
# 在新线程中执行清理,避免阻塞当前事件循环
def do_cleanup():
import time
time.sleep(0.5) # 给当前消息处理完成的时间
logger.info(f"[MessageProcessor] Starting cleanup for restart, reason: {reason}")
try:
from unilabos.app.utils import cleanup_for_restart
if cleanup_for_restart():
logger.info("[MessageProcessor] Cleanup successful, main() will restart")
else:
logger.error("[MessageProcessor] Cleanup failed")
except Exception as e:
logger.error(f"[MessageProcessor] Error during cleanup: {e}")
cleanup_thread = threading.Thread(target=do_cleanup, name="RestartCleanupThread", daemon=True)
cleanup_thread.start()
logger.info(f"[MessageProcessor] Restart cleanup scheduled")
async def _send_action_state_response( async def _send_action_state_response(
self, device_id: str, action_name: str, task_id: str, job_id: str, typ: str, free: bool, need_more: int self, device_id: str, action_name: str, task_id: str, job_id: str, typ: str, free: bool, need_more: int

View File

@@ -1,4 +1,5 @@
import json import json
# from nt import device_encoding # from nt import device_encoding
import threading import threading
import time import time
@@ -55,7 +56,11 @@ def main(
) -> None: ) -> None:
"""主函数""" """主函数"""
rclpy.init(args=rclpy_init_args) # Support restart - check if rclpy is already initialized
if not rclpy.ok():
rclpy.init(args=rclpy_init_args)
else:
logger.info("[ROS] rclpy already initialized, reusing context")
executor = rclpy.__executor = MultiThreadedExecutor() executor = rclpy.__executor = MultiThreadedExecutor()
# 创建主机节点 # 创建主机节点
host_node = HostNode( host_node = HostNode(
@@ -88,7 +93,7 @@ def main(
joint_republisher = JointRepublisher("joint_republisher", host_node.resource_tracker) joint_republisher = JointRepublisher("joint_republisher", host_node.resource_tracker)
# lh_joint_pub = LiquidHandlerJointPublisher( # lh_joint_pub = LiquidHandlerJointPublisher(
# resources_config=resources_list, resource_tracker=host_node.resource_tracker # resources_config=resources_list, resource_tracker=host_node.resource_tracker
# ) # )
executor.add_node(resource_mesh_manager) executor.add_node(resource_mesh_manager)
executor.add_node(joint_republisher) executor.add_node(joint_republisher)
# executor.add_node(lh_joint_pub) # executor.add_node(lh_joint_pub)

View File

@@ -70,6 +70,8 @@ class HostNode(BaseROS2DeviceNode):
_instance: ClassVar[Optional["HostNode"]] = None _instance: ClassVar[Optional["HostNode"]] = None
_ready_event: ClassVar[threading.Event] = threading.Event() _ready_event: ClassVar[threading.Event] = threading.Event()
_shutting_down: ClassVar[bool] = False # Flag to signal shutdown to background threads
_background_threads: ClassVar[List[threading.Thread]] = [] # Track all background threads for cleanup
_device_action_status: ClassVar[collections.defaultdict[str, DeviceActionStatus]] = collections.defaultdict( _device_action_status: ClassVar[collections.defaultdict[str, DeviceActionStatus]] = collections.defaultdict(
DeviceActionStatus DeviceActionStatus
) )
@@ -81,6 +83,48 @@ class HostNode(BaseROS2DeviceNode):
return cls._instance return cls._instance
return None return None
@classmethod
def shutdown_background_threads(cls, timeout: float = 5.0) -> None:
"""
Gracefully shutdown all background threads for clean exit or restart.
This method:
1. Sets shutdown flag to stop background operations
2. Waits for background threads to finish with timeout
3. Cleans up finished threads from tracking list
Args:
timeout: Maximum time to wait for each thread (seconds)
"""
cls._shutting_down = True
# Wait for background threads to finish
active_threads = []
for t in cls._background_threads:
if t.is_alive():
t.join(timeout=timeout)
if t.is_alive():
active_threads.append(t.name)
if active_threads:
logger.warning(f"[Host Node] Some background threads still running: {active_threads}")
# Clear the thread list
cls._background_threads.clear()
logger.info(f"[Host Node] Background threads shutdown complete")
@classmethod
def reset_state(cls) -> None:
"""
Reset the HostNode singleton state for restart or clean exit.
Call this after destroying the instance.
"""
cls._instance = None
cls._ready_event.clear()
cls._shutting_down = False
cls._background_threads.clear()
logger.info("[Host Node] State reset complete")
def __init__( def __init__(
self, self,
device_id: str, device_id: str,
@@ -294,12 +338,37 @@ class HostNode(BaseROS2DeviceNode):
bridge.publish_host_ready() bridge.publish_host_ready()
self.lab_logger().debug(f"Host ready signal sent via {bridge.__class__.__name__}") self.lab_logger().debug(f"Host ready signal sent via {bridge.__class__.__name__}")
def _send_re_register(self, sclient): def _send_re_register(self, sclient, device_namespace: str):
sclient.wait_for_service() """
request = SerialCommand.Request() Send re-register command to a device. This is a one-time operation.
request.command = ""
future = sclient.call_async(request) Args:
response = future.result() sclient: The service client
device_namespace: The device namespace for logging
"""
try:
# Use timeout to prevent indefinite blocking
if not sclient.wait_for_service(timeout_sec=10.0):
self.lab_logger().debug(f"[Host Node] Re-register timeout for {device_namespace}")
return
# Check shutdown flag after wait
if self._shutting_down:
self.lab_logger().debug(f"[Host Node] Re-register aborted for {device_namespace} (shutdown)")
return
request = SerialCommand.Request()
request.command = ""
future = sclient.call_async(request)
# Use timeout for result as well
future.result(timeout_sec=5.0)
self.lab_logger().debug(f"[Host Node] Re-register completed for {device_namespace}")
except Exception as e:
# Gracefully handle destruction during shutdown
if "destruction was requested" in str(e) or self._shutting_down:
self.lab_logger().debug(f"[Host Node] Re-register aborted for {device_namespace} (cleanup)")
else:
self.lab_logger().warning(f"[Host Node] Re-register failed for {device_namespace}: {e}")
def _discover_devices(self) -> None: def _discover_devices(self) -> None:
""" """
@@ -331,23 +400,27 @@ class HostNode(BaseROS2DeviceNode):
self._create_action_clients_for_device(device_id, namespace) self._create_action_clients_for_device(device_id, namespace)
self._online_devices.add(device_key) self._online_devices.add(device_key)
sclient = self.create_client(SerialCommand, f"/srv{namespace}/re_register_device") sclient = self.create_client(SerialCommand, f"/srv{namespace}/re_register_device")
threading.Thread( t = threading.Thread(
target=self._send_re_register, target=self._send_re_register,
args=(sclient,), args=(sclient, namespace),
daemon=True, daemon=True,
name=f"ROSDevice{self.device_id}_re_register_device_{namespace}", name=f"ROSDevice{self.device_id}_re_register_device_{namespace}",
).start() )
self._background_threads.append(t)
t.start()
elif device_key not in self._online_devices: elif device_key not in self._online_devices:
# 设备重新上线 # 设备重新上线
self.lab_logger().info(f"[Host Node] Device reconnected: {device_key}") self.lab_logger().info(f"[Host Node] Device reconnected: {device_key}")
self._online_devices.add(device_key) self._online_devices.add(device_key)
sclient = self.create_client(SerialCommand, f"/srv{namespace}/re_register_device") sclient = self.create_client(SerialCommand, f"/srv{namespace}/re_register_device")
threading.Thread( t = threading.Thread(
target=self._send_re_register, target=self._send_re_register,
args=(sclient,), args=(sclient, namespace),
daemon=True, daemon=True,
name=f"ROSDevice{self.device_id}_re_register_device_{namespace}", name=f"ROSDevice{self.device_id}_re_register_device_{namespace}",
).start() )
self._background_threads.append(t)
t.start()
# 检测离线设备 # 检测离线设备
offline_devices = self._online_devices - current_devices offline_devices = self._online_devices - current_devices
@@ -705,13 +778,14 @@ class HostNode(BaseROS2DeviceNode):
raise ValueError(f"ActionClient {action_id} not found.") raise ValueError(f"ActionClient {action_id} not found.")
action_client: ActionClient = self._action_clients[action_id] action_client: ActionClient = self._action_clients[action_id]
# 遍历action_kwargs下的所有子dict将"sample_uuid"的值赋给"sample_id" # 遍历action_kwargs下的所有子dict将"sample_uuid"的值赋给"sample_id"
def assign_sample_id(obj): def assign_sample_id(obj):
if isinstance(obj, dict): if isinstance(obj, dict):
if "sample_uuid" in obj: if "sample_uuid" in obj:
obj["sample_id"] = obj["sample_uuid"] obj["sample_id"] = obj["sample_uuid"]
obj.pop("sample_uuid") obj.pop("sample_uuid")
for k,v in obj.items(): for k, v in obj.items():
if k != "unilabos_extra": if k != "unilabos_extra":
assign_sample_id(v) assign_sample_id(v)
elif isinstance(obj, list): elif isinstance(obj, list):
@@ -742,9 +816,7 @@ class HostNode(BaseROS2DeviceNode):
self.lab_logger().info(f"[Host Node] Goal {action_id} ({item.job_id}) accepted") self.lab_logger().info(f"[Host Node] Goal {action_id} ({item.job_id}) accepted")
self._goals[item.job_id] = goal_handle self._goals[item.job_id] = goal_handle
goal_future = goal_handle.get_result_async() goal_future = goal_handle.get_result_async()
goal_future.add_done_callback( goal_future.add_done_callback(lambda f: self.get_result_callback(item, action_id, f))
lambda f: self.get_result_callback(item, action_id, f)
)
goal_future.result() goal_future.result()
def feedback_callback(self, item: "QueueItem", action_id: str, feedback_msg) -> None: def feedback_callback(self, item: "QueueItem", action_id: str, feedback_msg) -> None:
@@ -1167,6 +1239,7 @@ class HostNode(BaseROS2DeviceNode):
""" """
try: try:
from unilabos.app.web import http_client from unilabos.app.web import http_client
data = json.loads(request.command) data = json.loads(request.command)
if "uuid" in data and data["uuid"] is not None: if "uuid" in data and data["uuid"] is not None:
http_req = http_client.resource_tree_get([data["uuid"]], data["with_children"]) http_req = http_client.resource_tree_get([data["uuid"]], data["with_children"])