From cd84e261262285cdb819690cdbe126002c35120e Mon Sep 17 00:00:00 2001 From: Xuwznln <18435084+Xuwznln@users.noreply.github.com> Date: Thu, 28 Aug 2025 14:34:38 +0800 Subject: [PATCH] feat: websocket --- .conda/recipe.yaml | 2 +- README.md | 3 +- README_zh.md | 9 +- unilabos-linux-64.yaml | 2 +- unilabos-osx-64.yaml | 2 +- unilabos-osx-arm64.yaml | 3 +- unilabos-win64.yaml | 2 +- unilabos/app/communication.py | 204 ++++++++++++++++ unilabos/app/main.py | 22 +- unilabos/app/mq.py | 16 +- unilabos/app/register.py | 36 ++- unilabos/app/ws_client.py | 300 ++++++++++++++++++++++++ unilabos/config/config.py | 15 +- unilabos/config/example_config.py | 1 + unilabos/ros/nodes/presets/host_node.py | 75 +++--- unilabos/utils/environment_check.py | 1 + 16 files changed, 626 insertions(+), 67 deletions(-) create mode 100644 unilabos/app/communication.py create mode 100644 unilabos/app/ws_client.py diff --git a/.conda/recipe.yaml b/.conda/recipe.yaml index 44f2b2e8..182f3dab 100644 --- a/.conda/recipe.yaml +++ b/.conda/recipe.yaml @@ -61,7 +61,7 @@ requirements: - uvicorn - gradio - flask - - websocket + - websockets - ipython - jupyter - jupyros diff --git a/README.md b/README.md index 142e7f9d..93f2fcbe 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ # Uni-Lab-OS + **English** | [中文](README_zh.md) [![GitHub Stars](https://img.shields.io/github/stars/dptech-corp/Uni-Lab-OS.svg)](https://github.com/dptech-corp/Uni-Lab-OS/stargazers) @@ -74,4 +75,4 @@ This project is licensed under GPL-3.0 - see the [LICENSE](LICENSE) file for det ## Contact Us -- GitHub Issues: [https://github.com/dptech-corp/Uni-Lab-OS/issues](https://github.com/dptech-corp/Uni-Lab-OS/issues) \ No newline at end of file +- GitHub Issues: [https://github.com/dptech-corp/Uni-Lab-OS/issues](https://github.com/dptech-corp/Uni-Lab-OS/issues) diff --git a/README_zh.md b/README_zh.md index 9ac81598..07b400db 100644 --- a/README_zh.md +++ b/README_zh.md @@ -5,6 +5,7 @@ # Uni-Lab-OS + [English](README.md) | **中文** [![GitHub Stars](https://img.shields.io/github/stars/dptech-corp/Uni-Lab-OS.svg)](https://github.com/dptech-corp/Uni-Lab-OS/stargazers) @@ -12,7 +13,7 @@ [![GitHub Issues](https://img.shields.io/github/issues/dptech-corp/Uni-Lab-OS.svg)](https://github.com/dptech-corp/Uni-Lab-OS/issues) [![GitHub License](https://img.shields.io/github/license/dptech-corp/Uni-Lab-OS.svg)](https://github.com/dptech-corp/Uni-Lab-OS/blob/main/LICENSE) -Uni-Lab-OS是一个用于实验室自动化的综合平台,旨在连接和控制各种实验设备,实现实验流程的自动化和标准化。 +Uni-Lab-OS 是一个用于实验室自动化的综合平台,旨在连接和控制各种实验设备,实现实验流程的自动化和标准化。 ## 🏆 比赛 @@ -34,7 +35,7 @@ Uni-Lab-OS是一个用于实验室自动化的综合平台,旨在连接和控 ## 快速开始 -1. 配置Conda环境 +1. 配置 Conda 环境 Uni-Lab-OS 建议使用 `mamba` 管理环境。根据您的操作系统选择适当的环境文件: @@ -43,7 +44,7 @@ Uni-Lab-OS 建议使用 `mamba` 管理环境。根据您的操作系统选择适 mamba create -n unilab uni-lab::unilabos -c robostack-staging -c conda-forge ``` -2. 安装开发版Uni-Lab-OS: +2. 安装开发版 Uni-Lab-OS: ```bash # 克隆仓库 @@ -76,4 +77,4 @@ Uni-Lab-OS 使用预构建的 `unilabos_msgs` 进行系统通信。您可以在 ## 联系我们 -- GitHub Issues: [https://github.com/dptech-corp/Uni-Lab-OS/issues](https://github.com/dptech-corp/Uni-Lab-OS/issues) \ No newline at end of file +- GitHub Issues: [https://github.com/dptech-corp/Uni-Lab-OS/issues](https://github.com/dptech-corp/Uni-Lab-OS/issues) diff --git a/unilabos-linux-64.yaml b/unilabos-linux-64.yaml index c84e0451..2604b051 100644 --- a/unilabos-linux-64.yaml +++ b/unilabos-linux-64.yaml @@ -34,7 +34,7 @@ dependencies: - uvicorn - gradio - flask - - websocket + - websockets # Notebook - ipython - jupyter diff --git a/unilabos-osx-64.yaml b/unilabos-osx-64.yaml index ca9a96fa..2d0c3325 100644 --- a/unilabos-osx-64.yaml +++ b/unilabos-osx-64.yaml @@ -34,7 +34,7 @@ dependencies: - uvicorn - gradio - flask - - websocket + - websockets # Notebook - ipython - jupyter diff --git a/unilabos-osx-arm64.yaml b/unilabos-osx-arm64.yaml index 7f9675db..a4e88016 100644 --- a/unilabos-osx-arm64.yaml +++ b/unilabos-osx-arm64.yaml @@ -35,8 +35,7 @@ dependencies: - uvicorn - gradio - flask - - websocket - - paho-mqtt + - websockets # Notebook - ipython - jupyter diff --git a/unilabos-win64.yaml b/unilabos-win64.yaml index b2065a0c..9eb55fd3 100644 --- a/unilabos-win64.yaml +++ b/unilabos-win64.yaml @@ -34,7 +34,7 @@ dependencies: - uvicorn - gradio - flask - - websocket + - websockets # Notebook - ipython - jupyter diff --git a/unilabos/app/communication.py b/unilabos/app/communication.py new file mode 100644 index 00000000..60b93818 --- /dev/null +++ b/unilabos/app/communication.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python +# coding=utf-8 +""" +通信模块 + +提供MQTT和WebSocket的统一接口,支持通过配置选择通信协议。 +包含通信抽象层基类和通信客户端工厂。 +""" + +from abc import ABC, abstractmethod +from typing import Optional +from unilabos.config.config import BasicConfig +from unilabos.utils import logger + + +class BaseCommunicationClient(ABC): + """ + 通信客户端抽象基类 + + 定义了所有通信客户端(MQTT、WebSocket等)需要实现的接口。 + """ + + def __init__(self): + self.is_disabled = True + self.client_id = "" + + @abstractmethod + def start(self) -> None: + """ + 启动通信客户端连接 + """ + pass + + @abstractmethod + def stop(self) -> None: + """ + 停止通信客户端连接 + """ + pass + + @abstractmethod + def publish_device_status(self, device_status: dict, device_id: str, property_name: str) -> None: + """ + 发布设备状态信息 + + Args: + device_status: 设备状态字典 + device_id: 设备ID + property_name: 属性名称 + """ + pass + + @abstractmethod + def publish_job_status( + self, feedback_data: dict, job_id: str, status: str, return_info: Optional[str] = None + ) -> None: + """ + 发布作业状态信息 + + Args: + feedback_data: 反馈数据 + job_id: 作业ID + status: 作业状态 + return_info: 返回信息 + """ + pass + + @abstractmethod + def send_ping(self, ping_id: str, timestamp: float) -> None: + """ + 发送ping消息 + + Args: + ping_id: ping ID + timestamp: 时间戳 + """ + pass + + def setup_pong_subscription(self) -> None: + """ + 设置pong消息订阅(可选实现) + """ + pass + + @property + def is_connected(self) -> bool: + """ + 检查是否已连接 + + Returns: + 是否已连接 + """ + return not self.is_disabled + + +class CommunicationClientFactory: + """ + 通信客户端工厂类 + + 根据配置文件中的通信协议设置创建相应的客户端实例。 + """ + + _client_cache: Optional[BaseCommunicationClient] = None + + @classmethod + def create_client(cls, protocol: Optional[str] = None) -> BaseCommunicationClient: + """ + 创建通信客户端实例 + + Args: + protocol: 指定的协议类型,如果为None则使用配置文件中的设置 + + Returns: + 通信客户端实例 + + Raises: + ValueError: 当协议类型不支持时 + """ + if protocol is None: + protocol = BasicConfig.communication_protocol + + protocol = protocol.lower() + + if protocol == "mqtt": + return cls._create_mqtt_client() + elif protocol == "websocket": + return cls._create_websocket_client() + else: + logger.error(f"[CommunicationFactory] Unsupported protocol: {protocol}") + logger.warning(f"[CommunicationFactory] Falling back to MQTT") + return cls._create_mqtt_client() + + @classmethod + def get_client(cls, protocol: Optional[str] = None) -> BaseCommunicationClient: + """ + 获取通信客户端实例(单例模式) + + Args: + protocol: 指定的协议类型,如果为None则使用配置文件中的设置 + + Returns: + 通信客户端实例 + """ + if cls._client_cache is None: + cls._client_cache = cls.create_client(protocol) + logger.info(f"[CommunicationFactory] Created {type(cls._client_cache).__name__} client") + + return cls._client_cache + + @classmethod + def _create_mqtt_client(cls) -> BaseCommunicationClient: + """创建MQTT客户端""" + try: + from unilabos.app.mq import mqtt_client + return mqtt_client + except Exception as e: + logger.error(f"[CommunicationFactory] Failed to create MQTT client: {str(e)}") + raise + + @classmethod + def _create_websocket_client(cls) -> BaseCommunicationClient: + """创建WebSocket客户端""" + try: + from unilabos.app.ws_client import WebSocketClient + return WebSocketClient() + except Exception as e: + logger.error(f"[CommunicationFactory] Failed to create WebSocket client: {str(e)}") + logger.warning(f"[CommunicationFactory] Falling back to MQTT") + return cls._create_mqtt_client() + + @classmethod + def reset_client(cls): + """重置客户端缓存(用于测试或重新配置)""" + if cls._client_cache: + try: + cls._client_cache.stop() + except Exception as e: + logger.warning(f"[CommunicationFactory] Error stopping old client: {str(e)}") + + cls._client_cache = None + logger.info("[CommunicationFactory] Client cache reset") + + @classmethod + def get_supported_protocols(cls) -> list[str]: + """ + 获取支持的协议列表 + + Returns: + 支持的协议列表 + """ + return ["mqtt", "websocket"] + + +def get_communication_client(protocol: Optional[str] = None) -> BaseCommunicationClient: + """ + 获取通信客户端实例的便捷函数 + + Args: + protocol: 指定的协议类型,如果为None则使用配置文件中的设置 + + Returns: + 通信客户端实例 + """ + return CommunicationClientFactory.get_client(protocol) diff --git a/unilabos/app/main.py b/unilabos/app/main.py index 59b73c81..3d7761e8 100644 --- a/unilabos/app/main.py +++ b/unilabos/app/main.py @@ -10,7 +10,6 @@ from copy import deepcopy import yaml -from unilabos.resources.graphio import modify_to_backend_format # 首先添加项目根目录到路径 current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -20,6 +19,7 @@ if unilabos_dir not in sys.path: from unilabos.config.config import load_config, BasicConfig from unilabos.utils.banner_print import print_status, print_unilab_banner +from unilabos.resources.graphio import modify_to_backend_format def load_config_from_file(config_path, override_labid=None): @@ -146,6 +146,11 @@ def parse_args(): default="", help="实验室请求的sk", ) + parser.add_argument( + "--websocket", + action="store_true", + help="使用websocket而非mqtt作为通信协议", + ) parser.add_argument( "--skip_env_check", action="store_true", @@ -179,7 +184,7 @@ def main(): else: working_dir = os.path.abspath(os.path.join(os.getcwd(), "unilabos_data")) if args_dict.get("working_dir"): - working_dir = args_dict.get("working_dir") + working_dir = args_dict.get("working_dir", "") if config_path and not os.path.exists(config_path): config_path = os.path.join(working_dir, "local_config.py") if not os.path.exists(config_path): @@ -215,6 +220,7 @@ def main(): if args_dict["use_remote_resource"]: print_status("使用远程资源启动", "info") from unilabos.app.web import http_client + res = http_client.resource_get("host_node", False) if str(res.get("code", 0)) == "0" and len(res.get("data", [])) > 0: print_status("远程资源已存在,使用云端物料!", "info") @@ -229,6 +235,7 @@ def main(): BasicConfig.is_host_mode = not args_dict.get("without_host", False) BasicConfig.slave_no_host = args_dict.get("slave_no_host", False) BasicConfig.upload_registry = args_dict.get("upload_registry", False) + BasicConfig.communication_protocol = "websocket" if args_dict.get("websocket", False) else "mqtt" machine_name = os.popen("hostname").read().strip() machine_name = "".join([c if c.isalnum() or c == "_" else "_" for c in machine_name]) BasicConfig.machine_name = machine_name @@ -241,7 +248,7 @@ def main(): dict_to_nested_dict, initialize_resources, ) - from unilabos.app.mq import mqtt_client + from unilabos.app.communication import get_communication_client from unilabos.registry.registry import build_registry from unilabos.app.backend import start_backend from unilabos.app.web import http_client @@ -289,19 +296,22 @@ def main(): args_dict["bridges"] = [] + # 获取通信客户端(根据配置选择MQTT或WebSocket) + comm_client = get_communication_client() + if "mqtt" in args_dict["app_bridges"]: - args_dict["bridges"].append(mqtt_client) + args_dict["bridges"].append(comm_client) if "fastapi" in args_dict["app_bridges"]: args_dict["bridges"].append(http_client) if "mqtt" in args_dict["app_bridges"]: def _exit(signum, frame): - mqtt_client.stop() + comm_client.stop() sys.exit(0) signal.signal(signal.SIGINT, _exit) signal.signal(signal.SIGTERM, _exit) - mqtt_client.start() + comm_client.start() args_dict["resources_mesh_config"] = {} args_dict["resources_edge_config"] = resource_edge_info # web visiualize 2D diff --git a/unilabos/app/mq.py b/unilabos/app/mq.py index 1d3f9695..65d0ab42 100644 --- a/unilabos/app/mq.py +++ b/unilabos/app/mq.py @@ -15,17 +15,20 @@ import os from unilabos.config.config import MQConfig from unilabos.app.controler import job_add from unilabos.app.model import JobAddReq +from unilabos.app.communication import BaseCommunicationClient from unilabos.utils import logger from unilabos.utils.type_check import TypeEncoder from paho.mqtt.enums import CallbackAPIVersion -class MQTTClient: +class MQTTClient(BaseCommunicationClient): mqtt_disable = True def __init__(self): + super().__init__() self.mqtt_disable = not MQConfig.lab_id + self.is_disabled = self.mqtt_disable # 更新父类属性 self.client_id = f"{MQConfig.group_id}@@@{MQConfig.lab_id}{uuid.uuid4()}" logger.info("[MQTT] Client_id: " + self.client_id) self.client = mqtt.Client(CallbackAPIVersion.VERSION2, client_id=self.client_id, protocol=mqtt.MQTTv5) @@ -208,11 +211,12 @@ class MQTTClient: self.client.subscribe(pong_topic, 0) logger.debug(f"Subscribed to pong topic: {pong_topic}") - def handle_pong(self, pong_data: dict): - """处理pong响应(这个方法会在收到pong消息时被调用)""" - logger.debug(f"Pong received: {pong_data}") - # 这里会被HostNode的ping-pong处理逻辑调用 - pass + @property + def is_connected(self) -> bool: + """检查MQTT是否已连接""" + if self.is_disabled: + return False + return hasattr(self.client, "is_connected") and self.client.is_connected() mqtt_client = MQTTClient() diff --git a/unilabos/app/register.py b/unilabos/app/register.py index 06469018..a96ff16c 100644 --- a/unilabos/app/register.py +++ b/unilabos/app/register.py @@ -10,35 +10,44 @@ from unilabos.utils.log import logger from unilabos.utils.type_check import TypeEncoder -def register_devices_and_resources(mqtt_client, lab_registry): +def register_devices_and_resources(comm_client, lab_registry): """ - 注册设备和资源到 MQTT + 注册设备和资源到通信服务器(MQTT/WebSocket) """ # 注册资源信息 - 使用HTTP方式 from unilabos.app.web.client import http_client + logger.info("[UniLab Register] 开始注册设备和资源...") if BasicConfig.auth_secret(): # 注册设备信息 devices_to_register = {} for device_info in lab_registry.obtain_registry_device_info(): - devices_to_register[device_info["id"]] = json.loads(json.dumps(device_info, ensure_ascii=False, cls=TypeEncoder)) + devices_to_register[device_info["id"]] = json.loads( + json.dumps(device_info, ensure_ascii=False, cls=TypeEncoder) + ) logger.debug(f"[UniLab Register] 收集设备: {device_info['id']}") resources_to_register = {} for resource_info in lab_registry.obtain_registry_resource_info(): resources_to_register[resource_info["id"]] = resource_info logger.debug(f"[UniLab Register] 收集资源: {resource_info['id']}") - print("[UniLab Register] 设备注册", http_client.resource_registry({"resources": list(devices_to_register.values())}).text) - print("[UniLab Register] 资源注册", http_client.resource_registry({"resources": list(resources_to_register.values())}).text) + print( + "[UniLab Register] 设备注册", + http_client.resource_registry({"resources": list(devices_to_register.values())}).text, + ) + print( + "[UniLab Register] 资源注册", + http_client.resource_registry({"resources": list(resources_to_register.values())}).text, + ) else: # 注册设备信息 for device_info in lab_registry.obtain_registry_device_info(): - mqtt_client.publish_registry(device_info["id"], device_info, False) + comm_client.publish_registry(device_info["id"], device_info, False) logger.debug(f"[UniLab Register] 注册设备: {device_info['id']}") # # 注册资源信息 # for resource_info in lab_registry.obtain_registry_resource_info(): - # mqtt_client.publish_registry(resource_info["id"], resource_info, False) + # comm_client.publish_registry(resource_info["id"], resource_info, False) # logger.debug(f"[UniLab Register] 注册资源: {resource_info['id']}") resources_to_register = {} @@ -53,7 +62,9 @@ def register_devices_and_resources(mqtt_client, lab_registry): if response.status_code in [200, 201]: logger.info(f"[UniLab Register] 成功通过HTTP注册 {len(resources_to_register)} 个资源 {cost_time}ms") else: - logger.error(f"[UniLab Register] HTTP注册资源失败: {response.status_code}, {response.text} {cost_time}ms") + logger.error( + f"[UniLab Register] HTTP注册资源失败: {response.status_code}, {response.text} {cost_time}ms" + ) logger.info("[UniLab Register] 设备和资源注册完成.") @@ -99,15 +110,16 @@ def main(): BasicConfig.sk = args.sk # 构建注册表 build_registry(args.registry, args.complete_registry, True) - from unilabos.app.mq import mqtt_client + from unilabos.app.communication import get_communication_client - # 连接mqtt - mqtt_client.start() + # 获取通信客户端并启动连接 + comm_client = get_communication_client() + comm_client.start() from unilabos.registry.registry import lab_registry # 注册设备和资源 - register_devices_and_resources(mqtt_client, lab_registry) + register_devices_and_resources(comm_client, lab_registry) if __name__ == "__main__": diff --git a/unilabos/app/ws_client.py b/unilabos/app/ws_client.py new file mode 100644 index 00000000..94fa6c6d --- /dev/null +++ b/unilabos/app/ws_client.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python +# coding=utf-8 +""" +WebSocket通信客户端 + +基于WebSocket协议的通信客户端实现,继承自BaseCommunicationClient。 +""" + +import json +import time +import uuid +import threading +import asyncio +import traceback +from typing import Optional, Dict, Any +from urllib.parse import urlparse +from unilabos.app.controler import job_add +from unilabos.app.model import JobAddReq +from unilabos.ros.nodes.presets.host_node import HostNode + +try: + import websockets + import ssl as ssl_module + + HAS_WEBSOCKETS = True +except ImportError: + HAS_WEBSOCKETS = False + +from unilabos.app.communication import BaseCommunicationClient +from unilabos.config.config import WSConfig, HTTPConfig, BasicConfig +from unilabos.utils import logger + + +class WebSocketClient(BaseCommunicationClient): + """ + WebSocket通信客户端类 + + 实现基于WebSocket协议的实时通信功能。 + """ + + def __init__(self): + super().__init__() + + if not HAS_WEBSOCKETS: + logger.error("[WebSocket] websockets库未安装,WebSocket功能不可用") + self.is_disabled = True + return + + self.is_disabled = False + self.client_id = f"{uuid.uuid4()}" + + # WebSocket连接相关 + self.websocket = None + self.connection_loop = None + self.event_loop = None + self.connection_thread = None + self.is_running = False + self.connected = False + + # 消息处理 + self.message_queue = asyncio.Queue() if not self.is_disabled else None + self.reconnect_count = 0 + + # 构建WebSocket URL + self._build_websocket_url() + + logger.info(f"[WebSocket] Client_id: {self.client_id}") + + def _build_websocket_url(self): + """构建WebSocket连接URL""" + if not HTTPConfig.remote_addr: + self.websocket_url = None + return + + # 解析服务器URL + parsed = urlparse(HTTPConfig.remote_addr) + + # 根据SSL配置选择协议 + if parsed.scheme == "https": + scheme = "wss" + else: + scheme = "ws" + self.websocket_url = f"{scheme}://{parsed.netloc}/api/v1/lab" + + logger.debug(f"[WebSocket] URL: {self.websocket_url}") + + def start(self) -> None: + """启动WebSocket连接""" + if self.is_disabled: + logger.warning("[WebSocket] WebSocket is disabled, skipping connection.") + return + + if not self.websocket_url: + logger.error("[WebSocket] WebSocket URL not configured") + return + + logger.info(f"[WebSocket] Starting connection to {self.websocket_url}") + + self.is_running = True + + # 在单独线程中运行WebSocket连接 + self.connection_thread = threading.Thread(target=self._run_connection, daemon=True, name="WebSocketConnection") + self.connection_thread.start() + + def stop(self) -> None: + """停止WebSocket连接""" + if self.is_disabled: + return + + logger.info("[WebSocket] Stopping connection") + self.is_running = False + self.connected = False + + if self.event_loop and self.event_loop.is_running(): + asyncio.run_coroutine_threadsafe(self._close_connection(), self.event_loop) + + if self.connection_thread and self.connection_thread.is_alive(): + self.connection_thread.join(timeout=5) + + def _run_connection(self): + """在独立线程中运行WebSocket连接""" + try: + # 创建新的事件循环 + self.event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.event_loop) + + # 运行连接逻辑 + self.event_loop.run_until_complete(self._connection_handler()) + except Exception as e: + logger.error(f"[WebSocket] Connection thread error: {str(e)}") + logger.error(traceback.format_exc()) + finally: + if self.event_loop: + self.event_loop.close() + + async def _connection_handler(self): + """处理WebSocket连接和重连逻辑""" + while self.is_running: + try: + # 构建SSL上下文 + ssl_context = None + assert self.websocket_url is not None + if self.websocket_url.startswith("wss://"): + ssl_context = ssl_module.create_default_context() + + async with websockets.connect( + self.websocket_url, + ssl=ssl_context, + ping_interval=WSConfig.ping_interval, + ping_timeout=10, + additional_headers={"Authorization": f"Bearer {BasicConfig.auth_secret()}"}, + ) as websocket: + self.websocket = websocket + self.connected = True + self.reconnect_count = 0 + + logger.info(f"[WebSocket] Connected to {self.websocket_url}") + # 处理消息 + await self._message_handler() + + except websockets.exceptions.ConnectionClosed: + logger.warning("[WebSocket] Connection closed") + self.connected = False + except Exception as e: + logger.error(f"[WebSocket] Connection error: {str(e)}") + self.connected = False + + # 重连逻辑 + if self.is_running and self.reconnect_count < WSConfig.max_reconnect_attempts: + self.reconnect_count += 1 + logger.info( + f"[WebSocket] Reconnecting in {WSConfig.reconnect_interval}s " + f"(attempt {self.reconnect_count}/{WSConfig.max_reconnect_attempts})" + ) + await asyncio.sleep(WSConfig.reconnect_interval) + elif self.reconnect_count >= WSConfig.max_reconnect_attempts: + logger.error("[WebSocket] Max reconnection attempts reached") + break + + async def _close_connection(self): + """关闭WebSocket连接""" + if self.websocket: + await self.websocket.close() + self.websocket = None + + async def _send_message(self, message: Dict[str, Any]): + """发送消息""" + if not self.connected or not self.websocket: + logger.warning("[WebSocket] Not connected, cannot send message") + return + + try: + message_str = json.dumps(message, ensure_ascii=False) + await self.websocket.send(message_str) + logger.debug(f"[WebSocket] Message sent: {message['type']}") + except Exception as e: + logger.error(f"[WebSocket] Failed to send message: {str(e)}") + + async def _message_handler(self): + """处理接收到的消息""" + if not self.websocket: + logger.error("[WebSocket] WebSocket connection is None") + return + + try: + async for message in self.websocket: + try: + data = json.loads(message) + await self._process_message(data) + except json.JSONDecodeError: + logger.error(f"[WebSocket] Invalid JSON received: {message}") + except Exception as e: + logger.error(f"[WebSocket] Error processing message: {str(e)}") + except websockets.exceptions.ConnectionClosed: + logger.info("[WebSocket] Message handler stopped - connection closed") + except Exception as e: + logger.error(f"[WebSocket] Message handler error: {str(e)}") + + async def _process_message(self, input_message: Dict[str, Any]): + """处理收到的消息""" + message_type = input_message.get("type", "") + data = input_message.get("data", {}) + if message_type == "job_start": + # 处理作业启动消息 + await self._handle_job_start(data) + elif message_type == "pong": + # 处理pong响应 + self._handle_pong_sync(data) + else: + logger.debug(f"[WebSocket] Unknown message type: {message_type}") + + async def _handle_job_start(self, data: Dict[str, Any]): + """处理作业启动消息""" + try: + job_req = JobAddReq(**data.get("job_data", {})) + job_add(job_req) + job_id = getattr(job_req, "id", "unknown") + logger.info(f"[WebSocket] Job started: {job_id}") + except Exception as e: + logger.error(f"[WebSocket] Error handling job start: {str(e)}") + + def _handle_pong_sync(self, pong_data: Dict[str, Any]): + """同步处理pong响应""" + host_node = HostNode.get_instance(0) + if host_node: + host_node.handle_pong_response(pong_data) + + # 实现抽象基类的方法 + def publish_device_status(self, device_status: dict, device_id: str, property_name: str) -> None: + """发布设备状态""" + if self.is_disabled or not self.connected: + return + message = { + "type": "device_status", + "data": { + "device_id": device_id, + "property_name": property_name, + "status": device_status.get(device_id, {}).get(property_name), + "timestamp": time.time(), + }, + } + if self.event_loop: + asyncio.run_coroutine_threadsafe(self._send_message(message), self.event_loop) + logger.debug(f"[WebSocket] Device status published: {device_id}.{property_name}") + + def publish_job_status( + self, feedback_data: dict, job_id: str, status: str, return_info: Optional[str] = None + ) -> None: + """发布作业状态""" + if self.is_disabled or not self.connected: + logger.warning("[WebSocket] Not connected, cannot publish job status") + return + message = { + "type": "job_status", + "data": { + "job_id": job_id, + "status": status, + "feedback_data": feedback_data, + "return_info": return_info, + "timestamp": time.time(), + }, + } + if self.event_loop: + asyncio.run_coroutine_threadsafe(self._send_message(message), self.event_loop) + logger.debug(f"[WebSocket] Job status published: {job_id} - {status}") + + def send_ping(self, ping_id: str, timestamp: float) -> None: + """发送ping消息""" + if self.is_disabled or not self.connected: + logger.warning("[WebSocket] Not connected, cannot send ping") + return + message = {"type": "ping", "data": {"ping_id": ping_id, "client_timestamp": timestamp}} + if self.event_loop: + asyncio.run_coroutine_threadsafe(self._send_message(message), self.event_loop) + logger.debug(f"[WebSocket] Ping sent: {ping_id}") + + @property + def is_connected(self) -> bool: + """检查是否已连接""" + return self.connected and not self.is_disabled diff --git a/unilabos/config/config.py b/unilabos/config/config.py index 19f91b3a..c0edf2ba 100644 --- a/unilabos/config/config.py +++ b/unilabos/config/config.py @@ -5,6 +5,7 @@ import base64 import traceback import os import importlib.util +from typing import Optional from unilabos.utils import logger @@ -20,6 +21,8 @@ class BasicConfig: machine_name = "undefined" vis_2d_enable = False enable_resource_load = True + # 通信协议配置 + communication_protocol = "mqtt" # 支持: "mqtt", "websocket" @classmethod def auth_secret(cls): @@ -27,7 +30,7 @@ class BasicConfig: if not cls.ak or not cls.sk: return "" target = f"{cls.ak}:{cls.sk}" - base64_target = base64.b64encode(target.encode('utf-8')).decode('utf-8') + base64_target = base64.b64encode(target.encode("utf-8")).decode("utf-8") return base64_target @@ -50,6 +53,13 @@ class MQConfig: key_file = "" # 相对config.py所在目录的路径 +# WebSocket配置 +class WSConfig: + reconnect_interval = 5 # 重连间隔(秒) + max_reconnect_attempts = 10 # 最大重连次数 + ping_interval = 30 # ping间隔(秒) + + # OSS上传配置 class OSSUploadConfig: api_host = "" @@ -77,7 +87,7 @@ class ROSConfig: ] -def _update_config_from_module(module, override_labid: str): +def _update_config_from_module(module, override_labid: Optional[str]): for name, obj in globals().items(): if isinstance(obj, type) and name.endswith("Config"): if hasattr(module, name) and isinstance(getattr(module, name), type): @@ -171,7 +181,6 @@ def _update_config_from_env(): logger.warning(f"[ENV] 解析环境变量 {env_key} 失败: {e}") - def load_config(config_path=None, override_labid=None): # 如果提供了配置文件路径,从该文件导入配置 if config_path: diff --git a/unilabos/config/example_config.py b/unilabos/config/example_config.py index 91e08303..07018cba 100644 --- a/unilabos/config/example_config.py +++ b/unilabos/config/example_config.py @@ -12,6 +12,7 @@ class MQConfig: cert_file = "./lab.crt" key_file = "./lab.key" + # HTTP配置 class HTTPConfig: remote_addr = "https://uni-lab.bohrium.com/api/v1" diff --git a/unilabos/ros/nodes/presets/host_node.py b/unilabos/ros/nodes/presets/host_node.py index 5fe90684..4c7053c1 100644 --- a/unilabos/ros/nodes/presets/host_node.py +++ b/unilabos/ros/nodes/presets/host_node.py @@ -152,11 +152,15 @@ class HostNode(BaseROS2DeviceNode): self.device_status = {} # 用来存储设备状态 self.device_status_timestamps = {} # 用来存储设备状态最后更新时间 if BasicConfig.upload_registry: - from unilabos.app.mq import mqtt_client - register_devices_and_resources(mqtt_client, lab_registry) + from unilabos.app.communication import get_communication_client + + comm_client = get_communication_client() + register_devices_and_resources(comm_client, lab_registry) else: - self.lab_logger().warning("本次启动注册表不报送云端,如果您需要联网调试,请使用unilab-register命令进行单独报送,或者在启动命令增加--upload_registry") - time.sleep(1) # 等待MQTT连接稳定 + self.lab_logger().warning( + "本次启动注册表不报送云端,如果您需要联网调试,请使用unilab-register命令进行单独报送,或者在启动命令增加--upload_registry" + ) + time.sleep(1) # 等待通信连接稳定 # 首次发现网络中的设备 self._discover_devices() @@ -214,6 +218,7 @@ class HostNode(BaseROS2DeviceNode): for bridge in self.bridges: if hasattr(bridge, "resource_add"): from unilabos.app.web.client import HTTPClient + client: HTTPClient = bridge resource_start_time = time.time() resource_add_res = client.resource_add(add_schema(resource_with_parent_name), False) @@ -340,9 +345,10 @@ class HostNode(BaseROS2DeviceNode): self.lab_logger().trace(f"[Host Node] Created ActionClient (Discovery): {action_id}") action_name = action_id[len(namespace) + 1 :] edge_device_id = namespace[9:] - # from unilabos.app.mq import mqtt_client + # from unilabos.app.comm_factory import get_communication_client + # comm_client = get_communication_client() # info_with_schema = ros_action_to_json_schema(action_type) - # mqtt_client.publish_actions(action_name, { + # comm_client.publish_actions(action_name, { # "device_id": edge_device_id, # "device_type": "", # "action_name": action_name, @@ -365,7 +371,9 @@ class HostNode(BaseROS2DeviceNode): ): # 这里要求device_id传入必须是edge_device_id if device_id not in self.devices_names: - self.lab_logger().error(f"[Host Node] Device {device_id} not found in devices_names. Create resource failed.") + self.lab_logger().error( + f"[Host Node] Device {device_id} not found in devices_names. Create resource failed." + ) raise ValueError(f"[Host Node] Device {device_id} not found in devices_names. Create resource failed.") device_key = f"{self.devices_names[device_id]}/{device_id}" @@ -425,10 +433,12 @@ class HostNode(BaseROS2DeviceNode): res_creation_input.update( { "data": { - "liquids": [{ - "liquid_type": liquid_type[0] if liquid_type else None, - "liquid_volume": liquid_volume[0] if liquid_volume else None, - }] + "liquids": [ + { + "liquid_type": liquid_type[0] if liquid_type else None, + "liquid_volume": liquid_volume[0] if liquid_volume else None, + } + ] } } ) @@ -451,7 +461,9 @@ class HostNode(BaseROS2DeviceNode): ) ] - response = await self.create_resource_detailed(resources, device_ids, bind_parent_id, bind_location, other_calling_param) + response = await self.create_resource_detailed( + resources, device_ids, bind_parent_id, bind_location, other_calling_param + ) return response @@ -482,7 +494,9 @@ class HostNode(BaseROS2DeviceNode): self.devices_instances[device_id] = d # noinspection PyProtectedMember for action_name, action_value_mapping in d._ros_node._action_value_mappings.items(): - if action_name.startswith("auto-") or str(action_value_mapping.get("type", "")).startswith("UniLabJsonCommand"): + if action_name.startswith("auto-") or str(action_value_mapping.get("type", "")).startswith( + "UniLabJsonCommand" + ): continue action_id = f"/devices/{device_id}/{action_name}" if action_id not in self._action_clients: @@ -491,9 +505,10 @@ class HostNode(BaseROS2DeviceNode): self.lab_logger().trace( f"[Host Node] Created ActionClient (Local): {action_id}" ) # 子设备再创建用的是Discover发现的 - # from unilabos.app.mq import mqtt_client + # from unilabos.app.comm_factory import get_communication_client + # comm_client = get_communication_client() # info_with_schema = ros_action_to_json_schema(action_type) - # mqtt_client.publish_actions(action_name, { + # comm_client.publish_actions(action_name, { # "device_id": device_id, # "device_type": device_config["class"], # "action_name": action_name, @@ -591,13 +606,9 @@ class HostNode(BaseROS2DeviceNode): if hasattr(bridge, "publish_device_status"): bridge.publish_device_status(self.device_status, device_id, property_name) if bCreate: - self.lab_logger().trace( - f"Status created: {device_id}.{property_name} = {msg.data}" - ) + self.lab_logger().trace(f"Status created: {device_id}.{property_name} = {msg.data}") else: - self.lab_logger().debug( - f"Status updated: {device_id}.{property_name} = {msg.data}" - ) + self.lab_logger().debug(f"Status updated: {device_id}.{property_name} = {msg.data}") def send_goal( self, @@ -624,10 +635,12 @@ class HostNode(BaseROS2DeviceNode): action_name = action_name[5:] action_id = f"/devices/{device_id}/_execute_driver_command" action_kwargs = { - "string": json.dumps({ - "function_name": action_name, - "function_args": action_kwargs, - }) + "string": json.dumps( + { + "function_name": action_name, + "function_args": action_kwargs, + } + ) } if action_type.startswith("UniLabJsonCommandAsync"): action_id = f"/devices/{device_id}/_execute_driver_command_async" @@ -802,7 +815,7 @@ class HostNode(BaseROS2DeviceNode): """ self.lab_logger().info(f"[Host Node] Node info update request received: {request}") try: - from unilabos.app.mq import mqtt_client + from unilabos.app.communication import get_communication_client info = json.loads(request.command) if "SYNC_SLAVE_NODE_INFO" in info: @@ -811,9 +824,10 @@ class HostNode(BaseROS2DeviceNode): edge_device_id = info["edge_device_id"] self.device_machine_names[edge_device_id] = machine_name else: + comm_client = get_communication_client() registry_config = info["registry_config"] for device_config in registry_config: - mqtt_client.publish_registry(device_config["id"], device_config) + comm_client.publish_registry(device_config["id"], device_config) self.lab_logger().debug(f"[Host Node] Node info update: {info}") response.response = "OK" except Exception as e: @@ -840,6 +854,7 @@ class HostNode(BaseROS2DeviceNode): success = False if len(self.bridges) > 0: # 边的提交待定 from unilabos.app.web.client import HTTPClient + client: HTTPClient = self.bridges[-1] r = client.resource_add(add_schema(resources), False) success = bool(r) @@ -848,6 +863,7 @@ class HostNode(BaseROS2DeviceNode): if success: from unilabos.resources.graphio import physical_setup_graph + for resource in resources: if resource.get("id") not in physical_setup_graph.nodes: physical_setup_graph.add_node(resource["id"], **resource) @@ -988,9 +1004,10 @@ class HostNode(BaseROS2DeviceNode): send_timestamp = time.time() # 发送ping - from unilabos.app.mq import mqtt_client + from unilabos.app.communication import get_communication_client - mqtt_client.send_ping(ping_id, send_timestamp) + comm_client = get_communication_client() + comm_client.send_ping(ping_id, send_timestamp) # 等待pong响应 timeout = 10.0 diff --git a/unilabos/utils/environment_check.py b/unilabos/utils/environment_check.py index 0c6ae4d7..cd50b822 100644 --- a/unilabos/utils/environment_check.py +++ b/unilabos/utils/environment_check.py @@ -18,6 +18,7 @@ class EnvironmentChecker: self.required_packages = { # 包导入名 : pip安装名 # "pymodbus.framer.FramerType": "pymodbus==3.9.2", + "websockets": "websockets", "paho.mqtt": "paho-mqtt", "opentrons_shared_data": "opentrons_shared_data", }