feat: websocket

This commit is contained in:
Xuwznln
2025-08-28 14:34:38 +08:00
parent 02c79363c1
commit cd84e26126
16 changed files with 626 additions and 67 deletions

View File

@@ -61,7 +61,7 @@ requirements:
- uvicorn - uvicorn
- gradio - gradio
- flask - flask
- websocket - websockets
- ipython - ipython
- jupyter - jupyter
- jupyros - jupyros

View File

@@ -5,6 +5,7 @@
# Uni-Lab-OS # Uni-Lab-OS
<!-- Language switcher --> <!-- Language switcher -->
**English** | [中文](README_zh.md) **English** | [中文](README_zh.md)
[![GitHub Stars](https://img.shields.io/github/stars/dptech-corp/Uni-Lab-OS.svg)](https://github.com/dptech-corp/Uni-Lab-OS/stargazers) [![GitHub Stars](https://img.shields.io/github/stars/dptech-corp/Uni-Lab-OS.svg)](https://github.com/dptech-corp/Uni-Lab-OS/stargazers)

View File

@@ -5,6 +5,7 @@
# Uni-Lab-OS # Uni-Lab-OS
<!-- Language switcher --> <!-- Language switcher -->
[English](README.md) | **中文** [English](README.md) | **中文**
[![GitHub Stars](https://img.shields.io/github/stars/dptech-corp/Uni-Lab-OS.svg)](https://github.com/dptech-corp/Uni-Lab-OS/stargazers) [![GitHub Stars](https://img.shields.io/github/stars/dptech-corp/Uni-Lab-OS.svg)](https://github.com/dptech-corp/Uni-Lab-OS/stargazers)
@@ -12,7 +13,7 @@
[![GitHub Issues](https://img.shields.io/github/issues/dptech-corp/Uni-Lab-OS.svg)](https://github.com/dptech-corp/Uni-Lab-OS/issues) [![GitHub Issues](https://img.shields.io/github/issues/dptech-corp/Uni-Lab-OS.svg)](https://github.com/dptech-corp/Uni-Lab-OS/issues)
[![GitHub License](https://img.shields.io/github/license/dptech-corp/Uni-Lab-OS.svg)](https://github.com/dptech-corp/Uni-Lab-OS/blob/main/LICENSE) [![GitHub License](https://img.shields.io/github/license/dptech-corp/Uni-Lab-OS.svg)](https://github.com/dptech-corp/Uni-Lab-OS/blob/main/LICENSE)
Uni-Lab-OS是一个用于实验室自动化的综合平台旨在连接和控制各种实验设备实现实验流程的自动化和标准化。 Uni-Lab-OS 是一个用于实验室自动化的综合平台,旨在连接和控制各种实验设备,实现实验流程的自动化和标准化。
## 🏆 比赛 ## 🏆 比赛
@@ -34,7 +35,7 @@ Uni-Lab-OS是一个用于实验室自动化的综合平台旨在连接和控
## 快速开始 ## 快速开始
1. 配置Conda环境 1. 配置 Conda 环境
Uni-Lab-OS 建议使用 `mamba` 管理环境。根据您的操作系统选择适当的环境文件: Uni-Lab-OS 建议使用 `mamba` 管理环境。根据您的操作系统选择适当的环境文件:
@@ -43,7 +44,7 @@ Uni-Lab-OS 建议使用 `mamba` 管理环境。根据您的操作系统选择适
mamba create -n unilab uni-lab::unilabos -c robostack-staging -c conda-forge mamba create -n unilab uni-lab::unilabos -c robostack-staging -c conda-forge
``` ```
2. 安装开发版Uni-Lab-OS: 2. 安装开发版 Uni-Lab-OS:
```bash ```bash
# 克隆仓库 # 克隆仓库

View File

@@ -34,7 +34,7 @@ dependencies:
- uvicorn - uvicorn
- gradio - gradio
- flask - flask
- websocket - websockets
# Notebook # Notebook
- ipython - ipython
- jupyter - jupyter

View File

@@ -34,7 +34,7 @@ dependencies:
- uvicorn - uvicorn
- gradio - gradio
- flask - flask
- websocket - websockets
# Notebook # Notebook
- ipython - ipython
- jupyter - jupyter

View File

@@ -35,8 +35,7 @@ dependencies:
- uvicorn - uvicorn
- gradio - gradio
- flask - flask
- websocket - websockets
- paho-mqtt
# Notebook # Notebook
- ipython - ipython
- jupyter - jupyter

View File

@@ -34,7 +34,7 @@ dependencies:
- uvicorn - uvicorn
- gradio - gradio
- flask - flask
- websocket - websockets
# Notebook # Notebook
- ipython - ipython
- jupyter - jupyter

View File

@@ -0,0 +1,204 @@
#!/usr/bin/env python
# coding=utf-8
"""
通信模块
提供MQTT和WebSocket的统一接口支持通过配置选择通信协议。
包含通信抽象层基类和通信客户端工厂。
"""
from abc import ABC, abstractmethod
from typing import Optional
from unilabos.config.config import BasicConfig
from unilabos.utils import logger
class BaseCommunicationClient(ABC):
"""
通信客户端抽象基类
定义了所有通信客户端MQTT、WebSocket等需要实现的接口。
"""
def __init__(self):
self.is_disabled = True
self.client_id = ""
@abstractmethod
def start(self) -> None:
"""
启动通信客户端连接
"""
pass
@abstractmethod
def stop(self) -> None:
"""
停止通信客户端连接
"""
pass
@abstractmethod
def publish_device_status(self, device_status: dict, device_id: str, property_name: str) -> None:
"""
发布设备状态信息
Args:
device_status: 设备状态字典
device_id: 设备ID
property_name: 属性名称
"""
pass
@abstractmethod
def publish_job_status(
self, feedback_data: dict, job_id: str, status: str, return_info: Optional[str] = None
) -> None:
"""
发布作业状态信息
Args:
feedback_data: 反馈数据
job_id: 作业ID
status: 作业状态
return_info: 返回信息
"""
pass
@abstractmethod
def send_ping(self, ping_id: str, timestamp: float) -> None:
"""
发送ping消息
Args:
ping_id: ping ID
timestamp: 时间戳
"""
pass
def setup_pong_subscription(self) -> None:
"""
设置pong消息订阅可选实现
"""
pass
@property
def is_connected(self) -> bool:
"""
检查是否已连接
Returns:
是否已连接
"""
return not self.is_disabled
class CommunicationClientFactory:
"""
通信客户端工厂类
根据配置文件中的通信协议设置创建相应的客户端实例。
"""
_client_cache: Optional[BaseCommunicationClient] = None
@classmethod
def create_client(cls, protocol: Optional[str] = None) -> BaseCommunicationClient:
"""
创建通信客户端实例
Args:
protocol: 指定的协议类型如果为None则使用配置文件中的设置
Returns:
通信客户端实例
Raises:
ValueError: 当协议类型不支持时
"""
if protocol is None:
protocol = BasicConfig.communication_protocol
protocol = protocol.lower()
if protocol == "mqtt":
return cls._create_mqtt_client()
elif protocol == "websocket":
return cls._create_websocket_client()
else:
logger.error(f"[CommunicationFactory] Unsupported protocol: {protocol}")
logger.warning(f"[CommunicationFactory] Falling back to MQTT")
return cls._create_mqtt_client()
@classmethod
def get_client(cls, protocol: Optional[str] = None) -> BaseCommunicationClient:
"""
获取通信客户端实例(单例模式)
Args:
protocol: 指定的协议类型如果为None则使用配置文件中的设置
Returns:
通信客户端实例
"""
if cls._client_cache is None:
cls._client_cache = cls.create_client(protocol)
logger.info(f"[CommunicationFactory] Created {type(cls._client_cache).__name__} client")
return cls._client_cache
@classmethod
def _create_mqtt_client(cls) -> BaseCommunicationClient:
"""创建MQTT客户端"""
try:
from unilabos.app.mq import mqtt_client
return mqtt_client
except Exception as e:
logger.error(f"[CommunicationFactory] Failed to create MQTT client: {str(e)}")
raise
@classmethod
def _create_websocket_client(cls) -> BaseCommunicationClient:
"""创建WebSocket客户端"""
try:
from unilabos.app.ws_client import WebSocketClient
return WebSocketClient()
except Exception as e:
logger.error(f"[CommunicationFactory] Failed to create WebSocket client: {str(e)}")
logger.warning(f"[CommunicationFactory] Falling back to MQTT")
return cls._create_mqtt_client()
@classmethod
def reset_client(cls):
"""重置客户端缓存(用于测试或重新配置)"""
if cls._client_cache:
try:
cls._client_cache.stop()
except Exception as e:
logger.warning(f"[CommunicationFactory] Error stopping old client: {str(e)}")
cls._client_cache = None
logger.info("[CommunicationFactory] Client cache reset")
@classmethod
def get_supported_protocols(cls) -> list[str]:
"""
获取支持的协议列表
Returns:
支持的协议列表
"""
return ["mqtt", "websocket"]
def get_communication_client(protocol: Optional[str] = None) -> BaseCommunicationClient:
"""
获取通信客户端实例的便捷函数
Args:
protocol: 指定的协议类型如果为None则使用配置文件中的设置
Returns:
通信客户端实例
"""
return CommunicationClientFactory.get_client(protocol)

View File

@@ -10,7 +10,6 @@ from copy import deepcopy
import yaml import yaml
from unilabos.resources.graphio import modify_to_backend_format
# 首先添加项目根目录到路径 # 首先添加项目根目录到路径
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -20,6 +19,7 @@ if unilabos_dir not in sys.path:
from unilabos.config.config import load_config, BasicConfig from unilabos.config.config import load_config, BasicConfig
from unilabos.utils.banner_print import print_status, print_unilab_banner from unilabos.utils.banner_print import print_status, print_unilab_banner
from unilabos.resources.graphio import modify_to_backend_format
def load_config_from_file(config_path, override_labid=None): def load_config_from_file(config_path, override_labid=None):
@@ -146,6 +146,11 @@ def parse_args():
default="", default="",
help="实验室请求的sk", help="实验室请求的sk",
) )
parser.add_argument(
"--websocket",
action="store_true",
help="使用websocket而非mqtt作为通信协议",
)
parser.add_argument( parser.add_argument(
"--skip_env_check", "--skip_env_check",
action="store_true", action="store_true",
@@ -179,7 +184,7 @@ def main():
else: else:
working_dir = os.path.abspath(os.path.join(os.getcwd(), "unilabos_data")) working_dir = os.path.abspath(os.path.join(os.getcwd(), "unilabos_data"))
if args_dict.get("working_dir"): if args_dict.get("working_dir"):
working_dir = args_dict.get("working_dir") working_dir = args_dict.get("working_dir", "")
if config_path and not os.path.exists(config_path): if config_path and not os.path.exists(config_path):
config_path = os.path.join(working_dir, "local_config.py") config_path = os.path.join(working_dir, "local_config.py")
if not os.path.exists(config_path): if not os.path.exists(config_path):
@@ -215,6 +220,7 @@ def main():
if args_dict["use_remote_resource"]: if args_dict["use_remote_resource"]:
print_status("使用远程资源启动", "info") print_status("使用远程资源启动", "info")
from unilabos.app.web import http_client from unilabos.app.web import http_client
res = http_client.resource_get("host_node", False) res = http_client.resource_get("host_node", False)
if str(res.get("code", 0)) == "0" and len(res.get("data", [])) > 0: if str(res.get("code", 0)) == "0" and len(res.get("data", [])) > 0:
print_status("远程资源已存在,使用云端物料!", "info") print_status("远程资源已存在,使用云端物料!", "info")
@@ -229,6 +235,7 @@ def main():
BasicConfig.is_host_mode = not args_dict.get("without_host", False) BasicConfig.is_host_mode = not args_dict.get("without_host", False)
BasicConfig.slave_no_host = args_dict.get("slave_no_host", False) BasicConfig.slave_no_host = args_dict.get("slave_no_host", False)
BasicConfig.upload_registry = args_dict.get("upload_registry", False) BasicConfig.upload_registry = args_dict.get("upload_registry", False)
BasicConfig.communication_protocol = "websocket" if args_dict.get("websocket", False) else "mqtt"
machine_name = os.popen("hostname").read().strip() machine_name = os.popen("hostname").read().strip()
machine_name = "".join([c if c.isalnum() or c == "_" else "_" for c in machine_name]) machine_name = "".join([c if c.isalnum() or c == "_" else "_" for c in machine_name])
BasicConfig.machine_name = machine_name BasicConfig.machine_name = machine_name
@@ -241,7 +248,7 @@ def main():
dict_to_nested_dict, dict_to_nested_dict,
initialize_resources, initialize_resources,
) )
from unilabos.app.mq import mqtt_client from unilabos.app.communication import get_communication_client
from unilabos.registry.registry import build_registry from unilabos.registry.registry import build_registry
from unilabos.app.backend import start_backend from unilabos.app.backend import start_backend
from unilabos.app.web import http_client from unilabos.app.web import http_client
@@ -289,19 +296,22 @@ def main():
args_dict["bridges"] = [] args_dict["bridges"] = []
# 获取通信客户端根据配置选择MQTT或WebSocket
comm_client = get_communication_client()
if "mqtt" in args_dict["app_bridges"]: if "mqtt" in args_dict["app_bridges"]:
args_dict["bridges"].append(mqtt_client) args_dict["bridges"].append(comm_client)
if "fastapi" in args_dict["app_bridges"]: if "fastapi" in args_dict["app_bridges"]:
args_dict["bridges"].append(http_client) args_dict["bridges"].append(http_client)
if "mqtt" in args_dict["app_bridges"]: if "mqtt" in args_dict["app_bridges"]:
def _exit(signum, frame): def _exit(signum, frame):
mqtt_client.stop() comm_client.stop()
sys.exit(0) sys.exit(0)
signal.signal(signal.SIGINT, _exit) signal.signal(signal.SIGINT, _exit)
signal.signal(signal.SIGTERM, _exit) signal.signal(signal.SIGTERM, _exit)
mqtt_client.start() comm_client.start()
args_dict["resources_mesh_config"] = {} args_dict["resources_mesh_config"] = {}
args_dict["resources_edge_config"] = resource_edge_info args_dict["resources_edge_config"] = resource_edge_info
# web visiualize 2D # web visiualize 2D

View File

@@ -15,17 +15,20 @@ import os
from unilabos.config.config import MQConfig from unilabos.config.config import MQConfig
from unilabos.app.controler import job_add from unilabos.app.controler import job_add
from unilabos.app.model import JobAddReq from unilabos.app.model import JobAddReq
from unilabos.app.communication import BaseCommunicationClient
from unilabos.utils import logger from unilabos.utils import logger
from unilabos.utils.type_check import TypeEncoder from unilabos.utils.type_check import TypeEncoder
from paho.mqtt.enums import CallbackAPIVersion from paho.mqtt.enums import CallbackAPIVersion
class MQTTClient: class MQTTClient(BaseCommunicationClient):
mqtt_disable = True mqtt_disable = True
def __init__(self): def __init__(self):
super().__init__()
self.mqtt_disable = not MQConfig.lab_id self.mqtt_disable = not MQConfig.lab_id
self.is_disabled = self.mqtt_disable # 更新父类属性
self.client_id = f"{MQConfig.group_id}@@@{MQConfig.lab_id}{uuid.uuid4()}" self.client_id = f"{MQConfig.group_id}@@@{MQConfig.lab_id}{uuid.uuid4()}"
logger.info("[MQTT] Client_id: " + self.client_id) logger.info("[MQTT] Client_id: " + self.client_id)
self.client = mqtt.Client(CallbackAPIVersion.VERSION2, client_id=self.client_id, protocol=mqtt.MQTTv5) self.client = mqtt.Client(CallbackAPIVersion.VERSION2, client_id=self.client_id, protocol=mqtt.MQTTv5)
@@ -208,11 +211,12 @@ class MQTTClient:
self.client.subscribe(pong_topic, 0) self.client.subscribe(pong_topic, 0)
logger.debug(f"Subscribed to pong topic: {pong_topic}") logger.debug(f"Subscribed to pong topic: {pong_topic}")
def handle_pong(self, pong_data: dict): @property
"""处理pong响应这个方法会在收到pong消息时被调用""" def is_connected(self) -> bool:
logger.debug(f"Pong received: {pong_data}") """检查MQTT是否已连接"""
# 这里会被HostNode的ping-pong处理逻辑调用 if self.is_disabled:
pass return False
return hasattr(self.client, "is_connected") and self.client.is_connected()
mqtt_client = MQTTClient() mqtt_client = MQTTClient()

View File

@@ -10,35 +10,44 @@ from unilabos.utils.log import logger
from unilabos.utils.type_check import TypeEncoder from unilabos.utils.type_check import TypeEncoder
def register_devices_and_resources(mqtt_client, lab_registry): def register_devices_and_resources(comm_client, lab_registry):
""" """
注册设备和资源到 MQTT 注册设备和资源到通信服务器MQTT/WebSocket
""" """
# 注册资源信息 - 使用HTTP方式 # 注册资源信息 - 使用HTTP方式
from unilabos.app.web.client import http_client from unilabos.app.web.client import http_client
logger.info("[UniLab Register] 开始注册设备和资源...") logger.info("[UniLab Register] 开始注册设备和资源...")
if BasicConfig.auth_secret(): if BasicConfig.auth_secret():
# 注册设备信息 # 注册设备信息
devices_to_register = {} devices_to_register = {}
for device_info in lab_registry.obtain_registry_device_info(): for device_info in lab_registry.obtain_registry_device_info():
devices_to_register[device_info["id"]] = json.loads(json.dumps(device_info, ensure_ascii=False, cls=TypeEncoder)) devices_to_register[device_info["id"]] = json.loads(
json.dumps(device_info, ensure_ascii=False, cls=TypeEncoder)
)
logger.debug(f"[UniLab Register] 收集设备: {device_info['id']}") logger.debug(f"[UniLab Register] 收集设备: {device_info['id']}")
resources_to_register = {} resources_to_register = {}
for resource_info in lab_registry.obtain_registry_resource_info(): for resource_info in lab_registry.obtain_registry_resource_info():
resources_to_register[resource_info["id"]] = resource_info resources_to_register[resource_info["id"]] = resource_info
logger.debug(f"[UniLab Register] 收集资源: {resource_info['id']}") logger.debug(f"[UniLab Register] 收集资源: {resource_info['id']}")
print("[UniLab Register] 设备注册", http_client.resource_registry({"resources": list(devices_to_register.values())}).text) print(
print("[UniLab Register] 资源注册", http_client.resource_registry({"resources": list(resources_to_register.values())}).text) "[UniLab Register] 设备注册",
http_client.resource_registry({"resources": list(devices_to_register.values())}).text,
)
print(
"[UniLab Register] 资源注册",
http_client.resource_registry({"resources": list(resources_to_register.values())}).text,
)
else: else:
# 注册设备信息 # 注册设备信息
for device_info in lab_registry.obtain_registry_device_info(): for device_info in lab_registry.obtain_registry_device_info():
mqtt_client.publish_registry(device_info["id"], device_info, False) comm_client.publish_registry(device_info["id"], device_info, False)
logger.debug(f"[UniLab Register] 注册设备: {device_info['id']}") logger.debug(f"[UniLab Register] 注册设备: {device_info['id']}")
# # 注册资源信息 # # 注册资源信息
# for resource_info in lab_registry.obtain_registry_resource_info(): # for resource_info in lab_registry.obtain_registry_resource_info():
# mqtt_client.publish_registry(resource_info["id"], resource_info, False) # comm_client.publish_registry(resource_info["id"], resource_info, False)
# logger.debug(f"[UniLab Register] 注册资源: {resource_info['id']}") # logger.debug(f"[UniLab Register] 注册资源: {resource_info['id']}")
resources_to_register = {} resources_to_register = {}
@@ -53,7 +62,9 @@ def register_devices_and_resources(mqtt_client, lab_registry):
if response.status_code in [200, 201]: if response.status_code in [200, 201]:
logger.info(f"[UniLab Register] 成功通过HTTP注册 {len(resources_to_register)} 个资源 {cost_time}ms") logger.info(f"[UniLab Register] 成功通过HTTP注册 {len(resources_to_register)} 个资源 {cost_time}ms")
else: else:
logger.error(f"[UniLab Register] HTTP注册资源失败: {response.status_code}, {response.text} {cost_time}ms") logger.error(
f"[UniLab Register] HTTP注册资源失败: {response.status_code}, {response.text} {cost_time}ms"
)
logger.info("[UniLab Register] 设备和资源注册完成.") logger.info("[UniLab Register] 设备和资源注册完成.")
@@ -99,15 +110,16 @@ def main():
BasicConfig.sk = args.sk BasicConfig.sk = args.sk
# 构建注册表 # 构建注册表
build_registry(args.registry, args.complete_registry, True) build_registry(args.registry, args.complete_registry, True)
from unilabos.app.mq import mqtt_client from unilabos.app.communication import get_communication_client
# 连接mqtt # 获取通信客户端并启动连接
mqtt_client.start() comm_client = get_communication_client()
comm_client.start()
from unilabos.registry.registry import lab_registry from unilabos.registry.registry import lab_registry
# 注册设备和资源 # 注册设备和资源
register_devices_and_resources(mqtt_client, lab_registry) register_devices_and_resources(comm_client, lab_registry)
if __name__ == "__main__": if __name__ == "__main__":

300
unilabos/app/ws_client.py Normal file
View File

@@ -0,0 +1,300 @@
#!/usr/bin/env python
# coding=utf-8
"""
WebSocket通信客户端
基于WebSocket协议的通信客户端实现继承自BaseCommunicationClient。
"""
import json
import time
import uuid
import threading
import asyncio
import traceback
from typing import Optional, Dict, Any
from urllib.parse import urlparse
from unilabos.app.controler import job_add
from unilabos.app.model import JobAddReq
from unilabos.ros.nodes.presets.host_node import HostNode
try:
import websockets
import ssl as ssl_module
HAS_WEBSOCKETS = True
except ImportError:
HAS_WEBSOCKETS = False
from unilabos.app.communication import BaseCommunicationClient
from unilabos.config.config import WSConfig, HTTPConfig, BasicConfig
from unilabos.utils import logger
class WebSocketClient(BaseCommunicationClient):
"""
WebSocket通信客户端类
实现基于WebSocket协议的实时通信功能。
"""
def __init__(self):
super().__init__()
if not HAS_WEBSOCKETS:
logger.error("[WebSocket] websockets库未安装WebSocket功能不可用")
self.is_disabled = True
return
self.is_disabled = False
self.client_id = f"{uuid.uuid4()}"
# WebSocket连接相关
self.websocket = None
self.connection_loop = None
self.event_loop = None
self.connection_thread = None
self.is_running = False
self.connected = False
# 消息处理
self.message_queue = asyncio.Queue() if not self.is_disabled else None
self.reconnect_count = 0
# 构建WebSocket URL
self._build_websocket_url()
logger.info(f"[WebSocket] Client_id: {self.client_id}")
def _build_websocket_url(self):
"""构建WebSocket连接URL"""
if not HTTPConfig.remote_addr:
self.websocket_url = None
return
# 解析服务器URL
parsed = urlparse(HTTPConfig.remote_addr)
# 根据SSL配置选择协议
if parsed.scheme == "https":
scheme = "wss"
else:
scheme = "ws"
self.websocket_url = f"{scheme}://{parsed.netloc}/api/v1/lab"
logger.debug(f"[WebSocket] URL: {self.websocket_url}")
def start(self) -> None:
"""启动WebSocket连接"""
if self.is_disabled:
logger.warning("[WebSocket] WebSocket is disabled, skipping connection.")
return
if not self.websocket_url:
logger.error("[WebSocket] WebSocket URL not configured")
return
logger.info(f"[WebSocket] Starting connection to {self.websocket_url}")
self.is_running = True
# 在单独线程中运行WebSocket连接
self.connection_thread = threading.Thread(target=self._run_connection, daemon=True, name="WebSocketConnection")
self.connection_thread.start()
def stop(self) -> None:
"""停止WebSocket连接"""
if self.is_disabled:
return
logger.info("[WebSocket] Stopping connection")
self.is_running = False
self.connected = False
if self.event_loop and self.event_loop.is_running():
asyncio.run_coroutine_threadsafe(self._close_connection(), self.event_loop)
if self.connection_thread and self.connection_thread.is_alive():
self.connection_thread.join(timeout=5)
def _run_connection(self):
"""在独立线程中运行WebSocket连接"""
try:
# 创建新的事件循环
self.event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.event_loop)
# 运行连接逻辑
self.event_loop.run_until_complete(self._connection_handler())
except Exception as e:
logger.error(f"[WebSocket] Connection thread error: {str(e)}")
logger.error(traceback.format_exc())
finally:
if self.event_loop:
self.event_loop.close()
async def _connection_handler(self):
"""处理WebSocket连接和重连逻辑"""
while self.is_running:
try:
# 构建SSL上下文
ssl_context = None
assert self.websocket_url is not None
if self.websocket_url.startswith("wss://"):
ssl_context = ssl_module.create_default_context()
async with websockets.connect(
self.websocket_url,
ssl=ssl_context,
ping_interval=WSConfig.ping_interval,
ping_timeout=10,
additional_headers={"Authorization": f"Bearer {BasicConfig.auth_secret()}"},
) as websocket:
self.websocket = websocket
self.connected = True
self.reconnect_count = 0
logger.info(f"[WebSocket] Connected to {self.websocket_url}")
# 处理消息
await self._message_handler()
except websockets.exceptions.ConnectionClosed:
logger.warning("[WebSocket] Connection closed")
self.connected = False
except Exception as e:
logger.error(f"[WebSocket] Connection error: {str(e)}")
self.connected = False
# 重连逻辑
if self.is_running and self.reconnect_count < WSConfig.max_reconnect_attempts:
self.reconnect_count += 1
logger.info(
f"[WebSocket] Reconnecting in {WSConfig.reconnect_interval}s "
f"(attempt {self.reconnect_count}/{WSConfig.max_reconnect_attempts})"
)
await asyncio.sleep(WSConfig.reconnect_interval)
elif self.reconnect_count >= WSConfig.max_reconnect_attempts:
logger.error("[WebSocket] Max reconnection attempts reached")
break
async def _close_connection(self):
"""关闭WebSocket连接"""
if self.websocket:
await self.websocket.close()
self.websocket = None
async def _send_message(self, message: Dict[str, Any]):
"""发送消息"""
if not self.connected or not self.websocket:
logger.warning("[WebSocket] Not connected, cannot send message")
return
try:
message_str = json.dumps(message, ensure_ascii=False)
await self.websocket.send(message_str)
logger.debug(f"[WebSocket] Message sent: {message['type']}")
except Exception as e:
logger.error(f"[WebSocket] Failed to send message: {str(e)}")
async def _message_handler(self):
"""处理接收到的消息"""
if not self.websocket:
logger.error("[WebSocket] WebSocket connection is None")
return
try:
async for message in self.websocket:
try:
data = json.loads(message)
await self._process_message(data)
except json.JSONDecodeError:
logger.error(f"[WebSocket] Invalid JSON received: {message}")
except Exception as e:
logger.error(f"[WebSocket] Error processing message: {str(e)}")
except websockets.exceptions.ConnectionClosed:
logger.info("[WebSocket] Message handler stopped - connection closed")
except Exception as e:
logger.error(f"[WebSocket] Message handler error: {str(e)}")
async def _process_message(self, input_message: Dict[str, Any]):
"""处理收到的消息"""
message_type = input_message.get("type", "")
data = input_message.get("data", {})
if message_type == "job_start":
# 处理作业启动消息
await self._handle_job_start(data)
elif message_type == "pong":
# 处理pong响应
self._handle_pong_sync(data)
else:
logger.debug(f"[WebSocket] Unknown message type: {message_type}")
async def _handle_job_start(self, data: Dict[str, Any]):
"""处理作业启动消息"""
try:
job_req = JobAddReq(**data.get("job_data", {}))
job_add(job_req)
job_id = getattr(job_req, "id", "unknown")
logger.info(f"[WebSocket] Job started: {job_id}")
except Exception as e:
logger.error(f"[WebSocket] Error handling job start: {str(e)}")
def _handle_pong_sync(self, pong_data: Dict[str, Any]):
"""同步处理pong响应"""
host_node = HostNode.get_instance(0)
if host_node:
host_node.handle_pong_response(pong_data)
# 实现抽象基类的方法
def publish_device_status(self, device_status: dict, device_id: str, property_name: str) -> None:
"""发布设备状态"""
if self.is_disabled or not self.connected:
return
message = {
"type": "device_status",
"data": {
"device_id": device_id,
"property_name": property_name,
"status": device_status.get(device_id, {}).get(property_name),
"timestamp": time.time(),
},
}
if self.event_loop:
asyncio.run_coroutine_threadsafe(self._send_message(message), self.event_loop)
logger.debug(f"[WebSocket] Device status published: {device_id}.{property_name}")
def publish_job_status(
self, feedback_data: dict, job_id: str, status: str, return_info: Optional[str] = None
) -> None:
"""发布作业状态"""
if self.is_disabled or not self.connected:
logger.warning("[WebSocket] Not connected, cannot publish job status")
return
message = {
"type": "job_status",
"data": {
"job_id": job_id,
"status": status,
"feedback_data": feedback_data,
"return_info": return_info,
"timestamp": time.time(),
},
}
if self.event_loop:
asyncio.run_coroutine_threadsafe(self._send_message(message), self.event_loop)
logger.debug(f"[WebSocket] Job status published: {job_id} - {status}")
def send_ping(self, ping_id: str, timestamp: float) -> None:
"""发送ping消息"""
if self.is_disabled or not self.connected:
logger.warning("[WebSocket] Not connected, cannot send ping")
return
message = {"type": "ping", "data": {"ping_id": ping_id, "client_timestamp": timestamp}}
if self.event_loop:
asyncio.run_coroutine_threadsafe(self._send_message(message), self.event_loop)
logger.debug(f"[WebSocket] Ping sent: {ping_id}")
@property
def is_connected(self) -> bool:
"""检查是否已连接"""
return self.connected and not self.is_disabled

View File

@@ -5,6 +5,7 @@ import base64
import traceback import traceback
import os import os
import importlib.util import importlib.util
from typing import Optional
from unilabos.utils import logger from unilabos.utils import logger
@@ -20,6 +21,8 @@ class BasicConfig:
machine_name = "undefined" machine_name = "undefined"
vis_2d_enable = False vis_2d_enable = False
enable_resource_load = True enable_resource_load = True
# 通信协议配置
communication_protocol = "mqtt" # 支持: "mqtt", "websocket"
@classmethod @classmethod
def auth_secret(cls): def auth_secret(cls):
@@ -27,7 +30,7 @@ class BasicConfig:
if not cls.ak or not cls.sk: if not cls.ak or not cls.sk:
return "" return ""
target = f"{cls.ak}:{cls.sk}" target = f"{cls.ak}:{cls.sk}"
base64_target = base64.b64encode(target.encode('utf-8')).decode('utf-8') base64_target = base64.b64encode(target.encode("utf-8")).decode("utf-8")
return base64_target return base64_target
@@ -50,6 +53,13 @@ class MQConfig:
key_file = "" # 相对config.py所在目录的路径 key_file = "" # 相对config.py所在目录的路径
# WebSocket配置
class WSConfig:
reconnect_interval = 5 # 重连间隔(秒)
max_reconnect_attempts = 10 # 最大重连次数
ping_interval = 30 # ping间隔
# OSS上传配置 # OSS上传配置
class OSSUploadConfig: class OSSUploadConfig:
api_host = "" api_host = ""
@@ -77,7 +87,7 @@ class ROSConfig:
] ]
def _update_config_from_module(module, override_labid: str): def _update_config_from_module(module, override_labid: Optional[str]):
for name, obj in globals().items(): for name, obj in globals().items():
if isinstance(obj, type) and name.endswith("Config"): if isinstance(obj, type) and name.endswith("Config"):
if hasattr(module, name) and isinstance(getattr(module, name), type): if hasattr(module, name) and isinstance(getattr(module, name), type):
@@ -171,7 +181,6 @@ def _update_config_from_env():
logger.warning(f"[ENV] 解析环境变量 {env_key} 失败: {e}") logger.warning(f"[ENV] 解析环境变量 {env_key} 失败: {e}")
def load_config(config_path=None, override_labid=None): def load_config(config_path=None, override_labid=None):
# 如果提供了配置文件路径,从该文件导入配置 # 如果提供了配置文件路径,从该文件导入配置
if config_path: if config_path:

View File

@@ -12,6 +12,7 @@ class MQConfig:
cert_file = "./lab.crt" cert_file = "./lab.crt"
key_file = "./lab.key" key_file = "./lab.key"
# HTTP配置 # HTTP配置
class HTTPConfig: class HTTPConfig:
remote_addr = "https://uni-lab.bohrium.com/api/v1" remote_addr = "https://uni-lab.bohrium.com/api/v1"

View File

@@ -152,11 +152,15 @@ class HostNode(BaseROS2DeviceNode):
self.device_status = {} # 用来存储设备状态 self.device_status = {} # 用来存储设备状态
self.device_status_timestamps = {} # 用来存储设备状态最后更新时间 self.device_status_timestamps = {} # 用来存储设备状态最后更新时间
if BasicConfig.upload_registry: if BasicConfig.upload_registry:
from unilabos.app.mq import mqtt_client from unilabos.app.communication import get_communication_client
register_devices_and_resources(mqtt_client, lab_registry)
comm_client = get_communication_client()
register_devices_and_resources(comm_client, lab_registry)
else: else:
self.lab_logger().warning("本次启动注册表不报送云端如果您需要联网调试请使用unilab-register命令进行单独报送或者在启动命令增加--upload_registry") self.lab_logger().warning(
time.sleep(1) # 等待MQTT连接稳定 "本次启动注册表不报送云端如果您需要联网调试请使用unilab-register命令进行单独报送或者在启动命令增加--upload_registry"
)
time.sleep(1) # 等待通信连接稳定
# 首次发现网络中的设备 # 首次发现网络中的设备
self._discover_devices() self._discover_devices()
@@ -214,6 +218,7 @@ class HostNode(BaseROS2DeviceNode):
for bridge in self.bridges: for bridge in self.bridges:
if hasattr(bridge, "resource_add"): if hasattr(bridge, "resource_add"):
from unilabos.app.web.client import HTTPClient from unilabos.app.web.client import HTTPClient
client: HTTPClient = bridge client: HTTPClient = bridge
resource_start_time = time.time() resource_start_time = time.time()
resource_add_res = client.resource_add(add_schema(resource_with_parent_name), False) resource_add_res = client.resource_add(add_schema(resource_with_parent_name), False)
@@ -340,9 +345,10 @@ class HostNode(BaseROS2DeviceNode):
self.lab_logger().trace(f"[Host Node] Created ActionClient (Discovery): {action_id}") self.lab_logger().trace(f"[Host Node] Created ActionClient (Discovery): {action_id}")
action_name = action_id[len(namespace) + 1 :] action_name = action_id[len(namespace) + 1 :]
edge_device_id = namespace[9:] edge_device_id = namespace[9:]
# from unilabos.app.mq import mqtt_client # from unilabos.app.comm_factory import get_communication_client
# comm_client = get_communication_client()
# info_with_schema = ros_action_to_json_schema(action_type) # info_with_schema = ros_action_to_json_schema(action_type)
# mqtt_client.publish_actions(action_name, { # comm_client.publish_actions(action_name, {
# "device_id": edge_device_id, # "device_id": edge_device_id,
# "device_type": "", # "device_type": "",
# "action_name": action_name, # "action_name": action_name,
@@ -365,7 +371,9 @@ class HostNode(BaseROS2DeviceNode):
): ):
# 这里要求device_id传入必须是edge_device_id # 这里要求device_id传入必须是edge_device_id
if device_id not in self.devices_names: if device_id not in self.devices_names:
self.lab_logger().error(f"[Host Node] Device {device_id} not found in devices_names. Create resource failed.") self.lab_logger().error(
f"[Host Node] Device {device_id} not found in devices_names. Create resource failed."
)
raise ValueError(f"[Host Node] Device {device_id} not found in devices_names. Create resource failed.") raise ValueError(f"[Host Node] Device {device_id} not found in devices_names. Create resource failed.")
device_key = f"{self.devices_names[device_id]}/{device_id}" device_key = f"{self.devices_names[device_id]}/{device_id}"
@@ -425,10 +433,12 @@ class HostNode(BaseROS2DeviceNode):
res_creation_input.update( res_creation_input.update(
{ {
"data": { "data": {
"liquids": [{ "liquids": [
{
"liquid_type": liquid_type[0] if liquid_type else None, "liquid_type": liquid_type[0] if liquid_type else None,
"liquid_volume": liquid_volume[0] if liquid_volume else None, "liquid_volume": liquid_volume[0] if liquid_volume else None,
}] }
]
} }
} }
) )
@@ -451,7 +461,9 @@ class HostNode(BaseROS2DeviceNode):
) )
] ]
response = await self.create_resource_detailed(resources, device_ids, bind_parent_id, bind_location, other_calling_param) response = await self.create_resource_detailed(
resources, device_ids, bind_parent_id, bind_location, other_calling_param
)
return response return response
@@ -482,7 +494,9 @@ class HostNode(BaseROS2DeviceNode):
self.devices_instances[device_id] = d self.devices_instances[device_id] = d
# noinspection PyProtectedMember # noinspection PyProtectedMember
for action_name, action_value_mapping in d._ros_node._action_value_mappings.items(): for action_name, action_value_mapping in d._ros_node._action_value_mappings.items():
if action_name.startswith("auto-") or str(action_value_mapping.get("type", "")).startswith("UniLabJsonCommand"): if action_name.startswith("auto-") or str(action_value_mapping.get("type", "")).startswith(
"UniLabJsonCommand"
):
continue continue
action_id = f"/devices/{device_id}/{action_name}" action_id = f"/devices/{device_id}/{action_name}"
if action_id not in self._action_clients: if action_id not in self._action_clients:
@@ -491,9 +505,10 @@ class HostNode(BaseROS2DeviceNode):
self.lab_logger().trace( self.lab_logger().trace(
f"[Host Node] Created ActionClient (Local): {action_id}" f"[Host Node] Created ActionClient (Local): {action_id}"
) # 子设备再创建用的是Discover发现的 ) # 子设备再创建用的是Discover发现的
# from unilabos.app.mq import mqtt_client # from unilabos.app.comm_factory import get_communication_client
# comm_client = get_communication_client()
# info_with_schema = ros_action_to_json_schema(action_type) # info_with_schema = ros_action_to_json_schema(action_type)
# mqtt_client.publish_actions(action_name, { # comm_client.publish_actions(action_name, {
# "device_id": device_id, # "device_id": device_id,
# "device_type": device_config["class"], # "device_type": device_config["class"],
# "action_name": action_name, # "action_name": action_name,
@@ -591,13 +606,9 @@ class HostNode(BaseROS2DeviceNode):
if hasattr(bridge, "publish_device_status"): if hasattr(bridge, "publish_device_status"):
bridge.publish_device_status(self.device_status, device_id, property_name) bridge.publish_device_status(self.device_status, device_id, property_name)
if bCreate: if bCreate:
self.lab_logger().trace( self.lab_logger().trace(f"Status created: {device_id}.{property_name} = {msg.data}")
f"Status created: {device_id}.{property_name} = {msg.data}"
)
else: else:
self.lab_logger().debug( self.lab_logger().debug(f"Status updated: {device_id}.{property_name} = {msg.data}")
f"Status updated: {device_id}.{property_name} = {msg.data}"
)
def send_goal( def send_goal(
self, self,
@@ -624,10 +635,12 @@ class HostNode(BaseROS2DeviceNode):
action_name = action_name[5:] action_name = action_name[5:]
action_id = f"/devices/{device_id}/_execute_driver_command" action_id = f"/devices/{device_id}/_execute_driver_command"
action_kwargs = { action_kwargs = {
"string": json.dumps({ "string": json.dumps(
{
"function_name": action_name, "function_name": action_name,
"function_args": action_kwargs, "function_args": action_kwargs,
}) }
)
} }
if action_type.startswith("UniLabJsonCommandAsync"): if action_type.startswith("UniLabJsonCommandAsync"):
action_id = f"/devices/{device_id}/_execute_driver_command_async" action_id = f"/devices/{device_id}/_execute_driver_command_async"
@@ -802,7 +815,7 @@ class HostNode(BaseROS2DeviceNode):
""" """
self.lab_logger().info(f"[Host Node] Node info update request received: {request}") self.lab_logger().info(f"[Host Node] Node info update request received: {request}")
try: try:
from unilabos.app.mq import mqtt_client from unilabos.app.communication import get_communication_client
info = json.loads(request.command) info = json.loads(request.command)
if "SYNC_SLAVE_NODE_INFO" in info: if "SYNC_SLAVE_NODE_INFO" in info:
@@ -811,9 +824,10 @@ class HostNode(BaseROS2DeviceNode):
edge_device_id = info["edge_device_id"] edge_device_id = info["edge_device_id"]
self.device_machine_names[edge_device_id] = machine_name self.device_machine_names[edge_device_id] = machine_name
else: else:
comm_client = get_communication_client()
registry_config = info["registry_config"] registry_config = info["registry_config"]
for device_config in registry_config: for device_config in registry_config:
mqtt_client.publish_registry(device_config["id"], device_config) comm_client.publish_registry(device_config["id"], device_config)
self.lab_logger().debug(f"[Host Node] Node info update: {info}") self.lab_logger().debug(f"[Host Node] Node info update: {info}")
response.response = "OK" response.response = "OK"
except Exception as e: except Exception as e:
@@ -840,6 +854,7 @@ class HostNode(BaseROS2DeviceNode):
success = False success = False
if len(self.bridges) > 0: # 边的提交待定 if len(self.bridges) > 0: # 边的提交待定
from unilabos.app.web.client import HTTPClient from unilabos.app.web.client import HTTPClient
client: HTTPClient = self.bridges[-1] client: HTTPClient = self.bridges[-1]
r = client.resource_add(add_schema(resources), False) r = client.resource_add(add_schema(resources), False)
success = bool(r) success = bool(r)
@@ -848,6 +863,7 @@ class HostNode(BaseROS2DeviceNode):
if success: if success:
from unilabos.resources.graphio import physical_setup_graph from unilabos.resources.graphio import physical_setup_graph
for resource in resources: for resource in resources:
if resource.get("id") not in physical_setup_graph.nodes: if resource.get("id") not in physical_setup_graph.nodes:
physical_setup_graph.add_node(resource["id"], **resource) physical_setup_graph.add_node(resource["id"], **resource)
@@ -988,9 +1004,10 @@ class HostNode(BaseROS2DeviceNode):
send_timestamp = time.time() send_timestamp = time.time()
# 发送ping # 发送ping
from unilabos.app.mq import mqtt_client from unilabos.app.communication import get_communication_client
mqtt_client.send_ping(ping_id, send_timestamp) comm_client = get_communication_client()
comm_client.send_ping(ping_id, send_timestamp)
# 等待pong响应 # 等待pong响应
timeout = 10.0 timeout = 10.0

View File

@@ -18,6 +18,7 @@ class EnvironmentChecker:
self.required_packages = { self.required_packages = {
# 包导入名 : pip安装名 # 包导入名 : pip安装名
# "pymodbus.framer.FramerType": "pymodbus==3.9.2", # "pymodbus.framer.FramerType": "pymodbus==3.9.2",
"websockets": "websockets",
"paho.mqtt": "paho-mqtt", "paho.mqtt": "paho-mqtt",
"opentrons_shared_data": "opentrons_shared_data", "opentrons_shared_data": "opentrons_shared_data",
} }