Add workflow upload func.

This commit is contained in:
Xuwznln
2025-12-08 19:12:05 +08:00
parent ced961050d
commit 16ee3de086
32 changed files with 811 additions and 222 deletions

View File

@@ -20,6 +20,7 @@ if unilabos_dir not in sys.path:
from unilabos.utils.banner_print import print_status, print_unilab_banner
from unilabos.config.config import load_config, BasicConfig, HTTPConfig
def load_config_from_file(config_path):
if config_path is None:
config_path = os.environ.get("UNILABOS_BASICCONFIG_CONFIG_PATH", None)
@@ -41,7 +42,7 @@ def convert_argv_dashes_to_underscores(args: argparse.ArgumentParser):
for i, arg in enumerate(sys.argv):
for option_string in option_strings:
if arg.startswith(option_string):
new_arg = arg[:2] + arg[2:len(option_string)].replace("-", "_") + arg[len(option_string):]
new_arg = arg[:2] + arg[2 : len(option_string)].replace("-", "_") + arg[len(option_string) :]
sys.argv[i] = new_arg
break
@@ -155,14 +156,39 @@ def parse_args():
default=False,
help="Complete registry information",
)
# label
# workflow upload subcommand
workflow_parser = subparsers.add_parser(
"workflow_upload",
aliases=["wf"],
help="Upload workflow from xdl/json/python files",
)
workflow_parser.add_argument("-t", "--labeltype", default="singlepoint", type=str,
help="QM calculation type, support 'singlepoint', 'optimize' and 'dimer' currently")
workflow_parser.add_argument(
"-f",
"--workflow_file",
type=str,
required=True,
help="Path to the workflow file (JSON format)",
)
workflow_parser.add_argument(
"-n",
"--workflow_name",
type=str,
default=None,
help="Workflow name, if not provided will use the name from file or filename",
)
workflow_parser.add_argument(
"--tags",
type=str,
nargs="*",
default=[],
help="Tags for the workflow (space-separated)",
)
workflow_parser.add_argument(
"--published",
action="store_true",
default=False,
help="Whether to publish the workflow (default: False)",
)
return parser
@@ -173,9 +199,6 @@ def main():
convert_argv_dashes_to_underscores(args)
args_dict = vars(args.parse_args())
# 显示启动横幅
print_unilab_banner(args_dict)
# 环境检查 - 检查并自动安装必需的包 (可选)
if not args_dict.get("skip_env_check", False):
from unilabos.utils.environment_check import check_environment
@@ -254,18 +277,10 @@ def main():
print_status("传入了sk参数优先采用传入参数", "info")
BasicConfig.working_dir = working_dir
# 显示启动横幅
print_unilab_banner(args_dict)
workflow_upload = args_dict.get("command") in ("workflow_upload", "wf")
#####################################
######## 启动设备接入端(主入口) ########
#####################################
launch(args_dict)
def launch(args_dict: Dict[str, Any]):
# 使用远程资源启动
if args_dict["use_remote_resource"]:
if not workflow_upload and args_dict["use_remote_resource"]:
print_status("使用远程资源启动", "info")
from unilabos.app.web import http_client
@@ -301,11 +316,36 @@ def launch(args_dict: Dict[str, Any]):
from unilabos.resources.graphio import modify_to_backend_format
from unilabos.ros.nodes.resource_tracker import ResourceTreeSet, ResourceDict
# 显示启动横幅
print_unilab_banner(args_dict)
# 注册表
lab_registry = build_registry(
args_dict["registry_path"], args_dict.get("complete_registry", False), args_dict["upload_registry"]
args_dict["registry_path"], args_dict.get("complete_registry", False), BasicConfig.upload_registry
)
if BasicConfig.upload_registry:
# 设备注册到服务端 - 需要 ak 和 sk
if BasicConfig.ak and BasicConfig.sk:
print_status("开始注册设备到服务端...", "info")
try:
register_devices_and_resources(lab_registry)
print_status("设备注册完成", "info")
except Exception as e:
print_status(f"设备注册失败: {e}", "error")
else:
print_status("未提供 ak 和 sk跳过设备注册", "info")
else:
print_status("本次启动注册表不报送云端,如果您需要联网调试,请在启动命令增加--upload_registry", "warning")
# 处理 workflow_upload 子命令
if workflow_upload:
from unilabos.workflow.wf_utils import handle_workflow_upload_command
handle_workflow_upload_command(args_dict)
print_status("工作流上传完成,程序退出", "info")
os._exit(0)
if not BasicConfig.ak or not BasicConfig.sk:
print_status("后续运行必须拥有一个实验室,请前往 https://uni-lab.bohrium.com 注册实验室!", "warning")
os._exit(1)
@@ -382,20 +422,6 @@ def launch(args_dict: Dict[str, Any]):
args_dict["devices_config"] = resource_tree_set
args_dict["graph"] = graph_res.physical_setup_graph
if BasicConfig.upload_registry:
# 设备注册到服务端 - 需要 ak 和 sk
if BasicConfig.ak and BasicConfig.sk:
print_status("开始注册设备到服务端...", "info")
try:
register_devices_and_resources(lab_registry)
print_status("设备注册完成", "info")
except Exception as e:
print_status(f"设备注册失败: {e}", "error")
else:
print_status("未提供 ak 和 sk跳过设备注册", "info")
else:
print_status("本次启动注册表不报送云端,如果您需要联网调试,请在启动命令增加--upload_registry", "warning")
if args_dict["controllers"] is not None:
args_dict["controllers_config"] = yaml.safe_load(open(args_dict["controllers"], encoding="utf-8"))
else:
@@ -410,6 +436,7 @@ def launch(args_dict: Dict[str, Any]):
comm_client = get_communication_client()
if "websocket" in args_dict["app_bridges"]:
args_dict["bridges"].append(comm_client)
def _exit(signum, frame):
comm_client.stop()
sys.exit(0)
@@ -451,16 +478,13 @@ def launch(args_dict: Dict[str, Any]):
resource_visualization.start()
except OSError as e:
if "AMENT_PREFIX_PATH" in str(e):
print_status(
f"ROS 2环境未正确设置跳过3D可视化启动。错误详情: {e}",
"warning"
)
print_status(f"ROS 2环境未正确设置跳过3D可视化启动。错误详情: {e}", "warning")
print_status(
"建议解决方案:\n"
"1. 激活Conda环境: conda activate unilab\n"
"2. 或使用 --backend simple 参数\n"
"3. 或使用 --visual disable 参数禁用可视化",
"info"
"info",
)
else:
raise

