移除MQTT,更新launch文档,提供注册表示例文件,更新到0.10.5

This commit is contained in:
Xuwznln
2025-09-15 02:39:43 +08:00
parent 94cdcbf24e
commit 2ca0311de6
24 changed files with 1902 additions and 715 deletions

View File

@@ -15,24 +15,33 @@ def start_backend(
without_host: bool = False,
visual: str = "None",
resources_mesh_config: dict = {},
**kwargs
**kwargs,
):
if backend == "ros":
# 假设 ros_main, simple_main, automancer_main 是不同 backend 的启动函数
from unilabos.ros.main_slave_run import main, slave # 如果选择 'ros' 作为 backend
elif backend == 'simple':
elif backend == "simple":
# 这里假设 simple_backend 和 automancer_backend 是你定义的其他两个后端
# from simple_backend import main as simple_main
pass
elif backend == 'automancer':
elif backend == "automancer":
# from automancer_backend import main as automancer_main
pass
else:
raise ValueError(f"Unsupported backend: {backend}")
backend_thread = threading.Thread(
target=main if not without_host else slave,
args=(devices_config, resources_config, resources_edge_config, graph, controllers_config, bridges, visual, resources_mesh_config),
args=(
devices_config,
resources_config,
resources_edge_config,
graph,
controllers_config,
bridges,
visual,
resources_mesh_config,
),
name="backend_thread",
daemon=True,
)

View File

