mirror of
https://github.com/dptech-corp/Uni-Lab-OS.git
synced 2026-02-07 07:25:15 +00:00
新增注册表补全功能,修复Protocol执行失败
This commit is contained in:
@@ -55,7 +55,7 @@ def ros2_device_node(
|
||||
"read": "read_data",
|
||||
"extra_info": [],
|
||||
}
|
||||
|
||||
# FIXME 后面要删除
|
||||
for k, v in cls.__dict__.items():
|
||||
if not k.startswith("_") and isinstance(v, property):
|
||||
# noinspection PyUnresolvedReferences
|
||||
|
||||
@@ -727,7 +727,6 @@ def ros_action_to_json_schema(action_class: Any) -> Dict[str, Any]:
|
||||
|
||||
# 创建基础 schema
|
||||
schema = {
|
||||
'$schema': 'http://json-schema.org/draft-07/schema#',
|
||||
'title': action_class.__name__,
|
||||
'description': f"ROS Action {action_class.__name__} 的 JSON Schema",
|
||||
'type': 'object',
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
import io
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
@@ -10,6 +11,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
import asyncio
|
||||
|
||||
import rclpy
|
||||
import yaml
|
||||
from rclpy.node import Node
|
||||
from rclpy.action import ActionServer, ActionClient
|
||||
from rclpy.action.server import ServerGoalHandle
|
||||
@@ -302,6 +304,8 @@ class BaseROS2DeviceNode(Node, Generic[T]):
|
||||
# 创建动作服务
|
||||
if self.create_action_server:
|
||||
for action_name, action_value_mapping in self._action_value_mappings.items():
|
||||
if action_name.startswith("auto-"):
|
||||
continue
|
||||
self.create_ros_action_server(action_name, action_value_mapping)
|
||||
|
||||
# 创建线程池执行器
|
||||
@@ -838,6 +842,8 @@ class BaseROS2DeviceNode(Node, Generic[T]):
|
||||
class DeviceInitError(Exception):
|
||||
pass
|
||||
|
||||
class JsonCommandInitError(Exception):
|
||||
pass
|
||||
|
||||
class ROS2DeviceNode:
|
||||
"""
|
||||
@@ -954,12 +960,51 @@ class ROS2DeviceNode:
|
||||
self._ros_node: BaseROS2DeviceNode
|
||||
self._ros_node.lab_logger().info(f"初始化完成 {self._ros_node.uuid} {self.driver_is_ros}")
|
||||
self.driver_instance._ros_node = self._ros_node # type: ignore
|
||||
self.driver_instance._execute_driver_command = self._execute_driver_command # type: ignore
|
||||
self.driver_instance._execute_driver_command_async = self._execute_driver_command_async # type: ignore
|
||||
if hasattr(self.driver_instance, "post_init"):
|
||||
try:
|
||||
self.driver_instance.post_init(self._ros_node) # type: ignore
|
||||
except Exception as e:
|
||||
self._ros_node.lab_logger().error(f"设备后初始化失败: {e}")
|
||||
|
||||
def _execute_driver_command(self, string: str):
|
||||
try:
|
||||
target = json.loads(string)
|
||||
except Exception as ex:
|
||||
try:
|
||||
target = yaml.safe_load(io.StringIO(string))
|
||||
except Exception as ex2:
|
||||
raise JsonCommandInitError(f"执行动作时JSON/YAML解析失败: \n{ex}\n{ex2}\n原内容: {string}\n{traceback.format_exc()}")
|
||||
try:
|
||||
function_name = target["function_name"]
|
||||
function_args = target["function_args"]
|
||||
assert isinstance(function_args, dict), "执行动作时JSON必须为dict类型\n原JSON: {string}"
|
||||
function = getattr(self.driver_instance, function_name)
|
||||
assert callable(function), f"执行动作时JSON中的function_name对应的函数不可调用: {function_name}\n原JSON: {string}"
|
||||
return function(**function_args)
|
||||
except KeyError as ex:
|
||||
raise JsonCommandInitError(f"执行动作时JSON缺少function_name或function_args: {ex}\n原JSON: {string}\n{traceback.format_exc()}")
|
||||
|
||||
async def _execute_driver_command_async(self, string: str):
|
||||
try:
|
||||
target = json.loads(string)
|
||||
except Exception as ex:
|
||||
try:
|
||||
target = yaml.safe_load(io.StringIO(string))
|
||||
except Exception as ex2:
|
||||
raise JsonCommandInitError(f"执行动作时JSON/YAML解析失败: \n{ex}\n{ex2}\n原内容: {string}\n{traceback.format_exc()}")
|
||||
try:
|
||||
function_name = target["function_name"]
|
||||
function_args = target["function_args"]
|
||||
assert isinstance(function_args, dict), "执行动作时JSON必须为dict类型\n原JSON: {string}"
|
||||
function = getattr(self.driver_instance, function_name)
|
||||
assert callable(function), f"执行动作时JSON中的function_name对应的函数不可调用: {function_name}\n原JSON: {string}"
|
||||
assert asyncio.iscoroutinefunction(function), f"执行动作时JSON中的function并非异步: {function_name}\n原JSON: {string}"
|
||||
return await function(**function_args)
|
||||
except KeyError as ex:
|
||||
raise JsonCommandInitError(f"执行动作时JSON缺少function_name或function_args: {ex}\n原JSON: {string}\n{traceback.format_exc()}")
|
||||
|
||||
def _start_loop(self):
|
||||
def run_event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
@@ -567,6 +567,7 @@ class HostNode(BaseROS2DeviceNode):
|
||||
def send_goal(
|
||||
self,
|
||||
device_id: str,
|
||||
action_type: str,
|
||||
action_name: str,
|
||||
action_kwargs: Dict[str, Any],
|
||||
goal_uuid: Optional[str] = None,
|
||||
@@ -577,11 +578,26 @@ class HostNode(BaseROS2DeviceNode):
|
||||
|
||||
Args:
|
||||
device_id: 设备ID
|
||||
action_type: 动作类型
|
||||
action_name: 动作名称
|
||||
action_kwargs: 动作参数
|
||||
goal_uuid: 目标UUID,如果为None则自动生成
|
||||
server_info: 服务器发送信息,包含发送时间戳等
|
||||
"""
|
||||
action_id = f"/devices/{device_id}/{action_name}"
|
||||
if action_type.startswith("UniLabJsonCommand"):
|
||||
if action_name.startswith("auto-"):
|
||||
action_name = action_name[5:]
|
||||
action_id = f"/devices/{device_id}/_execute_driver_command"
|
||||
action_kwargs = {
|
||||
"string": json.dumps({
|
||||
"function_name": action_name,
|
||||
"function_args": action_kwargs,
|
||||
})
|
||||
}
|
||||
if action_type.startswith("UniLabJsonCommandAsync"):
|
||||
action_id = f"/devices/{device_id}/_execute_driver_command_async"
|
||||
else:
|
||||
action_id = f"/devices/{device_id}/{action_name}"
|
||||
if action_name == "test_latency" and server_info is not None:
|
||||
self.server_latest_timestamp = server_info.get("send_timestamp", 0.0)
|
||||
if action_id not in self._action_clients:
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import time
|
||||
import asyncio
|
||||
import traceback
|
||||
from types import MethodType
|
||||
from typing import Union
|
||||
|
||||
import rclpy
|
||||
@@ -22,6 +20,8 @@ from unilabos.ros.msgs.message_converter import (
|
||||
convert_from_ros_msg_with_mapping,
|
||||
)
|
||||
from unilabos.ros.nodes.base_device_node import BaseROS2DeviceNode, DeviceNodeResourceTracker, ROS2DeviceNode
|
||||
from unilabos.utils.log import error
|
||||
from unilabos.utils.type_check import serialize_result_info
|
||||
|
||||
|
||||
class ROS2ProtocolNode(BaseROS2DeviceNode):
|
||||
@@ -33,7 +33,15 @@ class ROS2ProtocolNode(BaseROS2DeviceNode):
|
||||
|
||||
# create_action_server = False # Action Server要自己创建
|
||||
|
||||
def __init__(self, device_id: str, children: dict, protocol_type: Union[str, list[str]], resource_tracker: DeviceNodeResourceTracker, *args, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
device_id: str,
|
||||
children: dict,
|
||||
protocol_type: Union[str, list[str]],
|
||||
resource_tracker: DeviceNodeResourceTracker,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self._setup_protocol_names(protocol_type)
|
||||
|
||||
# 初始化其它属性
|
||||
@@ -60,7 +68,9 @@ class ROS2ProtocolNode(BaseROS2DeviceNode):
|
||||
|
||||
for device_id, device_config in self.children.items():
|
||||
if device_config.get("type", "device") != "device":
|
||||
self.lab_logger().debug(f"[Protocol Node] Skipping type {device_config['type']} {device_id} already existed, skipping.")
|
||||
self.lab_logger().debug(
|
||||
f"[Protocol Node] Skipping type {device_config['type']} {device_id} already existed, skipping."
|
||||
)
|
||||
continue
|
||||
try:
|
||||
d = self.initialize_device(device_id, device_config)
|
||||
@@ -76,22 +86,27 @@ class ROS2ProtocolNode(BaseROS2DeviceNode):
|
||||
|
||||
# 设置硬件接口代理
|
||||
if d:
|
||||
hardware_interface = d.ros_node_instance._hardware_interface
|
||||
if (
|
||||
hasattr(d.driver_instance, d.ros_node_instance._hardware_interface["name"])
|
||||
and hasattr(d.driver_instance, d.ros_node_instance._hardware_interface["write"])
|
||||
and (d.ros_node_instance._hardware_interface["read"] is None or hasattr(d.driver_instance, d.ros_node_instance._hardware_interface["read"]))
|
||||
hasattr(d.driver_instance, hardware_interface["name"])
|
||||
and hasattr(d.driver_instance, hardware_interface["write"])
|
||||
and (hardware_interface["read"] is None or hasattr(d.driver_instance, hardware_interface["read"]))
|
||||
):
|
||||
|
||||
name = getattr(d.driver_instance, d.ros_node_instance._hardware_interface["name"])
|
||||
read = d.ros_node_instance._hardware_interface.get("read", None)
|
||||
write = d.ros_node_instance._hardware_interface.get("write", None)
|
||||
name = getattr(d.driver_instance, hardware_interface["name"])
|
||||
read = hardware_interface.get("read", None)
|
||||
write = hardware_interface.get("write", None)
|
||||
|
||||
# 如果硬件接口是字符串,通过通信设备提供
|
||||
if isinstance(name, str) and name in self.sub_devices:
|
||||
communicate_device = self.sub_devices[name]
|
||||
communicate_hardware_info = communicate_device.ros_node_instance._hardware_interface
|
||||
self._setup_hardware_proxy(d, self.sub_devices[name], read, write)
|
||||
self.lab_logger().info(f"\n通信代理:为子设备{device_id}\n 添加了{read}方法(来源:{name} {communicate_hardware_info['write']}) \n 添加了{write}方法(来源:{name} {communicate_hardware_info['read']})")
|
||||
self.lab_logger().info(
|
||||
f"\n通信代理:为子设备{device_id}\n "
|
||||
f"添加了{read}方法(来源:{name} {communicate_hardware_info['write']}) \n "
|
||||
f"添加了{write}方法(来源:{name} {communicate_hardware_info['read']})"
|
||||
)
|
||||
|
||||
def _setup_protocol_names(self, protocol_type):
|
||||
# 处理协议类型
|
||||
@@ -149,63 +164,127 @@ class ROS2ProtocolNode(BaseROS2DeviceNode):
|
||||
def _create_protocol_execute_callback(self, protocol_name, protocol_steps_generator):
|
||||
async def execute_protocol(goal_handle: ServerGoalHandle):
|
||||
"""执行完整的工作流"""
|
||||
self.get_logger().info(f'Executing {protocol_name} action...')
|
||||
action_value_mapping = self._action_value_mappings[protocol_name]
|
||||
print('+'*30)
|
||||
print(protocol_steps_generator)
|
||||
# 从目标消息中提取参数, 并调用Protocol生成器(根据设备连接图)生成action步骤
|
||||
goal = goal_handle.request
|
||||
protocol_kwargs = convert_from_ros_msg_with_mapping(goal, action_value_mapping["goal"])
|
||||
# 初始化结果信息变量
|
||||
execution_error = ""
|
||||
execution_success = False
|
||||
protocol_return_value = None
|
||||
|
||||
# 向Host查询物料当前状态
|
||||
for k, v in goal.get_fields_and_field_types().items():
|
||||
if v in ["unilabos_msgs/Resource", "sequence<unilabos_msgs/Resource>"]:
|
||||
r = ResourceGet.Request()
|
||||
r.id = protocol_kwargs[k]["id"] if v == "unilabos_msgs/Resource" else protocol_kwargs[k][0]["id"]
|
||||
r.with_children = True
|
||||
response = await self._resource_clients["resource_get"].call_async(r)
|
||||
protocol_kwargs[k] = list_to_nested_dict([convert_from_ros_msg(rs) for rs in response.resources])
|
||||
try:
|
||||
self.get_logger().info(f"Executing {protocol_name} action...")
|
||||
action_value_mapping = self._action_value_mappings[protocol_name]
|
||||
print("+" * 30)
|
||||
print(protocol_steps_generator)
|
||||
# 从目标消息中提取参数, 并调用Protocol生成器(根据设备连接图)生成action步骤
|
||||
goal = goal_handle.request
|
||||
protocol_kwargs = convert_from_ros_msg_with_mapping(goal, action_value_mapping["goal"])
|
||||
|
||||
from unilabos.resources.graphio import physical_setup_graph
|
||||
self.get_logger().info(f'Working on physical setup: {physical_setup_graph}')
|
||||
protocol_steps = protocol_steps_generator(G=physical_setup_graph, **protocol_kwargs)
|
||||
# 向Host查询物料当前状态
|
||||
for k, v in goal.get_fields_and_field_types().items():
|
||||
if v in ["unilabos_msgs/Resource", "sequence<unilabos_msgs/Resource>"]:
|
||||
r = ResourceGet.Request()
|
||||
resource_id = (
|
||||
protocol_kwargs[k]["id"] if v == "unilabos_msgs/Resource" else protocol_kwargs[k][0]["id"]
|
||||
)
|
||||
r.id = resource_id
|
||||
r.with_children = True
|
||||
response = await self._resource_clients["resource_get"].call_async(r)
|
||||
protocol_kwargs[k] = list_to_nested_dict(
|
||||
[convert_from_ros_msg(rs) for rs in response.resources]
|
||||
)
|
||||
|
||||
self.get_logger().info(f'Goal received: {protocol_kwargs}, running steps: \n{protocol_steps}')
|
||||
from unilabos.resources.graphio import physical_setup_graph
|
||||
|
||||
time_start = time.time()
|
||||
time_overall = 100
|
||||
self._busy = True
|
||||
self.lab_logger().info(f"Working on physical setup: {physical_setup_graph}")
|
||||
protocol_steps = protocol_steps_generator(G=physical_setup_graph, **protocol_kwargs)
|
||||
|
||||
# 逐步执行工作流
|
||||
for i, action in enumerate(protocol_steps):
|
||||
self.get_logger().info(f'Running step {i+1}: {action}')
|
||||
if type(action) == dict:
|
||||
# 如果是单个动作,直接执行
|
||||
if action["action_name"] == "wait":
|
||||
time.sleep(action["action_kwargs"]["time"])
|
||||
else:
|
||||
result = await self.execute_single_action(**action)
|
||||
elif type(action) == list:
|
||||
# 如果是并行动作,同时执行
|
||||
actions = action
|
||||
futures = [rclpy.get_global_executor().create_task(self.execute_single_action(**a)) for a in actions]
|
||||
results = [await f for f in futures]
|
||||
self.lab_logger().info(f"Goal received: {protocol_kwargs}, running steps: \n{protocol_steps}")
|
||||
|
||||
# 向Host更新物料当前状态
|
||||
for k, v in goal.get_fields_and_field_types().items():
|
||||
if v in ["unilabos_msgs/Resource", "sequence<unilabos_msgs/Resource>"]:
|
||||
r = ResourceUpdate.Request()
|
||||
r.resources = [
|
||||
convert_to_ros_msg(Resource, rs) for rs in nested_dict_to_list(protocol_kwargs[k])
|
||||
]
|
||||
response = await self._resource_clients["resource_update"].call_async(r)
|
||||
time_start = time.time()
|
||||
time_overall = 100
|
||||
self._busy = True
|
||||
|
||||
goal_handle.succeed()
|
||||
# 逐步执行工作流
|
||||
step_results = []
|
||||
for i, action in enumerate(protocol_steps):
|
||||
self.get_logger().info(f"Running step {i + 1}: {action}")
|
||||
if isinstance(action, dict):
|
||||
# 如果是单个动作,直接执行
|
||||
if action["action_name"] == "wait":
|
||||
time.sleep(action["action_kwargs"]["time"])
|
||||
step_results.append({"step": i + 1, "action": "wait", "result": "completed"})
|
||||
else:
|
||||
result = await self.execute_single_action(**action)
|
||||
step_results.append({"step": i + 1, "action": action["action_name"], "result": result})
|
||||
elif isinstance(action, list):
|
||||
# 如果是并行动作,同时执行
|
||||
actions = action
|
||||
futures = [
|
||||
rclpy.get_global_executor().create_task(self.execute_single_action(**a)) for a in actions
|
||||
]
|
||||
results = [await f for f in futures]
|
||||
step_results.append(
|
||||
{
|
||||
"step": i + 1,
|
||||
"parallel_actions": [a["action_name"] for a in actions],
|
||||
"results": results,
|
||||
}
|
||||
)
|
||||
|
||||
# 向Host更新物料当前状态
|
||||
for k, v in goal.get_fields_and_field_types().items():
|
||||
if v in ["unilabos_msgs/Resource", "sequence<unilabos_msgs/Resource>"]:
|
||||
r = ResourceUpdate.Request()
|
||||
r.resources = [
|
||||
convert_to_ros_msg(Resource, rs) for rs in nested_dict_to_list(protocol_kwargs[k])
|
||||
]
|
||||
response = await self._resource_clients["resource_update"].call_async(r)
|
||||
|
||||
# 设置成功状态和返回值
|
||||
execution_success = True
|
||||
protocol_return_value = {
|
||||
"protocol_name": protocol_name,
|
||||
"steps_executed": len(protocol_steps),
|
||||
"step_results": step_results,
|
||||
"total_time": time.time() - time_start,
|
||||
}
|
||||
|
||||
goal_handle.succeed()
|
||||
|
||||
except Exception as e:
|
||||
# 捕获并记录错误信息
|
||||
execution_error = traceback.format_exc()
|
||||
execution_success = False
|
||||
error(f"协议 {protocol_name} 执行失败")
|
||||
error(traceback.format_exc())
|
||||
self.lab_logger().error(f"协议执行出错: {str(e)}")
|
||||
|
||||
# 设置动作失败
|
||||
goal_handle.abort()
|
||||
|
||||
finally:
|
||||
self._busy = False
|
||||
|
||||
# 创建结果消息
|
||||
result = action_value_mapping["type"].Result()
|
||||
result.success = True
|
||||
result.success = execution_success
|
||||
|
||||
self._busy = False
|
||||
# 获取结果消息类型信息,检查是否有return_info字段
|
||||
result_msg_types = action_value_mapping["type"].Result.get_fields_and_field_types()
|
||||
|
||||
# 设置return_info字段(如果存在)
|
||||
for attr_name in result_msg_types.keys():
|
||||
if attr_name in ["success", "reached_goal"]:
|
||||
setattr(result, attr_name, execution_success)
|
||||
elif attr_name == "return_info":
|
||||
setattr(
|
||||
result,
|
||||
attr_name,
|
||||
serialize_result_info(execution_error, execution_success, protocol_return_value),
|
||||
)
|
||||
|
||||
self.lab_logger().info(f"协议 {protocol_name} 完成并返回结果")
|
||||
return result
|
||||
|
||||
return execute_protocol
|
||||
|
||||
async def execute_single_action(self, device_id, action_name, action_kwargs):
|
||||
@@ -241,14 +320,19 @@ class ROS2ProtocolNode(BaseROS2DeviceNode):
|
||||
|
||||
return result_future.result
|
||||
|
||||
|
||||
"""还没有改过的部分"""
|
||||
|
||||
def _setup_hardware_proxy(self, device: ROS2DeviceNode, communication_device: ROS2DeviceNode, read_method, write_method):
|
||||
def _setup_hardware_proxy(
|
||||
self, device: ROS2DeviceNode, communication_device: ROS2DeviceNode, read_method, write_method
|
||||
):
|
||||
"""为设备设置硬件接口代理"""
|
||||
# extra_info = [getattr(device.driver_instance, info) for info in communication_device.ros_node_instance._hardware_interface.get("extra_info", [])]
|
||||
write_func = getattr(communication_device.driver_instance, communication_device.ros_node_instance._hardware_interface["write"])
|
||||
read_func = getattr(communication_device.driver_instance, communication_device.ros_node_instance._hardware_interface["read"])
|
||||
write_func = getattr(
|
||||
communication_device.driver_instance, communication_device.ros_node_instance._hardware_interface["write"]
|
||||
)
|
||||
read_func = getattr(
|
||||
communication_device.driver_instance, communication_device.ros_node_instance._hardware_interface["read"]
|
||||
)
|
||||
|
||||
def _read(*args, **kwargs):
|
||||
return read_func(*args, **kwargs)
|
||||
@@ -264,7 +348,6 @@ class ROS2ProtocolNode(BaseROS2DeviceNode):
|
||||
# bound_write = MethodType(_write, device.driver_instance)
|
||||
setattr(device.driver_instance, write_method, _write)
|
||||
|
||||
|
||||
async def _update_resources(self, goal, protocol_kwargs):
|
||||
"""更新资源状态"""
|
||||
for k, v in goal.get_fields_and_field_types().items():
|
||||
|
||||
Reference in New Issue
Block a user