diff --git a/unilabos/app/ws_client.py b/unilabos/app/ws_client.py index 54389342..621a2138 100644 --- a/unilabos/app/ws_client.py +++ b/unilabos/app/ws_client.py @@ -1,52 +1,676 @@ #!/usr/bin/env python # coding=utf-8 """ -WebSocket通信客户端 +WebSocket通信客户端和任务调度器 基于WebSocket协议的通信客户端实现,继承自BaseCommunicationClient。 +包含WebSocketClient(连接管理)和TaskScheduler(任务调度)两个类。 """ import json +import logging import time import uuid import threading import asyncio import traceback +import websockets +import ssl as ssl_module +from dataclasses import dataclass from typing import Optional, Dict, Any from urllib.parse import urlparse -from unilabos.app.controler import job_add from unilabos.app.model import JobAddReq from unilabos.ros.nodes.presets.host_node import HostNode from unilabos.utils.type_check import serialize_result_info - -try: - import websockets - import ssl as ssl_module - - HAS_WEBSOCKETS = True -except ImportError: - HAS_WEBSOCKETS = False - from unilabos.app.communication import BaseCommunicationClient from unilabos.config.config import WSConfig, HTTPConfig, BasicConfig from unilabos.utils import logger +@dataclass +class QueueItem: + """队列项数据结构""" + + task_type: str # "query_action_status" 或 "job_call_back_status" + device_id: str + action_name: str + task_id: str + job_id: str + device_action_key: str + next_run_time: float # 下次执行时间戳 + retry_count: int = 0 # 重试次数 + + +class TaskScheduler: + """ + 任务调度器类 + + 负责任务队列管理、状态跟踪、业务逻辑处理等功能。 + """ + + def __init__(self, message_sender: "WebSocketClient"): + """初始化任务调度器""" + self.message_sender = message_sender + + # 队列管理 + self.action_queue = [] # 任务队列 + self.action_queue_lock = threading.Lock() # 队列锁 + + # 任务状态跟踪 + self.active_jobs = {} # job_id -> 任务信息 + self.cancel_events = {} # job_id -> asyncio.Event for cancellation + + self.just_free_sets = {} # device_name + action_name -> end_timestamp + + # 队列处理器 + self.queue_processor_thread = None + self.queue_running = False + + # 队列处理器相关方法 + def start(self) -> None: + """启动任务调度器""" + if self.queue_running: + logger.warning("[TaskScheduler] Already running") + return + + self.queue_running = True + self.queue_processor_thread = threading.Thread( + target=self._run_queue_processor, daemon=True, name="TaskScheduler" + ) + self.queue_processor_thread.start() + + def stop(self) -> None: + """停止任务调度器""" + self.queue_running = False + if self.queue_processor_thread and self.queue_processor_thread.is_alive(): + self.queue_processor_thread.join(timeout=5) + logger.info("[TaskScheduler] Stopped") + + def _run_queue_processor(self): + """在独立线程中运行队列处理器""" + loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(loop) + loop.run_until_complete(self._action_queue_processor()) + except Exception as e: + logger.error(f"[TaskScheduler] Queue processor thread error: {str(e)}") + finally: + if loop: + loop.close() + + async def _action_queue_processor(self) -> None: + """队列处理器 - 从队列头部取出任务处理,保持顺序,使用list避免队尾排队问题""" + logger.info("[TaskScheduler] Action queue processor started") + + try: + while self.queue_running: + try: + current_time = time.time() + items_to_process = [] + items_to_requeue = [] + + # 使用锁安全地复制队列内容 + with self.action_queue_lock: + if not self.action_queue: + # 队列为空,等待一段时间 + pass + else: + # 复制队列内容以避免并发修改问题 + items_to_process = self.action_queue.copy() + self.action_queue.clear() + + if not items_to_process: + await asyncio.sleep(0.2) # 队列为空时等待 + continue + + # 处理每个任务 + for item in items_to_process: + try: + # 检查是否到了执行时间,是我们本地的执行时间,按顺序填入 + if current_time < item.next_run_time: + # 还没到执行时间,保留在队列中(保持原有顺序) + items_to_requeue.append(item) + continue + + # 执行相应的任务 + should_continue = False + if item.task_type == "query_action_status": + should_continue = await self._process_query_status_item(item) + elif item.task_type == "job_call_back_status": + should_continue = await self._process_job_callback_item(item) + else: + logger.warning(f"[TaskScheduler] Unknown task type: {item.task_type}") + continue + + # 如果需要继续,放入重新排队列表 + if should_continue: + item.next_run_time = current_time + 10.0 # 10秒后再次执行 + item.retry_count += 1 + items_to_requeue.append(item) + logger.critical( + f"[TaskScheduler] Re-queued {item.task_type} for {item.device_action_key}" + ) + else: + logger.critical( + f"[TaskScheduler] Completed {item.task_type} for {item.device_action_key}" + ) + + except Exception as e: + logger.error(f"[TaskScheduler] Error processing item {item.task_type}: {str(e)}") + + # 将需要重新排队的任务放回队列开头(保持原有顺序,确保优先于新任务执行) + if items_to_requeue and self.action_queue is not None: + with self.action_queue_lock: + self.action_queue = items_to_requeue + self.action_queue + + await asyncio.sleep(0.1) # 短暂等待避免过度占用CPU + + except Exception as e: + logger.error(f"[TaskScheduler] Error in queue processor: {str(e)}") + await asyncio.sleep(1) # 错误后稍等再继续 + + except asyncio.CancelledError: + logger.info("[TaskScheduler] Action queue processor cancelled") + except Exception as e: + logger.error(f"[TaskScheduler] Fatal error in queue processor: {str(e)}") + finally: + logger.info("[TaskScheduler] Action queue processor stopped") + + # 队列处理方法 + async def _process_query_status_item(self, item: QueueItem) -> bool: + """处理query_action_status类型的队列项,返回True表示需要继续,False表示可以停止""" + try: + # 检查设备状态 + host_node = HostNode.get_instance(0) + if not host_node: + logger.error("[TaskScheduler] HostNode instance not available in queue processor") + return False + + action_jobs = len(host_node._device_action_status[item.device_action_key].job_ids) + free = not bool(action_jobs) + + # 发送状态报告 + if free: + # 设备空闲,发送最终状态并停止 + await self._publish_device_action_state( + item.device_id, item.action_name, item.task_id, item.job_id, "query_action_status", True, 0.0 + ) + self.just_free_sets[item.device_action_key] = time.time() + 30.0 + return False # 停止继续监控 + else: + if item.device_action_key in self.just_free_sets: + if time.time() < self.just_free_sets[item.device_action_key]: + return True # 继续监控 + else: + del self.just_free_sets[item.device_action_key] + + # 设备忙碌,发送状态并继续监控 + await self._publish_device_action_state( + item.device_id, item.action_name, item.task_id, item.job_id, "query_action_status", False, 10.0 + ) + return True # 继续监控 + + except Exception as e: + logger.error(f"[TaskScheduler] Error processing query status item: {str(e)}") + return False # 出错则停止 + + async def _process_job_callback_item(self, item: QueueItem) -> bool: + """处理job_call_back_status类型的队列项,返回True表示需要继续,False表示可以停止""" + try: + logger.critical(f"[TaskScheduler] Processing job callback item for job_id: {item.job_id}") + # 检查任务是否还在活跃列表中 + if item.job_id not in self.active_jobs: + logger.critical(f"[TaskScheduler] Job {item.job_id} no longer active") + return False + + # 检查是否收到取消信号 + if item.job_id in self.cancel_events and self.cancel_events[item.job_id].is_set(): + logger.critical(f"[TaskScheduler] Job {item.job_id} cancelled via cancel event") + return False + + # 检查设备状态 + host_node = HostNode.get_instance(0) + if not host_node: + logger.error( + f"[TaskScheduler] HostNode instance not available in job callback queue for job_id: {item.job_id}" + ) + return False + + action_jobs = len(host_node._device_action_status[item.device_action_key].job_ids) + free = not bool(action_jobs) + + logger.critical( + f"[TaskScheduler] Job {item.job_id} callback status check - free: {free}, action_jobs: {action_jobs}" + ) + + # 发送job_call_back_status状态 + await self._publish_device_action_state( + item.device_id, item.action_name, item.task_id, item.job_id, "job_call_back_status", free, 10.0 + ) + + # 如果任务完成,停止监控 + if free: + logger.critical(f"[TaskScheduler] Job {item.job_id} callback monitoring completed - device is free") + return False + else: + logger.critical(f"[TaskScheduler] Job {item.job_id} callback monitoring continues - device is busy") + return True # 继续监控 + + except Exception as e: + logger.error(f"[TaskScheduler] Error processing job callback item for job_id {item.job_id}: {str(e)}") + return False # 出错则停止 + + # 消息发送方法 + async def _publish_device_action_state( + self, device_id: str, action_name: str, task_id: str, job_id: str, typ: str, free: bool, need_more: float + ) -> None: + """发布设备动作状态""" + message = { + "action": "report_action_state", + "data": { + "action": typ, + "device_id": device_id, + "action_name": action_name, + "task_id": task_id, + "job_id": job_id, + "free": free, + "need_more": need_more, + }, + } + await self.message_sender.send_message(message) + logger.critical(f"[TaskScheduler] Published action state: {device_id}/{action_name} - {typ}") + + async def _publish_final_job_status( + self, feedback_data: dict, item: QueueItem, status: str, return_info: Optional[str] = None + ) -> None: + """发布最终作业状态""" + if not self.message_sender.is_connected(): + logger.warning("[TaskScheduler] Not connected, cannot publish final job status") + return + + # 只处理最终状态 + if status not in ["completed", "failed", "cancelled"]: + logger.critical(f"[TaskScheduler] Ignoring non-final status: {status}") + return + + message = { + "action": "job_status", + "data": { + "job_id": item.job_id, + "task_id": item.task_id, + "device_id": item.device_id, + "action_name": item.action_name, + "status": status, + "feedback_data": feedback_data, + "return_info": return_info, + "timestamp": time.time(), + }, + } + await self.message_sender.send_message(message) + logger.critical(f"[TaskScheduler] Final job status published: {item.job_id} - {status}") + + # 业务逻辑处理方法 + async def handle_query_state(self, data: Dict[str, str]) -> None: + """处理query_action_state消息""" + device_id = data.get("device_id", "") + if not device_id: + logger.error("[TaskScheduler] query_action_state missing device_id") + return + action_name = data.get("action_name", "") + if not action_name: + logger.error("[TaskScheduler] query_action_state missing action_name") + return + task_id = data.get("task_id", "") + if not task_id: + logger.error("[TaskScheduler] query_action_state missing task_id") + return + job_id = data.get("job_id", "") + if not job_id: + logger.error("[TaskScheduler] query_action_state missing job_id") + return + + device_action_key = f"/devices/{device_id}/{action_name}" + host_node = HostNode.get_instance(0) + if not host_node: + logger.error("[TaskScheduler] HostNode instance not available") + return + + action_jobs = len(host_node._device_action_status[device_action_key].job_ids) + free = not bool(action_jobs) + + # 如果设备空闲,立即响应free状态 + if free: + await self._publish_device_action_state( + device_id, action_name, task_id, job_id, "query_action_status", True, 0.0 + ) + logger.critical(f"[TaskScheduler] Device {device_id}/{action_name} is free, responded immediately") + return + + # 设备忙碌时,检查是否已有相同的轮询任务 + if self.action_queue is not None: + with self.action_queue_lock: + # 检查是否已存在相同job_id和task_id的轮询任务 + for existing_item in self.action_queue: + if ( + existing_item.task_type == "query_action_status" + and existing_item.job_id == job_id + and existing_item.task_id == task_id + and existing_item.device_action_key == device_action_key + ): + logger.error( + f"[TaskScheduler] Duplicate query_action_state ignored: " + f"job_id={job_id}, task_id={task_id}, server error" + ) + return + + # 没有重复,加入轮询队列 + queue_item = QueueItem( + task_type="query_action_status", + device_id=device_id, + action_name=action_name, + task_id=task_id, + job_id=job_id, + device_action_key=device_action_key, + next_run_time=time.time() + 10.0, # 10秒后执行 + ) + self.action_queue.append(queue_item) + logger.critical(f"[TaskScheduler] Device {device_id}/{action_name} is busy, added to polling queue") + + # 立即发送busy状态 + await self._publish_device_action_state( + device_id, action_name, task_id, job_id, "query_action_status", False, 10.0 + ) + else: + logger.warning("[TaskScheduler] Action queue not available") + + async def handle_job_start(self, data: Dict[str, Any]): + """处理作业启动消息""" + try: + req = JobAddReq(**data) + device_action_key = f"/devices/{req.device_id}/{req.action}" + + logger.critical( + f"[TaskScheduler] Starting job with job_id: {req.job_id}, " + f"device: {req.device_id}, action: {req.action}" + ) + + # 添加到活跃任务 + self.active_jobs[req.job_id] = { + "device_id": req.device_id, + "action_name": req.action, + "task_id": data.get("task_id", ""), + "start_time": time.time(), + "device_action_key": device_action_key, + "callback_started": False, # 标记callback是否已启动 + } + + # 创建取消事件 + self.cancel_events[req.job_id] = asyncio.Event() + logger.critical(f"[TaskScheduler] Created cancel event for job_id: {req.job_id}") + + try: + host_node = HostNode.get_instance(0) + if not host_node: + logger.error(f"[TaskScheduler] HostNode instance not available for job_id: {req.job_id}") + return + + host_node._device_action_status[device_action_key].job_ids[req.job_id] = time.time() + logger.critical(f"[TaskScheduler] Job registered in HostNode: {req.job_id}") + + # 启动callback定时发送 + await self._start_job_callback( + req.job_id, req.device_id, req.action, data.get("task_id", ""), device_action_key + ) + + # 创建兼容HostNode的QueueItem对象 + job_queue_item = QueueItem( + task_type="job_call_back_status", + device_id=req.device_id, + action_name=req.action, + task_id=data.get("task_id", ""), + job_id=req.job_id, + device_action_key=device_action_key, + next_run_time=time.time(), + ) + + logger.critical(f"[TaskScheduler] Sending goal to HostNode for job_id: {req.job_id}") + host_node.send_goal( + job_queue_item, + action_type=req.action_type, + action_kwargs=req.action_args, + server_info=req.server_info, + ) + logger.critical(f"[TaskScheduler] Goal sent successfully for job_id: {req.job_id}") + except Exception as e: + logger.error(f"[TaskScheduler] Exception during job start for job_id {req.job_id}: {str(e)}") + traceback.print_exc() + # 异常结束,先停止callback,然后发送失败状态 + await self._stop_job_callback( + req.job_id, "failed", serialize_result_info(traceback.format_exc(), False, {}) + ) + + host_node = HostNode.get_instance(0) + if host_node: + host_node._device_action_status[device_action_key].job_ids.pop(req.job_id, None) + logger.critical(f"[TaskScheduler] Cleaned up failed job from HostNode: {req.job_id}") + except Exception as e: + logger.error(f"[TaskScheduler] Error handling job start: {str(e)}") + + async def handle_cancel_action(self, data: Dict[str, Any]) -> None: + """处理取消动作请求""" + task_id = data.get("task_id") + job_id = data.get("job_id") + + logger.critical(f"[TaskScheduler] Handling cancel action request - task_id: {task_id}, job_id: {job_id}") + + if not task_id and not job_id: + logger.error("[TaskScheduler] cancel_action missing both task_id and job_id") + return + + # 通过job_id取消 + if job_id: + logger.critical(f"[TaskScheduler] Cancelling job by job_id: {job_id}") + # 设置取消事件 + if job_id in self.cancel_events: + self.cancel_events[job_id].set() + logger.critical(f"[TaskScheduler] Set cancel event for job_id: {job_id}") + else: + logger.warning(f"[TaskScheduler] Cancel event not found for job_id: {job_id}") + + # 停止job callback并发送取消状态 + if job_id in self.active_jobs: + logger.critical(f"[TaskScheduler] Found active job for cancellation: {job_id}") + # 调用HostNode的cancel_goal + host_node = HostNode.get_instance(0) + if host_node: + host_node.cancel_goal(job_id) + logger.critical(f"[TaskScheduler] Cancelled goal in HostNode for job_id: {job_id}") + else: + logger.error(f"[TaskScheduler] HostNode not available for cancel goal: {job_id}") + + # 停止callback并发送取消状态 + await self._stop_job_callback(job_id, "cancelled", "Job was cancelled by user request") + logger.critical(f"[TaskScheduler] Stopped job callback and sent cancel status for job_id: {job_id}") + else: + logger.warning(f"[TaskScheduler] Job not found in active jobs for cancellation: {job_id}") + + # 通过task_id取消(需要查找对应的job_id) + if task_id and not job_id: + logger.critical(f"[TaskScheduler] Cancelling jobs by task_id: {task_id}") + jobs_to_cancel = [] + for jid, job_info in self.active_jobs.items(): + if job_info.get("task_id") == task_id: + jobs_to_cancel.append(jid) + + logger.critical( + f"[TaskScheduler] Found {len(jobs_to_cancel)} jobs to cancel for task_id {task_id}: {jobs_to_cancel}" + ) + + for jid in jobs_to_cancel: + logger.critical(f"[TaskScheduler] Recursively cancelling job_id: {jid} for task_id: {task_id}") + # 递归调用自身来取消每个job + await self.handle_cancel_action({"job_id": jid}) + + logger.critical(f"[TaskScheduler] Completed cancel action handling - task_id: {task_id}, job_id: {job_id}") + + # job管理方法 + async def _start_job_callback( + self, job_id: str, device_id: str, action_name: str, task_id: str, device_action_key: str + ) -> None: + """启动job的callback定时发送""" + logger.critical(f"[TaskScheduler] Starting job callback for job_id: {job_id}") + if job_id not in self.active_jobs: + logger.warning(f"[TaskScheduler] Job not found in active jobs when starting callback: {job_id}") + return + + # 检查是否已经启动过callback + if self.active_jobs[job_id].get("callback_started", False): + logger.warning(f"[TaskScheduler] Job callback already started for job_id: {job_id}") + return + + # 标记callback已启动 + self.active_jobs[job_id]["callback_started"] = True + logger.critical(f"[TaskScheduler] Marked callback as started for job_id: {job_id}") + + # 将job_call_back_status任务放入队列 + queue_item = QueueItem( + task_type="job_call_back_status", + device_id=device_id, + action_name=action_name, + task_id=task_id, + job_id=job_id, + device_action_key=device_action_key, + next_run_time=time.time() + 10.0, # 10秒后开始报送 + ) + if self.action_queue is not None: + with self.action_queue_lock: + self.action_queue.append(queue_item) + logger.critical(f"[TaskScheduler] Added job callback to queue for job_id: {job_id}") + else: + logger.warning(f"[TaskScheduler] Action queue not available for job callback: {job_id}") + + async def _stop_job_callback(self, job_id: str, final_status: str, return_info: Optional[str] = None) -> None: + """停止job的callback定时发送并发送最终结果""" + logger.critical( + f"[TaskScheduler] Stopping job callback for job_id: {job_id} with final status: {final_status}" + ) + if job_id not in self.active_jobs: + logger.warning(f"[TaskScheduler] Job {job_id} not found in active jobs when stopping callback") + return + + job_info = self.active_jobs[job_id] + device_id = job_info["device_id"] + action_name = job_info["action_name"] + task_id = job_info["task_id"] + device_action_key = job_info["device_action_key"] + + logger.critical( + f"[TaskScheduler] Job {job_id} details - device: {device_id}, action: {action_name}, task: {task_id}" + ) + + # 移除活跃任务和取消事件(这会让队列处理器自动停止callback) + self.active_jobs.pop(job_id, None) + self.cancel_events.pop(job_id, None) + logger.critical(f"[TaskScheduler] Removed job {job_id} from active jobs and cancel events") + + # 发送最终的callback状态 + await self._publish_device_action_state( + device_id, action_name, task_id, job_id, "job_call_back_status", True, 0.0 + ) + logger.critical(f"[TaskScheduler] Published final callback status for job_id: {job_id}") + + # 发送最终的job状态 + error_queue_item = QueueItem( + task_type="job_call_back_status", + device_id=device_id, + action_name=action_name, + task_id=task_id, + job_id=job_id, + device_action_key=device_action_key, + next_run_time=time.time(), + ) + await self._publish_final_job_status({}, error_queue_item, final_status, return_info) + logger.critical(f"[TaskScheduler] Published final job status for job_id: {job_id}") + + logger.critical( + f"[TaskScheduler] Completed stopping job callback for {job_id} with final status: {final_status}" + ) + + # 外部接口方法 + def publish_job_status( + self, feedback_data: dict, item: "QueueItem", status: str, return_info: Optional[str] = None + ) -> None: + """发布作业状态,拦截最终结果(给HostNode调用的接口)""" + logger.critical(f"[TaskScheduler] Publishing job status for job_id: {item.job_id} - status: {status}") + if not self.message_sender.is_connected(): + logger.warning(f"[TaskScheduler] Not connected, cannot publish job status for job_id: {item.job_id}") + return + + # 拦截最终结果状态 + if status in ["completed", "failed"]: + logger.critical(f"[TaskScheduler] Intercepting final status for job_id: {item.job_id} - {status}") + # 如果是最终状态,通过_stop_job_callback处理 + try: + loop = asyncio.get_event_loop() + loop.create_task(self._stop_job_callback(item.job_id, status, return_info)) + logger.critical(f"[TaskScheduler] Scheduled final callback stop for job_id: {item.job_id}") + except RuntimeError: + # 如果没有运行的事件循环,创建一个 + asyncio.run(self._stop_job_callback(item.job_id, status, return_info)) + logger.critical(f"[TaskScheduler] Executed final callback stop for job_id: {item.job_id}") + logger.critical(f"[TaskScheduler] Intercepted final job status: {item.job_id} - {status}") + return + + # 对于running状态,正常发布 + logger.critical(f"[TaskScheduler] Publishing running status for job_id: {item.job_id} - {status}") + message = { + "action": "job_status", + "data": { + "job_id": item.job_id, + "task_id": item.task_id, + "device_id": item.device_id, + "action_name": item.action_name, + "status": status, + "feedback_data": feedback_data, + "return_info": return_info, + "timestamp": time.time(), + }, + } + try: + loop = asyncio.get_event_loop() + loop.create_task(self.message_sender.send_message(message)) + logger.critical(f"[TaskScheduler] Scheduled message send for job_id: {item.job_id} - {status}") + except RuntimeError: + asyncio.run(self.message_sender.send_message(message)) + logger.critical(f"[TaskScheduler] Executed message send for job_id: {item.job_id} - {status}") + logger.trace(f"[TaskScheduler] Job status published: {item.job_id} - {status}") # type: ignore + + def cancel_goal(self, job_id: str) -> None: + """取消指定的任务(给外部调用的接口)""" + logger.critical(f"[TaskScheduler] External cancel request for job_id: {job_id}") + if job_id in self.cancel_events: + logger.critical(f"[TaskScheduler] Found cancel event for job_id: {job_id}, processing cancellation") + try: + loop = asyncio.get_event_loop() + loop.create_task(self.handle_cancel_action({"job_id": job_id})) + logger.critical(f"[TaskScheduler] Scheduled cancel action for job_id: {job_id}") + except RuntimeError: + asyncio.run(self.handle_cancel_action({"job_id": job_id})) + logger.critical(f"[TaskScheduler] Executed cancel action for job_id: {job_id}") + logger.critical(f"[TaskScheduler] Initiated cancel for job_id: {job_id}") + else: + logger.warning(f"[TaskScheduler] Job {job_id} not found in cancel events for cancellation") + + class WebSocketClient(BaseCommunicationClient): """ WebSocket通信客户端类 - 实现基于WebSocket协议的实时通信功能。 + 专注于WebSocket连接管理和消息传输。 """ def __init__(self): super().__init__() - - if not HAS_WEBSOCKETS: - logger.error("[WebSocket] websockets库未安装,WebSocket功能不可用") - self.is_disabled = True - return - self.is_disabled = False self.client_id = f"{uuid.uuid4()}" @@ -62,11 +686,22 @@ class WebSocketClient(BaseCommunicationClient): self.message_queue = asyncio.Queue() if not self.is_disabled else None self.reconnect_count = 0 + # 任务调度器 + self.task_scheduler = None + # 构建WebSocket URL self._build_websocket_url() logger.info(f"[WebSocket] Client_id: {self.client_id}") + # 初始化方法 + def _initialize_task_scheduler(self): + """初始化任务调度器""" + if not self.task_scheduler: + self.task_scheduler = TaskScheduler(self) + self.task_scheduler.start() + logger.info("[WebSocket] Task scheduler initialized") + def _build_websocket_url(self): """构建WebSocket连接URL""" if not HTTPConfig.remote_addr: @@ -81,14 +716,15 @@ class WebSocketClient(BaseCommunicationClient): scheme = "wss" else: scheme = "ws" - if ":" in parsed.netloc: - self.websocket_url = f"{scheme}://{parsed.hostname}:{parsed.port + 1}/api/v1/lab" + if ":" in parsed.netloc and parsed.port is not None: + self.websocket_url = f"{scheme}://{parsed.hostname}:{parsed.port + 1}/api/v1/ws/lab" else: - self.websocket_url = f"{scheme}://{parsed.netloc}/api/v1/lab" + self.websocket_url = f"{scheme}://{parsed.netloc}/api/v1/ws/lab" logger.debug(f"[WebSocket] URL: {self.websocket_url}") + # 连接管理方法 def start(self) -> None: - """启动WebSocket连接""" + """启动WebSocket连接和任务调度器""" if self.is_disabled: logger.warning("[WebSocket] WebSocket is disabled, skipping connection.") return @@ -99,6 +735,9 @@ class WebSocketClient(BaseCommunicationClient): logger.info(f"[WebSocket] Starting connection to {self.websocket_url}") + # 初始化任务调度器 + self._initialize_task_scheduler() + self.is_running = True # 在单独线程中运行WebSocket连接 @@ -106,7 +745,7 @@ class WebSocketClient(BaseCommunicationClient): self.connection_thread.start() def stop(self) -> None: - """停止WebSocket连接""" + """停止WebSocket连接和任务调度器""" if self.is_disabled: return @@ -114,6 +753,10 @@ class WebSocketClient(BaseCommunicationClient): self.is_running = False self.connected = False + # 停止任务调度器 + if self.task_scheduler: + self.task_scheduler.stop() + if self.event_loop and self.event_loop.is_running(): asyncio.run_coroutine_threadsafe(self._close_connection(), self.event_loop) @@ -145,19 +788,22 @@ class WebSocketClient(BaseCommunicationClient): assert self.websocket_url is not None if self.websocket_url.startswith("wss://"): ssl_context = ssl_module.create_default_context() - + ws_logger = logging.getLogger("websockets.client") + ws_logger.setLevel(logging.INFO) async with websockets.connect( self.websocket_url, ssl=ssl_context, ping_interval=WSConfig.ping_interval, ping_timeout=10, additional_headers={"Authorization": f"Lab {BasicConfig.auth_secret()}"}, + logger=ws_logger, ) as websocket: self.websocket = websocket self.connected = True self.reconnect_count = 0 logger.info(f"[WebSocket] Connected to {self.websocket_url}") + # 处理消息 await self._message_handler() @@ -167,6 +813,9 @@ class WebSocketClient(BaseCommunicationClient): except Exception as e: logger.error(f"[WebSocket] Connection error: {str(e)}") self.connected = False + finally: + # WebSocket连接结束时只需重置websocket对象 + self.websocket = None # 重连逻辑 if self.is_running and self.reconnect_count < WSConfig.max_reconnect_attempts: @@ -188,19 +837,7 @@ class WebSocketClient(BaseCommunicationClient): await self.websocket.close() self.websocket = None - async def _send_message(self, message: Dict[str, Any]): - """发送消息""" - if not self.connected or not self.websocket: - logger.warning("[WebSocket] Not connected, cannot send message") - return - - try: - message_str = json.dumps(message, ensure_ascii=False) - await self.websocket.send(message_str) - logger.debug(f"[WebSocket] Message sent: {message['type']}") - except Exception as e: - logger.error(f"[WebSocket] Failed to send message: {str(e)}") - + # 消息处理方法 async def _message_handler(self): """处理接收到的消息""" if not self.websocket: @@ -221,69 +858,28 @@ class WebSocketClient(BaseCommunicationClient): except Exception as e: logger.error(f"[WebSocket] Message handler error: {str(e)}") - async def _handle_query_state(self, data: Dict[str, str]) -> None: - device_id = data.get("device_id", "") - if not device_id: - logger.error("[WebSocket] query_action_state missing device_id") - return - action_name = data.get("action_name", "") - if not action_name: - logger.error("[WebSocket] query_action_state missing action_name") - return - task_id = data.get("task_id", "") - if not task_id: - logger.error("[WebSocket] query_action_state missing task_id") - return - job_id = data.get("job_id", "") - if not task_id: - logger.error("[WebSocket] query_action_state missing job_id") - return - device_action_key = f"/devices/{device_id}/{action_name}" - action_jobs = len(HostNode.get_instance()._device_action_status[device_action_key].job_ids) - message = { - "type": "report_action_state", - "data": { - "device_id": device_id, - "action_name": action_name, - "task_id": task_id, - "job_id": job_id, - "free": bool(action_jobs) - }, - } - await self._send_message(message) - async def _process_message(self, input_message: Dict[str, Any]): """处理收到的消息""" - message_type = input_message.get("type", "") + message_type = input_message.get("action", "") data = input_message.get("data", {}) - if message_type == "job_start": - # 处理作业启动消息 - await self._handle_job_start(data) - elif message_type == "pong": - # 处理pong响应 - self._handle_pong_sync(data) - elif message_type == "query_action_state": - await self._handle_query_state(data) - else: - logger.debug(f"[WebSocket] Unknown message type: {message_type}") - async def _handle_job_start(self, data: Dict[str, Any]): - """处理作业启动消息""" - try: - req = JobAddReq(**data) - try: - req.job_id = str(uuid.uuid4()) - logger.info(f"[WebSocket] Job started: {req.job_id}") - HostNode.get_instance().send_goal(req.device_id, action_type=req.action_type, action_name=req.action, - action_kwargs=req.action_args, goal_uuid=req.job_id, - server_info=req.server_info) - except Exception as e: - for bridge in HostNode.get_instance().bridges: - traceback.print_exc() - if hasattr(bridge, "publish_job_status"): - self.publish_job_status({}, req.job_id, "failed", serialize_result_info(traceback.format_exc(), False, {})) - except Exception as e: - logger.error(f"[WebSocket] Error handling job start: {str(e)}") + if message_type == "pong": + # 处理pong响应(WebSocket层面的连接管理) + self._handle_pong_sync(data) + elif self.task_scheduler: + # 其他消息交给TaskScheduler处理 + if message_type == "job_start": + await self.task_scheduler.handle_job_start(data) + elif message_type == "query_action_state": + await self.task_scheduler.handle_query_state(data) + elif message_type == "cancel_action": + await self.task_scheduler.handle_cancel_action(data) + elif message_type == "": + return + else: + logger.debug(f"[WebSocket] Unknown message: {input_message}") + else: + logger.warning(f"[WebSocket] Task scheduler not available for message: {message_type}") def _handle_pong_sync(self, pong_data: Dict[str, Any]): """同步处理pong响应""" @@ -291,18 +887,43 @@ class WebSocketClient(BaseCommunicationClient): if host_node: host_node.handle_pong_response(pong_data) - # 实现抽象基类的方法 + # 消息发送方法 + async def _send_message(self, message: Dict[str, Any]): + """内部发送消息方法""" + if not self.connected or not self.websocket: + logger.warning("[WebSocket] Not connected, cannot send message") + return + + try: + message_str = json.dumps(message, ensure_ascii=False) + await self.websocket.send(message_str) + logger.debug(f"[WebSocket] Message sent: {message['action']}") + except Exception as e: + logger.error(f"[WebSocket] Failed to send message: {str(e)}") + + # MessageSender接口实现 + async def send_message(self, message: Dict[str, Any]) -> None: + """发送消息(TaskScheduler调用的接口)""" + await self._send_message(message) + + def is_connected(self) -> bool: + """检查是否已连接(TaskScheduler调用的接口)""" + return self.connected and not self.is_disabled + + # 基类方法实现 def publish_device_status(self, device_status: dict, device_id: str, property_name: str) -> None: """发布设备状态""" if self.is_disabled or not self.connected: return message = { - "type": "device_status", + "action": "device_status", "data": { "device_id": device_id, - "property_name": property_name, - "status": device_status.get(device_id, {}).get(property_name), - "timestamp": time.time(), + "data": { + "property_name": property_name, + "status": device_status.get(device_id, {}).get(property_name), + "timestamp": time.time(), + }, }, } if self.event_loop: @@ -310,37 +931,31 @@ class WebSocketClient(BaseCommunicationClient): logger.debug(f"[WebSocket] Device status published: {device_id}.{property_name}") def publish_job_status( - self, feedback_data: dict, job_id: str, status: str, return_info: Optional[str] = None + self, feedback_data: dict, item: "QueueItem", status: str, return_info: Optional[str] = None ) -> None: - """发布作业状态""" - if self.is_disabled or not self.connected: - logger.warning("[WebSocket] Not connected, cannot publish job status") - return - message = { - "type": "job_status", - "data": { - "job_id": job_id, - "status": status, - "feedback_data": feedback_data, - "return_info": return_info, - "timestamp": time.time(), - }, - } - if self.event_loop: - asyncio.run_coroutine_threadsafe(self._send_message(message), self.event_loop) - logger.debug(f"[WebSocket] Job status published: {job_id} - {status}") + """发布作业状态(转发给TaskScheduler)""" + logger.critical(f"[WebSocket] Received job status for job_id: {item.job_id} - status: {status}") + if self.task_scheduler: + self.task_scheduler.publish_job_status(feedback_data, item, status, return_info) + logger.critical(f"[WebSocket] Forwarded job status to TaskScheduler for job_id: {item.job_id}") + else: + logger.warning(f"[WebSocket] Task scheduler not available for job status: {item.job_id}") def send_ping(self, ping_id: str, timestamp: float) -> None: """发送ping消息""" if self.is_disabled or not self.connected: logger.warning("[WebSocket] Not connected, cannot send ping") return - message = {"type": "ping", "data": {"ping_id": ping_id, "client_timestamp": timestamp}} + message = {"action": "ping", "data": {"ping_id": ping_id, "client_timestamp": timestamp}} if self.event_loop: asyncio.run_coroutine_threadsafe(self._send_message(message), self.event_loop) logger.debug(f"[WebSocket] Ping sent: {ping_id}") - @property - def is_connected(self) -> bool: - """检查是否已连接""" - return self.connected and not self.is_disabled + def cancel_goal(self, job_id: str) -> None: + """取消指定的任务(转发给TaskScheduler)""" + logger.critical(f"[WebSocket] Received cancel goal request for job_id: {job_id}") + if self.task_scheduler: + self.task_scheduler.cancel_goal(job_id) + logger.critical(f"[WebSocket] Forwarded cancel goal to TaskScheduler for job_id: {job_id}") + else: + logger.warning(f"[WebSocket] Task scheduler not available for cancel goal: {job_id}") diff --git a/unilabos/ros/nodes/presets/host_node.py b/unilabos/ros/nodes/presets/host_node.py index 619fd377..eed9f420 100644 --- a/unilabos/ros/nodes/presets/host_node.py +++ b/unilabos/ros/nodes/presets/host_node.py @@ -6,7 +6,7 @@ import threading import time import traceback import uuid -from typing import Optional, Dict, Any, List, ClassVar, Set, Union +from typing import TYPE_CHECKING, Optional, Dict, Any, List, ClassVar, Set, Union from action_msgs.msg import GoalStatus from geometry_msgs.msg import Point @@ -42,6 +42,9 @@ from unilabos.ros.nodes.presets.controller_node import ControllerNode from unilabos.utils.exception import DeviceClassInvalid from unilabos.utils.type_check import serialize_result_info +if TYPE_CHECKING: + from unilabos.app.ws_client import QueueItem + @dataclass class DeviceActionStatus: @@ -230,12 +233,12 @@ class HostNode(BaseROS2DeviceNode): client: HTTPClient = bridge resource_start_time = time.time() - resource_add_res = client.resource_add(add_schema(resource_with_parent_name), False) + # resource_add_res = client.resource_add(add_schema(resource_with_parent_name), False) resource_end_time = time.time() self.lab_logger().info( f"[Host Node-Resource] 物料上传 {round(resource_end_time - resource_start_time, 5) * 1000} ms" ) - resource_add_res = client.resource_edge_add(self.resources_edge_config, False) + # resource_add_res = client.resource_edge_add(self.resources_edge_config, False) resource_edge_end_time = time.time() self.lab_logger().info( f"[Host Node-Resource] 物料关系上传 {round(resource_edge_end_time - resource_end_time, 5) * 1000} ms" @@ -621,11 +624,9 @@ class HostNode(BaseROS2DeviceNode): def send_goal( self, - device_id: str, + item: "QueueItem", action_type: str, - action_name: str, action_kwargs: Dict[str, Any], - goal_uuid: Optional[str] = None, server_info: Optional[Dict[str, Any]] = None, ) -> None: """ @@ -639,12 +640,9 @@ class HostNode(BaseROS2DeviceNode): goal_uuid: 目标UUID,如果为None则自动生成 server_info: 服务器发送信息,包含发送时间戳等 """ - if goal_uuid is None: - u = uuid.uuid4() - else: - u = uuid.UUID(goal_uuid) - device_action_key = f"/devices/{device_id}/{action_name}" - self._device_action_status[device_action_key].job_ids[str(u)] = time.time() + u = uuid.UUID(item.job_id) + device_id = item.device_id + action_name = item.action_name if action_type.startswith("UniLabJsonCommand"): if action_name.startswith("auto-"): action_name = action_name[5:] @@ -676,43 +674,44 @@ class HostNode(BaseROS2DeviceNode): future = action_client.send_goal_async( goal_msg, - feedback_callback=lambda feedback_msg: self.feedback_callback(action_id, str(u), feedback_msg), + feedback_callback=lambda feedback_msg: self.feedback_callback(item, action_id, feedback_msg), goal_uuid=goal_uuid_obj, ) future.add_done_callback( - lambda future: self.goal_response_callback(device_action_key, action_id, str(u), future) + lambda future: self.goal_response_callback(item, action_id, future) ) - def goal_response_callback(self, device_action_key: str, action_id: str, uuid_str: str, future) -> None: + def goal_response_callback(self, item: "QueueItem", action_id: str, future) -> None: """目标响应回调""" goal_handle = future.result() if not goal_handle.accepted: - self.lab_logger().warning(f"[Host Node] Goal {action_id} ({uuid_str}) rejected") + self.lab_logger().warning(f"[Host Node] Goal {item.action_name} ({item.job_id}) rejected") return - self.lab_logger().info(f"[Host Node] Goal {action_id} ({uuid_str}) accepted") - self._goals[uuid_str] = goal_handle + self.lab_logger().info(f"[Host Node] Goal {action_id} ({item.job_id}) accepted") + self._goals[item.job_id] = goal_handle goal_handle.get_result_async().add_done_callback( - lambda future: self.get_result_callback(device_action_key, action_id, uuid_str, future) + lambda future: self.get_result_callback(item, action_id, future) ) - def feedback_callback(self, action_id: str, uuid_str: str, feedback_msg) -> None: + def feedback_callback(self, item: "QueueItem", action_id: str, feedback_msg) -> None: """反馈回调""" feedback_data = convert_from_ros_msg(feedback_msg) feedback_data.pop("goal_id") - self.lab_logger().debug(f"[Host Node] Feedback for {action_id} ({uuid_str}): {feedback_data}") + self.lab_logger().trace(f"[Host Node] Feedback for {action_id} ({item.job_id}): {feedback_data}") for bridge in self.bridges: if hasattr(bridge, "publish_job_status"): - bridge.publish_job_status(feedback_data, uuid_str, "running") + bridge.publish_job_status(feedback_data, item, "running") - def get_result_callback(self, device_action_key: str, action_id: str, uuid_str: str, future) -> None: + def get_result_callback(self, item: "QueueItem", action_id: str, future) -> None: """获取结果回调""" + job_id = item.job_id + self._device_action_status[f"/devices/{item.device_id}/{item.action_name}"].job_ids.pop(item.job_id) result_msg = future.result().result result_data = convert_from_ros_msg(result_msg) status = "success" return_info_str = result_data.get("return_info") - self._device_action_status[device_action_key].job_ids.pop(uuid_str) if return_info_str is not None: try: ret = json.loads(return_info_str) @@ -732,13 +731,23 @@ class HostNode(BaseROS2DeviceNode): status = "failed" return_info_str = serialize_result_info("缺少return_info", False, result_data) - self.lab_logger().info(f"[Host Node] Result for {action_id} ({uuid_str}): {status}") + self.lab_logger().info(f"[Host Node] Result for {action_id} ({job_id}): {status}") self.lab_logger().debug(f"[Host Node] Result data: {result_data}") - if uuid_str: + if job_id: for bridge in self.bridges: if hasattr(bridge, "publish_job_status"): - bridge.publish_job_status(result_data, uuid_str, status, return_info_str) + bridge.publish_job_status(result_data, item, status, return_info_str) + # 如果是WebSocket客户端,通知任务完成 + if hasattr(bridge, "_finish_job_callback_status"): + import asyncio + + free = True # 任务完成,设备空闲 + need_more = 0.0 # 任务结束,不需要更多时间 + try: + asyncio.create_task(bridge._finish_job_callback_status(job_id, free, need_more)) + except Exception as e: + self.lab_logger().error(f"[Host Node] Error finishing job callback status: {e}") def cancel_goal(self, goal_uuid: str) -> None: """取消目标""" @@ -748,14 +757,14 @@ class HostNode(BaseROS2DeviceNode): else: self.lab_logger().warning(f"[Host Node] Goal {goal_uuid} not found, cannot cancel") - def get_goal_status(self, uuid_str: str) -> int: + def get_goal_status(self, job_id: str) -> int: """获取目标状态""" - if uuid_str in self._goals: - g = self._goals[uuid_str] + if job_id in self._goals: + g = self._goals[job_id] status = g.status - self.lab_logger().debug(f"[Host Node] Goal status for {uuid_str}: {status}") + self.lab_logger().debug(f"[Host Node] Goal status for {job_id}: {status}") return status - self.lab_logger().warning(f"[Host Node] Goal {uuid_str} not found, status unknown") + self.lab_logger().warning(f"[Host Node] Goal {job_id} not found, status unknown") return GoalStatus.STATUS_UNKNOWN """Controller Node"""