mirror of
https://github.com/dptech-corp/Uni-Lab-OS.git
synced 2025-12-17 13:01:12 +00:00
feat: websocket
This commit is contained in:
204
unilabos/app/communication.py
Normal file
204
unilabos/app/communication.py
Normal file
@@ -0,0 +1,204 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
"""
|
||||
通信模块
|
||||
|
||||
提供MQTT和WebSocket的统一接口,支持通过配置选择通信协议。
|
||||
包含通信抽象层基类和通信客户端工厂。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from unilabos.config.config import BasicConfig
|
||||
from unilabos.utils import logger
|
||||
|
||||
|
||||
class BaseCommunicationClient(ABC):
|
||||
"""
|
||||
通信客户端抽象基类
|
||||
|
||||
定义了所有通信客户端(MQTT、WebSocket等)需要实现的接口。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.is_disabled = True
|
||||
self.client_id = ""
|
||||
|
||||
@abstractmethod
|
||||
def start(self) -> None:
|
||||
"""
|
||||
启动通信客户端连接
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
停止通信客户端连接
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def publish_device_status(self, device_status: dict, device_id: str, property_name: str) -> None:
|
||||
"""
|
||||
发布设备状态信息
|
||||
|
||||
Args:
|
||||
device_status: 设备状态字典
|
||||
device_id: 设备ID
|
||||
property_name: 属性名称
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def publish_job_status(
|
||||
self, feedback_data: dict, job_id: str, status: str, return_info: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
发布作业状态信息
|
||||
|
||||
Args:
|
||||
feedback_data: 反馈数据
|
||||
job_id: 作业ID
|
||||
status: 作业状态
|
||||
return_info: 返回信息
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def send_ping(self, ping_id: str, timestamp: float) -> None:
|
||||
"""
|
||||
发送ping消息
|
||||
|
||||
Args:
|
||||
ping_id: ping ID
|
||||
timestamp: 时间戳
|
||||
"""
|
||||
pass
|
||||
|
||||
def setup_pong_subscription(self) -> None:
|
||||
"""
|
||||
设置pong消息订阅(可选实现)
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""
|
||||
检查是否已连接
|
||||
|
||||
Returns:
|
||||
是否已连接
|
||||
"""
|
||||
return not self.is_disabled
|
||||
|
||||
|
||||
class CommunicationClientFactory:
|
||||
"""
|
||||
通信客户端工厂类
|
||||
|
||||
根据配置文件中的通信协议设置创建相应的客户端实例。
|
||||
"""
|
||||
|
||||
_client_cache: Optional[BaseCommunicationClient] = None
|
||||
|
||||
@classmethod
|
||||
def create_client(cls, protocol: Optional[str] = None) -> BaseCommunicationClient:
|
||||
"""
|
||||
创建通信客户端实例
|
||||
|
||||
Args:
|
||||
protocol: 指定的协议类型,如果为None则使用配置文件中的设置
|
||||
|
||||
Returns:
|
||||
通信客户端实例
|
||||
|
||||
Raises:
|
||||
ValueError: 当协议类型不支持时
|
||||
"""
|
||||
if protocol is None:
|
||||
protocol = BasicConfig.communication_protocol
|
||||
|
||||
protocol = protocol.lower()
|
||||
|
||||
if protocol == "mqtt":
|
||||
return cls._create_mqtt_client()
|
||||
elif protocol == "websocket":
|
||||
return cls._create_websocket_client()
|
||||
else:
|
||||
logger.error(f"[CommunicationFactory] Unsupported protocol: {protocol}")
|
||||
logger.warning(f"[CommunicationFactory] Falling back to MQTT")
|
||||
return cls._create_mqtt_client()
|
||||
|
||||
@classmethod
|
||||
def get_client(cls, protocol: Optional[str] = None) -> BaseCommunicationClient:
|
||||
"""
|
||||
获取通信客户端实例(单例模式)
|
||||
|
||||
Args:
|
||||
protocol: 指定的协议类型,如果为None则使用配置文件中的设置
|
||||
|
||||
Returns:
|
||||
通信客户端实例
|
||||
"""
|
||||
if cls._client_cache is None:
|
||||
cls._client_cache = cls.create_client(protocol)
|
||||
logger.info(f"[CommunicationFactory] Created {type(cls._client_cache).__name__} client")
|
||||
|
||||
return cls._client_cache
|
||||
|
||||
@classmethod
|
||||
def _create_mqtt_client(cls) -> BaseCommunicationClient:
|
||||
"""创建MQTT客户端"""
|
||||
try:
|
||||
from unilabos.app.mq import mqtt_client
|
||||
return mqtt_client
|
||||
except Exception as e:
|
||||
logger.error(f"[CommunicationFactory] Failed to create MQTT client: {str(e)}")
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def _create_websocket_client(cls) -> BaseCommunicationClient:
|
||||
"""创建WebSocket客户端"""
|
||||
try:
|
||||
from unilabos.app.ws_client import WebSocketClient
|
||||
return WebSocketClient()
|
||||
except Exception as e:
|
||||
logger.error(f"[CommunicationFactory] Failed to create WebSocket client: {str(e)}")
|
||||
logger.warning(f"[CommunicationFactory] Falling back to MQTT")
|
||||
return cls._create_mqtt_client()
|
||||
|
||||
@classmethod
|
||||
def reset_client(cls):
|
||||
"""重置客户端缓存(用于测试或重新配置)"""
|
||||
if cls._client_cache:
|
||||
try:
|
||||
cls._client_cache.stop()
|
||||
except Exception as e:
|
||||
logger.warning(f"[CommunicationFactory] Error stopping old client: {str(e)}")
|
||||
|
||||
cls._client_cache = None
|
||||
logger.info("[CommunicationFactory] Client cache reset")
|
||||
|
||||
@classmethod
|
||||
def get_supported_protocols(cls) -> list[str]:
|
||||
"""
|
||||
获取支持的协议列表
|
||||
|
||||
Returns:
|
||||
支持的协议列表
|
||||
"""
|
||||
return ["mqtt", "websocket"]
|
||||
|
||||
|
||||
def get_communication_client(protocol: Optional[str] = None) -> BaseCommunicationClient:
|
||||
"""
|
||||
获取通信客户端实例的便捷函数
|
||||
|
||||
Args:
|
||||
protocol: 指定的协议类型,如果为None则使用配置文件中的设置
|
||||
|
||||
Returns:
|
||||
通信客户端实例
|
||||
"""
|
||||
return CommunicationClientFactory.get_client(protocol)
|
||||
@@ -10,7 +10,6 @@ from copy import deepcopy
|
||||
|
||||
import yaml
|
||||
|
||||
from unilabos.resources.graphio import modify_to_backend_format
|
||||
|
||||
# 首先添加项目根目录到路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -20,6 +19,7 @@ if unilabos_dir not in sys.path:
|
||||
|
||||
from unilabos.config.config import load_config, BasicConfig
|
||||
from unilabos.utils.banner_print import print_status, print_unilab_banner
|
||||
from unilabos.resources.graphio import modify_to_backend_format
|
||||
|
||||
|
||||
def load_config_from_file(config_path, override_labid=None):
|
||||
@@ -146,6 +146,11 @@ def parse_args():
|
||||
default="",
|
||||
help="实验室请求的sk",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--websocket",
|
||||
action="store_true",
|
||||
help="使用websocket而非mqtt作为通信协议",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_env_check",
|
||||
action="store_true",
|
||||
@@ -179,7 +184,7 @@ def main():
|
||||
else:
|
||||
working_dir = os.path.abspath(os.path.join(os.getcwd(), "unilabos_data"))
|
||||
if args_dict.get("working_dir"):
|
||||
working_dir = args_dict.get("working_dir")
|
||||
working_dir = args_dict.get("working_dir", "")
|
||||
if config_path and not os.path.exists(config_path):
|
||||
config_path = os.path.join(working_dir, "local_config.py")
|
||||
if not os.path.exists(config_path):
|
||||
@@ -215,6 +220,7 @@ def main():
|
||||
if args_dict["use_remote_resource"]:
|
||||
print_status("使用远程资源启动", "info")
|
||||
from unilabos.app.web import http_client
|
||||
|
||||
res = http_client.resource_get("host_node", False)
|
||||
if str(res.get("code", 0)) == "0" and len(res.get("data", [])) > 0:
|
||||
print_status("远程资源已存在,使用云端物料!", "info")
|
||||
@@ -229,6 +235,7 @@ def main():
|
||||
BasicConfig.is_host_mode = not args_dict.get("without_host", False)
|
||||
BasicConfig.slave_no_host = args_dict.get("slave_no_host", False)
|
||||
BasicConfig.upload_registry = args_dict.get("upload_registry", False)
|
||||
BasicConfig.communication_protocol = "websocket" if args_dict.get("websocket", False) else "mqtt"
|
||||
machine_name = os.popen("hostname").read().strip()
|
||||
machine_name = "".join([c if c.isalnum() or c == "_" else "_" for c in machine_name])
|
||||
BasicConfig.machine_name = machine_name
|
||||
@@ -241,7 +248,7 @@ def main():
|
||||
dict_to_nested_dict,
|
||||
initialize_resources,
|
||||
)
|
||||
from unilabos.app.mq import mqtt_client
|
||||
from unilabos.app.communication import get_communication_client
|
||||
from unilabos.registry.registry import build_registry
|
||||
from unilabos.app.backend import start_backend
|
||||
from unilabos.app.web import http_client
|
||||
@@ -289,19 +296,22 @@ def main():
|
||||
|
||||
args_dict["bridges"] = []
|
||||
|
||||
# 获取通信客户端(根据配置选择MQTT或WebSocket)
|
||||
comm_client = get_communication_client()
|
||||
|
||||
if "mqtt" in args_dict["app_bridges"]:
|
||||
args_dict["bridges"].append(mqtt_client)
|
||||
args_dict["bridges"].append(comm_client)
|
||||
if "fastapi" in args_dict["app_bridges"]:
|
||||
args_dict["bridges"].append(http_client)
|
||||
if "mqtt" in args_dict["app_bridges"]:
|
||||
|
||||
def _exit(signum, frame):
|
||||
mqtt_client.stop()
|
||||
comm_client.stop()
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, _exit)
|
||||
signal.signal(signal.SIGTERM, _exit)
|
||||
mqtt_client.start()
|
||||
comm_client.start()
|
||||
args_dict["resources_mesh_config"] = {}
|
||||
args_dict["resources_edge_config"] = resource_edge_info
|
||||
# web visiualize 2D
|
||||
|
||||
@@ -15,17 +15,20 @@ import os
|
||||
from unilabos.config.config import MQConfig
|
||||
from unilabos.app.controler import job_add
|
||||
from unilabos.app.model import JobAddReq
|
||||
from unilabos.app.communication import BaseCommunicationClient
|
||||
from unilabos.utils import logger
|
||||
from unilabos.utils.type_check import TypeEncoder
|
||||
|
||||
from paho.mqtt.enums import CallbackAPIVersion
|
||||
|
||||
|
||||
class MQTTClient:
|
||||
class MQTTClient(BaseCommunicationClient):
|
||||
mqtt_disable = True
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mqtt_disable = not MQConfig.lab_id
|
||||
self.is_disabled = self.mqtt_disable # 更新父类属性
|
||||
self.client_id = f"{MQConfig.group_id}@@@{MQConfig.lab_id}{uuid.uuid4()}"
|
||||
logger.info("[MQTT] Client_id: " + self.client_id)
|
||||
self.client = mqtt.Client(CallbackAPIVersion.VERSION2, client_id=self.client_id, protocol=mqtt.MQTTv5)
|
||||
@@ -208,11 +211,12 @@ class MQTTClient:
|
||||
self.client.subscribe(pong_topic, 0)
|
||||
logger.debug(f"Subscribed to pong topic: {pong_topic}")
|
||||
|
||||
def handle_pong(self, pong_data: dict):
|
||||
"""处理pong响应(这个方法会在收到pong消息时被调用)"""
|
||||
logger.debug(f"Pong received: {pong_data}")
|
||||
# 这里会被HostNode的ping-pong处理逻辑调用
|
||||
pass
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""检查MQTT是否已连接"""
|
||||
if self.is_disabled:
|
||||
return False
|
||||
return hasattr(self.client, "is_connected") and self.client.is_connected()
|
||||
|
||||
|
||||
mqtt_client = MQTTClient()
|
||||
|
||||
@@ -10,35 +10,44 @@ from unilabos.utils.log import logger
|
||||
from unilabos.utils.type_check import TypeEncoder
|
||||
|
||||
|
||||
def register_devices_and_resources(mqtt_client, lab_registry):
|
||||
def register_devices_and_resources(comm_client, lab_registry):
|
||||
"""
|
||||
注册设备和资源到 MQTT
|
||||
注册设备和资源到通信服务器(MQTT/WebSocket)
|
||||
"""
|
||||
|
||||
# 注册资源信息 - 使用HTTP方式
|
||||
from unilabos.app.web.client import http_client
|
||||
|
||||
logger.info("[UniLab Register] 开始注册设备和资源...")
|
||||
if BasicConfig.auth_secret():
|
||||
# 注册设备信息
|
||||
devices_to_register = {}
|
||||
for device_info in lab_registry.obtain_registry_device_info():
|
||||
devices_to_register[device_info["id"]] = json.loads(json.dumps(device_info, ensure_ascii=False, cls=TypeEncoder))
|
||||
devices_to_register[device_info["id"]] = json.loads(
|
||||
json.dumps(device_info, ensure_ascii=False, cls=TypeEncoder)
|
||||
)
|
||||
logger.debug(f"[UniLab Register] 收集设备: {device_info['id']}")
|
||||
resources_to_register = {}
|
||||
for resource_info in lab_registry.obtain_registry_resource_info():
|
||||
resources_to_register[resource_info["id"]] = resource_info
|
||||
logger.debug(f"[UniLab Register] 收集资源: {resource_info['id']}")
|
||||
print("[UniLab Register] 设备注册", http_client.resource_registry({"resources": list(devices_to_register.values())}).text)
|
||||
print("[UniLab Register] 资源注册", http_client.resource_registry({"resources": list(resources_to_register.values())}).text)
|
||||
print(
|
||||
"[UniLab Register] 设备注册",
|
||||
http_client.resource_registry({"resources": list(devices_to_register.values())}).text,
|
||||
)
|
||||
print(
|
||||
"[UniLab Register] 资源注册",
|
||||
http_client.resource_registry({"resources": list(resources_to_register.values())}).text,
|
||||
)
|
||||
else:
|
||||
# 注册设备信息
|
||||
for device_info in lab_registry.obtain_registry_device_info():
|
||||
mqtt_client.publish_registry(device_info["id"], device_info, False)
|
||||
comm_client.publish_registry(device_info["id"], device_info, False)
|
||||
logger.debug(f"[UniLab Register] 注册设备: {device_info['id']}")
|
||||
|
||||
# # 注册资源信息
|
||||
# for resource_info in lab_registry.obtain_registry_resource_info():
|
||||
# mqtt_client.publish_registry(resource_info["id"], resource_info, False)
|
||||
# comm_client.publish_registry(resource_info["id"], resource_info, False)
|
||||
# logger.debug(f"[UniLab Register] 注册资源: {resource_info['id']}")
|
||||
|
||||
resources_to_register = {}
|
||||
@@ -53,7 +62,9 @@ def register_devices_and_resources(mqtt_client, lab_registry):
|
||||
if response.status_code in [200, 201]:
|
||||
logger.info(f"[UniLab Register] 成功通过HTTP注册 {len(resources_to_register)} 个资源 {cost_time}ms")
|
||||
else:
|
||||
logger.error(f"[UniLab Register] HTTP注册资源失败: {response.status_code}, {response.text} {cost_time}ms")
|
||||
logger.error(
|
||||
f"[UniLab Register] HTTP注册资源失败: {response.status_code}, {response.text} {cost_time}ms"
|
||||
)
|
||||
logger.info("[UniLab Register] 设备和资源注册完成.")
|
||||
|
||||
|
||||
@@ -99,15 +110,16 @@ def main():
|
||||
BasicConfig.sk = args.sk
|
||||
# 构建注册表
|
||||
build_registry(args.registry, args.complete_registry, True)
|
||||
from unilabos.app.mq import mqtt_client
|
||||
from unilabos.app.communication import get_communication_client
|
||||
|
||||
# 连接mqtt
|
||||
mqtt_client.start()
|
||||
# 获取通信客户端并启动连接
|
||||
comm_client = get_communication_client()
|
||||
comm_client.start()
|
||||
|
||||
from unilabos.registry.registry import lab_registry
|
||||
|
||||
# 注册设备和资源
|
||||
register_devices_and_resources(mqtt_client, lab_registry)
|
||||
register_devices_and_resources(comm_client, lab_registry)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
300
unilabos/app/ws_client.py
Normal file
300
unilabos/app/ws_client.py
Normal file
@@ -0,0 +1,300 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
"""
|
||||
WebSocket通信客户端
|
||||
|
||||
基于WebSocket协议的通信客户端实现,继承自BaseCommunicationClient。
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import threading
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Optional, Dict, Any
|
||||
from urllib.parse import urlparse
|
||||
from unilabos.app.controler import job_add
|
||||
from unilabos.app.model import JobAddReq
|
||||
from unilabos.ros.nodes.presets.host_node import HostNode
|
||||
|
||||
try:
|
||||
import websockets
|
||||
import ssl as ssl_module
|
||||
|
||||
HAS_WEBSOCKETS = True
|
||||
except ImportError:
|
||||
HAS_WEBSOCKETS = False
|
||||
|
||||
from unilabos.app.communication import BaseCommunicationClient
|
||||
from unilabos.config.config import WSConfig, HTTPConfig, BasicConfig
|
||||
from unilabos.utils import logger
|
||||
|
||||
|
||||
class WebSocketClient(BaseCommunicationClient):
|
||||
"""
|
||||
WebSocket通信客户端类
|
||||
|
||||
实现基于WebSocket协议的实时通信功能。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
if not HAS_WEBSOCKETS:
|
||||
logger.error("[WebSocket] websockets库未安装,WebSocket功能不可用")
|
||||
self.is_disabled = True
|
||||
return
|
||||
|
||||
self.is_disabled = False
|
||||
self.client_id = f"{uuid.uuid4()}"
|
||||
|
||||
# WebSocket连接相关
|
||||
self.websocket = None
|
||||
self.connection_loop = None
|
||||
self.event_loop = None
|
||||
self.connection_thread = None
|
||||
self.is_running = False
|
||||
self.connected = False
|
||||
|
||||
# 消息处理
|
||||
self.message_queue = asyncio.Queue() if not self.is_disabled else None
|
||||
self.reconnect_count = 0
|
||||
|
||||
# 构建WebSocket URL
|
||||
self._build_websocket_url()
|
||||
|
||||
logger.info(f"[WebSocket] Client_id: {self.client_id}")
|
||||
|
||||
def _build_websocket_url(self):
|
||||
"""构建WebSocket连接URL"""
|
||||
if not HTTPConfig.remote_addr:
|
||||
self.websocket_url = None
|
||||
return
|
||||
|
||||
# 解析服务器URL
|
||||
parsed = urlparse(HTTPConfig.remote_addr)
|
||||
|
||||
# 根据SSL配置选择协议
|
||||
if parsed.scheme == "https":
|
||||
scheme = "wss"
|
||||
else:
|
||||
scheme = "ws"
|
||||
self.websocket_url = f"{scheme}://{parsed.netloc}/api/v1/lab"
|
||||
|
||||
logger.debug(f"[WebSocket] URL: {self.websocket_url}")
|
||||
|
||||
def start(self) -> None:
|
||||
"""启动WebSocket连接"""
|
||||
if self.is_disabled:
|
||||
logger.warning("[WebSocket] WebSocket is disabled, skipping connection.")
|
||||
return
|
||||
|
||||
if not self.websocket_url:
|
||||
logger.error("[WebSocket] WebSocket URL not configured")
|
||||
return
|
||||
|
||||
logger.info(f"[WebSocket] Starting connection to {self.websocket_url}")
|
||||
|
||||
self.is_running = True
|
||||
|
||||
# 在单独线程中运行WebSocket连接
|
||||
self.connection_thread = threading.Thread(target=self._run_connection, daemon=True, name="WebSocketConnection")
|
||||
self.connection_thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""停止WebSocket连接"""
|
||||
if self.is_disabled:
|
||||
return
|
||||
|
||||
logger.info("[WebSocket] Stopping connection")
|
||||
self.is_running = False
|
||||
self.connected = False
|
||||
|
||||
if self.event_loop and self.event_loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(self._close_connection(), self.event_loop)
|
||||
|
||||
if self.connection_thread and self.connection_thread.is_alive():
|
||||
self.connection_thread.join(timeout=5)
|
||||
|
||||
def _run_connection(self):
|
||||
"""在独立线程中运行WebSocket连接"""
|
||||
try:
|
||||
# 创建新的事件循环
|
||||
self.event_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.event_loop)
|
||||
|
||||
# 运行连接逻辑
|
||||
self.event_loop.run_until_complete(self._connection_handler())
|
||||
except Exception as e:
|
||||
logger.error(f"[WebSocket] Connection thread error: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
if self.event_loop:
|
||||
self.event_loop.close()
|
||||
|
||||
async def _connection_handler(self):
|
||||
"""处理WebSocket连接和重连逻辑"""
|
||||
while self.is_running:
|
||||
try:
|
||||
# 构建SSL上下文
|
||||
ssl_context = None
|
||||
assert self.websocket_url is not None
|
||||
if self.websocket_url.startswith("wss://"):
|
||||
ssl_context = ssl_module.create_default_context()
|
||||
|
||||
async with websockets.connect(
|
||||
self.websocket_url,
|
||||
ssl=ssl_context,
|
||||
ping_interval=WSConfig.ping_interval,
|
||||
ping_timeout=10,
|
||||
additional_headers={"Authorization": f"Bearer {BasicConfig.auth_secret()}"},
|
||||
) as websocket:
|
||||
self.websocket = websocket
|
||||
self.connected = True
|
||||
self.reconnect_count = 0
|
||||
|
||||
logger.info(f"[WebSocket] Connected to {self.websocket_url}")
|
||||
# 处理消息
|
||||
await self._message_handler()
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.warning("[WebSocket] Connection closed")
|
||||
self.connected = False
|
||||
except Exception as e:
|
||||
logger.error(f"[WebSocket] Connection error: {str(e)}")
|
||||
self.connected = False
|
||||
|
||||
# 重连逻辑
|
||||
if self.is_running and self.reconnect_count < WSConfig.max_reconnect_attempts:
|
||||
self.reconnect_count += 1
|
||||
logger.info(
|
||||
f"[WebSocket] Reconnecting in {WSConfig.reconnect_interval}s "
|
||||
f"(attempt {self.reconnect_count}/{WSConfig.max_reconnect_attempts})"
|
||||
)
|
||||
await asyncio.sleep(WSConfig.reconnect_interval)
|
||||
elif self.reconnect_count >= WSConfig.max_reconnect_attempts:
|
||||
logger.error("[WebSocket] Max reconnection attempts reached")
|
||||
break
|
||||
|
||||
async def _close_connection(self):
|
||||
"""关闭WebSocket连接"""
|
||||
if self.websocket:
|
||||
await self.websocket.close()
|
||||
self.websocket = None
|
||||
|
||||
async def _send_message(self, message: Dict[str, Any]):
|
||||
"""发送消息"""
|
||||
if not self.connected or not self.websocket:
|
||||
logger.warning("[WebSocket] Not connected, cannot send message")
|
||||
return
|
||||
|
||||
try:
|
||||
message_str = json.dumps(message, ensure_ascii=False)
|
||||
await self.websocket.send(message_str)
|
||||
logger.debug(f"[WebSocket] Message sent: {message['type']}")
|
||||
except Exception as e:
|
||||
logger.error(f"[WebSocket] Failed to send message: {str(e)}")
|
||||
|
||||
async def _message_handler(self):
|
||||
"""处理接收到的消息"""
|
||||
if not self.websocket:
|
||||
logger.error("[WebSocket] WebSocket connection is None")
|
||||
return
|
||||
|
||||
try:
|
||||
async for message in self.websocket:
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._process_message(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"[WebSocket] Invalid JSON received: {message}")
|
||||
except Exception as e:
|
||||
logger.error(f"[WebSocket] Error processing message: {str(e)}")
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info("[WebSocket] Message handler stopped - connection closed")
|
||||
except Exception as e:
|
||||
logger.error(f"[WebSocket] Message handler error: {str(e)}")
|
||||
|
||||
async def _process_message(self, input_message: Dict[str, Any]):
|
||||
"""处理收到的消息"""
|
||||
message_type = input_message.get("type", "")
|
||||
data = input_message.get("data", {})
|
||||
if message_type == "job_start":
|
||||
# 处理作业启动消息
|
||||
await self._handle_job_start(data)
|
||||
elif message_type == "pong":
|
||||
# 处理pong响应
|
||||
self._handle_pong_sync(data)
|
||||
else:
|
||||
logger.debug(f"[WebSocket] Unknown message type: {message_type}")
|
||||
|
||||
async def _handle_job_start(self, data: Dict[str, Any]):
|
||||
"""处理作业启动消息"""
|
||||
try:
|
||||
job_req = JobAddReq(**data.get("job_data", {}))
|
||||
job_add(job_req)
|
||||
job_id = getattr(job_req, "id", "unknown")
|
||||
logger.info(f"[WebSocket] Job started: {job_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"[WebSocket] Error handling job start: {str(e)}")
|
||||
|
||||
def _handle_pong_sync(self, pong_data: Dict[str, Any]):
|
||||
"""同步处理pong响应"""
|
||||
host_node = HostNode.get_instance(0)
|
||||
if host_node:
|
||||
host_node.handle_pong_response(pong_data)
|
||||
|
||||
# 实现抽象基类的方法
|
||||
def publish_device_status(self, device_status: dict, device_id: str, property_name: str) -> None:
|
||||
"""发布设备状态"""
|
||||
if self.is_disabled or not self.connected:
|
||||
return
|
||||
message = {
|
||||
"type": "device_status",
|
||||
"data": {
|
||||
"device_id": device_id,
|
||||
"property_name": property_name,
|
||||
"status": device_status.get(device_id, {}).get(property_name),
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
}
|
||||
if self.event_loop:
|
||||
asyncio.run_coroutine_threadsafe(self._send_message(message), self.event_loop)
|
||||
logger.debug(f"[WebSocket] Device status published: {device_id}.{property_name}")
|
||||
|
||||
def publish_job_status(
|
||||
self, feedback_data: dict, job_id: str, status: str, return_info: Optional[str] = None
|
||||
) -> None:
|
||||
"""发布作业状态"""
|
||||
if self.is_disabled or not self.connected:
|
||||
logger.warning("[WebSocket] Not connected, cannot publish job status")
|
||||
return
|
||||
message = {
|
||||
"type": "job_status",
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"status": status,
|
||||
"feedback_data": feedback_data,
|
||||
"return_info": return_info,
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
}
|
||||
if self.event_loop:
|
||||
asyncio.run_coroutine_threadsafe(self._send_message(message), self.event_loop)
|
||||
logger.debug(f"[WebSocket] Job status published: {job_id} - {status}")
|
||||
|
||||
def send_ping(self, ping_id: str, timestamp: float) -> None:
|
||||
"""发送ping消息"""
|
||||
if self.is_disabled or not self.connected:
|
||||
logger.warning("[WebSocket] Not connected, cannot send ping")
|
||||
return
|
||||
message = {"type": "ping", "data": {"ping_id": ping_id, "client_timestamp": timestamp}}
|
||||
if self.event_loop:
|
||||
asyncio.run_coroutine_threadsafe(self._send_message(message), self.event_loop)
|
||||
logger.debug(f"[WebSocket] Ping sent: {ping_id}")
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""检查是否已连接"""
|
||||
return self.connected and not self.is_disabled
|
||||
Reference in New Issue
Block a user