Files
Uni-Lab-OS/unilabos/app/ws_client.py
2025-09-01 16:34:23 +08:00

347 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# coding=utf-8
"""
WebSocket通信客户端
基于WebSocket协议的通信客户端实现继承自BaseCommunicationClient。
"""
import json
import time
import uuid
import threading
import asyncio
import traceback
from typing import Optional, Dict, Any
from urllib.parse import urlparse
from unilabos.app.controler import job_add
from unilabos.app.model import JobAddReq
from unilabos.ros.nodes.presets.host_node import HostNode
from unilabos.utils.type_check import serialize_result_info
try:
import websockets
import ssl as ssl_module
HAS_WEBSOCKETS = True
except ImportError:
HAS_WEBSOCKETS = False
from unilabos.app.communication import BaseCommunicationClient
from unilabos.config.config import WSConfig, HTTPConfig, BasicConfig
from unilabos.utils import logger
class WebSocketClient(BaseCommunicationClient):
    """
    WebSocket communication client.

    Keeps a connection to the lab server's WebSocket endpoint on a dedicated
    daemon thread that runs its own asyncio event loop, reconnecting after
    failures up to ``WSConfig.max_reconnect_attempts`` times. Incoming and
    outgoing messages are JSON objects shaped like ``{"type": ..., "data": ...}``.

    ``publish_device_status``, ``publish_job_status`` and ``send_ping`` do not
    block: they schedule the send onto the connection thread's loop with
    ``asyncio.run_coroutine_threadsafe`` and return immediately.
    """

    def __init__(self):
        super().__init__()
        if not HAS_WEBSOCKETS:
            logger.error("[WebSocket] websockets库未安装WebSocket功能不可用")
            self.is_disabled = True
            return
        self.is_disabled = False
        self.client_id = f"{uuid.uuid4()}"

        # Connection state
        self.websocket = None
        self.connection_loop = None
        self.event_loop = None
        self.connection_thread = None
        self.is_running = False
        self.connected = False

        # Message handling
        # NOTE(review): message_queue is not referenced anywhere else in this class.
        self.message_queue = asyncio.Queue()
        self.reconnect_count = 0

        # Derive the WebSocket URL from the configured HTTP server address
        self._build_websocket_url()
        logger.info(f"[WebSocket] Client_id: {self.client_id}")

    def _build_websocket_url(self):
        """
        Build ``self.websocket_url`` from ``HTTPConfig.remote_addr``.

        An ``https`` address maps to ``wss``; anything else maps to ``ws``. When
        the address carries an explicit port, the WebSocket URL targets that
        port + 1; otherwise the original netloc is reused. The path is always
        ``/api/v1/lab``. Sets ``websocket_url`` to None when no remote address
        is configured.
        """
        if not HTTPConfig.remote_addr:
            self.websocket_url = None
            return

        parsed = urlparse(HTTPConfig.remote_addr)
        scheme = "wss" if parsed.scheme == "https" else "ws"

        # Test the parsed port rather than ``":" in netloc``: an IPv6 literal or
        # a ``user:pw@host`` netloc contains ':' yet has no port, and
        # ``None + 1`` would raise TypeError.
        port = parsed.port
        if port is not None:
            host = parsed.hostname or ""
            # .hostname strips the brackets of an IPv6 literal; restore them.
            if ":" in host:
                host = f"[{host}]"
            self.websocket_url = f"{scheme}://{host}:{port + 1}/api/v1/lab"
        else:
            self.websocket_url = f"{scheme}://{parsed.netloc}/api/v1/lab"
        logger.debug(f"[WebSocket] URL: {self.websocket_url}")

    def start(self) -> None:
        """Start the connection thread (no-op when disabled or no URL is configured)."""
        if self.is_disabled:
            logger.warning("[WebSocket] WebSocket is disabled, skipping connection.")
            return
        if not self.websocket_url:
            logger.error("[WebSocket] WebSocket URL not configured")
            return

        logger.info(f"[WebSocket] Starting connection to {self.websocket_url}")
        self.is_running = True
        # The connection and its event loop live on their own daemon thread.
        self.connection_thread = threading.Thread(target=self._run_connection, daemon=True, name="WebSocketConnection")
        self.connection_thread.start()

    def stop(self) -> None:
        """Stop reconnecting, close the socket and wait up to 5 s for the thread to exit."""
        if self.is_disabled:
            return

        logger.info("[WebSocket] Stopping connection")
        self.is_running = False
        self.connected = False

        loop = self.event_loop
        if loop is not None and loop.is_running():
            self._run_on_loop(self._close_connection())

        if self.connection_thread and self.connection_thread.is_alive():
            self.connection_thread.join(timeout=5)

    def _run_on_loop(self, coro) -> bool:
        """
        Schedule ``coro`` on the connection thread's event loop without waiting.

        Returns True if it was scheduled. Returns False when there is no open
        loop; ``coro`` is then closed so no "never awaited" warning is emitted.
        """
        loop = self.event_loop
        if loop is None or loop.is_closed():
            coro.close()
            return False
        try:
            asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError:
            # The loop was closed between the check above and scheduling.
            coro.close()
            return False
        return True

    def _run_connection(self):
        """Thread target: run ``_connection_handler`` on a fresh event loop, then close it."""
        loop = asyncio.new_event_loop()
        self.event_loop = loop
        try:
            asyncio.set_event_loop(loop)
            loop.run_until_complete(self._connection_handler())
        except Exception as e:
            logger.error(f"[WebSocket] Connection thread error: {str(e)}")
            logger.error(traceback.format_exc())
        finally:
            # Stop publishers from scheduling onto a loop that is about to close.
            self.connected = False
            self.event_loop = None
            loop.close()

    async def _connection_handler(self):
        """
        Connect, serve incoming messages, and reconnect until stopped.

        ``reconnect_count`` is reset to 0 on every successful connection, so the
        attempt limit applies to consecutive failures only.
        """
        if not self.websocket_url:
            logger.error("[WebSocket] WebSocket URL not configured")
            return

        while self.is_running:
            try:
                ssl_context = None
                if self.websocket_url.startswith("wss://"):
                    ssl_context = ssl_module.create_default_context()

                async with websockets.connect(
                    self.websocket_url,
                    ssl=ssl_context,
                    ping_interval=WSConfig.ping_interval,
                    ping_timeout=10,
                    additional_headers={"Authorization": f"Lab {BasicConfig.auth_secret()}"},
                ) as websocket:
                    self.websocket = websocket
                    self.connected = True
                    self.reconnect_count = 0
                    logger.info(f"[WebSocket] Connected to {self.websocket_url}")
                    await self._message_handler()
            except websockets.exceptions.ConnectionClosed:
                logger.warning("[WebSocket] Connection closed")
            except Exception as e:
                logger.error(f"[WebSocket] Connection error: {str(e)}")
            finally:
                # _message_handler swallows ConnectionClosed and returns normally,
                # so the state must be cleared here, not only in the except clauses.
                self.connected = False
                self.websocket = None

            if not self.is_running:
                break
            if self.reconnect_count >= WSConfig.max_reconnect_attempts:
                logger.error("[WebSocket] Max reconnection attempts reached")
                break

            self.reconnect_count += 1
            logger.info(
                f"[WebSocket] Reconnecting in {WSConfig.reconnect_interval}s "
                f"(attempt {self.reconnect_count}/{WSConfig.max_reconnect_attempts})"
            )
            await asyncio.sleep(WSConfig.reconnect_interval)

    async def _close_connection(self):
        """Close the current WebSocket, if any, and drop the reference."""
        if self.websocket:
            await self.websocket.close()
            self.websocket = None

    async def _send_message(self, message: Dict[str, Any]):
        """
        Serialize ``message`` to JSON and send it over the socket.

        Logs a warning and returns when not connected; send failures are logged,
        not raised.
        """
        if not self.connected or not self.websocket:
            logger.warning("[WebSocket] Not connected, cannot send message")
            return

        try:
            message_str = json.dumps(message, ensure_ascii=False)
            await self.websocket.send(message_str)
            logger.debug(f"[WebSocket] Message sent: {message['type']}")
        except Exception as e:
            logger.error(f"[WebSocket] Failed to send message: {str(e)}")

    async def _message_handler(self):
        """
        Read messages until the socket closes, dispatching each to ``_process_message``.

        Per-message errors are logged and skipped. Closure and other errors end
        the loop and return normally; they are not re-raised.
        """
        if not self.websocket:
            logger.error("[WebSocket] WebSocket connection is None")
            return

        try:
            async for message in self.websocket:
                try:
                    data = json.loads(message)
                    await self._process_message(data)
                except json.JSONDecodeError:
                    logger.error(f"[WebSocket] Invalid JSON received: {message}")
                except Exception as e:
                    logger.error(f"[WebSocket] Error processing message: {str(e)}")
        except websockets.exceptions.ConnectionClosed:
            logger.info("[WebSocket] Message handler stopped - connection closed")
        except Exception as e:
            logger.error(f"[WebSocket] Message handler error: {str(e)}")

    async def _handle_query_state(self, data: Dict[str, str]) -> None:
        """
        Answer a ``query_action_state`` request with a ``report_action_state`` message.

        Requires ``device_id``, ``action_name``, ``task_id`` and ``job_id``; a
        missing field is logged and the request is dropped. ``free`` is True
        when the device action currently has no job ids recorded on the host node.
        """
        device_id = data.get("device_id", "")
        if not device_id:
            logger.error("[WebSocket] query_action_state missing device_id")
            return
        action_name = data.get("action_name", "")
        if not action_name:
            logger.error("[WebSocket] query_action_state missing action_name")
            return
        task_id = data.get("task_id", "")
        if not task_id:
            logger.error("[WebSocket] query_action_state missing task_id")
            return
        job_id = data.get("job_id", "")
        if not job_id:  # previously re-checked task_id, so job_id was never validated
            logger.error("[WebSocket] query_action_state missing job_id")
            return

        device_action_key = f"/devices/{device_id}/{action_name}"
        action_jobs = len(HostNode.get_instance()._device_action_status[device_action_key].job_ids)
        message = {
            "type": "report_action_state",
            "data": {
                "device_id": device_id,
                "action_name": action_name,
                "task_id": task_id,
                "job_id": job_id,
                # The action is free only when no jobs are recorded for it.
                "free": action_jobs == 0,
            },
        }
        await self._send_message(message)

    async def _process_message(self, input_message: Dict[str, Any]):
        """Dispatch a decoded message on its ``type`` field; unknown types are logged at debug."""
        message_type = input_message.get("type", "")
        data = input_message.get("data", {})

        if message_type == "job_start":
            await self._handle_job_start(data)
        elif message_type == "pong":
            self._handle_pong_sync(data)
        elif message_type == "query_action_state":
            await self._handle_query_state(data)
        else:
            logger.debug(f"[WebSocket] Unknown message type: {message_type}")

    async def _handle_job_start(self, data: Dict[str, Any]):
        """
        Build a ``JobAddReq`` from ``data``, assign it a fresh job id and send the goal.

        If ``send_goal`` raises, the traceback is logged and a single "failed"
        job status is published. Any other error is logged.
        """
        try:
            req = JobAddReq(**data)
            req.job_id = str(uuid.uuid4())
            logger.info(f"[WebSocket] Job started: {req.job_id}")
            try:
                HostNode.get_instance().send_goal(req.device_id, action_type=req.action_type, action_name=req.action,
                                                  action_kwargs=req.action_args, goal_uuid=req.job_id,
                                                  server_info=req.server_info)
            except Exception:
                error_trace = traceback.format_exc()
                logger.error(error_trace)
                # Report once; the previous per-bridge loop sent duplicate "failed" statuses.
                self.publish_job_status({}, req.job_id, "failed", serialize_result_info(error_trace, False, {}))
        except Exception as e:
            logger.error(f"[WebSocket] Error handling job start: {str(e)}")

    def _handle_pong_sync(self, pong_data: Dict[str, Any]):
        """Forward a pong payload to the host node, if one exists."""
        host_node = HostNode.get_instance(0)
        if host_node:
            host_node.handle_pong_response(pong_data)

    # Implementations of the abstract base-class interface

    def publish_device_status(self, device_status: dict, device_id: str, property_name: str) -> None:
        """Publish ``device_status[device_id][property_name]`` as a ``device_status`` message."""
        if self.is_disabled or not self.connected:
            return

        message = {
            "type": "device_status",
            "data": {
                "device_id": device_id,
                "property_name": property_name,
                "status": device_status.get(device_id, {}).get(property_name),
                "timestamp": time.time(),
            },
        }
        if self._run_on_loop(self._send_message(message)):
            logger.debug(f"[WebSocket] Device status published: {device_id}.{property_name}")

    def publish_job_status(
        self, feedback_data: dict, job_id: str, status: str, return_info: Optional[str] = None
    ) -> None:
        """Publish a ``job_status`` message for ``job_id``; warns and returns when not connected."""
        if self.is_disabled or not self.connected:
            logger.warning("[WebSocket] Not connected, cannot publish job status")
            return

        message = {
            "type": "job_status",
            "data": {
                "job_id": job_id,
                "status": status,
                "feedback_data": feedback_data,
                "return_info": return_info,
                "timestamp": time.time(),
            },
        }
        if self._run_on_loop(self._send_message(message)):
            logger.debug(f"[WebSocket] Job status published: {job_id} - {status}")

    def send_ping(self, ping_id: str, timestamp: float) -> None:
        """Send an application-level ``ping`` message; warns and returns when not connected."""
        if self.is_disabled or not self.connected:
            logger.warning("[WebSocket] Not connected, cannot send ping")
            return

        message = {"type": "ping", "data": {"ping_id": ping_id, "client_timestamp": timestamp}}
        if self._run_on_loop(self._send_message(message)):
            logger.debug(f"[WebSocket] Ping sent: {ping_id}")

    @property
    def is_connected(self) -> bool:
        """True when enabled and currently connected."""
        # Check is_disabled first: a disabled client never sets ``connected``.
        return not self.is_disabled and self.connected