feat: websocket test

This commit is contained in:
Xuwznln
2025-08-28 19:57:14 +08:00
parent cd84e26126
commit bbbdb06bbc
4 changed files with 35 additions and 18 deletions

View File

@@ -50,11 +50,15 @@ class Resp(BaseModel):
class JobAddReq(BaseModel): class JobAddReq(BaseModel):
device_id: str = Field(examples=["Gripper"], description="device id") device_id: str = Field(examples=["Gripper"], description="device id")
data: dict = Field(examples=[{"position": 30, "torque": 5, "action": "push_to"}]) action: str = Field(examples=["_execute_driver_command_async"], description="action name", default="")
action_type: str = Field(examples=["unilabos_msgs.action._str_single_input.StrSingleInput"], description="action name", default="")
action_args: dict = Field(examples=[{'string': 'string'}], description="action name", default="")
job_id: str = Field(examples=["job_id"], description="goal uuid") job_id: str = Field(examples=["job_id"], description="goal uuid")
node_id: str = Field(examples=["node_id"], description="node uuid") node_id: str = Field(examples=["node_id"], description="node uuid")
server_info: dict = Field(examples=[{"send_timestamp": 1717000000.0}], description="server info") server_info: dict = Field(examples=[{"send_timestamp": 1717000000.0}], description="server info")
data: dict = Field(examples=[{"position": 30, "torque": 5, "action": "push_to"}], default={})
class JobStepFinishReq(BaseModel): class JobStepFinishReq(BaseModel):
token: str = Field(examples=["030944"], description="token") token: str = Field(examples=["030944"], description="token")

View File

@@ -53,7 +53,7 @@ class HTTPClient:
f"{self.remote_addr}/lab/resource/edge/batch_create/?database_process_later={database_param}" f"{self.remote_addr}/lab/resource/edge/batch_create/?database_process_later={database_param}"
if not self.backend_go else f"{self.remote_addr}/lab/material/edge", if not self.backend_go else f"{self.remote_addr}/lab/material/edge",
json=resources, json=resources,
headers={"Authorization": f"lab {self.auth}"}, headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
timeout=100, timeout=100,
) )
if response.status_code != 200 and response.status_code != 201: if response.status_code != 200 and response.status_code != 201:
@@ -73,7 +73,7 @@ class HTTPClient:
response = requests.post( response = requests.post(
f"{self.remote_addr}/lab/resource/?database_process_later={1 if database_process_later else 0}" if not self.backend_go else f"{self.remote_addr}/lab/material", f"{self.remote_addr}/lab/resource/?database_process_later={1 if database_process_later else 0}" if not self.backend_go else f"{self.remote_addr}/lab/material",
json=resources if not self.backend_go else {"nodes": resources}, json=resources if not self.backend_go else {"nodes": resources},
headers={"Authorization": f"lab {self.auth}"}, headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
timeout=100, timeout=100,
) )
if response.status_code != 200: if response.status_code != 200:
@@ -96,7 +96,7 @@ class HTTPClient:
response = requests.get( response = requests.get(
f"{self.remote_addr}/lab/resource/?edge_format=1" if not self.backend_go else f"{self.remote_addr}/lab/material", f"{self.remote_addr}/lab/resource/?edge_format=1" if not self.backend_go else f"{self.remote_addr}/lab/material",
params={"id": id, "with_children": with_children}, params={"id": id, "with_children": with_children},
headers={"Authorization": f"lab {self.auth}"}, headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
timeout=20, timeout=20,
) )
return response.json() return response.json()
@@ -114,7 +114,7 @@ class HTTPClient:
response = requests.delete( response = requests.delete(
f"{self.remote_addr}/lab/resource/batch_delete/", f"{self.remote_addr}/lab/resource/batch_delete/",
params={"id": id}, params={"id": id},
headers={"Authorization": f"lab {self.auth}"}, headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
timeout=20, timeout=20,
) )
return response return response
@@ -132,7 +132,7 @@ class HTTPClient:
response = requests.patch( response = requests.patch(
f"{self.remote_addr}/lab/resource/batch_update/?edge_format=1", f"{self.remote_addr}/lab/resource/batch_update/?edge_format=1",
json=resources, json=resources,
headers={"Authorization": f"lab {self.auth}"}, headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
timeout=100, timeout=100,
) )
return response return response
@@ -156,7 +156,7 @@ class HTTPClient:
response = requests.post( response = requests.post(
f"{self.remote_addr}/api/account/file_upload/{scene}", f"{self.remote_addr}/api/account/file_upload/{scene}",
files=files, files=files,
headers={"Authorization": f"lab {self.auth}"}, headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
timeout=30, # 上传文件可能需要更长的超时时间 timeout=30, # 上传文件可能需要更长的超时时间
) )
return response return response
@@ -174,7 +174,7 @@ class HTTPClient:
response = requests.post( response = requests.post(
f"{self.remote_addr}/lab/registry/" if not self.backend_go else f"{self.remote_addr}/lab/resource", f"{self.remote_addr}/lab/registry/" if not self.backend_go else f"{self.remote_addr}/lab/resource",
json=registry_data, json=registry_data,
headers={"Authorization": f"lab {self.auth}"}, headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
timeout=30, timeout=30,
) )
if response.status_code not in [200, 201]: if response.status_code not in [200, 201]:
@@ -193,7 +193,7 @@ class HTTPClient:
""" """
response = requests.get( response = requests.get(
f"{self.remote_addr}/lab/resource/graph_info/", f"{self.remote_addr}/lab/resource/graph_info/",
headers={"Authorization": f"lab {self.auth}"}, headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
timeout=(3, 30), timeout=(3, 30),
) )
if response.status_code != 200: if response.status_code != 200:

View File

@@ -17,6 +17,7 @@ from urllib.parse import urlparse
from unilabos.app.controler import job_add from unilabos.app.controler import job_add
from unilabos.app.model import JobAddReq from unilabos.app.model import JobAddReq
from unilabos.ros.nodes.presets.host_node import HostNode from unilabos.ros.nodes.presets.host_node import HostNode
from unilabos.utils.type_check import serialize_result_info
try: try:
import websockets import websockets
@@ -80,8 +81,10 @@ class WebSocketClient(BaseCommunicationClient):
scheme = "wss" scheme = "wss"
else: else:
scheme = "ws" scheme = "ws"
if ":" in parsed.netloc:
self.websocket_url = f"{scheme}://{parsed.hostname}:{parsed.port + 1}/api/v1/lab"
else:
self.websocket_url = f"{scheme}://{parsed.netloc}/api/v1/lab" self.websocket_url = f"{scheme}://{parsed.netloc}/api/v1/lab"
logger.debug(f"[WebSocket] URL: {self.websocket_url}") logger.debug(f"[WebSocket] URL: {self.websocket_url}")
def start(self) -> None: def start(self) -> None:
@@ -148,7 +151,7 @@ class WebSocketClient(BaseCommunicationClient):
ssl=ssl_context, ssl=ssl_context,
ping_interval=WSConfig.ping_interval, ping_interval=WSConfig.ping_interval,
ping_timeout=10, ping_timeout=10,
additional_headers={"Authorization": f"Bearer {BasicConfig.auth_secret()}"}, additional_headers={"Authorization": f"Lab {BasicConfig.auth_secret()}"},
) as websocket: ) as websocket:
self.websocket = websocket self.websocket = websocket
self.connected = True self.connected = True
@@ -176,6 +179,8 @@ class WebSocketClient(BaseCommunicationClient):
elif self.reconnect_count >= WSConfig.max_reconnect_attempts: elif self.reconnect_count >= WSConfig.max_reconnect_attempts:
logger.error("[WebSocket] Max reconnection attempts reached") logger.error("[WebSocket] Max reconnection attempts reached")
break break
else:
self.reconnect_count -= 1
async def _close_connection(self): async def _close_connection(self):
"""关闭WebSocket连接""" """关闭WebSocket连接"""
@@ -232,10 +237,18 @@ class WebSocketClient(BaseCommunicationClient):
async def _handle_job_start(self, data: Dict[str, Any]): async def _handle_job_start(self, data: Dict[str, Any]):
"""处理作业启动消息""" """处理作业启动消息"""
try: try:
job_req = JobAddReq(**data.get("job_data", {})) req = JobAddReq(**data)
job_add(job_req) try:
job_id = getattr(job_req, "id", "unknown") req.job_id = str(uuid.uuid4())
logger.info(f"[WebSocket] Job started: {job_id}") logger.info(f"[WebSocket] Job started: {req.job_id}")
HostNode.get_instance().send_goal(req.device_id, action_type=req.action_type, action_name=req.action,
action_kwargs=req.action_args, goal_uuid=req.job_id,
server_info=req.server_info)
except Exception as e:
for bridge in HostNode.get_instance().bridges:
traceback.print_exc()
if hasattr(bridge, "publish_job_status"):
self.publish_job_status({}, req.job_id, "failed", serialize_result_info(traceback.format_exc(), False, {}))
except Exception as e: except Exception as e:
logger.error(f"[WebSocket] Error handling job start: {str(e)}") logger.error(f"[WebSocket] Error handling job start: {str(e)}")

View File

@@ -56,7 +56,7 @@ class MQConfig:
# WebSocket配置 # WebSocket配置
class WSConfig: class WSConfig:
reconnect_interval = 5 # 重连间隔(秒) reconnect_interval = 5 # 重连间隔(秒)
max_reconnect_attempts = 10 # 最大重连次数 max_reconnect_attempts = 999 # 最大重连次数
ping_interval = 30 # ping间隔 ping_interval = 30 # ping间隔
@@ -96,7 +96,7 @@ def _update_config_from_module(module, override_labid: Optional[str]):
setattr(obj, attr, getattr(getattr(module, name), attr)) setattr(obj, attr, getattr(getattr(module, name), attr))
# 更新OSS认证 # 更新OSS认证
if len(OSSUploadConfig.authorization) == 0: if len(OSSUploadConfig.authorization) == 0:
OSSUploadConfig.authorization = f"lab {MQConfig.lab_id}" OSSUploadConfig.authorization = f"Lab {MQConfig.lab_id}"
# 对 ca_file cert_file key_file 进行初始化 # 对 ca_file cert_file key_file 进行初始化
if override_labid: if override_labid:
MQConfig.lab_id = override_labid MQConfig.lab_id = override_labid