@@ -3,7 +3,7 @@
"""
通信模块
提供MQTT和WebSocket的统一接口支持通过配置选择通信协议。
提供WebSocket的统一接口支持通过配置选择通信协议。
包含通信抽象层基类和通信客户端工厂。
"""
@@ -17,7 +17,7 @@ class BaseCommunicationClient(ABC):
"""
通信客户端抽象基类
定义了所有通信客户端(MQTT、WebSocket等需要实现的接口。
定义了所有通信客户端WebSocket等需要实现的接口。
"""
def __init__(self):
@@ -121,14 +121,12 @@ class CommunicationClientFactory:
protocol = protocol.lower()
if protocol == "mqtt":
return cls._create_mqtt_client()
elif protocol == "websocket":
if protocol == "websocket":
return cls._create_websocket_client()
else:
logger.error(f"[CommunicationFactory] Unsupported protocol: {protocol}")
logger.warning(f"[CommunicationFactory] Falling back to MQTT")
return cls._create_mqtt_client()
logger.warning(f"[CommunicationFactory] Falling back to WebSocket")
return cls._create_websocket_client()
@classmethod
def get_client(cls, protocol: Optional[str] = None) -> BaseCommunicationClient:
@@ -147,26 +145,16 @@ class CommunicationClientFactory:
return cls._client_cache
@classmethod
def _create_mqtt_client(cls) -> BaseCommunicationClient:
"""创建MQTT客户端"""
try:
from unilabos.app.mq import mqtt_client
return mqtt_client
except Exception as e:
logger.error(f"[CommunicationFactory] Failed to create MQTT client: {str(e)}")
raise
@classmethod
def _create_websocket_client(cls) -> BaseCommunicationClient:
"""创建WebSocket客户端"""
try:
from unilabos.app.ws_client import WebSocketClient
return WebSocketClient()
except Exception as e:
logger.error(f"[CommunicationFactory] Failed to create WebSocket client: {str(e)}")
logger.warning(f"[CommunicationFactory] Falling back to MQTT")
return cls._create_mqtt_client()
raise
@classmethod
def reset_client(cls):
@@ -188,7 +176,7 @@ class CommunicationClientFactory:
Returns:
支持的协议列表
"""
return ["mqtt", "websocket"]
return ["websocket"]
def get_communication_client(protocol: Optional[str] = None) -> BaseCommunicationClient:

View File

@@ -51,14 +51,14 @@ def convert_argv_dashes_to_underscores(args: argparse.ArgumentParser):
def parse_args():
"""解析命令行参数"""
parser = argparse.ArgumentParser(description="Start Uni-Lab Edge server.")
parser.add_argument("-g", "--graph", help="Physical setup graph.")
parser.add_argument("-c", "--controllers", default=None, help="Controllers config file.")
parser.add_argument("-g", "--graph", help="Physical setup graph file path.")
parser.add_argument("-c", "--controllers", default=None, help="Controllers config file path.")
parser.add_argument(
"--registry_path",
type=str,
default=None,
action="append",
help="Path to the registry",
help="Path to the registry directory",
)
parser.add_argument(
"--working_dir",
@@ -75,84 +75,85 @@ def parse_args():
parser.add_argument(
"--app_bridges",
nargs="+",
default=["mqtt", "fastapi"],
help="Bridges to connect to. Now support 'mqtt' and 'fastapi'.",
default=["websocket", "fastapi"],
help="Bridges to connect to. Now support 'websocket' and 'fastapi'.",
)
parser.add_argument(
"--without_host",
"--is_slave",
action="store_true",
help="Run the backend as slave (without host).",
help="Run the backend as slave node (without host privileges).",
)
parser.add_argument(
"--slave_no_host",
action="store_true",
help="Slave模式下跳过等待host服务",
help="Skip waiting for host service in slave mode",
)
parser.add_argument(
"--upload_registry",
action="store_true",
help="启动unilab时同时报送注册表信息",
help="Upload registry information when starting unilab",
)
parser.add_argument(
"--use_remote_resource",
action="store_true",
help="启动unilab时使用远程资源启动",
help="Use remote resources when starting unilab",
)
parser.add_argument(
"--config",
type=str,
default=None,
help="配置文件路径,支持.py格式的Python配置文件",
help="Configuration file path, supports .py format Python config files",
)
parser.add_argument(
"--port",
type=int,
default=8002,
help="信息页web服务的启动端口",
help="Port for web service information page",
)
parser.add_argument(
"--disable_browser",
action="store_true",
help="是否在启动时关闭信息页",
help="Disable opening information page on startup",
)
parser.add_argument(
"--2d_vis",
action="store_true",
help="是否在pylabrobot实例启动时同时启动可视化",
help="Enable 2D visualization when starting pylabrobot instance",
)
parser.add_argument(
"--visual",
choices=["rviz", "web", "disable"],
default="disable",
help="选择可视化工具: rviz, web",
help="Choose visualization tool: rviz, web, or disable",
)
parser.add_argument(
"--ak",
type=str,
default="",
help="实验室请求的ak",
help="Access key for laboratory requests",
)
parser.add_argument(
"--sk",
type=str,
default="",
help="实验室请求的sk",
help="Secret key for laboratory requests",
)
parser.add_argument(
"--addr",
type=str,
default="https://uni-lab.bohrium.com/api/v1",
help="实验室后端地址",
)
parser.add_argument(
"--websocket",
action="store_true",
help="使用websocket而非mqtt作为通信协议",
help="Laboratory backend address",
)
parser.add_argument(
"--skip_env_check",
action="store_true",
help="跳过启动时的环境依赖检查",
help="Skip environment dependency check on startup",
)
parser.add_argument(
"--complete_registry",
action="store_true",
default=False,
help="Complete registry information",
)
return parser
@@ -237,13 +238,17 @@ def main():
print_status("远程资源不存在,本地将进行首次上报!", "info")
# 设置BasicConfig参数
BasicConfig.ak = args_dict.get("ak", "")
BasicConfig.sk = args_dict.get("sk", "")
if args_dict.get("ak", ""):
BasicConfig.ak = args_dict.get("ak", "")
print_status("传入了ak参数优先采用传入参数", "info")
if args_dict.get("sk", ""):
BasicConfig.sk = args_dict.get("sk", "")
print_status("传入了sk参数优先采用传入参数", "info")
BasicConfig.working_dir = working_dir
BasicConfig.is_host_mode = not args_dict.get("without_host", False)
BasicConfig.is_host_mode = not args_dict.get("is_slave", False)
BasicConfig.slave_no_host = args_dict.get("slave_no_host", False)
BasicConfig.upload_registry = args_dict.get("upload_registry", False)
BasicConfig.communication_protocol = "websocket" if args_dict.get("websocket", False) else "mqtt"
BasicConfig.communication_protocol = "websocket"
machine_name = os.popen("hostname").read().strip()
machine_name = "".join([c if c.isalnum() or c == "_" else "_" for c in machine_name])
BasicConfig.machine_name = machine_name
@@ -261,12 +266,19 @@ def main():
from unilabos.app.backend import start_backend
from unilabos.app.web import http_client
from unilabos.app.web import start_server
from unilabos.app.register import register_devices_and_resources
# 显示启动横幅
print_unilab_banner(args_dict)
# 注册表
lab_registry = build_registry(args_dict["registry_path"], False, args_dict["upload_registry"])
lab_registry = build_registry(
args_dict["registry_path"], args_dict.get("complete_registry", False), args_dict["upload_registry"]
)
if not BasicConfig.ak or not BasicConfig.sk:
print_status("后续运行必须拥有一个实验室,请前往 https://uni-lab.bohrium.com 注册实验室!", "warning")
os._exit(1)
if args_dict["graph"] is None:
request_startup_json = http_client.request_startup_json()
if not request_startup_json:
@@ -297,14 +309,24 @@ def main():
target_node = nodes[i["target"]]
source_handle = i["sourceHandle"]
target_handle = i["targetHandle"]
source_handler_keys = [h["handler_key"] for h in materials[source_node["class"]]["handles"] if h["io_type"] == 'source']
target_handler_keys = [h["handler_key"] for h in materials[target_node["class"]]["handles"] if h["io_type"] == 'target']
if not source_handle in source_handler_keys:
print_status(f"节点 {source_node['id']} 的source端点 {source_handle} 不存在,请检查,支持的端点 {source_handler_keys}", "error")
source_handler_keys = [
h["handler_key"] for h in materials[source_node["class"]]["handles"] if h["io_type"] == "source"
]
target_handler_keys = [
h["handler_key"] for h in materials[target_node["class"]]["handles"] if h["io_type"] == "target"
]
if source_handle not in source_handler_keys:
print_status(
f"节点 {source_node['id']} 的source端点 {source_handle} 不存在,请检查,支持的端点 {source_handler_keys}",
"error",
)
resource_edge_info.pop(edge_info - ind - 1)
continue
if not target_handle in target_handler_keys:
print_status(f"节点 {target_node['id']} 的target端点 {target_handle} 不存在,请检查,支持的端点 {target_handler_keys}", "error")
if target_handle not in target_handler_keys:
print_status(
f"节点 {target_node['id']} 的target端点 {target_handle} 不存在,请检查,支持的端点 {target_handler_keys}",
"error",
)
resource_edge_info.pop(edge_info - ind - 1)
continue
@@ -318,6 +340,19 @@ def main():
for i in args_dict["resources_config"]:
print_status(f"DeviceId: {i['id']}, Class: {i['class']}", "info")
# 设备注册到服务端 - 需要 ak 和 sk
if args_dict.get("ak") and args_dict.get("sk"):
print_status("检测到 ak 和 sk开始注册设备到服务端...", "info")
try:
register_devices_and_resources(lab_registry)
print_status("设备注册完成", "info")
except Exception as e:
print_status(f"设备注册失败: {e}", "error")
elif args_dict.get("ak") or args_dict.get("sk"):
print_status("警告ak 和 sk 必须同时提供才能注册设备到服务端", "warning")
else:
print_status("未提供 ak 和 sk跳过设备注册", "info")
if args_dict["controllers"] is not None:
args_dict["controllers_config"] = yaml.safe_load(open(args_dict["controllers"], encoding="utf-8"))
else:
@@ -325,14 +360,14 @@ def main():
args_dict["bridges"] = []
# 获取通信客户端(根据配置选择MQTT或WebSocket
# 获取通信客户端(仅支持WebSocket
comm_client = get_communication_client()
if "mqtt" in args_dict["app_bridges"]:
if "websocket" in args_dict["app_bridges"]:
args_dict["bridges"].append(comm_client)
if "fastapi" in args_dict["app_bridges"]:
args_dict["bridges"].append(http_client)
if "mqtt" in args_dict["app_bridges"]:
if "websocket" in args_dict["app_bridges"]:
def _exit(signum, frame):
comm_client.stop()

View File

@@ -1,225 +0,0 @@
import json
import time
import traceback
from typing import Optional
import uuid
import paho.mqtt.client as mqtt
import ssl
import base64
import hmac
from hashlib import sha1
import tempfile
import os
from unilabos.config.config import MQConfig
from unilabos.app.controler import job_add
from unilabos.app.model import JobAddReq
from unilabos.app.communication import BaseCommunicationClient
from unilabos.utils import logger
from unilabos.utils.type_check import TypeEncoder
from paho.mqtt.enums import CallbackAPIVersion
class MQTTClient(BaseCommunicationClient):
mqtt_disable = True
def __init__(self):
super().__init__()
self.mqtt_disable = not MQConfig.lab_id
self.is_disabled = self.mqtt_disable # 更新父类属性
self.client_id = f"{MQConfig.group_id}@@@{MQConfig.lab_id}{uuid.uuid4()}"
logger.info("[MQTT] Client_id: " + self.client_id)
self.client = mqtt.Client(CallbackAPIVersion.VERSION2, client_id=self.client_id, protocol=mqtt.MQTTv5)
self._setup_callbacks()
def _setup_callbacks(self):
self.client.on_log = self._on_log
self.client.on_connect = self._on_connect
self.client.on_message = self._on_message
self.client.on_disconnect = self._on_disconnect
def _on_log(self, client, userdata, level, buf):
# logger.info(f"[MQTT] log: {buf}")
pass
def _on_connect(self, client, userdata, flags, rc, properties=None):
logger.info("[MQTT] Connected with result code " + str(rc))
client.subscribe(f"labs/{MQConfig.lab_id}/job/start/", 0)
client.subscribe(f"labs/{MQConfig.lab_id}/pong/", 0)
def _on_message(self, client, userdata, msg) -> None:
# logger.info("[MQTT] on_message<<<< " + msg.topic + " " + str(msg.payload))
try:
payload_str = msg.payload.decode("utf-8")
payload_json = json.loads(payload_str)
if msg.topic == f"labs/{MQConfig.lab_id}/job/start/":
if "data" not in payload_json:
payload_json["data"] = {}
if "action" in payload_json:
payload_json["data"]["action"] = payload_json.pop("action")
if "action_type" in payload_json:
payload_json["data"]["action_type"] = payload_json.pop("action_type")
if "action_args" in payload_json:
payload_json["data"]["action_args"] = payload_json.pop("action_args")
if "action_kwargs" in payload_json:
payload_json["data"]["action_kwargs"] = payload_json.pop("action_kwargs")
job_req = JobAddReq.model_validate(payload_json)
data = job_add(job_req)
return
elif msg.topic == f"labs/{MQConfig.lab_id}/pong/":
# 处理pong响应通知HostNode
from unilabos.ros.nodes.presets.host_node import HostNode
host_instance = HostNode.get_instance(0)
if host_instance:
host_instance.handle_pong_response(payload_json)
return
except json.JSONDecodeError as e:
logger.error(f"[MQTT] JSON 解析错误: {e}")
logger.error(f"[MQTT] Raw message: {msg.payload}")
logger.error(traceback.format_exc())
except Exception as e:
logger.error(f"[MQTT] 处理消息时出错: {e}")
logger.error(traceback.format_exc())
def _on_disconnect(self, client, userdata, rc, reasonCode=None, properties=None):
if rc != 0:
logger.error(f"[MQTT] Unexpected disconnection {rc}")
def _setup_ssl_context(self):
temp_files = []
try:
with tempfile.NamedTemporaryFile(mode="w", delete=False) as ca_temp:
ca_temp.write(MQConfig.ca_content)
temp_files.append(ca_temp.name)
with tempfile.NamedTemporaryFile(mode="w", delete=False) as cert_temp:
cert_temp.write(MQConfig.cert_content)
temp_files.append(cert_temp.name)
with tempfile.NamedTemporaryFile(mode="w", delete=False) as key_temp:
key_temp.write(MQConfig.key_content)
temp_files.append(key_temp.name)
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
context.load_verify_locations(cafile=temp_files[0])
context.load_cert_chain(certfile=temp_files[1], keyfile=temp_files[2])
self.client.tls_set_context(context)
finally:
for temp_file in temp_files:
try:
os.unlink(temp_file)
except Exception as e:
pass
def start(self):
if self.mqtt_disable:
logger.warning("MQTT is disabled, skipping connection.")
return
userName = f"Signature|{MQConfig.access_key}|{MQConfig.instance_id}"
password = base64.b64encode(
hmac.new(MQConfig.secret_key.encode(), self.client_id.encode(), sha1).digest()
).decode()
self.client.username_pw_set(userName, password)
self._setup_ssl_context()
# 创建连接线程
def connect_thread_func():
try:
self.client.connect(MQConfig.broker_url, MQConfig.port, 60)
self.client.loop_start()
# 添加连接超时检测
max_attempts = 5
attempt = 0
while not self.client.is_connected() and attempt < max_attempts:
logger.info(
f"[MQTT] 正在连接到 {MQConfig.broker_url}:{MQConfig.port},尝试 {attempt+1}/{max_attempts}"
)
time.sleep(3)
attempt += 1
if self.client.is_connected():
logger.info(f"[MQTT] 已成功连接到 {MQConfig.broker_url}:{MQConfig.port}")
else:
logger.error(f"[MQTT] 连接超时,可能是账号密码错误或网络问题")
self.client.loop_stop()
except Exception as e:
logger.error(f"[MQTT] 连接失败: {str(e)}")
connect_thread_func()
# connect_thread = threading.Thread(target=connect_thread_func)
# connect_thread.daemon = True
# connect_thread.start()
def stop(self):
if self.mqtt_disable:
return
self.client.disconnect()
self.client.loop_stop()
def publish_device_status(self, device_status: dict, device_id, property_name):
# status = device_status.get(device_id, {})
if self.mqtt_disable:
return
status = {"data": device_status.get(device_id, {}), "device_id": device_id, "timestamp": time.time()}
address = f"labs/{MQConfig.lab_id}/devices/"
self.client.publish(address, json.dumps(status), qos=2)
# logger.info(f"Device {device_id} status published: address: {address}, {status}")
def publish_job_status(self, feedback_data: dict, job_id: str, status: str, return_info: Optional[dict] = None):
if self.mqtt_disable:
return
if return_info is None:
return_info = {}
jobdata = {"job_id": job_id, "data": feedback_data, "status": status, "return_info": return_info}
self.client.publish(f"labs/{MQConfig.lab_id}/job/list/", json.dumps(jobdata), qos=2)
def publish_registry(self, device_id: str, device_info: dict, print_debug: bool = True):
if self.mqtt_disable:
return
address = f"labs/{MQConfig.lab_id}/registry/"
registry_data = json.dumps({device_id: device_info}, ensure_ascii=False, cls=TypeEncoder)
self.client.publish(address, registry_data, qos=2)
if print_debug:
logger.debug(f"Registry data published: address: {address}, {registry_data}")
def publish_actions(self, action_id: str, action_info: dict):
if self.mqtt_disable:
return
address = f"labs/{MQConfig.lab_id}/actions/"
self.client.publish(address, json.dumps(action_info), qos=2)
logger.debug(f"Action data published: address: {address}, {action_id}, {action_info}")
def send_ping(self, ping_id: str, timestamp: float):
"""发送ping消息到服务端"""
if self.mqtt_disable:
return
address = f"labs/{MQConfig.lab_id}/ping/"
ping_data = {"ping_id": ping_id, "client_timestamp": timestamp, "type": "ping"}
self.client.publish(address, json.dumps(ping_data), qos=2)
def setup_pong_subscription(self):
"""设置pong消息订阅"""
if self.mqtt_disable:
return
pong_topic = f"labs/{MQConfig.lab_id}/pong/"
self.client.subscribe(pong_topic, 0)
logger.debug(f"Subscribed to pong topic: {pong_topic}")
@property
def is_connected(self) -> bool:
"""检查MQTT是否已连接"""
if self.is_disabled:
return False
return hasattr(self.client, "is_connected") and self.client.is_connected()
mqtt_client = MQTTClient()
if __name__ == "__main__":
mqtt_client.start()

View File

@@ -10,117 +10,53 @@ from unilabos.utils.log import logger
from unilabos.utils.type_check import TypeEncoder
def register_devices_and_resources(comm_client, lab_registry):
def register_devices_and_resources(lab_registry):
"""
注册设备和资源到通信服务器(MQTT/WebSocket
注册设备和资源到服务器(仅支持HTTP
"""
# 注册资源信息 - 使用HTTP方式
from unilabos.app.web.client import http_client
logger.info("[UniLab Register] 开始注册设备和资源...")
if BasicConfig.auth_secret():
# 注册设备信息
devices_to_register = {}
for device_info in lab_registry.obtain_registry_device_info():
devices_to_register[device_info["id"]] = json.loads(
json.dumps(device_info, ensure_ascii=False, cls=TypeEncoder)
)
logger.debug(f"[UniLab Register] 收集设备: {device_info['id']}")
resources_to_register = {}
for resource_info in lab_registry.obtain_registry_resource_info():
resources_to_register[resource_info["id"]] = resource_info
logger.debug(f"[UniLab Register] 收集资源: {resource_info['id']}")
print(
"[UniLab Register] 设备注册",
http_client.resource_registry({"resources": list(devices_to_register.values())}).text,
# 注册设备信息
devices_to_register = {}
for device_info in lab_registry.obtain_registry_device_info():
devices_to_register[device_info["id"]] = json.loads(
json.dumps(device_info, ensure_ascii=False, cls=TypeEncoder)
)
print(
"[UniLab Register] 资源注册",
http_client.resource_registry({"resources": list(resources_to_register.values())}).text,
)
else:
# 注册设备信息
for device_info in lab_registry.obtain_registry_device_info():
comm_client.publish_registry(device_info["id"], device_info, False)
logger.debug(f"[UniLab Register] 注册设备: {device_info['id']}")
logger.debug(f"[UniLab Register] 收集设备: {device_info['id']}")
# # 注册资源信息
# for resource_info in lab_registry.obtain_registry_resource_info():
# comm_client.publish_registry(resource_info["id"], resource_info, False)
# logger.debug(f"[UniLab Register] 注册资源: {resource_info['id']}")
resources_to_register = {}
for resource_info in lab_registry.obtain_registry_resource_info():
resources_to_register[resource_info["id"]] = resource_info
logger.debug(f"[UniLab Register] 收集资源: {resource_info['id']}")
resources_to_register = {}
for resource_info in lab_registry.obtain_registry_resource_info():
resources_to_register[resource_info["id"]] = resource_info
logger.debug(f"[UniLab Register] 准备注册资源: {resource_info['id']}")
if resources_to_register:
# 注册设备
if devices_to_register:
try:
start_time = time.time()
response = http_client.resource_registry(resources_to_register)
response = http_client.resource_registry({"resources": list(devices_to_register.values())})
cost_time = time.time() - start_time
if response.status_code in [200, 201]:
logger.info(f"[UniLab Register] 成功通过HTTP注册 {len(resources_to_register)}资源 {cost_time}ms")
logger.info(f"[UniLab Register] 成功注册 {len(devices_to_register)}设备 {cost_time}ms")
else:
logger.error(
f"[UniLab Register] HTTP注册资源失败: {response.status_code}, {response.text} {cost_time}ms"
)
logger.error(f"[UniLab Register] 设备注册失败: {response.status_code}, {response.text} {cost_time}ms")
except Exception as e:
logger.error(f"[UniLab Register] 设备注册异常: {e}")
# 注册资源
if resources_to_register:
try:
start_time = time.time()
response = http_client.resource_registry({"resources": list(resources_to_register.values())})
cost_time = time.time() - start_time
if response.status_code in [200, 201]:
logger.info(f"[UniLab Register] 成功注册 {len(resources_to_register)} 个资源 {cost_time}ms")
else:
logger.error(f"[UniLab Register] 资源注册失败: {response.status_code}, {response.text} {cost_time}ms")
except Exception as e:
logger.error(f"[UniLab Register] 资源注册异常: {e}")
logger.info("[UniLab Register] 设备和资源注册完成.")
def main():
"""
命令行入口函数
"""
parser = argparse.ArgumentParser(description="注册设备和资源到 MQTT")
parser.add_argument(
"--registry",
type=str,
default=None,
action="append",
help="注册表路径",
)
parser.add_argument(
"--config",
type=str,
default=None,
help="配置文件路径,支持.py格式的Python配置文件",
)
parser.add_argument(
"--ak",
type=str,
default="",
help="实验室请求的ak",
)
parser.add_argument(
"--sk",
type=str,
default="",
help="实验室请求的sk",
)
parser.add_argument(
"--complete_registry",
action="store_true",
default=False,
help="是否补全注册表",
)
args = parser.parse_args()
load_config_from_file(args.config)
BasicConfig.ak = args.ak
BasicConfig.sk = args.sk
# 构建注册表
build_registry(args.registry, args.complete_registry, True)
from unilabos.app.communication import get_communication_client
# 获取通信客户端并启动连接
comm_client = get_communication_client()
comm_client.start()
from unilabos.registry.registry import lab_registry
# 注册设备和资源
register_devices_and_resources(comm_client, lab_registry)
if __name__ == "__main__":
main()

View File

@@ -9,16 +9,13 @@ import asyncio
import yaml
from unilabos.app.controler import devices, job_add, job_info
from unilabos.app.web.controler import devices, job_add, job_info
from unilabos.app.model import (
Resp,
RespCode,
JobStatusResp,
JobAddResp,
JobAddReq,
JobStepFinishReq,
JobPreintakeFinishReq,
JobFinishReq,
)
from unilabos.app.web.utils.host_utils import get_host_node_info
from unilabos.registry.registry import lab_registry

View File

@@ -3,6 +3,7 @@ HTTP客户端模块
提供与远程服务器通信的客户端功能只有host需要用
"""
import json
import os
from typing import List, Dict, Any, Optional
@@ -15,7 +16,6 @@ from unilabos.utils import logger
class HTTPClient:
"""HTTP客户端用于与远程服务器通信"""
backend_go = False # 是否使用Go后端
def __init__(self, remote_addr: Optional[str] = None, auth: Optional[str] = None) -> None:
"""
@@ -32,7 +32,6 @@ class HTTPClient:
auth_secret = BasicConfig.auth_secret()
if auth_secret:
self.auth = auth_secret
self.backend_go = True
info(f"正在使用ak sk作为授权信息 {auth_secret}")
else:
self.auth = MQConfig.lab_id
@@ -48,17 +47,15 @@ class HTTPClient:
Returns:
Response: API响应对象
"""
database_param = 1 if database_process_later else 0
response = requests.post(
f"{self.remote_addr}/lab/resource/edge/batch_create/?database_process_later={database_param}"
if not self.backend_go else f"{self.remote_addr}/lab/material/edge",
f"{self.remote_addr}/lab/material/edge",
json={
"edges": resources,
} if self.backend_go else resources,
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
},
headers={"Authorization": f"Lab {self.auth}"},
timeout=100,
)
if self.backend_go and response.status_code == 200:
if response.status_code == 200:
res = response.json()
if "code" in res and res["code"] != 0:
logger.error(f"添加物料关系失败: {response.text}")
@@ -77,12 +74,12 @@ class HTTPClient:
Response: API响应对象
"""
response = requests.post(
f"{self.remote_addr}/lab/resource/?database_process_later={1 if database_process_later else 0}" if not self.backend_go else f"{self.remote_addr}/lab/material",
json=resources if not self.backend_go else {"nodes": resources},
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
f"{self.remote_addr}/lab/material",
json={"nodes": resources},
headers={"Authorization": f"Lab {self.auth}"},
timeout=100,
)
if self.backend_go and response.status_code == 200:
if response.status_code == 200:
res = response.json()
if "code" in res and res["code"] != 0:
logger.error(f"添加物料失败: {response.text}")
@@ -102,9 +99,9 @@ class HTTPClient:
Dict: 返回的资源数据
"""
response = requests.get(
f"{self.remote_addr}/lab/resource/?edge_format=1" if not self.backend_go else f"{self.remote_addr}/lab/material",
f"{self.remote_addr}/lab/material",
params={"id": id, "with_children": with_children},
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
headers={"Authorization": f"Lab {self.auth}"},
timeout=20,
)
return response.json()
@@ -122,7 +119,7 @@ class HTTPClient:
response = requests.delete(
f"{self.remote_addr}/lab/resource/batch_delete/",
params={"id": id},
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
headers={"Authorization": f"Lab {self.auth}"},
timeout=20,
)
return response
@@ -140,7 +137,7 @@ class HTTPClient:
response = requests.patch(
f"{self.remote_addr}/lab/resource/batch_update/?edge_format=1",
json=resources,
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
headers={"Authorization": f"Lab {self.auth}"},
timeout=100,
)
return response
@@ -164,7 +161,7 @@ class HTTPClient:
response = requests.post(
f"{self.remote_addr}/api/account/file_upload/{scene}",
files=files,
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
headers={"Authorization": f"Lab {self.auth}"},
timeout=30, # 上传文件可能需要更长的超时时间
)
return response
@@ -180,9 +177,9 @@ class HTTPClient:
Response: API响应对象
"""
response = requests.post(
f"{self.remote_addr}/lab/registry/" if not self.backend_go else f"{self.remote_addr}/lab/resource",
f"{self.remote_addr}/lab/resource",
json=registry_data,
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
headers={"Authorization": f"Lab {self.auth}"},
timeout=30,
)
if response.status_code not in [200, 201]:
@@ -201,7 +198,7 @@ class HTTPClient:
"""
response = requests.get(
f"{self.remote_addr}/lab/resource/graph_info/",
headers={"Authorization": f"{'lab' if not self.backend_go else 'Lab'} {self.auth}"},
headers={"Authorization": f"Lab {self.auth}"},
timeout=(3, 30),
)
if response.status_code != 200: