mirror of
https://github.com/dptech-corp/Uni-Lab-OS.git
synced 2025-12-17 21:11:12 +00:00
feat: websocket test
This commit is contained in:
@@ -50,11 +50,15 @@ class Resp(BaseModel):
|
||||
|
||||
class JobAddReq(BaseModel):
|
||||
device_id: str = Field(examples=["Gripper"], description="device id")
|
||||
data: dict = Field(examples=[{"position": 30, "torque": 5, "action": "push_to"}])
|
||||
action: str = Field(examples=["_execute_driver_command_async"], description="action name", default="")
|
||||
action_type: str = Field(examples=["unilabos_msgs.action._str_single_input.StrSingleInput"], description="action name", default="")
|
||||
action_args: dict = Field(examples=[{'string': 'string'}], description="action name", default="")
|
||||
job_id: str = Field(examples=["job_id"], description="goal uuid")
|
||||
node_id: str = Field(examples=["node_id"], description="node uuid")
|
||||
server_info: dict = Field(examples=[{"send_timestamp": 1717000000.0}], description="server info")
|
||||
|
||||
data: dict = Field(examples=[{"position": 30, "torque": 5, "action": "push_to"}], default={})
|
||||
|
||||
|
||||
class JobStepFinishReq(BaseModel):
|
||||
token: str = Field(examples=["030944"], description="token")
|
||||
|
||||
@@ -53,7 +53,7 @@ class HTTPClient:
|
||||
f"{self.remote_addr}/lab/resource/edge/batch_create/?database_process_later={database_param}"
|
||||
if not self.backend_go else f"{self.remote_addr}/lab/material/edge",
|
||||
json=resources,
|
||||
headers={"Authorization": f"lab {self.auth}"},
|
||||
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
|
||||
timeout=100,
|
||||
)
|
||||
if response.status_code != 200 and response.status_code != 201:
|
||||
@@ -73,7 +73,7 @@ class HTTPClient:
|
||||
response = requests.post(
|
||||
f"{self.remote_addr}/lab/resource/?database_process_later={1 if database_process_later else 0}" if not self.backend_go else f"{self.remote_addr}/lab/material",
|
||||
json=resources if not self.backend_go else {"nodes": resources},
|
||||
headers={"Authorization": f"lab {self.auth}"},
|
||||
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
|
||||
timeout=100,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
@@ -96,7 +96,7 @@ class HTTPClient:
|
||||
response = requests.get(
|
||||
f"{self.remote_addr}/lab/resource/?edge_format=1" if not self.backend_go else f"{self.remote_addr}/lab/material",
|
||||
params={"id": id, "with_children": with_children},
|
||||
headers={"Authorization": f"lab {self.auth}"},
|
||||
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
|
||||
timeout=20,
|
||||
)
|
||||
return response.json()
|
||||
@@ -114,7 +114,7 @@ class HTTPClient:
|
||||
response = requests.delete(
|
||||
f"{self.remote_addr}/lab/resource/batch_delete/",
|
||||
params={"id": id},
|
||||
headers={"Authorization": f"lab {self.auth}"},
|
||||
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
|
||||
timeout=20,
|
||||
)
|
||||
return response
|
||||
@@ -132,7 +132,7 @@ class HTTPClient:
|
||||
response = requests.patch(
|
||||
f"{self.remote_addr}/lab/resource/batch_update/?edge_format=1",
|
||||
json=resources,
|
||||
headers={"Authorization": f"lab {self.auth}"},
|
||||
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
|
||||
timeout=100,
|
||||
)
|
||||
return response
|
||||
@@ -156,7 +156,7 @@ class HTTPClient:
|
||||
response = requests.post(
|
||||
f"{self.remote_addr}/api/account/file_upload/{scene}",
|
||||
files=files,
|
||||
headers={"Authorization": f"lab {self.auth}"},
|
||||
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
|
||||
timeout=30, # 上传文件可能需要更长的超时时间
|
||||
)
|
||||
return response
|
||||
@@ -174,7 +174,7 @@ class HTTPClient:
|
||||
response = requests.post(
|
||||
f"{self.remote_addr}/lab/registry/" if not self.backend_go else f"{self.remote_addr}/lab/resource",
|
||||
json=registry_data,
|
||||
headers={"Authorization": f"lab {self.auth}"},
|
||||
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
|
||||
timeout=30,
|
||||
)
|
||||
if response.status_code not in [200, 201]:
|
||||
@@ -193,7 +193,7 @@ class HTTPClient:
|
||||
"""
|
||||
response = requests.get(
|
||||
f"{self.remote_addr}/lab/resource/graph_info/",
|
||||
headers={"Authorization": f"lab {self.auth}"},
|
||||
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
|
||||
timeout=(3, 30),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
|
||||
@@ -17,6 +17,7 @@ from urllib.parse import urlparse
|
||||
from unilabos.app.controler import job_add
|
||||
from unilabos.app.model import JobAddReq
|
||||
from unilabos.ros.nodes.presets.host_node import HostNode
|
||||
from unilabos.utils.type_check import serialize_result_info
|
||||
|
||||
try:
|
||||
import websockets
|
||||
@@ -80,8 +81,10 @@ class WebSocketClient(BaseCommunicationClient):
|
||||
scheme = "wss"
|
||||
else:
|
||||
scheme = "ws"
|
||||
self.websocket_url = f"{scheme}://{parsed.netloc}/api/v1/lab"
|
||||
|
||||
if ":" in parsed.netloc:
|
||||
self.websocket_url = f"{scheme}://{parsed.hostname}:{parsed.port + 1}/api/v1/lab"
|
||||
else:
|
||||
self.websocket_url = f"{scheme}://{parsed.netloc}/api/v1/lab"
|
||||
logger.debug(f"[WebSocket] URL: {self.websocket_url}")
|
||||
|
||||
def start(self) -> None:
|
||||
@@ -148,7 +151,7 @@ class WebSocketClient(BaseCommunicationClient):
|
||||
ssl=ssl_context,
|
||||
ping_interval=WSConfig.ping_interval,
|
||||
ping_timeout=10,
|
||||
additional_headers={"Authorization": f"Bearer {BasicConfig.auth_secret()}"},
|
||||
additional_headers={"Authorization": f"Lab {BasicConfig.auth_secret()}"},
|
||||
) as websocket:
|
||||
self.websocket = websocket
|
||||
self.connected = True
|
||||
@@ -176,6 +179,8 @@ class WebSocketClient(BaseCommunicationClient):
|
||||
elif self.reconnect_count >= WSConfig.max_reconnect_attempts:
|
||||
logger.error("[WebSocket] Max reconnection attempts reached")
|
||||
break
|
||||
else:
|
||||
self.reconnect_count -= 1
|
||||
|
||||
async def _close_connection(self):
|
||||
"""关闭WebSocket连接"""
|
||||
@@ -232,10 +237,18 @@ class WebSocketClient(BaseCommunicationClient):
|
||||
async def _handle_job_start(self, data: Dict[str, Any]):
|
||||
"""处理作业启动消息"""
|
||||
try:
|
||||
job_req = JobAddReq(**data.get("job_data", {}))
|
||||
job_add(job_req)
|
||||
job_id = getattr(job_req, "id", "unknown")
|
||||
logger.info(f"[WebSocket] Job started: {job_id}")
|
||||
req = JobAddReq(**data)
|
||||
try:
|
||||
req.job_id = str(uuid.uuid4())
|
||||
logger.info(f"[WebSocket] Job started: {req.job_id}")
|
||||
HostNode.get_instance().send_goal(req.device_id, action_type=req.action_type, action_name=req.action,
|
||||
action_kwargs=req.action_args, goal_uuid=req.job_id,
|
||||
server_info=req.server_info)
|
||||
except Exception as e:
|
||||
for bridge in HostNode.get_instance().bridges:
|
||||
traceback.print_exc()
|
||||
if hasattr(bridge, "publish_job_status"):
|
||||
self.publish_job_status({}, req.job_id, "failed", serialize_result_info(traceback.format_exc(), False, {}))
|
||||
except Exception as e:
|
||||
logger.error(f"[WebSocket] Error handling job start: {str(e)}")
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ class MQConfig:
|
||||
# WebSocket配置
|
||||
class WSConfig:
|
||||
reconnect_interval = 5 # 重连间隔(秒)
|
||||
max_reconnect_attempts = 10 # 最大重连次数
|
||||
max_reconnect_attempts = 999 # 最大重连次数
|
||||
ping_interval = 30 # ping间隔(秒)
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ def _update_config_from_module(module, override_labid: Optional[str]):
|
||||
setattr(obj, attr, getattr(getattr(module, name), attr))
|
||||
# 更新OSS认证
|
||||
if len(OSSUploadConfig.authorization) == 0:
|
||||
OSSUploadConfig.authorization = f"lab {MQConfig.lab_id}"
|
||||
OSSUploadConfig.authorization = f"Lab {MQConfig.lab_id}"
|
||||
# 对 ca_file cert_file key_file 进行初始化
|
||||
if override_labid:
|
||||
MQConfig.lab_id = override_labid
|
||||
|
||||
Reference in New Issue
Block a user