View File

@@ -76,7 +76,8 @@ class HTTPClient:
Dict[str, str]: 旧UUID到新UUID的映射关系 {old_uuid: new_uuid}
"""
with open(os.path.join(BasicConfig.working_dir, "req_resource_tree_add.json"), "w", encoding="utf-8") as f:
f.write(json.dumps({"nodes": [x for xs in resources.dump() for x in xs], "mount_uuid": mount_uuid}, indent=4))
payload = {"nodes": [x for xs in resources.dump() for x in xs], "mount_uuid": mount_uuid}
f.write(json.dumps(payload, indent=4))
# 从序列化数据中提取所有节点的UUID保存旧UUID
old_uuids = {n.res_content.uuid: n for n in resources.all_nodes}
if not self.initialized or first_add:
@@ -331,6 +332,67 @@ class HTTPClient:
logger.error(f"响应内容: {response.text}")
return None
def workflow_import(
self,
name: str,
workflow_uuid: str,
workflow_name: str,
nodes: List[Dict[str, Any]],
edges: List[Dict[str, Any]],
tags: Optional[List[str]] = None,
published: bool = False,
) -> Dict[str, Any]:
"""
导入工作流到服务器
Args:
name: 工作流名称(顶层)
workflow_uuid: 工作流UUID
workflow_name: 工作流名称data内部
nodes: 工作流节点列表
edges: 工作流边列表
tags: 工作流标签列表,默认为空列表
published: 是否发布工作流默认为False
Returns:
Dict: API响应数据包含 code 和 data (uuid, name)
"""
# target_lab_uuid 暂时使用默认值,后续由后端根据 ak/sk 获取
payload = {
"target_lab_uuid": "28c38bb0-63f6-4352-b0d8-b5b8eb1766d5",
"name": name,
"data": {
"workflow_uuid": workflow_uuid,
"workflow_name": workflow_name,
"nodes": nodes,
"edges": edges,
"tags": tags if tags is not None else [],
"published": published,
},
}
# 保存请求到文件
with open(os.path.join(BasicConfig.working_dir, "req_workflow_upload.json"), "w", encoding="utf-8") as f:
f.write(json.dumps(payload, indent=4, ensure_ascii=False))
response = requests.post(
f"{self.remote_addr}/lab/workflow/owner/import",
json=payload,
headers={"Authorization": f"Lab {self.auth}"},
timeout=60,
)
# 保存响应到文件
with open(os.path.join(BasicConfig.working_dir, "res_workflow_upload.json"), "w", encoding="utf-8") as f:
f.write(f"{response.status_code}" + "\n" + response.text)
if response.status_code == 200:
res = response.json()
if "code" in res and res["code"] != 0:
logger.error(f"导入工作流失败: {response.text}")
return res
else:
logger.error(f"导入工作流失败: {response.status_code}, {response.text}")
return {"code": response.status_code, "message": response.text}
# 创建默认客户端实例
http_client = HTTPClient()