Compare commits


3 Commits

Author   SHA1        Message            Date
Xuwznln  831f4549f9  ws protocol        2025-09-02 18:51:27 +08:00
Xuwznln  f4d4eb06d3  ws test version 2  2025-09-02 18:29:05 +08:00
Xuwznln  e3b8164f6b  ws test version 1  2025-09-02 14:32:02 +08:00
3 changed files with 705 additions and 154 deletions

View File

@@ -53,6 +53,7 @@ class JobAddReq(BaseModel):
action: str = Field(examples=["_execute_driver_command_async"], description="action name", default="")
action_type: str = Field(examples=["unilabos_msgs.action._str_single_input.StrSingleInput"], description="action name", default="")
action_args: dict = Field(examples=[{'string': 'string'}], description="action name", default="")
task_id: str = Field(examples=["task_id"], description="task uuid")
job_id: str = Field(examples=["job_id"], description="goal uuid")
node_id: str = Field(examples=["node_id"], description="node uuid")
server_info: dict = Field(examples=[{"send_timestamp": 1717000000.0}], description="server info")
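For reference, a minimal sketch of how the extended model might be populated on the scheduler side; the field values are illustrative, and it assumes the device_id field defined earlier in the model. The new task_id travels alongside the existing job_id so several jobs can be grouped under one task.

from unilabos.app.model import JobAddReq

# Hypothetical payload as it might arrive over the schedule WebSocket channel
payload = {
    "device_id": "device-001",
    "action": "_execute_driver_command_async",
    "action_type": "unilabos_msgs.action._str_single_input.StrSingleInput",
    "action_args": {"string": "string"},
    "task_id": "task-0001",   # new in this change
    "job_id": "job-0001",
    "node_id": "node-001",
    "server_info": {"send_timestamp": 1717000000.0},
}
req = JobAddReq(**payload)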

View File

@@ -1,52 +1,615 @@
#!/usr/bin/env python
# coding=utf-8
"""
- WebSocket communication client
+ WebSocket communication client and task scheduler
Communication client implementation based on the WebSocket protocol, inheriting from BaseCommunicationClient.
Contains two classes: WebSocketClient for connection management and TaskScheduler for task scheduling.
"""
import json
import logging
import time
import uuid
import threading
import asyncio
import traceback
import websockets
import ssl as ssl_module
from dataclasses import dataclass
from typing import Optional, Dict, Any
from urllib.parse import urlparse
- from unilabos.app.controler import job_add
from unilabos.app.model import JobAddReq
from unilabos.ros.nodes.presets.host_node import HostNode
from unilabos.utils.type_check import serialize_result_info
- try:
-     import websockets
-     import ssl as ssl_module
-     HAS_WEBSOCKETS = True
- except ImportError:
-     HAS_WEBSOCKETS = False
from unilabos.app.communication import BaseCommunicationClient
from unilabos.config.config import WSConfig, HTTPConfig, BasicConfig
from unilabos.utils import logger
@dataclass
class QueueItem:
"""队列项数据结构"""
task_type: str # "query_action_status" 或 "job_call_back_status"
device_id: str
action_name: str
task_id: str
job_id: str
device_action_key: str
next_run_time: float # timestamp of the next scheduled run
retry_count: int = 0 # retry count
class TaskScheduler:
"""
Task scheduler class
Responsible for task queue management, status tracking, and business-logic handling.
"""
def __init__(self, message_sender: "WebSocketClient"):
"""初始化任务调度器"""
self.message_sender = message_sender
# Queue management
self.action_queue = [] # task queue
self.action_queue_lock = threading.Lock() # queue lock
# Task status tracking
self.active_jobs = {} # job_id -> job info
self.cancel_events = {} # job_id -> asyncio.Event for cancellation
# Immediate-execution flag dict - device_id+action_name -> timestamp
self.immediate_execution_flags = {} # device/action combinations that should run immediately
self.immediate_execution_lock = threading.Lock() # lock for the immediate-execution flags
# Queue processor
self.queue_processor_thread = None
self.queue_running = False
# Queue processor methods
def start(self) -> None:
"""启动任务调度器"""
if self.queue_running:
logger.warning("[TaskScheduler] Already running")
return
self.queue_running = True
self.queue_processor_thread = threading.Thread(
target=self._run_queue_processor, daemon=True, name="TaskScheduler"
)
self.queue_processor_thread.start()
def stop(self) -> None:
"""停止任务调度器"""
self.queue_running = False
if self.queue_processor_thread and self.queue_processor_thread.is_alive():
self.queue_processor_thread.join(timeout=5)
logger.info("[TaskScheduler] Stopped")
def _run_queue_processor(self):
"""在独立线程中运行队列处理器"""
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
loop.run_until_complete(self._action_queue_processor())
except Exception as e:
logger.error(f"[TaskScheduler] Queue processor thread error: {str(e)}")
finally:
if loop:
loop.close()
async def _action_queue_processor(self) -> None:
"""队列处理器 - 从队列头部取出任务处理保持顺序使用list避免队尾排队问题"""
logger.info("[TaskScheduler] Action queue processor started")
try:
while self.queue_running:
try:
current_time = time.time()
items_to_process = []
items_to_requeue = []
# Safely copy the queue contents under the lock
with self.action_queue_lock:
if not self.action_queue:
# Queue is empty, wait for a while
pass
else:
# Copy the queue contents to avoid concurrent modification issues
items_to_process = self.action_queue.copy()
self.action_queue.clear()
if not items_to_process:
await asyncio.sleep(0.2) # wait while the queue is empty
continue
with self.immediate_execution_lock:
expired_keys = [k for k, v in self.immediate_execution_flags.items() if current_time > v]
for k in expired_keys:
del self.immediate_execution_flags[k]
immediate_execution = self.immediate_execution_flags.copy()
# Process each item
for item in items_to_process:
try:
# Check whether it is due; this is our local run time, filled in order
if current_time < item.next_run_time and item.device_action_key not in immediate_execution:
# Not due yet, keep it in the queue (preserving the original order)
items_to_requeue.append(item)
continue
# Execute the corresponding task
should_continue = False
if item.task_type == "query_action_status":
should_continue = await self._process_query_status_item(item)
elif item.task_type == "job_call_back_status":
should_continue = await self._process_job_callback_item(item)
else:
logger.warning(f"[TaskScheduler] Unknown task type: {item.task_type}")
continue
# If it needs to continue, add it to the re-queue list
if should_continue:
item.next_run_time = current_time + 10 # run again in 10 seconds
item.retry_count += 1
items_to_requeue.append(item)
logger.trace( # type: ignore
f"[TaskScheduler] Re-queued {item.job_id} {item.task_type} "
f"for {item.device_action_key}"
)
else:
logger.debug(
f"[TaskScheduler] Completed {item.job_id} {item.task_type} "
f"for {item.device_action_key}"
)
except Exception as e:
logger.error(f"[TaskScheduler] Error processing item {item.task_type}: {str(e)}")
# Put items that need re-queuing back at the head of the queue (preserving order so they run before new items)
if items_to_requeue and self.action_queue is not None:
with self.action_queue_lock:
self.action_queue = items_to_requeue + self.action_queue
await asyncio.sleep(0.1) # brief pause to avoid hogging the CPU
except Exception as e:
logger.error(f"[TaskScheduler] Error in queue processor: {str(e)}")
await asyncio.sleep(1) # wait briefly after an error before continuing
except asyncio.CancelledError:
logger.info("[TaskScheduler] Action queue processor cancelled")
except Exception as e:
logger.error(f"[TaskScheduler] Fatal error in queue processor: {str(e)}")
finally:
logger.info("[TaskScheduler] Action queue processor stopped")
# Queue item handlers
async def _process_query_status_item(self, item: QueueItem) -> bool:
"""Handle a query_action_status queue item; returns True to keep polling, False to stop"""
try:
# Check the device status
host_node = HostNode.get_instance(0)
if not host_node:
logger.error("[TaskScheduler] HostNode instance not available in queue processor")
return False
action_jobs = len(host_node._device_action_status[item.device_action_key].job_ids)
free = not bool(action_jobs)
# Send the status report
if free:
# Device is free: send the final status and stop
# The same logic as handle_query_state needs to be added below
host_node._device_action_status[item.device_action_key].job_ids[item.job_id] = time.time()
await self._publish_device_action_state(
item.device_id, item.action_name, item.task_id, item.job_id, "query_action_status", True, 0
)
return False # stop monitoring
else:
# Device is busy: send the status and keep monitoring
await self._publish_device_action_state(
item.device_id, item.action_name, item.task_id, item.job_id, "query_action_status", False, 10
)
return True # keep monitoring
except Exception as e:
logger.error(f"[TaskScheduler] Error processing query status item: {str(e)}")
return False # stop on error
async def _process_job_callback_item(self, item: QueueItem) -> bool:
"""处理job_call_back_status类型的队列项返回True表示需要继续False表示可以停止"""
try:
# Check whether the job is still in the active list
if item.job_id not in self.active_jobs:
logger.debug(f"[TaskScheduler] Job {item.job_id} no longer active")
return False
# Check whether a cancel signal has been received
if item.job_id in self.cancel_events and self.cancel_events[item.job_id].is_set():
logger.info(f"[TaskScheduler] Job {item.job_id} cancelled via cancel event")
return False
# Check the device status
host_node = HostNode.get_instance(0)
if not host_node:
logger.error(
f"[TaskScheduler] HostNode instance not available in job callback queue for job_id: {item.job_id}"
)
return False
action_jobs = len(host_node._device_action_status[item.device_action_key].job_ids)
free = not bool(action_jobs)
# Send the job_call_back_status state
await self._publish_device_action_state(
item.device_id, item.action_name, item.task_id, item.job_id, "job_call_back_status", free, 10
)
# Stop monitoring if the job has finished
if free:
return False
else:
return True # keep monitoring
except Exception as e:
logger.error(f"[TaskScheduler] Error processing job callback item for job_id {item.job_id}: {str(e)}")
return False # stop on error
# Message sending methods
async def _publish_device_action_state(
self, device_id: str, action_name: str, task_id: str, job_id: str, typ: str, free: bool, need_more: int
) -> None:
"""发布设备动作状态"""
message = {
"action": "report_action_state",
"data": {
"type": typ,
"device_id": device_id,
"action_name": action_name,
"task_id": task_id,
"job_id": job_id,
"free": free,
"need_more": need_more,
},
}
await self.message_sender.send_message(message)
# Business-logic handlers
async def handle_query_state(self, data: Dict[str, str]) -> None:
"""Handle query_action_state messages"""
device_id = data.get("device_id", "")
if not device_id:
logger.error("[TaskScheduler] query_action_state missing device_id")
return
action_name = data.get("action_name", "")
if not action_name:
logger.error("[TaskScheduler] query_action_state missing action_name")
return
task_id = data.get("task_id", "")
if not task_id:
logger.error("[TaskScheduler] query_action_state missing task_id")
return
job_id = data.get("job_id", "")
if not job_id:
logger.error("[TaskScheduler] query_action_state missing job_id")
return
device_action_key = f"/devices/{device_id}/{action_name}"
host_node = HostNode.get_instance(0)
if not host_node:
logger.error("[TaskScheduler] HostNode instance not available")
return
action_jobs = len(host_node._device_action_status[device_action_key].job_ids)
free = not bool(action_jobs)
# If the device is free, respond with the free state immediately
if free:
await self._publish_device_action_state(
device_id, action_name, task_id, job_id, "query_action_status", True, 0
)
logger.debug(f"[TaskScheduler] {job_id} Device {device_id}/{action_name} is free, responded immediately")
host_node = HostNode.get_instance(0)
if not host_node:
logger.error(f"[TaskScheduler] HostNode instance not available for job_id: {job_id}")
return
host_node._device_action_status[device_action_key].job_ids[job_id] = time.time()
return
# When the device is busy, check whether an identical polling task already exists
if self.action_queue is not None:
with self.action_queue_lock:
# Check whether a polling task with the same job_id and task_id already exists
for existing_item in self.action_queue:
if (
existing_item.task_type == "query_action_status"
and existing_item.job_id == job_id
and existing_item.task_id == task_id
and existing_item.device_action_key == device_action_key
):
logger.error(
f"[TaskScheduler] Duplicate query_action_state ignored: "
f"job_id={job_id}, task_id={task_id}, server error"
)
return
# No duplicate found, add to the polling queue
queue_item = QueueItem(
task_type="query_action_status",
device_id=device_id,
action_name=action_name,
task_id=task_id,
job_id=job_id,
device_action_key=device_action_key,
next_run_time=time.time() + 10, # run in 10 seconds
)
self.action_queue.append(queue_item)
logger.debug(
f"[TaskScheduler] {job_id} Device {device_id}/{action_name} is busy, "
f"added to polling queue {action_jobs}"
)
# Send the busy state immediately
await self._publish_device_action_state(
device_id, action_name, task_id, job_id, "query_action_status", False, 10
)
else:
logger.warning("[TaskScheduler] Action queue not available")
async def handle_job_start(self, data: Dict[str, Any]):
"""处理作业启动消息"""
try:
req = JobAddReq(**data)
device_action_key = f"/devices/{req.device_id}/{req.action}"
logger.info(
f"[TaskScheduler] Starting job with job_id: {req.job_id}, "
f"device: {req.device_id}, action: {req.action}"
)
# Add to the active jobs
self.active_jobs[req.job_id] = {
"device_id": req.device_id,
"action_name": req.action,
"task_id": data.get("task_id", ""),
"start_time": time.time(),
"device_action_key": device_action_key,
"callback_started": False, # 标记callback是否已启动
}
# 创建取消事件todo要移动到query_state中
self.cancel_events[req.job_id] = asyncio.Event()
try:
# Start the periodic callback reporting
await self._start_job_callback(req.job_id, req.device_id, req.action, req.task_id, device_action_key)
# Create a QueueItem object compatible with HostNode
job_queue_item = QueueItem(
task_type="job_call_back_status",
device_id=req.device_id,
action_name=req.action,
task_id=req.task_id,
job_id=req.job_id,
device_action_key=device_action_key,
next_run_time=time.time(),
)
host_node = HostNode.get_instance(0)
if not host_node:
logger.error(f"[TaskScheduler] HostNode instance not available for job_id: {req.job_id}")
return
host_node.send_goal(
job_queue_item,
action_type=req.action_type,
action_kwargs=req.action_args,
server_info=req.server_info,
)
except Exception as e:
logger.error(f"[TaskScheduler] Exception during job start for job_id {req.job_id}: {str(e)}")
traceback.print_exc()
# On an abnormal end, stop the callback first, then send the failed status
await self._stop_job_callback(
req.job_id, "failed", serialize_result_info(traceback.format_exc(), False, {})
)
host_node = HostNode.get_instance(0)
if host_node:
host_node._device_action_status[device_action_key].job_ids.pop(req.job_id, None)
logger.warning(f"[TaskScheduler] Cleaned up failed job from HostNode: {req.job_id}")
except Exception as e:
logger.error(f"[TaskScheduler] Error handling job start: {str(e)}")
async def handle_cancel_action(self, data: Dict[str, Any]) -> None:
"""处理取消动作请求"""
task_id = data.get("task_id")
job_id = data.get("job_id")
logger.debug(f"[TaskScheduler] Handling cancel action request - task_id: {task_id}, job_id: {job_id}")
if not task_id and not job_id:
logger.error("[TaskScheduler] cancel_action missing both task_id and job_id")
return
# Cancel by job_id
if job_id:
logger.info(f"[TaskScheduler] Cancelling job by job_id: {job_id}")
# Set the cancel event
if job_id in self.cancel_events:
self.cancel_events[job_id].set()
logger.debug(f"[TaskScheduler] Set cancel event for job_id: {job_id}")
else:
logger.warning(f"[TaskScheduler] Cancel event not found for job_id: {job_id}")
# Stop the job callback and send the cancelled status
if job_id in self.active_jobs:
logger.debug(f"[TaskScheduler] Found active job for cancellation: {job_id}")
# Call HostNode's cancel_goal
host_node = HostNode.get_instance(0)
if host_node:
host_node.cancel_goal(job_id)
logger.info(f"[TaskScheduler] Cancelled goal in HostNode for job_id: {job_id}")
else:
logger.error(f"[TaskScheduler] HostNode not available for cancel goal: {job_id}")
# Stop the callback and send the cancelled status
await self._stop_job_callback(job_id, "cancelled", "Job was cancelled by user request")
logger.info(f"[TaskScheduler] Stopped job callback and sent cancel status for job_id: {job_id}")
else:
logger.warning(f"[TaskScheduler] Job not found in active jobs for cancellation: {job_id}")
# Cancel by task_id: look up the corresponding job_ids
if task_id and not job_id:
logger.debug(f"[TaskScheduler] Cancelling jobs by task_id: {task_id}")
jobs_to_cancel = []
for jid, job_info in self.active_jobs.items():
if job_info.get("task_id") == task_id:
jobs_to_cancel.append(jid)
logger.debug(
f"[TaskScheduler] Found {len(jobs_to_cancel)} jobs to cancel for task_id {task_id}: {jobs_to_cancel}"
)
for jid in jobs_to_cancel:
logger.debug(f"[TaskScheduler] Recursively cancelling job_id: {jid} for task_id: {task_id}")
# Recursively call ourselves to cancel each job
await self.handle_cancel_action({"job_id": jid})
logger.debug(f"[TaskScheduler] Completed cancel action handling - task_id: {task_id}, job_id: {job_id}")
# Job management methods
async def _start_job_callback(
self, job_id: str, device_id: str, action_name: str, task_id: str, device_action_key: str
) -> None:
"""启动job的callback定时发送"""
if job_id not in self.active_jobs:
logger.debug(f"[TaskScheduler] Job not found in active jobs when starting callback: {job_id}")
return
# Check whether the callback has already been started
if self.active_jobs[job_id].get("callback_started", False):
logger.warning(f"[TaskScheduler] Job callback already started for job_id: {job_id}")
return
# Mark the callback as started
self.active_jobs[job_id]["callback_started"] = True
# Put a job_call_back_status task into the queue
queue_item = QueueItem(
task_type="job_call_back_status",
device_id=device_id,
action_name=action_name,
task_id=task_id,
job_id=job_id,
device_action_key=device_action_key,
next_run_time=time.time() + 10, # start reporting in 10 seconds
)
if self.action_queue is not None:
with self.action_queue_lock:
self.action_queue.append(queue_item)
else:
logger.debug(f"[TaskScheduler] Action queue not available for job callback: {job_id}")
async def _stop_job_callback(self, job_id: str, final_status: str, return_info: Optional[str] = None) -> None:
"""停止job的callback定时发送并发送最终结果"""
logger.info(f"[TaskScheduler] Stopping job callback for job_id: {job_id} with final status: {final_status}")
if job_id not in self.active_jobs:
logger.debug(f"[TaskScheduler] Job {job_id} not found in active jobs when stopping callback")
return
job_info = self.active_jobs[job_id]
device_id = job_info["device_id"]
action_name = job_info["action_name"]
task_id = job_info["task_id"]
device_action_key = job_info["device_action_key"]
logger.debug(
f"[TaskScheduler] Job {job_id} details - device: {device_id}, action: {action_name}, task: {task_id}"
)
# Remove the job from active jobs and cancel events; this makes the queue processor stop the callback automatically
self.active_jobs.pop(job_id, None)
self.cancel_events.pop(job_id, None)
logger.debug(f"[TaskScheduler] Removed job {job_id} from active jobs and cancel events")
# Send the final callback state
await self._publish_device_action_state(
device_id, action_name, task_id, job_id, "job_call_back_status", True, 0
)
logger.debug(f"[TaskScheduler] Completed stopping job callback for {job_id} with final status: {final_status}")
# External interface methods
def publish_job_status(
self, feedback_data: dict, item: "QueueItem", status: str, return_info: Optional[str] = None
) -> None:
"""发布作业状态拦截最终结果给HostNode调用的接口"""
if not self.message_sender.is_connected():
logger.debug(f"[TaskScheduler] Not connected, cannot publish job status for job_id: {item.job_id}")
return
# Intercept final result statuses
if status in ["success", "failed"]:
host_node = HostNode.get_instance(0)
if host_node:
host_node._device_action_status[item.device_action_key].job_ids.pop(item.job_id)
logger.info(f"[TaskScheduler] Intercepting final status for job_id: {item.job_id} - {status}")
# Give other queued actions with the same name a chance to run at least once
with self.immediate_execution_lock:
self.immediate_execution_flags[item.device_action_key] = time.time() + 3
# If it is a final status, handle it via _stop_job_callback
if self.message_sender.event_loop:
asyncio.run_coroutine_threadsafe(
self._stop_job_callback(item.job_id, status, return_info), self.message_sender.event_loop
).result()
# Upload the execution result info
message = {
"action": "job_status",
"data": {
"job_id": item.job_id,
"task_id": item.task_id,
"device_id": item.device_id,
"action_name": item.action_name,
"status": status,
"feedback_data": feedback_data,
"return_info": return_info,
"timestamp": time.time(),
},
}
try:
loop = asyncio.get_event_loop()
loop.create_task(self.message_sender.send_message(message))
except RuntimeError:
asyncio.run(self.message_sender.send_message(message))
logger.trace(f"[TaskScheduler] Job status published: {item.job_id} - {status}") # type: ignore
def cancel_goal(self, job_id: str) -> None:
"""取消指定的任务(给外部调用的接口)"""
logger.debug(f"[TaskScheduler] External cancel request for job_id: {job_id}")
if job_id in self.cancel_events:
logger.debug(f"[TaskScheduler] Found cancel event for job_id: {job_id}, processing cancellation")
try:
loop = asyncio.get_event_loop()
loop.create_task(self.handle_cancel_action({"job_id": job_id}))
logger.debug(f"[TaskScheduler] Scheduled cancel action for job_id: {job_id}")
except RuntimeError:
asyncio.run(self.handle_cancel_action({"job_id": job_id}))
logger.debug(f"[TaskScheduler] Executed cancel action for job_id: {job_id}")
logger.debug(f"[TaskScheduler] Initiated cancel for job_id: {job_id}")
else:
logger.debug(f"[TaskScheduler] Job {job_id} not found in cancel events for cancellation")
class WebSocketClient(BaseCommunicationClient):
"""
WebSocket communication client class
- Implements real-time communication based on the WebSocket protocol
+ Focuses on WebSocket connection management and message transport
"""
def __init__(self):
super().__init__()
- if not HAS_WEBSOCKETS:
- logger.error("[WebSocket] websockets库未安装WebSocket功能不可用")
- self.is_disabled = True
- return
self.is_disabled = False
self.client_id = f"{uuid.uuid4()}"
@@ -62,11 +625,22 @@ class WebSocketClient(BaseCommunicationClient):
self.message_queue = asyncio.Queue() if not self.is_disabled else None
self.reconnect_count = 0
# Task scheduler
self.task_scheduler = None
# Build the WebSocket URL
self._build_websocket_url()
logger.info(f"[WebSocket] Client_id: {self.client_id}")
# Initialization methods
def _initialize_task_scheduler(self):
"""Initialize the task scheduler"""
if not self.task_scheduler:
self.task_scheduler = TaskScheduler(self)
self.task_scheduler.start()
logger.info("[WebSocket] Task scheduler initialized")
def _build_websocket_url(self):
"""Build the WebSocket connection URL"""
if not HTTPConfig.remote_addr:
@@ -81,14 +655,15 @@ class WebSocketClient(BaseCommunicationClient):
scheme = "wss"
else:
scheme = "ws"
- if ":" in parsed.netloc:
+ if ":" in parsed.netloc and parsed.port is not None:
- self.websocket_url = f"{scheme}://{parsed.hostname}:{parsed.port + 1}/api/v1/lab"
+ self.websocket_url = f"{scheme}://{parsed.hostname}:{parsed.port + 1}/api/v1/ws/schedule"
else:
- self.websocket_url = f"{scheme}://{parsed.netloc}/api/v1/lab"
+ self.websocket_url = f"{scheme}://{parsed.netloc}/api/v1/ws/schedule"
logger.debug(f"[WebSocket] URL: {self.websocket_url}")
# Connection management methods
def start(self) -> None:
- """Start the WebSocket connection"""
+ """Start the WebSocket connection and the task scheduler"""
if self.is_disabled:
logger.warning("[WebSocket] WebSocket is disabled, skipping connection.")
return
@@ -99,6 +674,9 @@ class WebSocketClient(BaseCommunicationClient):
logger.info(f"[WebSocket] Starting connection to {self.websocket_url}") logger.info(f"[WebSocket] Starting connection to {self.websocket_url}")
# 初始化任务调度器
self._initialize_task_scheduler()
self.is_running = True self.is_running = True
# 在单独线程中运行WebSocket连接 # 在单独线程中运行WebSocket连接
@@ -106,7 +684,7 @@ class WebSocketClient(BaseCommunicationClient):
self.connection_thread.start()
def stop(self) -> None:
- """Stop the WebSocket connection"""
+ """Stop the WebSocket connection and the task scheduler"""
if self.is_disabled:
return
@@ -114,6 +692,10 @@ class WebSocketClient(BaseCommunicationClient):
self.is_running = False
self.connected = False
# Stop the task scheduler
if self.task_scheduler:
self.task_scheduler.stop()
if self.event_loop and self.event_loop.is_running():
asyncio.run_coroutine_threadsafe(self._close_connection(), self.event_loop)
@@ -145,19 +727,22 @@ class WebSocketClient(BaseCommunicationClient):
assert self.websocket_url is not None
if self.websocket_url.startswith("wss://"):
ssl_context = ssl_module.create_default_context()
ws_logger = logging.getLogger("websockets.client")
ws_logger.setLevel(logging.INFO)
async with websockets.connect(
self.websocket_url,
ssl=ssl_context,
ping_interval=WSConfig.ping_interval,
ping_timeout=10,
additional_headers={"Authorization": f"Lab {BasicConfig.auth_secret()}"},
logger=ws_logger,
) as websocket:
self.websocket = websocket
self.connected = True
self.reconnect_count = 0
logger.info(f"[WebSocket] Connected to {self.websocket_url}")
# Handle messages
await self._message_handler()
@@ -167,6 +752,9 @@ class WebSocketClient(BaseCommunicationClient):
except Exception as e:
logger.error(f"[WebSocket] Connection error: {str(e)}")
self.connected = False
finally:
# When the WebSocket connection ends, only the websocket object needs to be reset
self.websocket = None
# Reconnect logic
if self.is_running and self.reconnect_count < WSConfig.max_reconnect_attempts:
@@ -188,19 +776,7 @@ class WebSocketClient(BaseCommunicationClient):
await self.websocket.close()
self.websocket = None
- async def _send_message(self, message: Dict[str, Any]):
- """Send a message"""
- if not self.connected or not self.websocket:
- logger.warning("[WebSocket] Not connected, cannot send message")
- return
- try:
- message_str = json.dumps(message, ensure_ascii=False)
- await self.websocket.send(message_str)
- logger.debug(f"[WebSocket] Message sent: {message['type']}")
- except Exception as e:
- logger.error(f"[WebSocket] Failed to send message: {str(e)}")
+ # Message handling methods
async def _message_handler(self):
"""Handle received messages"""
if not self.websocket:
@@ -221,69 +797,28 @@ class WebSocketClient(BaseCommunicationClient):
except Exception as e:
logger.error(f"[WebSocket] Message handler error: {str(e)}")
- async def _handle_query_state(self, data: Dict[str, str]) -> None:
- device_id = data.get("device_id", "")
- if not device_id:
- logger.error("[WebSocket] query_action_state missing device_id")
- return
- action_name = data.get("action_name", "")
- if not action_name:
- logger.error("[WebSocket] query_action_state missing action_name")
- return
- task_id = data.get("task_id", "")
- if not task_id:
- logger.error("[WebSocket] query_action_state missing task_id")
- return
- job_id = data.get("job_id", "")
- if not task_id:
- logger.error("[WebSocket] query_action_state missing job_id")
- return
- device_action_key = f"/devices/{device_id}/{action_name}"
- action_jobs = len(HostNode.get_instance()._device_action_status[device_action_key].job_ids)
- message = {
- "type": "report_action_state",
- "data": {
- "device_id": device_id,
- "action_name": action_name,
- "task_id": task_id,
- "job_id": job_id,
- "free": bool(action_jobs)
- },
- }
- await self._send_message(message)
async def _process_message(self, input_message: Dict[str, Any]):
"""Handle an incoming message"""
- message_type = input_message.get("type", "")
+ message_type = input_message.get("action", "")
data = input_message.get("data", {})
- if message_type == "job_start":
- # Handle job start messages
- await self._handle_job_start(data)
- elif message_type == "pong":
- # Handle pong responses
- self._handle_pong_sync(data)
- elif message_type == "query_action_state":
- await self._handle_query_state(data)
- else:
- logger.debug(f"[WebSocket] Unknown message type: {message_type}")
- async def _handle_job_start(self, data: Dict[str, Any]):
- """Handle job start messages"""
- try:
- req = JobAddReq(**data)
- try:
- req.job_id = str(uuid.uuid4())
- logger.info(f"[WebSocket] Job started: {req.job_id}")
- HostNode.get_instance().send_goal(req.device_id, action_type=req.action_type, action_name=req.action,
- action_kwargs=req.action_args, goal_uuid=req.job_id,
- server_info=req.server_info)
- except Exception as e:
- for bridge in HostNode.get_instance().bridges:
- traceback.print_exc()
- if hasattr(bridge, "publish_job_status"):
- self.publish_job_status({}, req.job_id, "failed", serialize_result_info(traceback.format_exc(), False, {}))
- except Exception as e:
- logger.error(f"[WebSocket] Error handling job start: {str(e)}")
+ if message_type == "pong":
+ # Handle pong responses (WebSocket-level connection management)
+ self._handle_pong_sync(data)
+ elif self.task_scheduler:
+ # Hand all other messages over to the TaskScheduler
+ if message_type == "job_start":
+ await self.task_scheduler.handle_job_start(data)
+ elif message_type == "query_action_state":
+ await self.task_scheduler.handle_query_state(data)
+ elif message_type == "cancel_action":
+ await self.task_scheduler.handle_cancel_action(data)
+ elif message_type == "":
+ return
+ else:
+ logger.debug(f"[WebSocket] Unknown message: {input_message}")
+ else:
+ logger.warning(f"[WebSocket] Task scheduler not available for message: {message_type}")
def _handle_pong_sync(self, pong_data: Dict[str, Any]):
"""Handle pong responses synchronously"""
@@ -291,18 +826,43 @@ class WebSocketClient(BaseCommunicationClient):
if host_node:
host_node.handle_pong_response(pong_data)
- # Abstract base class method implementations
+ # Message sending methods
async def _send_message(self, message: Dict[str, Any]):
"""Internal message sending method"""
if not self.connected or not self.websocket:
logger.warning("[WebSocket] Not connected, cannot send message")
return
try:
message_str = json.dumps(message, ensure_ascii=False)
await self.websocket.send(message_str)
logger.debug(f"[WebSocket] Message sent: {message['action']}")
except Exception as e:
logger.error(f"[WebSocket] Failed to send message: {str(e)}")
# MessageSender interface implementation
async def send_message(self, message: Dict[str, Any]) -> None:
"""Send a message (interface called by TaskScheduler)"""
await self._send_message(message)
def is_connected(self) -> bool:
"""Check whether the client is connected (interface called by TaskScheduler)"""
return self.connected and not self.is_disabled
# Base class method implementations
def publish_device_status(self, device_status: dict, device_id: str, property_name: str) -> None:
"""Publish device status"""
if self.is_disabled or not self.connected:
return
message = {
- "type": "device_status",
+ "action": "device_status",
"data": {
"device_id": device_id,
- "property_name": property_name,
- "status": device_status.get(device_id, {}).get(property_name),
- "timestamp": time.time(),
+ "data": {
+ "property_name": property_name,
+ "status": device_status.get(device_id, {}).get(property_name),
+ "timestamp": time.time(),
+ },
},
}
if self.event_loop:
@@ -310,37 +870,29 @@ class WebSocketClient(BaseCommunicationClient):
logger.debug(f"[WebSocket] Device status published: {device_id}.{property_name}") logger.debug(f"[WebSocket] Device status published: {device_id}.{property_name}")
def publish_job_status( def publish_job_status(
self, feedback_data: dict, job_id: str, status: str, return_info: Optional[str] = None self, feedback_data: dict, item: "QueueItem", status: str, return_info: Optional[str] = None
) -> None: ) -> None:
"""发布作业状态""" """发布作业状态转发给TaskScheduler"""
if self.is_disabled or not self.connected: if self.task_scheduler:
logger.warning("[WebSocket] Not connected, cannot publish job status") self.task_scheduler.publish_job_status(feedback_data, item, status, return_info)
return else:
message = { logger.debug(f"[WebSocket] Task scheduler not available for job status: {item.job_id}")
"type": "job_status",
"data": {
"job_id": job_id,
"status": status,
"feedback_data": feedback_data,
"return_info": return_info,
"timestamp": time.time(),
},
}
if self.event_loop:
asyncio.run_coroutine_threadsafe(self._send_message(message), self.event_loop)
logger.debug(f"[WebSocket] Job status published: {job_id} - {status}")
def send_ping(self, ping_id: str, timestamp: float) -> None: def send_ping(self, ping_id: str, timestamp: float) -> None:
"""发送ping消息""" """发送ping消息"""
if self.is_disabled or not self.connected: if self.is_disabled or not self.connected:
logger.warning("[WebSocket] Not connected, cannot send ping") logger.warning("[WebSocket] Not connected, cannot send ping")
return return
message = {"type": "ping", "data": {"ping_id": ping_id, "client_timestamp": timestamp}} message = {"action": "ping", "data": {"ping_id": ping_id, "client_timestamp": timestamp}}
if self.event_loop: if self.event_loop:
asyncio.run_coroutine_threadsafe(self._send_message(message), self.event_loop) asyncio.run_coroutine_threadsafe(self._send_message(message), self.event_loop)
logger.debug(f"[WebSocket] Ping sent: {ping_id}") logger.debug(f"[WebSocket] Ping sent: {ping_id}")
@property def cancel_goal(self, job_id: str) -> None:
def is_connected(self) -> bool: """取消指定的任务转发给TaskScheduler"""
"""检查是否已连接""" logger.debug(f"[WebSocket] Received cancel goal request for job_id: {job_id}")
return self.connected and not self.is_disabled if self.task_scheduler:
self.task_scheduler.cancel_goal(job_id)
logger.debug(f"[WebSocket] Forwarded cancel goal to TaskScheduler for job_id: {job_id}")
else:
logger.debug(f"[WebSocket] Task scheduler not available for cancel goal: {job_id}")

View File

@@ -6,7 +6,7 @@ import threading
import time
import traceback
import uuid
- from typing import Optional, Dict, Any, List, ClassVar, Set, Union
+ from typing import TYPE_CHECKING, Optional, Dict, Any, List, ClassVar, Set, Union
from action_msgs.msg import GoalStatus
from geometry_msgs.msg import Point
@@ -42,6 +42,9 @@ from unilabos.ros.nodes.presets.controller_node import ControllerNode
from unilabos.utils.exception import DeviceClassInvalid
from unilabos.utils.type_check import serialize_result_info
if TYPE_CHECKING:
from unilabos.app.ws_client import QueueItem
@dataclass
class DeviceActionStatus:
@@ -621,11 +624,9 @@ class HostNode(BaseROS2DeviceNode):
def send_goal(
self,
- device_id: str,
+ item: "QueueItem",
action_type: str,
- action_name: str,
action_kwargs: Dict[str, Any],
- goal_uuid: Optional[str] = None,
server_info: Optional[Dict[str, Any]] = None,
) -> None:
"""
@@ -639,12 +640,9 @@ class HostNode(BaseROS2DeviceNode):
goal_uuid: goal UUID; if None, one is generated automatically
server_info: server send info, including the send timestamp, etc.
"""
- if goal_uuid is None:
- u = uuid.uuid4()
- else:
- u = uuid.UUID(goal_uuid)
- device_action_key = f"/devices/{device_id}/{action_name}"
- self._device_action_status[device_action_key].job_ids[str(u)] = time.time()
+ u = uuid.UUID(item.job_id)
+ device_id = item.device_id
+ action_name = item.action_name
if action_type.startswith("UniLabJsonCommand"):
if action_name.startswith("auto-"):
action_name = action_name[5:]
@@ -676,43 +674,43 @@ class HostNode(BaseROS2DeviceNode):
future = action_client.send_goal_async(
goal_msg,
- feedback_callback=lambda feedback_msg: self.feedback_callback(action_id, str(u), feedback_msg),
+ feedback_callback=lambda feedback_msg: self.feedback_callback(item, action_id, feedback_msg),
goal_uuid=goal_uuid_obj,
)
future.add_done_callback(
- lambda future: self.goal_response_callback(device_action_key, action_id, str(u), future)
+ lambda future: self.goal_response_callback(item, action_id, future)
)
- def goal_response_callback(self, device_action_key: str, action_id: str, uuid_str: str, future) -> None:
+ def goal_response_callback(self, item: "QueueItem", action_id: str, future) -> None:
"""Goal response callback"""
goal_handle = future.result()
if not goal_handle.accepted:
- self.lab_logger().warning(f"[Host Node] Goal {action_id} ({uuid_str}) rejected")
+ self.lab_logger().warning(f"[Host Node] Goal {item.action_name} ({item.job_id}) rejected")
return
- self.lab_logger().info(f"[Host Node] Goal {action_id} ({uuid_str}) accepted")
+ self.lab_logger().info(f"[Host Node] Goal {action_id} ({item.job_id}) accepted")
- self._goals[uuid_str] = goal_handle
+ self._goals[item.job_id] = goal_handle
goal_handle.get_result_async().add_done_callback(
- lambda future: self.get_result_callback(device_action_key, action_id, uuid_str, future)
+ lambda future: self.get_result_callback(item, action_id, future)
)
- def feedback_callback(self, action_id: str, uuid_str: str, feedback_msg) -> None:
+ def feedback_callback(self, item: "QueueItem", action_id: str, feedback_msg) -> None:
"""Feedback callback"""
feedback_data = convert_from_ros_msg(feedback_msg)
feedback_data.pop("goal_id")
- self.lab_logger().debug(f"[Host Node] Feedback for {action_id} ({uuid_str}): {feedback_data}")
+ self.lab_logger().trace(f"[Host Node] Feedback for {action_id} ({item.job_id}): {feedback_data}")
for bridge in self.bridges:
if hasattr(bridge, "publish_job_status"):
- bridge.publish_job_status(feedback_data, uuid_str, "running")
+ bridge.publish_job_status(feedback_data, item, "running")
- def get_result_callback(self, device_action_key: str, action_id: str, uuid_str: str, future) -> None:
+ def get_result_callback(self, item: "QueueItem", action_id: str, future) -> None:
"""Get result callback"""
job_id = item.job_id
result_msg = future.result().result
result_data = convert_from_ros_msg(result_msg)
status = "success"
return_info_str = result_data.get("return_info")
- self._device_action_status[device_action_key].job_ids.pop(uuid_str)
if return_info_str is not None:
try:
ret = json.loads(return_info_str)
@@ -732,13 +730,13 @@ class HostNode(BaseROS2DeviceNode):
status = "failed" status = "failed"
return_info_str = serialize_result_info("缺少return_info", False, result_data) return_info_str = serialize_result_info("缺少return_info", False, result_data)
self.lab_logger().info(f"[Host Node] Result for {action_id} ({uuid_str}): {status}") self.lab_logger().info(f"[Host Node] Result for {action_id} ({job_id}): {status}")
self.lab_logger().debug(f"[Host Node] Result data: {result_data}") self.lab_logger().debug(f"[Host Node] Result data: {result_data}")
if uuid_str: if job_id:
for bridge in self.bridges: for bridge in self.bridges:
if hasattr(bridge, "publish_job_status"): if hasattr(bridge, "publish_job_status"):
bridge.publish_job_status(result_data, uuid_str, status, return_info_str) bridge.publish_job_status(result_data, item, status, return_info_str)
def cancel_goal(self, goal_uuid: str) -> None: def cancel_goal(self, goal_uuid: str) -> None:
"""取消目标""" """取消目标"""
@@ -748,14 +746,14 @@ class HostNode(BaseROS2DeviceNode):
else:
self.lab_logger().warning(f"[Host Node] Goal {goal_uuid} not found, cannot cancel")
- def get_goal_status(self, uuid_str: str) -> int:
+ def get_goal_status(self, job_id: str) -> int:
"""Get goal status"""
- if uuid_str in self._goals:
+ if job_id in self._goals:
- g = self._goals[uuid_str]
+ g = self._goals[job_id]
status = g.status
- self.lab_logger().debug(f"[Host Node] Goal status for {uuid_str}: {status}")
+ self.lab_logger().debug(f"[Host Node] Goal status for {job_id}: {status}")
return status
- self.lab_logger().warning(f"[Host Node] Goal {uuid_str} not found, status unknown")
+ self.lab_logger().warning(f"[Host Node] Goal {job_id} not found, status unknown")
return GoalStatus.STATUS_UNKNOWN
"""Controller Node""" """Controller Node"""