From bbbdb06bbcc3a6c6f3c92d6e5741126e30c1ebbf Mon Sep 17 00:00:00 2001 From: Xuwznln <18435084+Xuwznln@users.noreply.github.com> Date: Thu, 28 Aug 2025 19:57:14 +0800 Subject: [PATCH] feat: websocket test --- unilabos/app/model.py | 6 +++++- unilabos/app/web/client.py | 16 ++++++++-------- unilabos/app/ws_client.py | 27 ++++++++++++++++++++------- unilabos/config/config.py | 4 ++-- 4 files changed, 35 insertions(+), 18 deletions(-) diff --git a/unilabos/app/model.py b/unilabos/app/model.py index a5b8c786..48b3a689 100644 --- a/unilabos/app/model.py +++ b/unilabos/app/model.py @@ -50,11 +50,15 @@ class Resp(BaseModel): class JobAddReq(BaseModel): device_id: str = Field(examples=["Gripper"], description="device id") - data: dict = Field(examples=[{"position": 30, "torque": 5, "action": "push_to"}]) + action: str = Field(examples=["_execute_driver_command_async"], description="action name", default="") + action_type: str = Field(examples=["unilabos_msgs.action._str_single_input.StrSingleInput"], description="action name", default="") + action_args: dict = Field(examples=[{'string': 'string'}], description="action name", default="") job_id: str = Field(examples=["job_id"], description="goal uuid") node_id: str = Field(examples=["node_id"], description="node uuid") server_info: dict = Field(examples=[{"send_timestamp": 1717000000.0}], description="server info") + data: dict = Field(examples=[{"position": 30, "torque": 5, "action": "push_to"}], default={}) + class JobStepFinishReq(BaseModel): token: str = Field(examples=["030944"], description="token") diff --git a/unilabos/app/web/client.py b/unilabos/app/web/client.py index 214424b3..eff1284f 100644 --- a/unilabos/app/web/client.py +++ b/unilabos/app/web/client.py @@ -53,7 +53,7 @@ class HTTPClient: f"{self.remote_addr}/lab/resource/edge/batch_create/?database_process_later={database_param}" if not self.backend_go else f"{self.remote_addr}/lab/material/edge", json=resources, - headers={"Authorization": f"lab {self.auth}"}, + headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"}, timeout=100, ) if response.status_code != 200 and response.status_code != 201: @@ -73,7 +73,7 @@ class HTTPClient: response = requests.post( f"{self.remote_addr}/lab/resource/?database_process_later={1 if database_process_later else 0}" if not self.backend_go else f"{self.remote_addr}/lab/material", json=resources if not self.backend_go else {"nodes": resources}, - headers={"Authorization": f"lab {self.auth}"}, + headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"}, timeout=100, ) if response.status_code != 200: @@ -96,7 +96,7 @@ class HTTPClient: response = requests.get( f"{self.remote_addr}/lab/resource/?edge_format=1" if not self.backend_go else f"{self.remote_addr}/lab/material", params={"id": id, "with_children": with_children}, - headers={"Authorization": f"lab {self.auth}"}, + headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"}, timeout=20, ) return response.json() @@ -114,7 +114,7 @@ class HTTPClient: response = requests.delete( f"{self.remote_addr}/lab/resource/batch_delete/", params={"id": id}, - headers={"Authorization": f"lab {self.auth}"}, + headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"}, timeout=20, ) return response @@ -132,7 +132,7 @@ class HTTPClient: response = requests.patch( f"{self.remote_addr}/lab/resource/batch_update/?edge_format=1", json=resources, - headers={"Authorization": f"lab {self.auth}"}, + headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"}, timeout=100, ) return response @@ -156,7 +156,7 @@ class HTTPClient: response = requests.post( f"{self.remote_addr}/api/account/file_upload/{scene}", files=files, - headers={"Authorization": f"lab {self.auth}"}, + headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"}, timeout=30, # 上传文件可能需要更长的超时时间 ) return response @@ -174,7 +174,7 @@ class HTTPClient: response = requests.post( f"{self.remote_addr}/lab/registry/" if not self.backend_go else f"{self.remote_addr}/lab/resource", json=registry_data, - headers={"Authorization": f"lab {self.auth}"}, + headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"}, timeout=30, ) if response.status_code not in [200, 201]: @@ -193,7 +193,7 @@ class HTTPClient: """ response = requests.get( f"{self.remote_addr}/lab/resource/graph_info/", - headers={"Authorization": f"lab {self.auth}"}, + headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"}, timeout=(3, 30), ) if response.status_code != 200: diff --git a/unilabos/app/ws_client.py b/unilabos/app/ws_client.py index 94fa6c6d..61e26f00 100644 --- a/unilabos/app/ws_client.py +++ b/unilabos/app/ws_client.py @@ -17,6 +17,7 @@ from urllib.parse import urlparse from unilabos.app.controler import job_add from unilabos.app.model import JobAddReq from unilabos.ros.nodes.presets.host_node import HostNode +from unilabos.utils.type_check import serialize_result_info try: import websockets @@ -80,8 +81,10 @@ class WebSocketClient(BaseCommunicationClient): scheme = "wss" else: scheme = "ws" - self.websocket_url = f"{scheme}://{parsed.netloc}/api/v1/lab" - + if ":" in parsed.netloc: + self.websocket_url = f"{scheme}://{parsed.hostname}:{parsed.port + 1}/api/v1/lab" + else: + self.websocket_url = f"{scheme}://{parsed.netloc}/api/v1/lab" logger.debug(f"[WebSocket] URL: {self.websocket_url}") def start(self) -> None: @@ -148,7 +151,7 @@ class WebSocketClient(BaseCommunicationClient): ssl=ssl_context, ping_interval=WSConfig.ping_interval, ping_timeout=10, - additional_headers={"Authorization": f"Bearer {BasicConfig.auth_secret()}"}, + additional_headers={"Authorization": f"Lab {BasicConfig.auth_secret()}"}, ) as websocket: self.websocket = websocket self.connected = True @@ -176,6 +179,8 @@ class WebSocketClient(BaseCommunicationClient): elif self.reconnect_count >= WSConfig.max_reconnect_attempts: logger.error("[WebSocket] Max reconnection attempts reached") break + else: + self.reconnect_count -= 1 async def _close_connection(self): """关闭WebSocket连接""" @@ -232,10 +237,18 @@ class WebSocketClient(BaseCommunicationClient): async def _handle_job_start(self, data: Dict[str, Any]): """处理作业启动消息""" try: - job_req = JobAddReq(**data.get("job_data", {})) - job_add(job_req) - job_id = getattr(job_req, "id", "unknown") - logger.info(f"[WebSocket] Job started: {job_id}") + req = JobAddReq(**data) + try: + req.job_id = str(uuid.uuid4()) + logger.info(f"[WebSocket] Job started: {req.job_id}") + HostNode.get_instance().send_goal(req.device_id, action_type=req.action_type, action_name=req.action, + action_kwargs=req.action_args, goal_uuid=req.job_id, + server_info=req.server_info) + except Exception as e: + for bridge in HostNode.get_instance().bridges: + traceback.print_exc() + if hasattr(bridge, "publish_job_status"): + self.publish_job_status({}, req.job_id, "failed", serialize_result_info(traceback.format_exc(), False, {})) except Exception as e: logger.error(f"[WebSocket] Error handling job start: {str(e)}") diff --git a/unilabos/config/config.py b/unilabos/config/config.py index c0edf2ba..99c31eea 100644 --- a/unilabos/config/config.py +++ b/unilabos/config/config.py @@ -56,7 +56,7 @@ class MQConfig: # WebSocket配置 class WSConfig: reconnect_interval = 5 # 重连间隔(秒) - max_reconnect_attempts = 10 # 最大重连次数 + max_reconnect_attempts = 999 # 最大重连次数 ping_interval = 30 # ping间隔(秒) @@ -96,7 +96,7 @@ def _update_config_from_module(module, override_labid: Optional[str]): setattr(obj, attr, getattr(getattr(module, name), attr)) # 更新OSS认证 if len(OSSUploadConfig.authorization) == 0: - OSSUploadConfig.authorization = f"lab {MQConfig.lab_id}" + OSSUploadConfig.authorization = f"Lab {MQConfig.lab_id}" # 对 ca_file cert_file key_file 进行初始化 if override_labid: MQConfig.lab_id = override_labid