mirror of
https://github.com/dptech-corp/Uni-Lab-OS.git
synced 2025-12-17 04:51:10 +00:00
Initial commit
This commit is contained in:
0
unilabos/ros/__init__.py
Normal file
0
unilabos/ros/__init__.py
Normal file
86
unilabos/ros/device_node_wrapper.py
Normal file
86
unilabos/ros/device_node_wrapper.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from typing import Dict, Any, Optional, Type, TypeVar
|
||||
|
||||
from unilabos.ros.msgs.message_converter import (
|
||||
get_msg_type,
|
||||
get_action_type,
|
||||
)
|
||||
from unilabos.ros.nodes.base_device_node import init_wrapper, ROS2DeviceNode
|
||||
|
||||
# 定义泛型类型变量
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# noinspection PyMissingConstructor
|
||||
class ROS2DeviceNodeWrapper(ROS2DeviceNode):
|
||||
def __init__(self, device_id: str, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def ros2_device_node(
|
||||
cls: Type[T],
|
||||
status_types: Optional[Dict[str, Any]] = None,
|
||||
action_value_mappings: Optional[Dict[str, Any]] = None,
|
||||
hardware_interface: Optional[Dict[str, Any]] = None,
|
||||
print_publish: bool = False,
|
||||
children: Optional[Dict[str, Any]] = None,
|
||||
) -> Type[ROS2DeviceNodeWrapper]:
|
||||
"""Create a ROS2 Node class for a device class with properties and actions.
|
||||
|
||||
Args:
|
||||
cls: 要封装的设备类
|
||||
status_types: 需要发布的状态和传感器信息,每个(PROP: TYPE),PROP应该匹配cls.PROP或cls.get_PROP(),
|
||||
TYPE应该是ROS2消息类型。默认为{}。
|
||||
action_value_mappings: 设备动作。默认为{}。
|
||||
每个(ACTION: {'type': CMD_TYPE, 'goal': {FIELD: PROP}, 'feedback': {FIELD: PROP}, 'result': {FIELD: PROP}}),
|
||||
hardware_interface: 硬件接口配置。默认为{"name": "hardware_interface", "write": "send_command", "read": "read_data", "extra_info": []}。
|
||||
print_publish: 是否打印发布信息。默认为False。
|
||||
children: 物料/子节点信息。
|
||||
|
||||
Returns:
|
||||
Type: 封装了设备类的ROS2节点类。
|
||||
"""
|
||||
# 从属性中自动发现可发布状态
|
||||
if status_types is None:
|
||||
status_types = {}
|
||||
if action_value_mappings is None:
|
||||
action_value_mappings = {}
|
||||
if hardware_interface is None:
|
||||
hardware_interface = {
|
||||
"name": "hardware_interface",
|
||||
"write": "send_command",
|
||||
"read": "read_data",
|
||||
"extra_info": [],
|
||||
}
|
||||
|
||||
for k, v in cls.__dict__.items():
|
||||
if not k.startswith("_") and isinstance(v, property):
|
||||
# noinspection PyUnresolvedReferences
|
||||
property_type = v.fget.__annotations__.get("return", str)
|
||||
get_method_name = f"get_{k}"
|
||||
set_method_name = f"set_{k}"
|
||||
|
||||
if k not in status_types and hasattr(cls, get_method_name):
|
||||
status_types[k] = get_msg_type(property_type)
|
||||
|
||||
if f"set_{k}" not in action_value_mappings and hasattr(cls, set_method_name):
|
||||
action_value_mappings[f"set_{k}"] = get_action_type(property_type)
|
||||
# 创建一个包装类来返回ROS2DeviceNode
|
||||
wrapper_class_name = f"ROS2NodeWrapper4{cls.__name__}"
|
||||
ROS2DeviceNodeWrapper = type(
|
||||
wrapper_class_name,
|
||||
(ROS2DeviceNode,),
|
||||
{
|
||||
"__init__": lambda self, *args, **kwargs: init_wrapper(
|
||||
self,
|
||||
driver_class=cls,
|
||||
status_types=status_types,
|
||||
action_value_mappings=action_value_mappings,
|
||||
hardware_interface=hardware_interface,
|
||||
print_publish=print_publish,
|
||||
children=children,
|
||||
*args,
|
||||
**kwargs,
|
||||
),
|
||||
},
|
||||
)
|
||||
return ROS2DeviceNodeWrapper
|
||||
51
unilabos/ros/initialize_device.py
Normal file
51
unilabos/ros/initialize_device.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from typing import Optional
|
||||
from unilabos.registry.registry import lab_registry
|
||||
from unilabos.ros.nodes.base_device_node import ROS2DeviceNode, DeviceInitError
|
||||
from unilabos.ros.device_node_wrapper import ros2_device_node
|
||||
from unilabos.utils import logger
|
||||
from unilabos.utils.import_manager import default_manager
|
||||
|
||||
|
||||
def initialize_device_from_dict(device_id, device_config) -> Optional[ROS2DeviceNode]:
|
||||
"""Initializes a device based on its configuration.
|
||||
|
||||
This function dynamically imports the appropriate device class and creates an instance of it using the provided device configuration.
|
||||
It also sets up action clients for the device based on its action value mappings.
|
||||
|
||||
Args:
|
||||
device_id (str): The unique identifier for the device.
|
||||
device_config (dict): The configuration dictionary for the device, which includes the class type and other parameters.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
d = None
|
||||
device_class_config = device_config["class"]
|
||||
if isinstance(device_class_config, str): # 如果是字符串,则直接去lab_registry中查找,获取class
|
||||
if device_class_config not in lab_registry.device_type_registry:
|
||||
raise ValueError(f"Device class {device_class_config} not found.")
|
||||
device_class_config = device_config["class"] = lab_registry.device_type_registry[device_class_config]["class"]
|
||||
if isinstance(device_class_config, dict):
|
||||
DEVICE = default_manager.get_class(device_class_config["module"])
|
||||
# 不管是ros2的实例,还是python的,都必须包一次,除了HostNode
|
||||
DEVICE = ros2_device_node(
|
||||
DEVICE,
|
||||
status_types=device_class_config.get("status_types", {}),
|
||||
action_value_mappings=device_class_config.get("action_value_mappings", {}),
|
||||
hardware_interface=device_class_config.get(
|
||||
"hardware_interface",
|
||||
{"name": "hardware_interface", "write": "send_command", "read": "read_data", "extra_info": []},
|
||||
),
|
||||
children=device_config.get("children", {})
|
||||
)
|
||||
try:
|
||||
d = DEVICE(
|
||||
device_id=device_id, driver_is_ros=device_class_config["type"] == "ros2", driver_params=device_config.get("config", {})
|
||||
)
|
||||
except DeviceInitError as ex:
|
||||
return d
|
||||
else:
|
||||
logger.warning(f"initialize device {device_id} failed, provided device_config: {device_config}")
|
||||
return d
|
||||
124
unilabos/ros/main_slave_run.py
Normal file
124
unilabos/ros/main_slave_run.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import os
|
||||
import traceback
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
import rclpy
|
||||
from unilabos_msgs.msg import Resource # type: ignore
|
||||
from unilabos_msgs.srv import ResourceAdd # type: ignore
|
||||
from rclpy.executors import MultiThreadedExecutor
|
||||
from rclpy.node import Node
|
||||
from rclpy.timer import Timer
|
||||
|
||||
from unilabos.ros.initialize_device import initialize_device_from_dict
|
||||
from unilabos.ros.msgs.message_converter import (
|
||||
convert_to_ros_msg,
|
||||
)
|
||||
from unilabos.ros.nodes.presets.host_node import HostNode
|
||||
from unilabos.ros.x.rclpyx import run_event_loop_in_thread
|
||||
from unilabos.utils import logger
|
||||
from unilabos.config.config import BasicConfig
|
||||
|
||||
|
||||
def exit() -> None:
|
||||
"""关闭ROS节点和资源"""
|
||||
host_instance = HostNode.get_instance()
|
||||
if host_instance is not None:
|
||||
# 停止发现定时器
|
||||
# noinspection PyProtectedMember
|
||||
if hasattr(host_instance, "_discovery_timer") and isinstance(host_instance._discovery_timer, Timer):
|
||||
# noinspection PyProtectedMember
|
||||
host_instance._discovery_timer.cancel()
|
||||
for _, device_node in host_instance.devices_instances.items():
|
||||
if hasattr(device_node, "destroy_node"):
|
||||
device_node.ros_node_instance.destroy_node()
|
||||
host_instance.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
def main(
|
||||
devices_config: Dict[str, Any] = {},
|
||||
resources_config={},
|
||||
graph: Optional[Dict[str, Any]] = None,
|
||||
controllers_config: Dict[str, Any] = {},
|
||||
bridges: List[Any] = [],
|
||||
args: List[str] = ["--log-level", "debug"],
|
||||
discovery_interval: float = 5.0,
|
||||
) -> None:
|
||||
"""主函数"""
|
||||
rclpy.init(args=args)
|
||||
rclpy.__executor = executor = MultiThreadedExecutor()
|
||||
|
||||
# 创建主机节点
|
||||
host_node = HostNode(
|
||||
"host_node",
|
||||
devices_config,
|
||||
resources_config,
|
||||
graph,
|
||||
controllers_config,
|
||||
bridges,
|
||||
discovery_interval,
|
||||
)
|
||||
|
||||
executor.add_node(host_node)
|
||||
# run_event_loop_in_thread()
|
||||
|
||||
try:
|
||||
executor.spin()
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
print(f"Exception caught: {e}")
|
||||
finally:
|
||||
exit()
|
||||
|
||||
|
||||
def slave(
|
||||
devices_config: Dict[str, Any] = {},
|
||||
resources_config=[],
|
||||
graph: Optional[Dict[str, Any]] = None,
|
||||
controllers_config: Dict[str, Any] = {},
|
||||
bridges: List[Any] = [],
|
||||
args: List[str] = ["--log-level", "debug"],
|
||||
) -> None:
|
||||
"""从节点函数"""
|
||||
rclpy.init(args=args)
|
||||
rclpy.__executor = executor = MultiThreadedExecutor()
|
||||
|
||||
for device_id, device_config in devices_config.items():
|
||||
d = initialize_device_from_dict(device_id, device_config)
|
||||
if d is None:
|
||||
continue
|
||||
# 默认初始化
|
||||
# if d is not None and isinstance(d, Node):
|
||||
# executor.add_node(d)
|
||||
# else:
|
||||
# print(f"Warning: Device {device_id} could not be initialized or is not a valid Node")
|
||||
|
||||
machine_name = os.popen("hostname").read().strip()
|
||||
machine_name = "".join([c if c.isalnum() or c == "_" else "_" for c in machine_name])
|
||||
n = Node(f"slaveMachine_{machine_name}", parameter_overrides=[])
|
||||
executor.add_node(n)
|
||||
|
||||
if BasicConfig.slave_no_host:
|
||||
# 确保ResourceAdd存在
|
||||
if "ResourceAdd" in globals():
|
||||
rclient = n.create_client(ResourceAdd, "/resources/add")
|
||||
rclient.wait_for_service() # FIXME 可能一直等待,加一个参数
|
||||
|
||||
request = ResourceAdd.Request()
|
||||
request.resources = [convert_to_ros_msg(Resource, resource) for resource in resources_config]
|
||||
response = rclient.call_async(request)
|
||||
else:
|
||||
print("Warning: ResourceAdd service not available")
|
||||
|
||||
run_event_loop_in_thread()
|
||||
|
||||
try:
|
||||
executor.spin()
|
||||
except Exception as e:
|
||||
print(f"Exception caught: {e}")
|
||||
finally:
|
||||
exit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
unilabos/ros/msgs/__init__.py
Normal file
0
unilabos/ros/msgs/__init__.py
Normal file
789
unilabos/ros/msgs/message_converter.py
Normal file
789
unilabos/ros/msgs/message_converter.py
Normal file
@@ -0,0 +1,789 @@
|
||||
"""
|
||||
消息转换器
|
||||
|
||||
该模块提供了在Python对象(dataclass, Pydantic模型)和ROS消息类型之间进行转换的功能。
|
||||
使用ImportManager动态导入和管理所需模块。
|
||||
"""
|
||||
|
||||
import json
|
||||
import traceback
|
||||
from io import StringIO
|
||||
from typing import Iterable, Any, Dict, Type, TypeVar, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
from dataclasses import asdict, is_dataclass
|
||||
|
||||
from rosidl_parser.definition import UnboundedSequence, NamespacedType, BasicType, UnboundedString
|
||||
|
||||
from unilabos.utils import logger
|
||||
from unilabos.utils.import_manager import ImportManager
|
||||
from unilabos.config.config import ROSConfig
|
||||
|
||||
# 定义泛型类型
|
||||
T = TypeVar("T")
|
||||
DataClassT = TypeVar("DataClassT")
|
||||
|
||||
# 从配置中获取需要导入的模块列表
|
||||
ROS_MODULES = ROSConfig.modules
|
||||
|
||||
msg_converter_manager = ImportManager(ROS_MODULES)
|
||||
|
||||
|
||||
"""geometry_msgs"""
|
||||
Point = msg_converter_manager.get_class("geometry_msgs.msg:Point")
|
||||
Pose = msg_converter_manager.get_class("geometry_msgs.msg:Pose")
|
||||
"""std_msgs"""
|
||||
Float64 = msg_converter_manager.get_class("std_msgs.msg:Float64")
|
||||
Float64MultiArray = msg_converter_manager.get_class("std_msgs.msg:Float64MultiArray")
|
||||
Int32 = msg_converter_manager.get_class("std_msgs.msg:Int32")
|
||||
Int64 = msg_converter_manager.get_class("std_msgs.msg:Int64")
|
||||
String = msg_converter_manager.get_class("std_msgs.msg:String")
|
||||
Bool = msg_converter_manager.get_class("std_msgs.msg:Bool")
|
||||
"""nav2_msgs"""
|
||||
NavigateToPose = msg_converter_manager.get_class("nav2_msgs.action:NavigateToPose")
|
||||
NavigateThroughPoses = msg_converter_manager.get_class("nav2_msgs.action:NavigateThroughPoses")
|
||||
SingleJointPosition = msg_converter_manager.get_class("control_msgs.action:SingleJointPosition")
|
||||
"""unilabos_msgs"""
|
||||
Resource = msg_converter_manager.get_class("unilabos_msgs.msg:Resource")
|
||||
SendCmd = msg_converter_manager.get_class("unilabos_msgs.action:SendCmd")
|
||||
"""unilabos"""
|
||||
imsg = msg_converter_manager.get_module("unilabos.messages")
|
||||
Point3D = msg_converter_manager.get_class("unilabos.messages:Point3D")
|
||||
|
||||
# 基本消息类型映射
|
||||
_msg_mapping: Dict[Type, Type] = {
|
||||
float: Float64,
|
||||
list[float]: Float64MultiArray,
|
||||
int: Int32,
|
||||
str: String,
|
||||
bool: Bool,
|
||||
Point3D: Point,
|
||||
}
|
||||
|
||||
# Action类型映射
|
||||
_action_mapping: Dict[Type, Dict[str, Any]] = {
|
||||
float: {
|
||||
"type": SingleJointPosition,
|
||||
"goal": {"position": "position", "max_velocity": "max_velocity"},
|
||||
"feedback": {"position": "position"},
|
||||
"result": {},
|
||||
},
|
||||
str: {
|
||||
"type": SendCmd,
|
||||
"goal": {"command": "position"},
|
||||
"feedback": {"status": "status"},
|
||||
"result": {},
|
||||
},
|
||||
Point3D: {
|
||||
"type": NavigateToPose,
|
||||
"goal": {"pose.pose.position": "position"},
|
||||
"feedback": {
|
||||
"current_pose.pose.position": "position",
|
||||
"navigation_time.sec": "time_spent",
|
||||
"estimated_time_remaining.sec": "time_remaining",
|
||||
},
|
||||
"result": {},
|
||||
},
|
||||
list[Point3D]: {
|
||||
"type": NavigateThroughPoses,
|
||||
"goal": {"poses[].pose.position": "positions[]"},
|
||||
"feedback": {
|
||||
"current_pose.pose.position": "position",
|
||||
"navigation_time.sec": "time_spent",
|
||||
"estimated_time_remaining.sec": "time_remaining",
|
||||
"number_of_poses_remaining": "pose_number_remaining",
|
||||
},
|
||||
"result": {},
|
||||
},
|
||||
}
|
||||
|
||||
# 添加Protocol action类型到映射
|
||||
for py_msgtype in imsg.__all__:
|
||||
if py_msgtype not in _action_mapping and py_msgtype.endswith("Protocol"):
|
||||
try:
|
||||
protocol_class = msg_converter_manager.get_class(f"unilabos.messages.{py_msgtype}")
|
||||
action_name = py_msgtype.replace("Protocol", "")
|
||||
action_type = msg_converter_manager.get_class(f"unilabos_msgs.action.{action_name}")
|
||||
|
||||
if action_type:
|
||||
_action_mapping[protocol_class] = {
|
||||
"type": action_type,
|
||||
"goal": {k: k for k in action_type.Goal().get_fields_and_field_types().keys()},
|
||||
"feedback": {
|
||||
(k if "time" not in k else f"{k}.sec"): k
|
||||
for k in action_type.Feedback().get_fields_and_field_types().keys()
|
||||
},
|
||||
"result": {k: k for k in action_type.Result().get_fields_and_field_types().keys()},
|
||||
}
|
||||
except Exception:
|
||||
logger.debug(f"Failed to load Protocol class: {py_msgtype}")
|
||||
|
||||
# Python到ROS消息转换器
|
||||
_msg_converter: Dict[Type, Any] = {
|
||||
float: float,
|
||||
Float64: lambda x: Float64(data=float(x)),
|
||||
Float64MultiArray: lambda x: Float64MultiArray(data=[float(i) for i in x]),
|
||||
int: int,
|
||||
Int32: lambda x: Int32(data=int(x)),
|
||||
Int64: lambda x: Int64(data=int(x)),
|
||||
bool: bool,
|
||||
Bool: lambda x: Bool(data=bool(x)),
|
||||
str: str,
|
||||
String: lambda x: String(data=str(x)),
|
||||
Point: lambda x: Point(x=x.x, y=x.y, z=x.z),
|
||||
Resource: lambda x: Resource(
|
||||
id=x["id"],
|
||||
name=x["name"],
|
||||
sample_id=x.get("sample_id", "") or "",
|
||||
children=list(x.get("children", [])),
|
||||
parent=x.get("parent", "") or "",
|
||||
type=x["type"],
|
||||
category=x.get("class", "") or x["type"],
|
||||
pose=(
|
||||
Pose(position=Point(x=float(x["position"]["x"]), y=float(x["position"]["y"]), z=float(x["position"]["z"])))
|
||||
if x.get("position", None) is not None
|
||||
else Pose()
|
||||
),
|
||||
config=json.dumps(x.get("config", {})),
|
||||
data=json.dumps(x.get("data", {})),
|
||||
),
|
||||
}
|
||||
|
||||
def json_or_yaml_loads(data: str) -> Any:
|
||||
try:
|
||||
return json.loads(data)
|
||||
except Exception as e:
|
||||
try:
|
||||
return yaml.safe_load(StringIO(data))
|
||||
except:
|
||||
pass
|
||||
raise e
|
||||
|
||||
# ROS消息到Python转换器
|
||||
_msg_converter_back: Dict[Type, Any] = {
|
||||
float: float,
|
||||
Float64: lambda x: x.data,
|
||||
Float64MultiArray: lambda x: x.data,
|
||||
int: int,
|
||||
Int32: lambda x: x.data,
|
||||
Int64: lambda x: x.data,
|
||||
bool: bool,
|
||||
Bool: lambda x: x.data,
|
||||
str: str,
|
||||
String: lambda x: x.data,
|
||||
Point: lambda x: Point3D(x=x.x, y=x.y, z=x.z),
|
||||
Resource: lambda x: {
|
||||
"id": x.id,
|
||||
"name": x.name,
|
||||
"sample_id": x.sample_id if x.sample_id else None,
|
||||
"children": list(x.children),
|
||||
"parent": x.parent if x.parent else None,
|
||||
"type": x.type,
|
||||
"class": x.category,
|
||||
"position": {"x": x.pose.position.x, "y": x.pose.position.y, "z": x.pose.position.z},
|
||||
"config": json_or_yaml_loads(x.config or "{}"),
|
||||
"data": json_or_yaml_loads(x.data or "{}"),
|
||||
},
|
||||
}
|
||||
|
||||
# 消息数据类型映射
|
||||
_msg_data_mapping: Dict[str, Type] = {
|
||||
"double": float,
|
||||
"float": float,
|
||||
"int": int,
|
||||
"bool": bool,
|
||||
"str": str,
|
||||
}
|
||||
|
||||
|
||||
def compare_model_fields(cls1: Any, cls2: Any) -> bool:
|
||||
"""比较两个类的字段是否相同"""
|
||||
|
||||
def get_class_fields(cls: Any) -> set:
|
||||
if hasattr(cls, "__annotations__"):
|
||||
return set(cls.__annotations__.keys())
|
||||
else:
|
||||
return set(cls.__dict__.keys())
|
||||
|
||||
fields1 = get_class_fields(cls1)
|
||||
fields2 = get_class_fields(cls2)
|
||||
return fields1 == fields2
|
||||
|
||||
|
||||
def get_msg_type(datatype: Type) -> Type:
|
||||
"""
|
||||
获取与Python数据类型对应的ROS消息类型
|
||||
|
||||
Args:
|
||||
datatype: Python数据类型、Pydantic模型或dataclass
|
||||
|
||||
Returns:
|
||||
对应的ROS消息类型
|
||||
|
||||
Raises:
|
||||
ValueError: 如果不支持的消息类型
|
||||
"""
|
||||
# 直接匹配已知类型
|
||||
if isinstance(datatype, type) and datatype in _msg_mapping:
|
||||
return _msg_mapping[datatype]
|
||||
|
||||
# 尝试通过字段比较匹配
|
||||
for k, v in _msg_mapping.items():
|
||||
if compare_model_fields(k, datatype):
|
||||
return v
|
||||
|
||||
raise ValueError(f"Unsupported message type: {datatype}")
|
||||
|
||||
|
||||
def get_action_type(datatype: Type) -> Dict[str, Any]:
|
||||
"""
|
||||
获取与Python数据类型对应的ROS动作类型
|
||||
|
||||
Args:
|
||||
datatype: Python数据类型、Pydantic模型或dataclass
|
||||
|
||||
Returns:
|
||||
对应的ROS动作类型配置
|
||||
|
||||
Raises:
|
||||
ValueError: 如果不支持的动作类型
|
||||
"""
|
||||
# 直接匹配已知类型
|
||||
if isinstance(datatype, type) and datatype in _action_mapping:
|
||||
return _action_mapping[datatype]
|
||||
|
||||
# 尝试通过字段比较匹配
|
||||
for k, v in _action_mapping.items():
|
||||
if compare_model_fields(k, datatype):
|
||||
return v
|
||||
|
||||
raise ValueError(f"Unsupported action type: {datatype}")
|
||||
|
||||
|
||||
def get_ros_type_by_msgname(msgname: str) -> Type:
|
||||
"""
|
||||
通过消息名称获取ROS类型
|
||||
|
||||
Args:
|
||||
msgname: ROS消息名称,格式为 'package_name/(action,msg,srv)/TypeName'
|
||||
|
||||
Returns:
|
||||
对应的ROS类型
|
||||
|
||||
Raises:
|
||||
ValueError: 如果无效的ROS消息名称
|
||||
ImportError: 如果无法加载类型
|
||||
"""
|
||||
parts = msgname.split("/")
|
||||
if len(parts) != 3 or parts[1] not in ("action", "msg", "srv"):
|
||||
raise ValueError(
|
||||
f"Invalid ROS message name: {msgname}. Format should be 'package_name/(action,msg,srv)/TypeName'"
|
||||
)
|
||||
|
||||
package_name, msg_type, type_name = parts
|
||||
full_module_path = f"{package_name}.{msg_type}"
|
||||
|
||||
try:
|
||||
# 尝试通过ImportManager获取
|
||||
return msg_converter_manager.get_class(f"{full_module_path}.{type_name}")
|
||||
except KeyError:
|
||||
# 尝试动态导入
|
||||
try:
|
||||
msg_converter_manager.load_module(full_module_path)
|
||||
return msg_converter_manager.get_class(f"{full_module_path}.{type_name}")
|
||||
except Exception as e:
|
||||
raise ImportError(f"Failed to load type {type_name}. Make sure the package is installed.") from e
|
||||
|
||||
|
||||
def _extract_data(obj: Any) -> Dict[str, Any]:
|
||||
"""提取对象数据为字典"""
|
||||
if is_dataclass(obj) and not isinstance(obj, type) and hasattr(obj, "__dataclass_fields__"):
|
||||
return asdict(obj)
|
||||
elif isinstance(obj, BaseModel):
|
||||
return obj.model_dump()
|
||||
elif isinstance(obj, dict):
|
||||
return obj
|
||||
else:
|
||||
return {"data": obj}
|
||||
|
||||
|
||||
def convert_to_ros_msg(ros_msg_type: Union[Type, Any], obj: Any) -> Any:
|
||||
"""
|
||||
将Python对象转换为ROS消息实例
|
||||
|
||||
Args:
|
||||
ros_msg_type: 目标ROS消息类型
|
||||
obj: Python对象(基本类型、dataclass或Pydantic实例)
|
||||
|
||||
Returns:
|
||||
ROS消息实例
|
||||
"""
|
||||
# 尝试使用预定义转换器
|
||||
try:
|
||||
if isinstance(ros_msg_type, type) and ros_msg_type in _msg_converter:
|
||||
return _msg_converter[ros_msg_type](obj)
|
||||
except Exception as e:
|
||||
logger.error(f"Converter error: {type(ros_msg_type)} -> {obj}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 创建ROS消息实例
|
||||
ros_msg = ros_msg_type() if isinstance(ros_msg_type, type) else ros_msg_type
|
||||
|
||||
# 提取数据
|
||||
data = _extract_data(obj)
|
||||
|
||||
# 转换数据到ROS消息
|
||||
for key, value in data.items():
|
||||
if hasattr(ros_msg, key):
|
||||
attr = getattr(ros_msg, key)
|
||||
if isinstance(attr, (float, int, str, bool)):
|
||||
setattr(ros_msg, key, value)
|
||||
elif isinstance(attr, (list, tuple)) and isinstance(value, Iterable):
|
||||
setattr(ros_msg, key, list(value))
|
||||
else:
|
||||
nested_ros_msg = convert_to_ros_msg(type(attr)(), value)
|
||||
setattr(ros_msg, key, nested_ros_msg)
|
||||
else:
|
||||
# 跳过不存在的字段,防止报错
|
||||
continue
|
||||
|
||||
return ros_msg
|
||||
|
||||
|
||||
def convert_to_ros_msg_with_mapping(ros_msg_type: Type, obj: Any, value_mapping: Dict[str, str]) -> Any:
|
||||
"""
|
||||
根据字段映射将Python对象转换为ROS消息
|
||||
|
||||
Args:
|
||||
ros_msg_type: 目标ROS消息类型
|
||||
obj: Python对象
|
||||
value_mapping: 字段名映射关系字典
|
||||
|
||||
Returns:
|
||||
ROS消息实例
|
||||
"""
|
||||
# 创建ROS消息实例
|
||||
ros_msg = ros_msg_type() if isinstance(ros_msg_type, type) else ros_msg_type
|
||||
|
||||
# 提取数据
|
||||
data = _extract_data(obj)
|
||||
|
||||
# 按照映射关系处理每个字段
|
||||
for msg_name, attr_name in value_mapping.items():
|
||||
msg_path = msg_name.split(".")
|
||||
attr_base = attr_name.rstrip("[]")
|
||||
|
||||
if attr_base not in data:
|
||||
continue
|
||||
|
||||
value = data[attr_base]
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
if not attr_name.endswith("[]"):
|
||||
# 处理单值映射,如 {"pose.position": "position"}
|
||||
current = ros_msg
|
||||
for i, name in enumerate(msg_path[:-1]):
|
||||
current = getattr(current, name)
|
||||
|
||||
last_field = msg_path[-1]
|
||||
field_type = type(getattr(current, last_field))
|
||||
setattr(current, last_field, convert_to_ros_msg(field_type, value))
|
||||
else:
|
||||
# 处理列表值映射,如 {"poses[].position": "positions[]"}
|
||||
if not isinstance(value, Iterable) or isinstance(value, (str, dict)):
|
||||
continue
|
||||
|
||||
items = list(value)
|
||||
if not items:
|
||||
continue
|
||||
|
||||
# 仅支持简单路径的数组映射
|
||||
if len(msg_path) <= 2:
|
||||
array_field = msg_path[0]
|
||||
if hasattr(ros_msg, array_field):
|
||||
if len(msg_path) == 1:
|
||||
# 直接设置数组
|
||||
setattr(ros_msg, array_field, items)
|
||||
else:
|
||||
# 设置数组元素的属性
|
||||
target_field = msg_path[1]
|
||||
array_items = getattr(ros_msg, array_field)
|
||||
|
||||
# 确保数组大小匹配
|
||||
while len(array_items) < len(items):
|
||||
# 添加新元素类型
|
||||
if array_items:
|
||||
elem_type = type(array_items[0])
|
||||
array_items.append(elem_type())
|
||||
else:
|
||||
# 无法确定元素类型时中断
|
||||
break
|
||||
|
||||
# 设置每个元素的属性
|
||||
for i, val in enumerate(items):
|
||||
if i < len(array_items):
|
||||
setattr(array_items[i], target_field, val)
|
||||
except Exception as e:
|
||||
# 忽略映射错误
|
||||
logger.debug(f"Mapping error for {msg_name} -> {attr_name}: {str(e)}")
|
||||
continue
|
||||
|
||||
return ros_msg
|
||||
|
||||
|
||||
def convert_from_ros_msg(msg: Any) -> Any:
|
||||
"""
|
||||
将ROS消息对象递归转换为Python对象
|
||||
|
||||
Args:
|
||||
msg: ROS消息实例
|
||||
|
||||
Returns:
|
||||
Python对象(字典或基本类型)
|
||||
"""
|
||||
# 使用预定义转换器
|
||||
if type(msg) in _msg_converter_back:
|
||||
return _msg_converter_back[type(msg)](msg)
|
||||
|
||||
# 处理标准ROS消息
|
||||
elif hasattr(msg, "__slots__") and hasattr(msg, "_fields_and_field_types"):
|
||||
result = {}
|
||||
for field in msg.__slots__:
|
||||
field_value = getattr(msg, field)
|
||||
field_name = field[1:] if field.startswith("_") else field
|
||||
result[field_name] = convert_from_ros_msg(field_value)
|
||||
return result
|
||||
|
||||
# 处理列表或元组
|
||||
elif isinstance(msg, (list, tuple)):
|
||||
return [convert_from_ros_msg(item) for item in msg]
|
||||
|
||||
# 返回基本类型
|
||||
else:
|
||||
return msg
|
||||
|
||||
|
||||
def convert_from_ros_msg_with_mapping(ros_msg: Any, value_mapping: Dict[str, str]) -> Dict[str, Any]:
|
||||
"""
|
||||
根据字段映射将ROS消息转换为Python字典
|
||||
|
||||
Args:
|
||||
ros_msg: ROS消息实例
|
||||
value_mapping: 字段名映射关系字典
|
||||
|
||||
Returns:
|
||||
Python字典
|
||||
"""
|
||||
data: Dict[str, Any] = {}
|
||||
|
||||
for msg_name, attr_name in value_mapping.items():
|
||||
msg_path = msg_name.split(".")
|
||||
current = ros_msg
|
||||
|
||||
try:
|
||||
if not attr_name.endswith("[]"):
|
||||
# 处理单值映射
|
||||
for name in msg_path:
|
||||
current = getattr(current, name)
|
||||
data[attr_name] = convert_from_ros_msg(current)
|
||||
else:
|
||||
# 处理列表值映射
|
||||
for name in msg_path:
|
||||
if name.endswith("[]"):
|
||||
base_name = name[:-2]
|
||||
if hasattr(current, base_name):
|
||||
current = list(getattr(current, base_name))
|
||||
else:
|
||||
current = []
|
||||
break
|
||||
else:
|
||||
if isinstance(current, list):
|
||||
next_level = []
|
||||
for item in current:
|
||||
if hasattr(item, name):
|
||||
next_level.append(getattr(item, name))
|
||||
current = next_level
|
||||
elif hasattr(current, name):
|
||||
current = getattr(current, name)
|
||||
else:
|
||||
current = []
|
||||
break
|
||||
|
||||
attr_key = attr_name[:-2]
|
||||
if current:
|
||||
data[attr_key] = [convert_from_ros_msg(item) for item in current]
|
||||
except (AttributeError, TypeError):
|
||||
logger.debug(f"Mapping conversion error for {msg_name} -> {attr_name}")
|
||||
continue
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def set_msg_data(dtype_str: str, data: Any) -> Any:
|
||||
"""
|
||||
将数据转换为指定消息类型
|
||||
|
||||
Args:
|
||||
dtype_str: 消息类型字符串
|
||||
data: 要转换的数据
|
||||
|
||||
Returns:
|
||||
转换后的数据
|
||||
"""
|
||||
converter = _msg_data_mapping.get(dtype_str, str)
|
||||
return converter(data)
|
||||
|
||||
|
||||
"""
|
||||
ROS Action 到 JSON Schema 转换器
|
||||
|
||||
该模块提供了将 ROS Action 定义转换为 JSON Schema 的功能,
|
||||
用于规范化 Action 接口和生成文档。
|
||||
"""
|
||||
|
||||
import json
|
||||
import yaml
|
||||
from typing import Any, Dict, Type, Union, Optional
|
||||
|
||||
from unilabos.utils import logger
|
||||
from unilabos.utils.import_manager import ImportManager
|
||||
from unilabos.config.config import ROSConfig
|
||||
|
||||
basic_type_map = {
|
||||
'bool': {'type': 'boolean'},
|
||||
'int8': {'type': 'integer', 'minimum': -128, 'maximum': 127},
|
||||
'uint8': {'type': 'integer', 'minimum': 0, 'maximum': 255},
|
||||
'int16': {'type': 'integer', 'minimum': -32768, 'maximum': 32767},
|
||||
'uint16': {'type': 'integer', 'minimum': 0, 'maximum': 65535},
|
||||
'int32': {'type': 'integer', 'minimum': -2147483648, 'maximum': 2147483647},
|
||||
'uint32': {'type': 'integer', 'minimum': 0, 'maximum': 4294967295},
|
||||
'int64': {'type': 'integer'},
|
||||
'uint64': {'type': 'integer', 'minimum': 0},
|
||||
'double': {'type': 'number'},
|
||||
'float32': {'type': 'number'},
|
||||
'float64': {'type': 'number'},
|
||||
'string': {'type': 'string'},
|
||||
'char': {'type': 'string', 'maxLength': 1},
|
||||
'byte': {'type': 'integer', 'minimum': 0, 'maximum': 255},
|
||||
}
|
||||
|
||||
|
||||
def ros_field_type_to_json_schema(type_info: Type | str, slot_type: str=None) -> Dict[str, Any]:
|
||||
"""
|
||||
将 ROS 字段类型转换为 JSON Schema 类型定义
|
||||
|
||||
Args:
|
||||
type_info: ROS 类型
|
||||
slot_type: ROS 类型
|
||||
|
||||
Returns:
|
||||
对应的 JSON Schema 类型定义
|
||||
"""
|
||||
if isinstance(type_info, UnboundedSequence):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': ros_field_type_to_json_schema(type_info.value_type)
|
||||
}
|
||||
if isinstance(type_info, NamespacedType):
|
||||
cls_name = ".".join(type_info.namespaces) + ":" + type_info.name
|
||||
type_class = msg_converter_manager.get_class(cls_name)
|
||||
return ros_field_type_to_json_schema(type_class)
|
||||
elif isinstance(type_info, BasicType):
|
||||
return ros_field_type_to_json_schema(type_info.typename)
|
||||
elif isinstance(type_info, UnboundedString):
|
||||
return basic_type_map['string']
|
||||
elif isinstance(type_info, str):
|
||||
if type_info in basic_type_map:
|
||||
return basic_type_map[type_info]
|
||||
|
||||
# 处理时间和持续时间类型
|
||||
if type_info in ('time', 'duration', 'builtin_interfaces/Time', 'builtin_interfaces/Duration'):
|
||||
return {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'sec': {'type': 'integer', 'description': '秒'},
|
||||
'nanosec': {'type': 'integer', 'description': '纳秒'}
|
||||
},
|
||||
'required': ['sec', 'nanosec']
|
||||
}
|
||||
else:
|
||||
return ros_message_to_json_schema(type_info)
|
||||
# # 处理数组类型
|
||||
# if field_type.endswith('[]'):
|
||||
# item_type = field_type[:-2]
|
||||
# return {
|
||||
# 'type': 'array',
|
||||
# 'items': ros_field_type_to_json_schema(item_type)
|
||||
# }
|
||||
|
||||
|
||||
|
||||
# # 处理复杂类型(尝试加载并处理)
|
||||
# try:
|
||||
# # 如果它是一个完整的消息类型规范 (包名/msg/类型名)
|
||||
# if '/' in field_type:
|
||||
# msg_class = get_ros_type_by_msgname(field_type)
|
||||
# return ros_message_to_json_schema(msg_class)
|
||||
# else:
|
||||
# # 可能是相对引用或简单名称
|
||||
# return {'type': 'object', 'description': f'复合类型: {field_type}'}
|
||||
# except Exception as e:
|
||||
# # 如果无法解析,返回通用对象类型
|
||||
# logger.debug(f"无法解析类型 {field_type}: {str(e)}")
|
||||
# return {'type': 'object', 'description': f'未知类型: {field_type}'}
|
||||
|
||||
def ros_message_to_json_schema(msg_class: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
将 ROS 消息类转换为 JSON Schema
|
||||
|
||||
Args:
|
||||
msg_class: ROS 消息类
|
||||
|
||||
Returns:
|
||||
对应的 JSON Schema 定义
|
||||
"""
|
||||
schema = {
|
||||
'type': 'object',
|
||||
'properties': {},
|
||||
'required': []
|
||||
}
|
||||
|
||||
# 获取类名作为标题
|
||||
if hasattr(msg_class, '__name__'):
|
||||
schema['title'] = msg_class.__name__
|
||||
|
||||
# 获取消息的字段和字段类型
|
||||
try:
|
||||
for ind, slot_info in enumerate(msg_class._fields_and_field_types.items()):
|
||||
slot_name, slot_type = slot_info
|
||||
type_info = msg_class.SLOT_TYPES[ind]
|
||||
field_schema = ros_field_type_to_json_schema(type_info, slot_type)
|
||||
schema['properties'][slot_name] = field_schema
|
||||
schema['required'].append(slot_name)
|
||||
# if hasattr(msg_class, 'get_fields_and_field_types'):
|
||||
# fields_and_types = msg_class.get_fields_and_field_types()
|
||||
#
|
||||
# for field_name, field_type in fields_and_types.items():
|
||||
# # 将 ROS 字段类型转换为 JSON Schema
|
||||
# field_schema = ros_field_type_to_json_schema(field_type)
|
||||
#
|
||||
# schema['properties'][field_name] = field_schema
|
||||
# schema['required'].append(field_name)
|
||||
# elif hasattr(msg_class, '__slots__') and hasattr(msg_class, '_fields_and_field_types'):
|
||||
# # 直接从实例属性获取
|
||||
# for field_name in msg_class.__slots__:
|
||||
# # 移除前导下划线(如果有)
|
||||
# clean_name = field_name[1:] if field_name.startswith('_') else field_name
|
||||
#
|
||||
# # 从 _fields_and_field_types 获取类型
|
||||
# if clean_name in msg_class._fields_and_field_types:
|
||||
# field_type = msg_class._fields_and_field_types[clean_name]
|
||||
# field_schema = ros_field_type_to_json_schema(field_type)
|
||||
#
|
||||
# schema['properties'][clean_name] = field_schema
|
||||
# schema['required'].append(clean_name)
|
||||
except Exception as e:
|
||||
# 如果获取字段类型失败,添加错误信息
|
||||
schema['description'] = f"解析消息字段时出错: {str(e)}"
|
||||
logger.error(f"解析 {msg_class.__name__} 消息字段失败: {str(e)}")
|
||||
|
||||
return schema
|
||||
|
||||
def ros_action_to_json_schema(action_class: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
将 ROS Action 类转换为 JSON Schema
|
||||
|
||||
Args:
|
||||
action_class: ROS Action 类
|
||||
|
||||
Returns:
|
||||
完整的 JSON Schema 定义
|
||||
"""
|
||||
if not hasattr(action_class, 'Goal') or not hasattr(action_class, 'Feedback') or not hasattr(action_class, 'Result'):
|
||||
raise ValueError(f"{action_class.__name__} 不是有效的 ROS Action 类")
|
||||
|
||||
# 创建基础 schema
|
||||
schema = {
|
||||
'$schema': 'http://json-schema.org/draft-07/schema#',
|
||||
'title': action_class.__name__,
|
||||
'description': f"ROS Action {action_class.__name__} 的 JSON Schema",
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'goal': {
|
||||
'description': 'Action 目标 - 从客户端发送到服务器',
|
||||
**ros_message_to_json_schema(action_class.Goal)
|
||||
},
|
||||
'feedback': {
|
||||
'description': 'Action 反馈 - 执行过程中从服务器发送到客户端',
|
||||
**ros_message_to_json_schema(action_class.Feedback)
|
||||
},
|
||||
'result': {
|
||||
'description': 'Action 结果 - 完成后从服务器发送到客户端',
|
||||
**ros_message_to_json_schema(action_class.Result)
|
||||
}
|
||||
},
|
||||
'required': ['goal']
|
||||
}
|
||||
|
||||
return schema
|
||||
|
||||
def convert_ros_action_to_jsonschema(
|
||||
action_name_or_type: Union[str, Type],
|
||||
output_file: Optional[str] = None,
|
||||
format: str = 'json'
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
将 ROS Action 类型转换为 JSON Schema,并可选地保存到文件
|
||||
|
||||
Args:
|
||||
action_name_or_type: ROS Action 类型名称或类
|
||||
output_file: 可选,输出 JSON Schema 的文件路径
|
||||
format: 输出格式,'json' 或 'yaml'
|
||||
|
||||
Returns:
|
||||
JSON Schema 定义(字典)
|
||||
"""
|
||||
# 处理输入参数
|
||||
if isinstance(action_name_or_type, str):
|
||||
# 如果是字符串,尝试加载 Action 类型
|
||||
action_type = get_ros_type_by_msgname(action_name_or_type)
|
||||
else:
|
||||
action_type = action_name_or_type
|
||||
|
||||
# 生成 JSON Schema
|
||||
schema = ros_action_to_json_schema(action_type)
|
||||
|
||||
# 如果指定了输出文件,将 Schema 保存到文件
|
||||
if output_file:
|
||||
if format.lower() == 'json':
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(schema, f, indent=2, ensure_ascii=False)
|
||||
elif format.lower() == 'yaml':
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
yaml.safe_dump(schema, f, default_flow_style=False, allow_unicode=True)
|
||||
else:
|
||||
raise ValueError(f"不支持的格式: {format},请使用 'json' 或 'yaml'")
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
# 示例用法
|
||||
if __name__ == "__main__":
|
||||
# 示例:转换 NavigateToPose action
|
||||
try:
|
||||
from nav2_msgs.action import NavigateToPose
|
||||
|
||||
# 转换为 JSON Schema 并打印
|
||||
schema = convert_ros_action_to_jsonschema(NavigateToPose)
|
||||
print(json.dumps(schema, indent=2, ensure_ascii=False))
|
||||
|
||||
# 保存到文件
|
||||
# convert_ros_action_to_jsonschema(NavigateToPose, "navigate_to_pose_schema.json")
|
||||
|
||||
# 或者使用字符串形式的 action 名称
|
||||
# schema = convert_ros_action_to_jsonschema("nav2_msgs/action/NavigateToPose")
|
||||
except ImportError:
|
||||
print("无法导入 NavigateToPose action,请确保已安装相关 ROS 包。")
|
||||
0
unilabos/ros/nodes/__init__.py
Normal file
0
unilabos/ros/nodes/__init__.py
Normal file
672
unilabos/ros/nodes/base_device_node.py
Normal file
672
unilabos/ros/nodes/base_device_node.py
Normal file
@@ -0,0 +1,672 @@
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import get_type_hints, TypeVar, Generic, Dict, Any, Type, TypedDict, Optional
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import asyncio
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.action import ActionServer
|
||||
from rclpy.action.server import ServerGoalHandle
|
||||
from rclpy.client import Client
|
||||
from rclpy.callback_groups import ReentrantCallbackGroup
|
||||
|
||||
from unilabos.resources.graphio import convert_resources_to_type, convert_resources_from_type
|
||||
from unilabos.ros.msgs.message_converter import (
|
||||
convert_to_ros_msg,
|
||||
convert_from_ros_msg,
|
||||
convert_from_ros_msg_with_mapping,
|
||||
convert_to_ros_msg_with_mapping,
|
||||
)
|
||||
from unilabos_msgs.srv import ResourceAdd, ResourceGet, ResourceDelete, ResourceUpdate, ResourceList # type: ignore
|
||||
from unilabos_msgs.msg import Resource # type: ignore
|
||||
|
||||
from unilabos.ros.nodes.resource_tracker import DeviceNodeResourceTracker
|
||||
from unilabos.ros.x.rclpyx import get_event_loop
|
||||
from unilabos.ros.utils.driver_creator import ProtocolNodeCreator, PyLabRobotCreator, DeviceClassCreator
|
||||
from unilabos.utils.async_util import run_async_func
|
||||
from unilabos.utils.log import info, debug, warning, error, critical, logger
|
||||
from unilabos.utils.type_check import get_type_class
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# 在线设备注册表
|
||||
registered_devices: Dict[str, "DeviceInfoType"] = {}
|
||||
|
||||
|
||||
# 实现同时记录自定义日志和ROS2日志的适配器
|
||||
class ROSLoggerAdapter:
|
||||
"""同时向自定义日志和ROS2日志发送消息的适配器"""
|
||||
|
||||
@property
|
||||
def identifier(self):
|
||||
return f"{self.namespace}/{self.node_name}"
|
||||
|
||||
def __init__(self, ros_logger, node_name, namespace):
|
||||
"""
|
||||
初始化日志适配器
|
||||
|
||||
Args:
|
||||
ros_logger: ROS2日志记录器
|
||||
node_name: 节点名称
|
||||
namespace: 命名空间
|
||||
"""
|
||||
self.ros_logger = ros_logger
|
||||
self.node_name = node_name
|
||||
self.namespace = namespace
|
||||
self.level_2_logger_func = {
|
||||
"info": info,
|
||||
"debug": debug,
|
||||
"warning": warning,
|
||||
"error": error,
|
||||
"critical": critical,
|
||||
}
|
||||
|
||||
def _log(self, level, msg, *args, **kwargs):
|
||||
"""实际执行日志记录的内部方法"""
|
||||
# 添加前缀,使日志更易识别
|
||||
msg = f"[{self.identifier}] {msg}"
|
||||
# 向ROS2日志发送消息(标准库logging不支持stack_level参数)
|
||||
ros_log_func = getattr(self.ros_logger, "debug") # 默认发送debug,这样不会显示在控制台
|
||||
ros_log_func(msg)
|
||||
self.level_2_logger_func[level](msg, *args, stack_level=1, **kwargs)
|
||||
|
||||
def debug(self, msg, *args, **kwargs):
|
||||
"""记录DEBUG级别日志"""
|
||||
self._log("debug", msg, *args, **kwargs)
|
||||
|
||||
def info(self, msg, *args, **kwargs):
|
||||
"""记录INFO级别日志"""
|
||||
self._log("info", msg, *args, **kwargs)
|
||||
|
||||
def warning(self, msg, *args, **kwargs):
|
||||
"""记录WARNING级别日志"""
|
||||
self._log("warning", msg, *args, **kwargs)
|
||||
|
||||
def error(self, msg, *args, **kwargs):
|
||||
"""记录ERROR级别日志"""
|
||||
self._log("error", msg, *args, **kwargs)
|
||||
|
||||
def critical(self, msg, *args, **kwargs):
|
||||
"""记录CRITICAL级别日志"""
|
||||
self._log("critical", msg, *args, **kwargs)
|
||||
|
||||
|
||||
def init_wrapper(
|
||||
self,
|
||||
device_id: str,
|
||||
driver_class: type[T],
|
||||
status_types: Dict[str, Any],
|
||||
action_value_mappings: Dict[str, Any],
|
||||
hardware_interface: Dict[str, Any],
|
||||
print_publish: bool,
|
||||
children: Optional[list] = None,
|
||||
driver_params: Optional[Dict[str, Any]] = None,
|
||||
driver_is_ros: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化设备节点的包装函数,和ROS2DeviceNode初始化保持一致"""
|
||||
if driver_params is None:
|
||||
driver_params = kwargs.copy()
|
||||
if children is None:
|
||||
children = []
|
||||
kwargs["device_id"] = device_id
|
||||
kwargs["driver_class"] = driver_class
|
||||
kwargs["driver_params"] = driver_params
|
||||
kwargs["status_types"] = status_types
|
||||
kwargs["action_value_mappings"] = action_value_mappings
|
||||
kwargs["hardware_interface"] = hardware_interface
|
||||
kwargs["children"] = children
|
||||
kwargs["print_publish"] = print_publish
|
||||
kwargs["driver_is_ros"] = driver_is_ros
|
||||
super(type(self), self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class PropertyPublisher:
|
||||
def __init__(
|
||||
self,
|
||||
node: "BaseROS2DeviceNode",
|
||||
name: str,
|
||||
get_method,
|
||||
msg_type,
|
||||
initial_period: float = 5.0,
|
||||
print_publish=True,
|
||||
):
|
||||
self.node = node
|
||||
self.name = name
|
||||
self.msg_type = msg_type
|
||||
self.get_method = get_method
|
||||
self.timer_period = initial_period
|
||||
self.print_publish = print_publish
|
||||
|
||||
self._value = None
|
||||
self.publisher_ = node.create_publisher(msg_type, f"{name}", 10)
|
||||
self.timer = node.create_timer(self.timer_period, self.publish_property)
|
||||
self.__loop = get_event_loop()
|
||||
str_msg_type = str(msg_type)[8:-2]
|
||||
self.node.lab_logger().debug(f"发布属性: {name}, 类型: {str_msg_type}, 周期: {initial_period}秒")
|
||||
|
||||
def get_property(self):
|
||||
if asyncio.iscoroutinefunction(self.get_method):
|
||||
# 如果是异步函数,运行事件循环并等待结果
|
||||
self.node.get_logger().debug(f"【PropertyPublisher.get_property】获取异步属性: {self.name}")
|
||||
loop = self.__loop
|
||||
if loop:
|
||||
future = asyncio.run_coroutine_threadsafe(self.get_method(), loop)
|
||||
self._value = future.result()
|
||||
return self._value
|
||||
else:
|
||||
self.node.get_logger().error(f"【PropertyPublisher.get_property】事件循环未初始化")
|
||||
return None
|
||||
else:
|
||||
# 如果是同步函数,直接调用并返回结果
|
||||
self.node.get_logger().debug(f"【PropertyPublisher.get_property】获取同步属性: {self.name}")
|
||||
self._value = self.get_method()
|
||||
return self._value
|
||||
|
||||
async def get_property_async(self):
|
||||
try:
|
||||
# 获取异步属性值
|
||||
self.node.get_logger().debug(f"【PropertyPublisher.get_property_async】异步获取属性: {self.name}")
|
||||
self._value = await self.get_method()
|
||||
except Exception as e:
|
||||
self.node.get_logger().error(f"【PropertyPublisher.get_property_async】获取异步属性出错: {str(e)}")
|
||||
|
||||
def publish_property(self):
|
||||
try:
|
||||
self.node.get_logger().debug(f"【PropertyPublisher.publish_property】开始发布属性: {self.name}")
|
||||
value = self.get_property()
|
||||
if self.print_publish:
|
||||
self.node.get_logger().info(f"【PropertyPublisher.publish_property】发布 {self.msg_type}: {value}")
|
||||
if value is not None:
|
||||
msg = convert_to_ros_msg(self.msg_type, value)
|
||||
self.publisher_.publish(msg)
|
||||
self.node.get_logger().debug(f"【PropertyPublisher.publish_property】属性 {self.name} 发布成功")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
self.node.get_logger().error(f"【PropertyPublisher.publish_property】发布属性出错: {str(e)}")
|
||||
|
||||
def change_frequency(self, period):
|
||||
# 动态改变定时器频率
|
||||
self.timer_period = period
|
||||
self.node.get_logger().info(
|
||||
f"【PropertyPublisher.change_frequency】修改 {self.name} 定时器周期为: {self.timer_period} 秒"
|
||||
)
|
||||
|
||||
# 重置定时器
|
||||
self.timer.cancel()
|
||||
self.timer = self.node.create_timer(self.timer_period, self.publish_property)
|
||||
|
||||
|
||||
class BaseROS2DeviceNode(Node, Generic[T]):
|
||||
"""
|
||||
ROS2设备节点基类
|
||||
|
||||
这个类提供了ROS2设备节点的基本功能,包括属性发布、动作服务等。
|
||||
通过泛型参数T来指定具体的设备类型。
|
||||
"""
|
||||
|
||||
@property
|
||||
def identifier(self):
|
||||
return f"{self.namespace}/{self.device_id}"
|
||||
|
||||
node_name: str
|
||||
namespace: str
|
||||
# TODO 要删除,添加时间相关的属性,避免动态添加属性的警告
|
||||
time_spent = 0.0
|
||||
time_remaining = 0.0
|
||||
create_action_server = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
driver_instance: T,
|
||||
device_id: str,
|
||||
status_types: Dict[str, Any],
|
||||
action_value_mappings: Dict[str, Any],
|
||||
hardware_interface: Dict[str, Any],
|
||||
print_publish=True,
|
||||
resource_tracker: Optional["DeviceNodeResourceTracker"] = None,
|
||||
):
|
||||
"""
|
||||
初始化ROS2设备节点
|
||||
|
||||
Args:
|
||||
driver_instance: 设备实例
|
||||
device_id: 设备标识符
|
||||
status_types: 需要发布的状态和传感器信息
|
||||
action_value_mappings: 设备动作
|
||||
hardware_interface: 硬件接口配置
|
||||
print_publish: 是否打印发布信息
|
||||
"""
|
||||
self.driver_instance = driver_instance
|
||||
self.device_id = device_id
|
||||
self.uuid = str(uuid.uuid4())
|
||||
self.publish_high_frequency = False
|
||||
self.callback_group = ReentrantCallbackGroup()
|
||||
self.resource_tracker = resource_tracker
|
||||
|
||||
# 初始化ROS节点
|
||||
self.node_name = f'{device_id.split("/")[-1]}'
|
||||
self.namespace = f"/devices/{device_id}"
|
||||
Node.__init__(self, self.node_name, namespace=self.namespace) # type: ignore
|
||||
if self.resource_tracker is None:
|
||||
self.lab_logger().critical("资源跟踪器未初始化,请检查")
|
||||
|
||||
# 创建自定义日志记录器
|
||||
self._lab_logger = ROSLoggerAdapter(self.get_logger(), self.node_name, self.namespace)
|
||||
|
||||
self._action_servers = {}
|
||||
self._property_publishers = {}
|
||||
self._status_types = status_types
|
||||
self._action_value_mappings = action_value_mappings
|
||||
self._hardware_interface = hardware_interface
|
||||
self._print_publish = print_publish
|
||||
|
||||
# 创建属性发布者
|
||||
for attr_name, msg_type in self._status_types.items():
|
||||
if isinstance(attr_name, (int, float)):
|
||||
if "param" in msg_type.keys():
|
||||
pass
|
||||
else:
|
||||
for k, v in msg_type.items():
|
||||
self.create_ros_publisher(k, v, initial_period=5.0)
|
||||
else:
|
||||
self.create_ros_publisher(attr_name, msg_type)
|
||||
|
||||
# 创建动作服务
|
||||
if self.create_action_server:
|
||||
for action_name, action_value_mapping in self._action_value_mappings.items():
|
||||
self.create_ros_action_server(action_name, action_value_mapping)
|
||||
|
||||
# 创建线程池执行器
|
||||
self._executor = ThreadPoolExecutor(max_workers=max(len(action_value_mappings), 1))
|
||||
|
||||
# 创建资源管理客户端
|
||||
self._resource_clients: Dict[str, Client] = {
|
||||
"resource_add": self.create_client(ResourceAdd, "/resources/add"),
|
||||
"resource_get": self.create_client(ResourceGet, "/resources/get"),
|
||||
"resource_delete": self.create_client(ResourceDelete, "/resources/delete"),
|
||||
"resource_update": self.create_client(ResourceUpdate, "/resources/update"),
|
||||
"resource_list": self.create_client(ResourceList, "/resources/list"),
|
||||
}
|
||||
|
||||
# 向全局在线设备注册表添加设备信息
|
||||
self.register_device()
|
||||
rclpy.get_global_executor().add_node(self)
|
||||
self.lab_logger().debug(f"ROS节点初始化完成")
|
||||
|
||||
def register_device(self):
|
||||
"""向注册表中注册设备信息"""
|
||||
topics_info = self._property_publishers.copy()
|
||||
actions_info = self._action_servers.copy()
|
||||
# 创建设备信息
|
||||
device_info = DeviceInfoType(
|
||||
id=self.device_id,
|
||||
uuid=self.uuid,
|
||||
node_name=self.node_name,
|
||||
namespace=self.namespace,
|
||||
driver_instance=self.driver_instance,
|
||||
status_publishers=topics_info,
|
||||
actions=actions_info,
|
||||
hardware_interface=self._hardware_interface,
|
||||
base_node_instance=self,
|
||||
)
|
||||
# 加入全局注册表
|
||||
registered_devices[self.device_id] = device_info
|
||||
|
||||
def lab_logger(self):
|
||||
"""
|
||||
获取实验室自定义日志记录器
|
||||
|
||||
这个日志记录器会同时向ROS2日志和自定义日志发送消息,
|
||||
并使用node_name和namespace作为标识。
|
||||
|
||||
Returns:
|
||||
日志记录器实例
|
||||
"""
|
||||
return self._lab_logger
|
||||
|
||||
def create_ros_publisher(self, attr_name, msg_type, initial_period=5.0):
|
||||
"""创建ROS发布者"""
|
||||
|
||||
# 获取属性值的方法
|
||||
def get_device_attr():
|
||||
try:
|
||||
if hasattr(self.driver_instance, f"get_{attr_name}"):
|
||||
return getattr(self.driver_instance, f"get_{attr_name}")()
|
||||
else:
|
||||
return getattr(self.driver_instance, attr_name)
|
||||
except AttributeError as ex:
|
||||
self.lab_logger().error(
|
||||
f"publish error, {str(type(self.driver_instance))[8:-2]} has no attribute '{attr_name}'"
|
||||
)
|
||||
|
||||
self._property_publishers[attr_name] = PropertyPublisher(
|
||||
self, attr_name, get_device_attr, msg_type, initial_period, self._print_publish
|
||||
)
|
||||
|
||||
def create_ros_action_server(self, action_name, action_value_mapping):
|
||||
"""创建ROS动作服务器"""
|
||||
action_type = action_value_mapping["type"]
|
||||
str_action_type = str(action_type)[8:-2]
|
||||
|
||||
self._action_servers[action_name] = ActionServer(
|
||||
self,
|
||||
action_type,
|
||||
action_name,
|
||||
execute_callback=self._create_execute_callback(action_name, action_value_mapping),
|
||||
callback_group=ReentrantCallbackGroup(),
|
||||
)
|
||||
|
||||
self.lab_logger().debug(f"发布动作: {action_name}, 类型: {str_action_type}")
|
||||
|
||||
def _create_execute_callback(self, action_name, action_value_mapping):
|
||||
"""创建动作执行回调函数"""
|
||||
|
||||
async def execute_callback(goal_handle: ServerGoalHandle):
|
||||
self.lab_logger().info(f"执行动作: {action_name}")
|
||||
goal = goal_handle.request
|
||||
|
||||
# 从目标消息中提取参数, 并调用对应的方法
|
||||
if "sequence" in self._action_value_mappings:
|
||||
# 如果一个指令对应函数的连续调用,如启动和等待结果,默认参数应该属于第一个函数调用
|
||||
def ACTION(**kwargs):
|
||||
for i, action in enumerate(self._action_value_mappings["sequence"]):
|
||||
if i == 0:
|
||||
self.lab_logger().info(f"执行序列动作第一步: {action}")
|
||||
getattr(self.driver_instance, action)(**kwargs)
|
||||
else:
|
||||
self.lab_logger().info(f"执行序列动作后续步骤: {action}")
|
||||
getattr(self.driver_instance, action)()
|
||||
|
||||
action_paramtypes = get_type_hints(
|
||||
getattr(self.driver_instance, self._action_value_mappings["sequence"][0])
|
||||
)
|
||||
else:
|
||||
ACTION = getattr(self.driver_instance, action_name)
|
||||
action_paramtypes = get_type_hints(ACTION)
|
||||
|
||||
action_kwargs = convert_from_ros_msg_with_mapping(goal, action_value_mapping["goal"])
|
||||
self.lab_logger().debug(f"接收到原始目标: {action_kwargs}")
|
||||
|
||||
# 向Host查询物料当前状态
|
||||
for k, v in goal.get_fields_and_field_types().items():
|
||||
if v in ["unilabos_msgs/Resource", "sequence<unilabos_msgs/Resource>"]:
|
||||
self.lab_logger().info(f"查询资源状态: Key: {k} Type: {v}")
|
||||
try:
|
||||
r = ResourceGet.Request()
|
||||
r.id = action_kwargs[k]["id"] if v == "unilabos_msgs/Resource" else action_kwargs[k][0]["id"]
|
||||
r.with_children = True
|
||||
response = await self._resource_clients["resource_get"].call_async(r)
|
||||
except Exception:
|
||||
logger.error(f"资源查询失败,默认使用本地资源")
|
||||
# 删除对response.resources的检查,因为它总是存在
|
||||
resources_list = [convert_from_ros_msg(rs) for rs in response.resources] # type: ignore # FIXME
|
||||
self.lab_logger().debug(f"资源查询结果: {len(resources_list)} 个资源")
|
||||
type_hint = action_paramtypes[k]
|
||||
final_type = get_type_class(type_hint)
|
||||
# 判断 ACTION 是否需要特殊的物料类型如 pylabrobot.resources.Resource,并做转换
|
||||
final_resource = convert_resources_to_type(resources_list, final_type)
|
||||
action_kwargs[k] = self.resource_tracker.figure_resource(final_resource)
|
||||
|
||||
self.lab_logger().info(f"准备执行: {action_kwargs}, 函数: {ACTION.__name__}")
|
||||
time_start = time.time()
|
||||
time_overall = 100
|
||||
|
||||
# 将阻塞操作放入线程池执行
|
||||
if asyncio.iscoroutinefunction(ACTION):
|
||||
try:
|
||||
self.lab_logger().info(f"异步执行动作 {ACTION}")
|
||||
future = ROS2DeviceNode.run_async_func(ACTION, **action_kwargs)
|
||||
except Exception as e:
|
||||
self.lab_logger().error(f"创建异步任务失败: {traceback.format_exc()}")
|
||||
raise e
|
||||
else:
|
||||
self.lab_logger().info(f"同步执行动作 {ACTION}")
|
||||
future = self._executor.submit(ACTION, **action_kwargs)
|
||||
|
||||
action_type = action_value_mapping["type"]
|
||||
feedback_msg_types = action_type.Feedback.get_fields_and_field_types()
|
||||
result_msg_types = action_type.Result.get_fields_and_field_types()
|
||||
|
||||
while not future.done():
|
||||
if goal_handle.is_cancel_requested:
|
||||
self.lab_logger().info(f"取消动作: {action_name}")
|
||||
future.cancel() # 尝试取消线程池中的任务
|
||||
goal_handle.canceled()
|
||||
return action_type.Result()
|
||||
|
||||
self.time_spent = time.time() - time_start
|
||||
self.time_remaining = time_overall - self.time_spent
|
||||
|
||||
# 发布反馈
|
||||
feedback_values = {}
|
||||
for msg_name, attr_name in action_value_mapping["feedback"].items():
|
||||
if hasattr(self.driver_instance, f"get_{attr_name}"):
|
||||
method = getattr(self.driver_instance, f"get_{attr_name}")
|
||||
if not asyncio.iscoroutinefunction(method):
|
||||
feedback_values[msg_name] = method()
|
||||
elif hasattr(self.driver_instance, attr_name):
|
||||
feedback_values[msg_name] = getattr(self.driver_instance, attr_name)
|
||||
|
||||
if self._print_publish:
|
||||
self.lab_logger().info(f"反馈: {feedback_values}")
|
||||
|
||||
feedback_msg = convert_to_ros_msg_with_mapping(
|
||||
ros_msg_type=action_type.Feedback(),
|
||||
obj=feedback_values,
|
||||
value_mapping=action_value_mapping["feedback"],
|
||||
)
|
||||
|
||||
goal_handle.publish_feedback(feedback_msg)
|
||||
time.sleep(0.5)
|
||||
|
||||
if future.cancelled():
|
||||
self.lab_logger().info(f"动作 {action_name} 已取消")
|
||||
return action_type.Result()
|
||||
|
||||
self.lab_logger().info(f"动作执行完成: {action_name}")
|
||||
del future
|
||||
|
||||
# 向Host更新物料当前状态
|
||||
for k, v in goal.get_fields_and_field_types().items():
|
||||
if v not in ["unilabos_msgs/Resource", "sequence<unilabos_msgs/Resource>"]:
|
||||
continue
|
||||
self.lab_logger().info(f"更新资源状态: {k}")
|
||||
r = ResourceUpdate.Request()
|
||||
# 仅当action_kwargs[k]不为None时尝试转换
|
||||
akv = action_kwargs[k]
|
||||
apv = action_paramtypes[k]
|
||||
final_type = get_type_class(apv)
|
||||
if final_type is None:
|
||||
continue
|
||||
try:
|
||||
r.resources = [
|
||||
convert_to_ros_msg(Resource, self.resource_tracker.root_resource(rs))
|
||||
for rs in convert_resources_from_type(akv, final_type) # type: ignore # FIXME # 考虑反查到最大的
|
||||
]
|
||||
response = await self._resource_clients["resource_update"].call_async(r)
|
||||
self.lab_logger().debug(f"资源更新结果: {response}")
|
||||
except Exception as e:
|
||||
self.lab_logger().error(f"资源更新失败: {e}")
|
||||
self.lab_logger().error(traceback.format_exc())
|
||||
|
||||
# 发布结果
|
||||
goal_handle.succeed()
|
||||
self.lab_logger().info(f"设置动作成功: {action_name}")
|
||||
|
||||
result_values = {}
|
||||
for msg_name, attr_name in action_value_mapping["result"].items():
|
||||
if hasattr(self.driver_instance, f"get_{attr_name}"):
|
||||
result_values[msg_name] = getattr(self.driver_instance, f"get_{attr_name}")()
|
||||
elif hasattr(self.driver_instance, attr_name):
|
||||
result_values[msg_name] = getattr(self.driver_instance, attr_name)
|
||||
|
||||
result_msg = convert_to_ros_msg_with_mapping(
|
||||
ros_msg_type=action_type.Result(), obj=result_values, value_mapping=action_value_mapping["result"]
|
||||
)
|
||||
|
||||
for attr_name in result_msg_types.keys():
|
||||
if attr_name in ["success", "reached_goal"]:
|
||||
setattr(result_msg, attr_name, True)
|
||||
|
||||
self.lab_logger().info(f"动作 {action_name} 完成并返回结果")
|
||||
return result_msg
|
||||
|
||||
return execute_callback
|
||||
|
||||
# 异步上下文管理方法
|
||||
async def __aenter__(self):
|
||||
"""进入异步上下文"""
|
||||
self.lab_logger().info(f"进入异步上下文: {self.device_id}")
|
||||
if hasattr(self.driver_instance, "__aenter__"):
|
||||
await self.driver_instance.__aenter__() # type: ignore
|
||||
self.lab_logger().info(f"异步上下文初始化完成: {self.device_id}")
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""退出异步上下文"""
|
||||
self.lab_logger().info(f"退出异步上下文: {self.device_id}")
|
||||
if hasattr(self.driver_instance, "__aexit__"):
|
||||
await self.driver_instance.__aexit__(exc_type, exc_val, exc_tb) # type: ignore
|
||||
self.lab_logger().info(f"异步上下文清理完成: {self.device_id}")
|
||||
|
||||
|
||||
class DeviceInitError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ROS2DeviceNode:
|
||||
"""
|
||||
ROS2设备节点类
|
||||
|
||||
这个类封装了设备类实例和ROS2节点的功能,提供ROS2接口。
|
||||
它不继承设备类,而是通过代理模式访问设备类的属性和方法。
|
||||
"""
|
||||
|
||||
# 类变量,用于循环管理
|
||||
_loop = None
|
||||
_loop_running = False
|
||||
_loop_thread = None
|
||||
|
||||
@classmethod
|
||||
def get_loop(cls):
|
||||
return cls._loop
|
||||
|
||||
@classmethod
|
||||
def run_async_func(cls, func, **kwargs):
|
||||
return run_async_func(func, loop=cls._loop, **kwargs)
|
||||
|
||||
@property
|
||||
def driver_instance(self):
|
||||
return self._driver_instance
|
||||
|
||||
@property
|
||||
def ros_node_instance(self):
|
||||
return self._ros_node
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_id: str,
|
||||
driver_class: Type[T],
|
||||
driver_params: Dict[str, Any],
|
||||
status_types: Dict[str, Any],
|
||||
action_value_mappings: Dict[str, Any],
|
||||
hardware_interface: Dict[str, Any],
|
||||
children: Dict[str, Any],
|
||||
print_publish: bool = True,
|
||||
driver_is_ros: bool = False,
|
||||
):
|
||||
"""
|
||||
初始化ROS2设备节点
|
||||
|
||||
Args:
|
||||
device_id: 设备标识符
|
||||
driver_class: 设备类
|
||||
status_types: 状态类型映射
|
||||
action_value_mappings: 动作值映射
|
||||
hardware_interface: 硬件接口配置
|
||||
children:
|
||||
print_publish: 是否打印发布信息
|
||||
driver_is_ros:
|
||||
"""
|
||||
# 在初始化时检查循环状态
|
||||
if ROS2DeviceNode._loop_running and ROS2DeviceNode._loop_thread is not None:
|
||||
pass
|
||||
elif ROS2DeviceNode._loop_thread is None:
|
||||
self._start_loop()
|
||||
|
||||
# 保存设备类是否支持异步上下文
|
||||
self._has_async_context = hasattr(driver_class, "__aenter__") and hasattr(driver_class, "__aexit__")
|
||||
self._driver_class = driver_class
|
||||
self.driver_is_ros = driver_is_ros
|
||||
self.resource_tracker = DeviceNodeResourceTracker()
|
||||
|
||||
# use_pylabrobot_creator 使用 cls的包路径检测
|
||||
use_pylabrobot_creator = driver_class.__module__.startswith("pylabrobot")
|
||||
|
||||
# TODO: 要在创建之前预先请求服务器是否有当前id的物料,放到resource_tracker中,让pylabrobot进行创建
|
||||
# 创建设备类实例
|
||||
if use_pylabrobot_creator:
|
||||
self._driver_creator = PyLabRobotCreator(
|
||||
driver_class, children=children, resource_tracker=self.resource_tracker
|
||||
)
|
||||
else:
|
||||
from unilabos.ros.nodes.presets.protocol_node import ROS2ProtocolNode
|
||||
|
||||
if self._driver_class is ROS2ProtocolNode:
|
||||
self._driver_creator = ProtocolNodeCreator(driver_class, children=children)
|
||||
else:
|
||||
self._driver_creator = DeviceClassCreator(driver_class)
|
||||
|
||||
if driver_is_ros:
|
||||
driver_params["device_id"] = device_id
|
||||
driver_params["resource_tracker"] = self.resource_tracker
|
||||
self._driver_instance = self._driver_creator.create_instance(driver_params)
|
||||
if self._driver_instance is None:
|
||||
logger.critical(f"设备实例创建失败 {driver_class}, params: {driver_params}")
|
||||
raise DeviceInitError("错误: 设备实例创建失败")
|
||||
|
||||
# 创建ROS2节点
|
||||
if driver_is_ros:
|
||||
self._ros_node = self._driver_instance # type: ignore
|
||||
else:
|
||||
self._ros_node = BaseROS2DeviceNode(
|
||||
driver_instance=self._driver_instance,
|
||||
device_id=device_id,
|
||||
status_types=status_types,
|
||||
action_value_mappings=action_value_mappings,
|
||||
hardware_interface=hardware_interface,
|
||||
print_publish=print_publish,
|
||||
resource_tracker=self.resource_tracker,
|
||||
)
|
||||
self._ros_node: BaseROS2DeviceNode
|
||||
self._ros_node.lab_logger().info(f"初始化完成 {self._ros_node.uuid} {self.driver_is_ros}")
|
||||
|
||||
def _start_loop(self):
|
||||
def run_event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
ROS2DeviceNode._loop = loop
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_forever()
|
||||
|
||||
ROS2DeviceNode._loop_thread = threading.Thread(target=run_event_loop, daemon=True, name="ROS2DeviceNode")
|
||||
ROS2DeviceNode._loop_thread.start()
|
||||
logger.info(f"循环线程已启动")
|
||||
|
||||
|
||||
class DeviceInfoType(TypedDict):
|
||||
id: str
|
||||
uuid: str
|
||||
node_name: str
|
||||
namespace: str
|
||||
driver_instance: Any
|
||||
status_publishers: Dict[str, PropertyPublisher]
|
||||
actions: Dict[str, ActionServer]
|
||||
hardware_interface: Dict[str, Any]
|
||||
base_node_instance: BaseROS2DeviceNode
|
||||
0
unilabos/ros/nodes/presets/__init__.py
Normal file
0
unilabos/ros/nodes/presets/__init__.py
Normal file
122
unilabos/ros/nodes/presets/controller_node.py
Normal file
122
unilabos/ros/nodes/presets/controller_node.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from typing import Callable, Dict
|
||||
from std_msgs.msg import Float64
|
||||
|
||||
from unilabos.ros.nodes.base_device_node import BaseROS2DeviceNode, DeviceNodeResourceTracker
|
||||
|
||||
|
||||
class ControllerNode(BaseROS2DeviceNode):
|
||||
namespace_prefix = "/controllers"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_id: str,
|
||||
controller_func: Callable,
|
||||
update_rate: float,
|
||||
inputs: Dict[str, Dict[str, type | str]],
|
||||
outputs: Dict[str, Dict[str, type]],
|
||||
parameters: Dict,
|
||||
resource_tracker: DeviceNodeResourceTracker,
|
||||
):
|
||||
"""
|
||||
通用控制器节点
|
||||
|
||||
:param controller_id: 控制器的唯一标识符(作为命名空间的一部分)
|
||||
:param update_rate: 控制器更新频率 (Hz)
|
||||
:param controller_func: 控制器函数,接收 Python 格式的 inputs 和 parameters 返回 outputs
|
||||
:param input_types: 输入话题及其消息类型的字典
|
||||
:param output_types: 输出话题及其消息类型的字典
|
||||
:param parameters: 控制器函数的额外参数
|
||||
"""
|
||||
# 先准备所需的属性,以便在调用父类初始化前就可以使用
|
||||
self.device_id = device_id
|
||||
self.controller_func = controller_func
|
||||
self.update_rate = update_rate
|
||||
self.update_time = 1.0 / update_rate
|
||||
self.parameters = parameters
|
||||
self.inputs = {topic: None for topic in inputs.keys()}
|
||||
self.control_input_subscribers = {}
|
||||
self.control_output_publishers = {}
|
||||
self.topic_mapping = {
|
||||
**{input_info["topic"]: input for input, input_info in inputs.items()},
|
||||
**{output_info["topic"]: output for output, output_info in outputs.items()},
|
||||
}
|
||||
|
||||
# 调用BaseROS2DeviceNode初始化,使用自身作为driver_instance
|
||||
status_types = {}
|
||||
action_value_mappings = {}
|
||||
hardware_interface = {}
|
||||
|
||||
# 使用短ID作为节点名,完整ID(带namespace_prefix)作为device_id
|
||||
BaseROS2DeviceNode.__init__(
|
||||
self,
|
||||
driver_instance=self,
|
||||
device_id=device_id,
|
||||
status_types=status_types,
|
||||
action_value_mappings=action_value_mappings,
|
||||
hardware_interface=hardware_interface,
|
||||
print_publish=False,
|
||||
resource_tracker=resource_tracker
|
||||
)
|
||||
|
||||
# 原始初始化逻辑
|
||||
# 初始化订阅者
|
||||
for input, input_info in inputs.items():
|
||||
msg_type = input_info["type"]
|
||||
topic = str(input_info["topic"])
|
||||
self.control_input_subscribers[input] = self.create_subscription(
|
||||
msg_type, topic, lambda msg, t=topic: self.input_callback(t, msg), 10
|
||||
)
|
||||
|
||||
# 初始化发布者
|
||||
for output, output_info in outputs.items():
|
||||
self.lab_logger().info(f"Creating publisher for {output} {output_info}")
|
||||
msg_type = output_info["type"]
|
||||
topic = str(output_info["topic"])
|
||||
self.control_output_publishers[output] = self.create_publisher(msg_type, topic, 10)
|
||||
|
||||
# 定时器,用于定期调用控制逻辑
|
||||
self.timer = self.create_timer(self.update_time, self.control_loop)
|
||||
|
||||
def input_callback(self, topic: str, msg):
|
||||
"""
|
||||
更新指定话题的输入数据,并将 ROS 消息转换为普通 Python 数据。
|
||||
支持 `std_msgs` 类型消息。
|
||||
"""
|
||||
self.inputs[self.topic_mapping[topic]] = msg.data
|
||||
self.lab_logger().info(f"Received input on topic {topic}: {msg.data}")
|
||||
|
||||
def control_loop(self):
|
||||
"""主控制逻辑"""
|
||||
# 检查所有输入是否已更新
|
||||
if all(value is not None for value in self.inputs.values()):
|
||||
self.lab_logger().info(
|
||||
f"Calling controller function with inputs: {self.inputs}, parameters: {self.parameters}"
|
||||
)
|
||||
try:
|
||||
# 调用控制器函数,传入 Python 格式的数据
|
||||
outputs = self.controller_func(**self.inputs, **self.parameters)
|
||||
self.lab_logger().info(f"Inputs: {self.inputs}, Outputs: {outputs}")
|
||||
self.inputs = {topic: None for topic in self.inputs.keys()}
|
||||
except Exception as e:
|
||||
self.lab_logger().error(f"Controller function execution failed: {e}")
|
||||
return
|
||||
|
||||
# 发布控制信号,将普通 Python 数据转换为 ROS 消息
|
||||
if isinstance(outputs, dict):
|
||||
for topic, value in outputs.items():
|
||||
if topic in self.control_output_publishers:
|
||||
# 支持 Float64 输出
|
||||
if isinstance(value, (float, int)):
|
||||
self.control_output_publishers[topic].publish(Float64(data=value))
|
||||
else:
|
||||
self.lab_logger().error(f"Unsupported output type for topic {topic}: {type(value)}")
|
||||
else:
|
||||
self.lab_logger().warning(f"Output topic {topic} is not defined in output_types.")
|
||||
else:
|
||||
publisher = list(self.control_output_publishers.values())[0]
|
||||
if isinstance(outputs, (float, int)):
|
||||
publisher.publish(Float64(data=outputs))
|
||||
else:
|
||||
self.lab_logger().error(f"Unsupported output type: {type(outputs)}")
|
||||
else:
|
||||
self.lab_logger().info("Waiting for all inputs to be received.")
|
||||
623
unilabos/ros/nodes/presets/host_node.py
Normal file
623
unilabos/ros/nodes/presets/host_node.py
Normal file
@@ -0,0 +1,623 @@
|
||||
import copy
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any, List, ClassVar, Set
|
||||
|
||||
from action_msgs.msg import GoalStatus
|
||||
from unilabos_msgs.msg import Resource # type: ignore
|
||||
from unilabos_msgs.srv import ResourceAdd, ResourceGet, ResourceDelete, ResourceUpdate, ResourceList # type: ignore
|
||||
from rclpy.action import ActionClient, get_action_server_names_and_types_by_node
|
||||
from rclpy.callback_groups import ReentrantCallbackGroup
|
||||
from rclpy.service import Service
|
||||
from unique_identifier_msgs.msg import UUID
|
||||
|
||||
from unilabos.resources.registry import add_schema
|
||||
from unilabos.ros.initialize_device import initialize_device_from_dict
|
||||
from unilabos.ros.msgs.message_converter import (
|
||||
get_msg_type,
|
||||
get_ros_type_by_msgname,
|
||||
convert_from_ros_msg,
|
||||
convert_to_ros_msg,
|
||||
msg_converter_manager, ros_action_to_json_schema,
|
||||
)
|
||||
from unilabos.ros.nodes.base_device_node import BaseROS2DeviceNode, ROS2DeviceNode, DeviceNodeResourceTracker
|
||||
from unilabos.ros.nodes.presets.controller_node import ControllerNode
|
||||
|
||||
|
||||
class HostNode(BaseROS2DeviceNode):
|
||||
"""
|
||||
主机节点类,负责管理设备、资源和控制器
|
||||
|
||||
作为单例模式实现,确保整个应用中只有一个主机节点实例
|
||||
"""
|
||||
|
||||
_instance: ClassVar[Optional["HostNode"]] = None
|
||||
_ready_event: ClassVar[threading.Event] = threading.Event()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, timeout=None) -> Optional["HostNode"]:
|
||||
if cls._ready_event.wait(timeout):
|
||||
return cls._instance
|
||||
return None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_id: str,
|
||||
devices_config: Dict[str, Any],
|
||||
resources_config: Any,
|
||||
physical_setup_graph: Optional[Dict[str, Any]] = None,
|
||||
controllers_config: Optional[Dict[str, Any]] = None,
|
||||
bridges: Optional[List[Any]] = None,
|
||||
discovery_interval: float = 180.0, # 设备发现间隔,单位为秒
|
||||
):
|
||||
"""
|
||||
初始化主机节点
|
||||
|
||||
Args:
|
||||
device_id: 节点名称
|
||||
devices_config: 设备配置
|
||||
resources_config: 资源配置
|
||||
physical_setup_graph: 物理设置图
|
||||
controllers_config: 控制器配置
|
||||
bridges: 桥接器列表
|
||||
discovery_interval: 设备发现间隔(秒),默认5秒
|
||||
"""
|
||||
if self._instance is not None:
|
||||
self._instance.lab_logger().critical("[Host Node] HostNode instance already exists.")
|
||||
# 初始化Node基类,传递空参数覆盖列表
|
||||
BaseROS2DeviceNode.__init__(
|
||||
self,
|
||||
driver_instance=self,
|
||||
device_id=device_id,
|
||||
status_types={},
|
||||
action_value_mappings={},
|
||||
hardware_interface={},
|
||||
print_publish=False,
|
||||
resource_tracker=DeviceNodeResourceTracker(), # host node并不是通过initialize 包一层传进来的
|
||||
)
|
||||
|
||||
# 设置单例实例
|
||||
self.__class__._instance = self
|
||||
|
||||
# 初始化配置
|
||||
self.devices_config = devices_config
|
||||
self.resources_config = resources_config
|
||||
self.physical_setup_graph = physical_setup_graph
|
||||
if controllers_config is None:
|
||||
controllers_config = {}
|
||||
self.controllers_config = controllers_config
|
||||
if bridges is None:
|
||||
bridges = []
|
||||
self.bridges = bridges
|
||||
|
||||
# 创建设备、动作客户端和目标存储
|
||||
self.devices_names: Dict[str, str] = {} # 存储设备名称和命名空间的映射
|
||||
self.devices_instances: Dict[str, ROS2DeviceNode] = {} # 存储设备实例
|
||||
self._action_clients: Dict[str, ActionClient] = {} # 用来存储多个ActionClient实例
|
||||
self._action_value_mappings: Dict[str, Dict] = (
|
||||
{}
|
||||
) # 用来存储多个ActionClient的type, goal, feedback, result的变量名映射关系
|
||||
self._goals: Dict[str, Any] = {} # 用来存储多个目标的状态
|
||||
self._online_devices: Set[str] = set() # 用于跟踪在线设备
|
||||
self._last_discovery_time = 0.0 # 上次设备发现的时间
|
||||
self._discovery_lock = threading.Lock() # 设备发现的互斥锁
|
||||
self._subscribed_topics = set() # 用于跟踪已订阅的话题
|
||||
|
||||
# 创建物料增删改查服务(非客户端)
|
||||
self._init_resource_service()
|
||||
|
||||
self.device_status = {} # 用来存储设备状态
|
||||
self.device_status_timestamps = {} # 用来存储设备状态最后更新时间
|
||||
|
||||
# 首次发现网络中的设备
|
||||
self._discover_devices()
|
||||
|
||||
# 初始化所有本机设备节点,多一次过滤,防止重复初始化
|
||||
for device_id, device_config in devices_config.items():
|
||||
if device_config.get("type", "device") != "device":
|
||||
self.lab_logger().debug(f"[Host Node] Skipping type {device_config['type']} {device_id} already existed, skipping.")
|
||||
continue
|
||||
if device_id not in self.devices_names:
|
||||
self.initialize_device(device_id, device_config)
|
||||
else:
|
||||
self.lab_logger().warning(f"[Host Node] Device {device_id} already existed, skipping.")
|
||||
self.update_device_status_subscriptions()
|
||||
# TODO: 需要验证 初始化所有控制器节点
|
||||
if controllers_config:
|
||||
update_rate = controllers_config["controller_manager"]["ros__parameters"]["update_rate"]
|
||||
for controller_id, controller_config in controllers_config["controller_manager"]["ros__parameters"][
|
||||
"controllers"
|
||||
].items():
|
||||
controller_config["update_rate"] = update_rate
|
||||
self.initialize_controller(controller_id, controller_config)
|
||||
|
||||
for bridge in self.bridges:
|
||||
if hasattr(bridge, "resource_add"):
|
||||
self.lab_logger().info("[Host Node-Resource] Adding resources to bridge.")
|
||||
bridge.resource_add(add_schema(resources_config))
|
||||
|
||||
# 创建定时器,定期发现设备
|
||||
self._discovery_timer = self.create_timer(
|
||||
discovery_interval, self._discovery_devices_callback, callback_group=ReentrantCallbackGroup()
|
||||
)
|
||||
|
||||
self.lab_logger().info("[Host Node] Host node initialized.")
|
||||
HostNode._ready_event.set()
|
||||
|
||||
def _discover_devices(self) -> None:
|
||||
"""
|
||||
发现网络中的设备
|
||||
|
||||
检测ROS2网络中的所有设备节点,并为它们创建ActionClient
|
||||
同时检测设备离线情况
|
||||
"""
|
||||
self.lab_logger().debug("[Host Node] Discovering devices in the network...")
|
||||
|
||||
# 获取当前所有设备
|
||||
nodes_and_names = self.get_node_names_and_namespaces()
|
||||
|
||||
# 跟踪本次发现的设备,用于检测离线设备
|
||||
current_devices = set()
|
||||
|
||||
for device_id, namespace in nodes_and_names:
|
||||
if not namespace.startswith("/devices"):
|
||||
continue
|
||||
|
||||
# 将设备添加到当前设备集合
|
||||
device_key = f"{namespace}/{device_id}"
|
||||
current_devices.add(device_key)
|
||||
|
||||
# 如果是新设备,记录并创建ActionClient
|
||||
if device_id not in self.devices_names:
|
||||
self.lab_logger().info(f"[Host Node] Discovered new device: {device_key}")
|
||||
self.devices_names[device_id] = namespace
|
||||
self._create_action_clients_for_device(device_id, namespace)
|
||||
self._online_devices.add(device_key)
|
||||
elif device_key not in self._online_devices:
|
||||
# 设备重新上线
|
||||
self.lab_logger().info(f"[Host Node] Device reconnected: {device_key}")
|
||||
self._online_devices.add(device_key)
|
||||
|
||||
# 检测离线设备
|
||||
offline_devices = self._online_devices - current_devices
|
||||
for device_key in offline_devices:
|
||||
self.lab_logger().warning(f"[Host Node] Device offline: {device_key}")
|
||||
self._online_devices.discard(device_key)
|
||||
|
||||
# 更新在线设备列表
|
||||
self._online_devices = current_devices
|
||||
self.lab_logger().debug(f"[Host Node] Total online devices: {len(self._online_devices)}")
|
||||
|
||||
def _discovery_devices_callback(self) -> None:
|
||||
"""
|
||||
设备发现定时器回调函数
|
||||
"""
|
||||
# 使用互斥锁确保同时只有一个发现过程
|
||||
if self._discovery_lock.acquire(blocking=False):
|
||||
try:
|
||||
self._discover_devices()
|
||||
# 发现新设备后,更新设备状态订阅
|
||||
self.update_device_status_subscriptions()
|
||||
finally:
|
||||
self._discovery_lock.release()
|
||||
else:
|
||||
self.lab_logger().debug("[Host Node] Device discovery already in progress, skipping.")
|
||||
|
||||
def _create_action_clients_for_device(self, device_id: str, namespace: str) -> None:
|
||||
"""
|
||||
为设备创建所有必要的ActionClient
|
||||
|
||||
Args:
|
||||
device_id: 设备ID
|
||||
namespace: 设备命名空间
|
||||
"""
|
||||
for action_id, action_types in get_action_server_names_and_types_by_node(self, device_id, namespace):
|
||||
if action_id not in self._action_clients:
|
||||
try:
|
||||
action_type = get_ros_type_by_msgname(action_types[0])
|
||||
self._action_clients[action_id] = ActionClient(
|
||||
self, action_type, action_id, callback_group=self.callback_group
|
||||
)
|
||||
self.lab_logger().debug(f"[Host Node] Created ActionClient: {action_id}")
|
||||
from unilabos.app.mq import mqtt_client
|
||||
info_with_schema = ros_action_to_json_schema(action_type)
|
||||
mqtt_client.publish_actions(action_id, info_with_schema)
|
||||
except Exception as e:
|
||||
self.lab_logger().error(f"[Host Node] Failed to create ActionClient for {action_id}: {str(e)}")
|
||||
|
||||
def initialize_device(self, device_id: str, device_config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
根据配置初始化设备
|
||||
|
||||
此函数根据提供的设备配置动态导入适当的设备类并创建其实例。
|
||||
同时为设备的动作值映射设置动作客户端。
|
||||
|
||||
Args:
|
||||
device_id: 设备唯一标识符
|
||||
device_config: 设备配置字典,包含类型和其他参数
|
||||
"""
|
||||
self.lab_logger().info(f"[Host Node] Initializing device: {device_id}")
|
||||
|
||||
device_config_copy = copy.deepcopy(device_config)
|
||||
d = initialize_device_from_dict(device_id, device_config_copy)
|
||||
if d is None:
|
||||
return
|
||||
# noinspection PyProtectedMember
|
||||
self.devices_names[device_id] = d._ros_node.namespace
|
||||
self.devices_instances[device_id] = d
|
||||
# noinspection PyProtectedMember
|
||||
for action_name, action_value_mapping in d._ros_node._action_value_mappings.items():
|
||||
action_id = f"/devices/{device_id}/{action_name}"
|
||||
if action_id not in self._action_clients:
|
||||
action_type = action_value_mapping["type"]
|
||||
self._action_clients[action_id] = ActionClient(self, action_type, action_id)
|
||||
self.lab_logger().debug(f"[Host Node] Created ActionClient: {action_id}")
|
||||
from unilabos.app.mq import mqtt_client
|
||||
info_with_schema = ros_action_to_json_schema(action_type)
|
||||
mqtt_client.publish_actions(action_id, info_with_schema)
|
||||
else:
|
||||
self.lab_logger().warning(f"[Host Node] ActionClient {action_id} already exists.")
|
||||
device_key = f"{self.devices_names[device_id]}/{device_id}"
|
||||
# 添加到在线设备列表
|
||||
self._online_devices.add(device_key)
|
||||
|
||||
def update_device_status_subscriptions(self) -> None:
|
||||
"""
|
||||
更新设备状态订阅
|
||||
|
||||
扫描所有设备话题,为新的话题创建订阅,确保不会重复订阅
|
||||
"""
|
||||
topic_names_and_types = self.get_topic_names_and_types()
|
||||
for topic, types in topic_names_and_types:
|
||||
# 检查是否为设备状态话题且未订阅过
|
||||
if (
|
||||
topic.startswith("/devices/")
|
||||
and not types[0].endswith("FeedbackMessage")
|
||||
and "_action" not in topic
|
||||
and topic not in self._subscribed_topics
|
||||
):
|
||||
|
||||
# 解析设备名和属性名
|
||||
parts = topic.split("/")
|
||||
if len(parts) >= 4:
|
||||
device_id = parts[-2]
|
||||
property_name = parts[-1]
|
||||
|
||||
# 初始化设备状态字典
|
||||
if device_id not in self.device_status:
|
||||
self.device_status[device_id] = {}
|
||||
self.device_status_timestamps[device_id] = {}
|
||||
|
||||
# 默认初始化属性值为 None
|
||||
self.device_status[device_id][property_name] = None
|
||||
self.device_status_timestamps[device_id][property_name] = 0 # 初始化时间戳
|
||||
|
||||
# 动态创建订阅
|
||||
try:
|
||||
type_class = msg_converter_manager.search_class(types[0].replace("/", "."))
|
||||
if type_class is None:
|
||||
self.lab_logger().error(f"[Host Node] Invalid type {types[0]} for {topic}")
|
||||
else:
|
||||
self.create_subscription(
|
||||
type_class,
|
||||
topic,
|
||||
lambda msg, d=device_id, p=property_name: self.property_callback(msg, d, p),
|
||||
1,
|
||||
callback_group=ReentrantCallbackGroup(),
|
||||
)
|
||||
# 标记为已订阅
|
||||
self._subscribed_topics.add(topic)
|
||||
self.lab_logger().debug(f"[Host Node] Subscribed to new topic: {topic}")
|
||||
except (NameError, SyntaxError) as e:
|
||||
self.lab_logger().error(f"[Host Node] Failed to create subscription for topic {topic}: {e}")
|
||||
|
||||
"""设备相关"""
|
||||
|
||||
def property_callback(self, msg, device_id: str, property_name: str) -> None:
|
||||
"""
|
||||
更新设备状态字典中的属性值,并发送到桥接器。
|
||||
|
||||
Args:
|
||||
msg: 接收到的消息
|
||||
device_id: 设备ID
|
||||
property_name: 属性名称
|
||||
"""
|
||||
# 更新设备状态字典
|
||||
if hasattr(msg, "data"):
|
||||
bChange = False
|
||||
if isinstance(msg.data, (float, int, str)):
|
||||
if self.device_status[device_id][property_name] != msg.data:
|
||||
bChange = True
|
||||
self.device_status[device_id][property_name] = msg.data
|
||||
# 更新时间戳
|
||||
self.device_status_timestamps[device_id][property_name] = time.time()
|
||||
else:
|
||||
self.lab_logger().debug(
|
||||
f"[Host Node] Unsupported data type for {device_id}/{property_name}: {type(msg.data)}"
|
||||
)
|
||||
|
||||
# 所有 Bridge 对象都应具有 publish_device_status 方法;都会收到设备状态更新
|
||||
if bChange:
|
||||
for bridge in self.bridges:
|
||||
if hasattr(bridge, "publish_device_status"):
|
||||
bridge.publish_device_status(self.device_status, device_id, property_name)
|
||||
self.lab_logger().debug(
|
||||
f"[Host Node] Status updated: {device_id}.{property_name} = {msg.data}"
|
||||
)
|
||||
|
||||
def send_goal(
|
||||
self, device_id: str, action_name: str, action_kwargs: Dict[str, Any], goal_uuid: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
向设备发送目标请求
|
||||
|
||||
Args:
|
||||
device_id: 设备ID
|
||||
action_name: 动作名称
|
||||
action_kwargs: 动作参数
|
||||
goal_uuid: 目标UUID,如果为None则自动生成
|
||||
"""
|
||||
action_id = f"/devices/{device_id}/{action_name}"
|
||||
if action_id not in self._action_clients:
|
||||
self.lab_logger().error(f"[Host Node] ActionClient {action_id} not found.")
|
||||
return
|
||||
|
||||
action_client: ActionClient = self._action_clients[action_id]
|
||||
|
||||
goal_msg = convert_to_ros_msg(action_client._action_type.Goal(), action_kwargs)
|
||||
|
||||
self.lab_logger().info(f"[Host Node] Sending goal for {action_id}: {goal_msg}")
|
||||
action_client.wait_for_server()
|
||||
|
||||
uuid_str = goal_uuid
|
||||
if goal_uuid is not None:
|
||||
u = uuid.UUID(goal_uuid)
|
||||
goal_uuid_obj = UUID(uuid=list(u.bytes))
|
||||
else:
|
||||
goal_uuid_obj = None
|
||||
|
||||
future = action_client.send_goal_async(
|
||||
goal_msg,
|
||||
feedback_callback=lambda feedback_msg: self.feedback_callback(action_id, uuid_str, feedback_msg),
|
||||
goal_uuid=goal_uuid_obj,
|
||||
)
|
||||
future.add_done_callback(lambda future: self.goal_response_callback(action_id, uuid_str, future))
|
||||
|
||||
def goal_response_callback(self, action_id: str, uuid_str: Optional[str], future) -> None:
|
||||
"""目标响应回调"""
|
||||
goal_handle = future.result()
|
||||
if not goal_handle.accepted:
|
||||
self.lab_logger().warning(f"[Host Node] Goal {action_id} ({uuid_str}) rejected")
|
||||
return
|
||||
|
||||
self.lab_logger().info(f"[Host Node] Goal {action_id} ({uuid_str}) accepted")
|
||||
if uuid_str:
|
||||
self._goals[uuid_str] = goal_handle
|
||||
goal_handle.get_result_async().add_done_callback(
|
||||
lambda future: self.get_result_callback(action_id, uuid_str, future)
|
||||
)
|
||||
|
||||
def feedback_callback(self, action_id: str, uuid_str: Optional[str], feedback_msg) -> None:
|
||||
"""反馈回调"""
|
||||
feedback_data = convert_from_ros_msg(feedback_msg)
|
||||
feedback_data.pop("goal_id")
|
||||
self.lab_logger().debug(f"[Host Node] Feedback for {action_id} ({uuid_str}): {feedback_data}")
|
||||
|
||||
if uuid_str:
|
||||
for bridge in self.bridges:
|
||||
if hasattr(bridge, "publish_job_status"):
|
||||
bridge.publish_job_status(feedback_data, uuid_str, "running")
|
||||
|
||||
def get_result_callback(self, action_id: str, uuid_str: Optional[str], future) -> None:
|
||||
"""获取结果回调"""
|
||||
result_msg = future.result().result
|
||||
result_data = convert_from_ros_msg(result_msg)
|
||||
self.lab_logger().info(f"[Host Node] Result for {action_id} ({uuid_str}): success")
|
||||
self.lab_logger().debug(f"[Host Node] Result data: {result_data}")
|
||||
|
||||
if uuid_str:
|
||||
for bridge in self.bridges:
|
||||
if hasattr(bridge, "publish_job_status"):
|
||||
bridge.publish_job_status(result_data, uuid_str, "success")
|
||||
|
||||
def cancel_goal(self, goal_uuid: str) -> None:
|
||||
"""取消目标"""
|
||||
if goal_uuid in self._goals:
|
||||
self.lab_logger().info(f"[Host Node] Cancelling goal {goal_uuid}")
|
||||
self._goals[goal_uuid].cancel_goal_async()
|
||||
else:
|
||||
self.lab_logger().warning(f"[Host Node] Goal {goal_uuid} not found, cannot cancel")
|
||||
|
||||
def get_goal_status(self, uuid_str: str) -> int:
|
||||
"""获取目标状态"""
|
||||
if uuid_str in self._goals:
|
||||
g = self._goals[uuid_str]
|
||||
status = g.status
|
||||
self.lab_logger().debug(f"[Host Node] Goal status for {uuid_str}: {status}")
|
||||
return status
|
||||
self.lab_logger().warning(f"[Host Node] Goal {uuid_str} not found, status unknown")
|
||||
return GoalStatus.STATUS_UNKNOWN
|
||||
|
||||
"""Controller Node"""
|
||||
|
||||
def initialize_controller(self, controller_id: str, controller_config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
初始化控制器
|
||||
|
||||
Args:
|
||||
controller_id: 控制器ID
|
||||
controller_config: 控制器配置
|
||||
"""
|
||||
self.lab_logger().info(f"[Host Node] Initializing controller: {controller_id}")
|
||||
|
||||
class_name = controller_config.pop("type")
|
||||
controller_func = globals()[class_name]
|
||||
|
||||
for input_name, input_info in controller_config["inputs"].items():
|
||||
controller_config["inputs"][input_name]["type"] = get_msg_type(eval(input_info["type"]))
|
||||
for output_name, output_info in controller_config["outputs"].items():
|
||||
controller_config["outputs"][output_name]["type"] = get_msg_type(eval(output_info["type"]))
|
||||
|
||||
if controller_config["parameters"] is None:
|
||||
controller_config["parameters"] = {}
|
||||
|
||||
controller = ControllerNode(controller_id, controller_func=controller_func, **controller_config)
|
||||
self.lab_logger().info(f"[Host Node] Controller {controller_id} created.")
|
||||
# rclpy.get_global_executor().add_node(controller)
|
||||
|
||||
"""Resource"""
|
||||
|
||||
def _init_resource_service(self):
|
||||
self._resource_services: Dict[str, Service] = {
|
||||
"resource_add": self.create_service(
|
||||
ResourceAdd, "/resources/add", self._resource_add_callback, callback_group=ReentrantCallbackGroup()
|
||||
),
|
||||
"resource_get": self.create_service(
|
||||
ResourceGet, "/resources/get", self._resource_get_callback, callback_group=ReentrantCallbackGroup()
|
||||
),
|
||||
"resource_delete": self.create_service(
|
||||
ResourceDelete,
|
||||
"/resources/delete",
|
||||
self._resource_delete_callback,
|
||||
callback_group=ReentrantCallbackGroup(),
|
||||
),
|
||||
"resource_update": self.create_service(
|
||||
ResourceUpdate,
|
||||
"/resources/update",
|
||||
self._resource_update_callback,
|
||||
callback_group=ReentrantCallbackGroup(),
|
||||
),
|
||||
"resource_list": self.create_service(
|
||||
ResourceList, "/resources/list", self._resource_list_callback, callback_group=ReentrantCallbackGroup()
|
||||
),
|
||||
}
|
||||
|
||||
def _resource_add_callback(self, request, response):
|
||||
"""
|
||||
添加资源回调
|
||||
|
||||
处理添加资源请求,将资源数据传递到桥接器
|
||||
|
||||
Args:
|
||||
request: 包含资源数据的请求对象
|
||||
response: 响应对象
|
||||
|
||||
Returns:
|
||||
响应对象,包含操作结果
|
||||
"""
|
||||
resources = [convert_from_ros_msg(resource) for resource in request.resources]
|
||||
self.lab_logger().info(f"[Host Node-Resource] Add request received: {len(resources)} resources")
|
||||
|
||||
success = False
|
||||
if len(self.bridges) > 0:
|
||||
r = self.bridges[-1].resource_add(add_schema(resources))
|
||||
success = bool(r)
|
||||
|
||||
response.success = success
|
||||
self.lab_logger().info(f"[Host Node-Resource] Add request completed, success: {success}")
|
||||
return response
|
||||
|
||||
def _resource_get_callback(self, request, response):
|
||||
"""
|
||||
获取资源回调
|
||||
|
||||
处理获取资源请求,从桥接器或本地查询资源数据
|
||||
|
||||
Args:
|
||||
request: 包含资源ID的请求对象
|
||||
response: 响应对象
|
||||
|
||||
Returns:
|
||||
响应对象,包含查询到的资源
|
||||
"""
|
||||
self.lab_logger().info(f"[Host Node-Resource] Get request for ID: {request.id}")
|
||||
|
||||
if len(self.bridges) > 0:
|
||||
# 云上物料服务,根据 id 查询物料
|
||||
try:
|
||||
r = self.bridges[-1].resource_get(request.id, request.with_children)["data"]
|
||||
self.lab_logger().debug(f"[Host Node-Resource] Retrieved from bridge: {len(r)} resources")
|
||||
except Exception as e:
|
||||
self.lab_logger().error(f"[Host Node-Resource] Error retrieving from bridge: {str(e)}")
|
||||
r = []
|
||||
else:
|
||||
# 本地物料服务,根据 id 查询物料
|
||||
r = [resource for resource in self.resources_config if resource.get("id") == request.id]
|
||||
self.lab_logger().debug(f"[Host Node-Resource] Retrieved from local: {len(r)} resources")
|
||||
|
||||
response.resources = [convert_to_ros_msg(Resource, resource) for resource in r]
|
||||
return response
|
||||
|
||||
def _resource_delete_callback(self, request, response):
|
||||
"""
|
||||
删除资源回调
|
||||
|
||||
处理删除资源请求,将删除指令传递到桥接器
|
||||
|
||||
Args:
|
||||
request: 包含资源ID的请求对象
|
||||
response: 响应对象
|
||||
|
||||
Returns:
|
||||
响应对象,包含操作结果
|
||||
"""
|
||||
self.lab_logger().info(f"[Host Node-Resource] Delete request for ID: {request.id}")
|
||||
|
||||
success = False
|
||||
if len(self.bridges) > 0:
|
||||
try:
|
||||
r = self.bridges[-1].resource_delete(request.id)
|
||||
success = bool(r)
|
||||
except Exception as e:
|
||||
self.lab_logger().error(f"[Host Node-Resource] Error deleting resource: {str(e)}")
|
||||
|
||||
response.success = success
|
||||
self.lab_logger().info(f"[Host Node-Resource] Delete request completed, success: {success}")
|
||||
return response
|
||||
|
||||
def _resource_update_callback(self, request, response):
|
||||
"""
|
||||
更新资源回调
|
||||
|
||||
处理更新资源请求,将更新指令传递到桥接器
|
||||
|
||||
Args:
|
||||
request: 包含资源数据的请求对象
|
||||
response: 响应对象
|
||||
|
||||
Returns:
|
||||
响应对象,包含操作结果
|
||||
"""
|
||||
resources = [convert_from_ros_msg(resource) for resource in request.resources]
|
||||
self.lab_logger().info(f"[Host Node-Resource] Update request received: {len(resources)} resources")
|
||||
|
||||
success = False
|
||||
if len(self.bridges) > 0:
|
||||
try:
|
||||
r = self.bridges[-1].resource_update(add_schema(resources))
|
||||
success = bool(r)
|
||||
except Exception as e:
|
||||
self.lab_logger().error(f"[Host Node-Resource] Error updating resources: {str(e)}")
|
||||
|
||||
response.success = success
|
||||
self.lab_logger().info(f"[Host Node-Resource] Update request completed, success: {success}")
|
||||
return response
|
||||
|
||||
def _resource_list_callback(self, request, response):
|
||||
"""
|
||||
列出资源回调
|
||||
|
||||
处理列出资源请求,返回所有可用资源
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
response: 响应对象
|
||||
|
||||
Returns:
|
||||
响应对象,包含资源列表
|
||||
"""
|
||||
self.lab_logger().info(f"[Host Node-Resource] List request received")
|
||||
# 这里可以实现返回资源列表的逻辑
|
||||
self.lab_logger().debug(f"[Host Node-Resource] List parameters: {request}")
|
||||
return response
|
||||
267
unilabos/ros/nodes/presets/protocol_node.py
Normal file
267
unilabos/ros/nodes/presets/protocol_node.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import time
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Union
|
||||
|
||||
import rclpy
|
||||
from unilabos.messages import * # type: ignore # protocol names
|
||||
from rclpy.action import ActionServer, ActionClient
|
||||
from rclpy.action.server import ServerGoalHandle
|
||||
from rclpy.callback_groups import ReentrantCallbackGroup
|
||||
from unilabos_msgs.msg import Resource # type: ignore
|
||||
from unilabos_msgs.srv import ResourceGet, ResourceUpdate # type: ignore
|
||||
|
||||
from unilabos.compile import action_protocol_generators
|
||||
from unilabos.resources.graphio import list_to_nested_dict, nested_dict_to_list
|
||||
from unilabos.ros.initialize_device import initialize_device_from_dict
|
||||
from unilabos.ros.msgs.message_converter import (
|
||||
get_action_type,
|
||||
convert_to_ros_msg,
|
||||
convert_from_ros_msg,
|
||||
convert_from_ros_msg_with_mapping,
|
||||
)
|
||||
from unilabos.ros.nodes.base_device_node import BaseROS2DeviceNode, DeviceNodeResourceTracker
|
||||
|
||||
|
||||
class ROS2ProtocolNode(BaseROS2DeviceNode):
|
||||
"""
|
||||
ROS2ProtocolNode代表管理ROS2环境中设备通信和动作的协议节点。
|
||||
它初始化设备节点,处理动作客户端,并基于指定的协议执行工作流。
|
||||
它还物理上代表一组协同工作的设备,如带夹持器的机械臂,带传送带的CNC机器等。
|
||||
"""
|
||||
|
||||
# create_action_server = False # Action Server要自己创建
|
||||
|
||||
def __init__(self, device_id: str, children: dict, protocol_type: Union[str, list[str]], resource_tracker: DeviceNodeResourceTracker, *args, **kwargs):
|
||||
self._setup_protocol_names(protocol_type)
|
||||
|
||||
# 初始化其它属性
|
||||
self.children = children
|
||||
self._busy = False
|
||||
self.sub_devices = {}
|
||||
self._goals = {}
|
||||
self._protocol_servers = {}
|
||||
self._action_clients = {}
|
||||
|
||||
# 初始化基类,让基类处理常规动作
|
||||
super().__init__(
|
||||
driver_instance=self,
|
||||
device_id=device_id,
|
||||
status_types={},
|
||||
action_value_mappings=self.protocol_action_mappings,
|
||||
hardware_interface={},
|
||||
print_publish=False,
|
||||
resource_tracker=resource_tracker,
|
||||
)
|
||||
|
||||
# 初始化子设备
|
||||
communication_node_id = None
|
||||
for device_id, device_config in self.children.items():
|
||||
if device_config.get("type", "device") != "device":
|
||||
self.lab_logger().debug(f"[Protocol Node] Skipping type {device_config['type']} {device_id} already existed, skipping.")
|
||||
continue
|
||||
d = self.initialize_device(device_id, device_config)
|
||||
if d is None:
|
||||
continue
|
||||
|
||||
if "serial_" in device_id or "io_" in device_id:
|
||||
communication_node_id = device_id
|
||||
continue
|
||||
|
||||
# 设置硬件接口代理
|
||||
if d and hasattr(d, "_hardware_interface"):
|
||||
if (
|
||||
hasattr(d, d._hardware_interface["name"])
|
||||
and hasattr(d, d._hardware_interface["write"])
|
||||
and (d._hardware_interface["read"] is None or hasattr(d, d._hardware_interface["read"]))
|
||||
):
|
||||
|
||||
name = getattr(d, d._hardware_interface["name"])
|
||||
read = d._hardware_interface.get("read", None)
|
||||
write = d._hardware_interface.get("write", None)
|
||||
|
||||
# 如果硬件接口是字符串,通过通信设备提供
|
||||
if isinstance(name, str) and communication_node_id in self.sub_devices:
|
||||
self._setup_hardware_proxy(d, self.sub_devices[communication_node_id], read, write)
|
||||
|
||||
def _setup_protocol_names(self, protocol_type):
|
||||
# 处理协议类型
|
||||
if isinstance(protocol_type, str):
|
||||
if "," not in protocol_type:
|
||||
self.protocol_names = [protocol_type]
|
||||
else:
|
||||
self.protocol_names = [protocol.strip() for protocol in protocol_type.split(",")]
|
||||
else:
|
||||
self.protocol_names = protocol_type
|
||||
# 准备协议相关的动作值映射
|
||||
self.protocol_action_mappings = {}
|
||||
for protocol_name in self.protocol_names:
|
||||
protocol_type = globals()[protocol_name]
|
||||
self.protocol_action_mappings[protocol_name] = get_action_type(protocol_type)
|
||||
|
||||
def initialize_device(self, device_id, device_config):
|
||||
"""初始化设备并创建相应的动作客户端"""
|
||||
device_id_abs = f"{self.device_id}/{device_id}"
|
||||
self.lab_logger().info(f"初始化子设备: {device_id_abs}")
|
||||
d = self.sub_devices[device_id] = initialize_device_from_dict(device_id_abs, device_config)
|
||||
|
||||
# 为子设备的每个动作创建动作客户端
|
||||
if d is not None and hasattr(d, "ros_node_instance"):
|
||||
node = d.ros_node_instance
|
||||
for action_name, action_mapping in node._action_value_mappings.items():
|
||||
action_id = f"/devices/{device_id_abs}/{action_name}"
|
||||
if action_id not in self._action_clients:
|
||||
self._action_clients[action_id] = ActionClient(
|
||||
self, action_mapping["type"], action_id, callback_group=self.callback_group
|
||||
)
|
||||
self.lab_logger().debug(f"为子设备 {device_id} 创建动作客户端: {action_name}")
|
||||
return d
|
||||
|
||||
def create_ros_action_server(self, action_name, action_value_mapping):
|
||||
"""创建ROS动作服务器"""
|
||||
# 和Base创建的路径是一致的
|
||||
protocol_name = action_name
|
||||
action_type = action_value_mapping["type"]
|
||||
str_action_type = str(action_type)[8:-2]
|
||||
protocol_type = globals()[protocol_name]
|
||||
protocol_steps_generator = action_protocol_generators[protocol_type]
|
||||
|
||||
self._action_servers[action_name] = ActionServer(
|
||||
self,
|
||||
action_type,
|
||||
action_name,
|
||||
execute_callback=self._create_protocol_execute_callback(action_name, protocol_steps_generator),
|
||||
callback_group=ReentrantCallbackGroup(),
|
||||
)
|
||||
|
||||
self.lab_logger().debug(f"发布动作: {action_name}, 类型: {str_action_type}")
|
||||
|
||||
def _create_protocol_execute_callback(self, protocol_name, protocol_steps_generator):
|
||||
async def execute_protocol(goal_handle: ServerGoalHandle):
|
||||
"""执行完整的工作流"""
|
||||
self.get_logger().info(f'Executing {protocol_name} action...')
|
||||
action_value_mapping = self._action_value_mappings[protocol_name]
|
||||
print('+'*30)
|
||||
print(protocol_steps_generator)
|
||||
# 从目标消息中提取参数, 并调用Protocol生成器(根据设备连接图)生成action步骤
|
||||
goal = goal_handle.request
|
||||
protocol_kwargs = convert_from_ros_msg_with_mapping(goal, action_value_mapping["goal"])
|
||||
|
||||
# 向Host查询物料当前状态
|
||||
for k, v in goal.get_fields_and_field_types().items():
|
||||
if v in ["unilabos_msgs/Resource", "sequence<unilabos_msgs/Resource>"]:
|
||||
r = ResourceGet.Request()
|
||||
r.id = protocol_kwargs[k]["id"] if v == "unilabos_msgs/Resource" else protocol_kwargs[k][0]["id"]
|
||||
r.with_children = True
|
||||
response = await self._resource_clients["resource_get"].call_async(r)
|
||||
protocol_kwargs[k] = list_to_nested_dict([convert_from_ros_msg(rs) for rs in response.resources])
|
||||
|
||||
from unilabos.resources.graphio import physical_setup_graph
|
||||
self.get_logger().info(f'Working on physical setup: {physical_setup_graph}')
|
||||
protocol_steps = protocol_steps_generator(G=physical_setup_graph, **protocol_kwargs)
|
||||
|
||||
self.get_logger().info(f'Goal received: {protocol_kwargs}, running steps: \n{protocol_steps}')
|
||||
|
||||
time_start = time.time()
|
||||
time_overall = 100
|
||||
self._busy = True
|
||||
|
||||
# 逐步执行工作流
|
||||
for i, action in enumerate(protocol_steps):
|
||||
self.get_logger().info(f'Running step {i+1}: {action}')
|
||||
if type(action) == dict:
|
||||
# 如果是单个动作,直接执行
|
||||
if action["action_name"] == "wait":
|
||||
time.sleep(action["action_kwargs"]["time"])
|
||||
else:
|
||||
result = await self.execute_single_action(**action)
|
||||
elif type(action) == list:
|
||||
# 如果是并行动作,同时执行
|
||||
actions = action
|
||||
futures = [rclpy.get_global_executor().create_task(self.execute_single_action(**a)) for a in actions]
|
||||
results = [await f for f in futures]
|
||||
|
||||
# 向Host更新物料当前状态
|
||||
for k, v in goal.get_fields_and_field_types().items():
|
||||
if v in ["unilabos_msgs/Resource", "sequence<unilabos_msgs/Resource>"]:
|
||||
r = ResourceUpdate.Request()
|
||||
r.resources = [
|
||||
convert_to_ros_msg(Resource, rs) for rs in nested_dict_to_list(protocol_kwargs[k])
|
||||
]
|
||||
response = await self._resource_clients["resource_update"].call_async(r)
|
||||
|
||||
goal_handle.succeed()
|
||||
result = action_value_mapping["type"].Result()
|
||||
result.success = True
|
||||
|
||||
self._busy = False
|
||||
return result
|
||||
return execute_protocol
|
||||
|
||||
async def execute_single_action(self, device_id, action_name, action_kwargs):
|
||||
"""执行单个动作"""
|
||||
# 构建动作ID
|
||||
if device_id in ["", None, "self"]:
|
||||
action_id = f"/devices/{self.device_id}/{action_name}"
|
||||
else:
|
||||
action_id = f"/devices/{self.device_id}/{device_id}/{action_name}"
|
||||
|
||||
# 检查动作客户端是否存在
|
||||
if action_id not in self._action_clients:
|
||||
self.lab_logger().error(f"找不到动作客户端: {action_id}")
|
||||
return None
|
||||
|
||||
# 发送动作请求
|
||||
action_client = self._action_clients[action_id]
|
||||
goal_msg = convert_to_ros_msg(action_client._action_type.Goal(), action_kwargs)
|
||||
|
||||
self.lab_logger().info(f"发送动作请求到: {action_id}")
|
||||
action_client.wait_for_server()
|
||||
|
||||
# 等待动作完成
|
||||
request_future = action_client.send_goal_async(goal_msg)
|
||||
handle = await request_future
|
||||
|
||||
if not handle.accepted:
|
||||
self.lab_logger().error(f"动作请求被拒绝: {action_name}")
|
||||
return None
|
||||
|
||||
result_future = await handle.get_result_async()
|
||||
self.lab_logger().info(f"动作完成: {action_name}")
|
||||
|
||||
return result_future.result
|
||||
|
||||
|
||||
"""还没有改过的部分"""
|
||||
|
||||
def _setup_hardware_proxy(self, device, communication_device, read_method, write_method):
|
||||
"""为设备设置硬件接口代理"""
|
||||
extra_info = [getattr(device, info) for info in communication_device._hardware_interface.get("extra_info", [])]
|
||||
write_func = getattr(communication_device, communication_device._hardware_interface["write"])
|
||||
read_func = getattr(communication_device, communication_device._hardware_interface["read"])
|
||||
|
||||
def _read():
|
||||
return read_func(*extra_info)
|
||||
|
||||
def _write(command):
|
||||
return write_func(*extra_info, command)
|
||||
|
||||
if read_method:
|
||||
setattr(device, read_method, _read)
|
||||
if write_method:
|
||||
setattr(device, write_method, _write)
|
||||
|
||||
|
||||
async def _update_resources(self, goal, protocol_kwargs):
|
||||
"""更新资源状态"""
|
||||
for k, v in goal.get_fields_and_field_types().items():
|
||||
if v in ["unilabos_msgs/Resource", "sequence<unilabos_msgs/Resource>"]:
|
||||
if protocol_kwargs[k] is not None:
|
||||
try:
|
||||
r = ResourceUpdate.Request()
|
||||
r.resources = [
|
||||
convert_to_ros_msg(Resource, rs) for rs in nested_dict_to_list(protocol_kwargs[k])
|
||||
]
|
||||
await self._resource_clients["resource_update"].call_async(r)
|
||||
except Exception as e:
|
||||
self.lab_logger().error(f"更新资源失败: {e}")
|
||||
84
unilabos/ros/nodes/presets/serial_node.py
Normal file
84
unilabos/ros/nodes/presets/serial_node.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from threading import Lock
|
||||
|
||||
from unilabos_msgs.srv import SerialCommand
|
||||
from serial import Serial, SerialException
|
||||
|
||||
from unilabos.ros.nodes.base_device_node import BaseROS2DeviceNode, DeviceNodeResourceTracker
|
||||
|
||||
|
||||
class ROS2SerialNode(BaseROS2DeviceNode):
|
||||
def __init__(self, device_id, port: str, baudrate: int = 9600, resource_tracker: DeviceNodeResourceTracker=None):
|
||||
# 保存属性,以便在调用父类初始化前使用
|
||||
self.port = port
|
||||
self.baudrate = baudrate
|
||||
self._hardware_interface = {"name": "hardware_interface", "write": "send_command", "read": "read_data"}
|
||||
self._busy = False
|
||||
self._closing = False
|
||||
self._query_lock = Lock()
|
||||
|
||||
# 初始化硬件接口
|
||||
try:
|
||||
self.hardware_interface = Serial(baudrate=baudrate, port=port)
|
||||
except (OSError, SerialException) as e:
|
||||
# 因为还没调用父类初始化,无法使用日志,直接抛出异常
|
||||
raise RuntimeError(f"Failed to connect to serial port {port} at {baudrate} baudrate.") from e
|
||||
|
||||
# 初始化BaseROS2DeviceNode,使用自身作为driver_instance
|
||||
BaseROS2DeviceNode.__init__(
|
||||
self,
|
||||
driver_instance=self,
|
||||
device_id=device_id,
|
||||
status_types={},
|
||||
action_value_mappings={},
|
||||
hardware_interface=self._hardware_interface,
|
||||
print_publish=False,
|
||||
resource_tracker=resource_tracker,
|
||||
)
|
||||
|
||||
# 现在可以使用日志
|
||||
self.lab_logger().info(
|
||||
f"【ROS2SerialNode.__init__】初始化串口节点: {device_id}, 端口: {port}, 波特率: {baudrate}"
|
||||
)
|
||||
self.lab_logger().info(f"【ROS2SerialNode.__init__】成功连接串口设备")
|
||||
|
||||
# 创建服务
|
||||
self.create_service(SerialCommand, "serialwrite", self.handle_serial_request)
|
||||
self.lab_logger().info(f"【ROS2SerialNode.__init__】创建串口写入服务: serialwrite")
|
||||
|
||||
def send_command(self, command: str):
|
||||
self.lab_logger().info(f"【ROS2SerialNode.send_command】发送命令: {command}")
|
||||
with self._query_lock:
|
||||
if self._closing:
|
||||
self.lab_logger().error(f"【ROS2SerialNode.send_command】设备正在关闭,无法发送命令")
|
||||
raise RuntimeError
|
||||
|
||||
full_command = f"{command}\n"
|
||||
full_command_data = bytearray(full_command, "ascii")
|
||||
|
||||
response = self.hardware_interface.write(full_command_data)
|
||||
# time.sleep(0.05)
|
||||
output = self._receive(self.hardware_interface.read_until(b"\n"))
|
||||
self.lab_logger().info(f"【ROS2SerialNode.send_command】接收响应: {output}")
|
||||
return output
|
||||
|
||||
def read_data(self):
|
||||
self.lab_logger().debug(f"【ROS2SerialNode.read_data】读取数据")
|
||||
with self._query_lock:
|
||||
if self._closing:
|
||||
self.lab_logger().error(f"【ROS2SerialNode.read_data】设备正在关闭,无法读取数据")
|
||||
raise RuntimeError
|
||||
data = self.hardware_interface.read_until(b"\n")
|
||||
result = self._receive(data)
|
||||
self.lab_logger().debug(f"【ROS2SerialNode.read_data】读取到数据: {result}")
|
||||
return result
|
||||
|
||||
def _receive(self, data: bytes):
|
||||
ascii_string = "".join(chr(byte) for byte in data)
|
||||
self.lab_logger().debug(f"【ROS2SerialNode._receive】接收数据: {ascii_string}")
|
||||
return ascii_string
|
||||
|
||||
def handle_serial_request(self, request, response):
|
||||
self.lab_logger().info(f"【ROS2SerialNode.handle_serial_request】收到串口命令请求: {request.command}")
|
||||
response.response = self.send_command(request.command)
|
||||
self.lab_logger().info(f"【ROS2SerialNode.handle_serial_request】命令响应: {response.response}")
|
||||
return response
|
||||
67
unilabos/ros/nodes/resource_tracker.py
Normal file
67
unilabos/ros/nodes/resource_tracker.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from unilabos.utils.log import logger
|
||||
|
||||
|
||||
class DeviceNodeResourceTracker:
|
||||
|
||||
def __init__(self):
|
||||
self.resources = []
|
||||
self.root_resource2resource = {}
|
||||
pass
|
||||
|
||||
def root_resource(self, resource):
|
||||
if id(resource) in self.root_resource2resource:
|
||||
return self.root_resource2resource[id(resource)]
|
||||
else:
|
||||
return resource
|
||||
|
||||
def add_resource(self, resource):
|
||||
# 使用内存地址跟踪是否为同一个resource
|
||||
for r in self.resources:
|
||||
if id(r) == id(resource):
|
||||
return
|
||||
# 添加资源到跟踪器
|
||||
self.resources.append(resource)
|
||||
|
||||
def clear_resource(self):
|
||||
self.resources = []
|
||||
|
||||
def figure_resource(self, resource):
|
||||
# 使用内存地址跟踪是否为同一个resource
|
||||
if isinstance(resource, list):
|
||||
return [self.figure_resource(r) for r in resource]
|
||||
res_id = resource.id if hasattr(resource, "id") else None
|
||||
res_name = resource.name if hasattr(resource, "name") else None
|
||||
res_identifier = res_id if res_id else res_name
|
||||
identifier_key = "id" if res_id else "name"
|
||||
resource_cls_type = type(resource)
|
||||
if res_identifier is None:
|
||||
logger.warning(f"resource {resource} 没有id或name,暂不能对应figure")
|
||||
res_list = []
|
||||
for r in self.resources:
|
||||
res_list.extend(
|
||||
self.loop_find_resource(r, resource_cls_type, identifier_key, getattr(resource, identifier_key))
|
||||
)
|
||||
assert len(res_list) == 1, f"找到多个资源,请检查资源是否唯一: {res_list}"
|
||||
self.root_resource2resource[id(resource)] = res_list[0]
|
||||
# 后续加入其他对比方式
|
||||
return res_list[0]
|
||||
|
||||
def loop_find_resource(self, resource, resource_cls_type, identifier_key, compare_value):
|
||||
res_list = []
|
||||
children = getattr(resource, "children", [])
|
||||
for child in children:
|
||||
res_list.extend(self.loop_find_resource(child, resource_cls_type, identifier_key, compare_value))
|
||||
if resource_cls_type == type(resource):
|
||||
if hasattr(resource, identifier_key):
|
||||
if getattr(resource, identifier_key) == compare_value:
|
||||
res_list.append(resource)
|
||||
return res_list
|
||||
|
||||
def filter_find_list(self, res_list, compare_std_dict):
|
||||
new_list = []
|
||||
for res in res_list:
|
||||
for k, v in compare_std_dict.items():
|
||||
if hasattr(res, k):
|
||||
if getattr(res, k) == v:
|
||||
new_list.append(res)
|
||||
return new_list
|
||||
0
unilabos/ros/scripts/__init__.py
Normal file
0
unilabos/ros/scripts/__init__.py
Normal file
55
unilabos/ros/scripts/pydantic2rosmsg.py
Normal file
55
unilabos/ros/scripts/pydantic2rosmsg.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import os
|
||||
import inspect
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import get_type_hints
|
||||
|
||||
# 定义你要解析的 pydantic 模型所在的 Python 文件
|
||||
MODULES = ['my_pydantic_models'] # 替换为你的 Python 模块名
|
||||
|
||||
ROS_MSG_DIR = 'msg' # 消息文件生成目录
|
||||
|
||||
|
||||
def map_python_type_to_ros(python_type):
|
||||
type_map = {
|
||||
int: 'int32',
|
||||
float: 'float64',
|
||||
str: 'string',
|
||||
bool: 'bool',
|
||||
list: '[]', # List in Pydantic should be handled separately
|
||||
}
|
||||
return type_map.get(python_type, None)
|
||||
|
||||
|
||||
def generate_ros_msg_from_pydantic(model):
|
||||
fields = get_type_hints(model)
|
||||
ros_msg_lines = []
|
||||
|
||||
for field_name, field_type in fields.items():
|
||||
ros_type = map_python_type_to_ros(field_type)
|
||||
if not ros_type:
|
||||
raise TypeError(f"Unsupported type {field_type} for field {field_name}")
|
||||
|
||||
ros_msg_lines.append(f"{ros_type} {field_name}\n")
|
||||
|
||||
return ''.join(ros_msg_lines)
|
||||
|
||||
|
||||
def save_ros_msg_file(model_name, ros_msg_definition):
|
||||
msg_file_path = os.path.join(ROS_MSG_DIR, f'{model_name}.msg')
|
||||
os.makedirs(ROS_MSG_DIR, exist_ok=True)
|
||||
with open(msg_file_path, 'w') as msg_file:
|
||||
msg_file.write(ros_msg_definition)
|
||||
|
||||
|
||||
def main():
|
||||
for module_name in MODULES:
|
||||
module = __import__(module_name)
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseModel) and obj != BaseModel:
|
||||
print(f"Generating ROS message for Pydantic model: {name}")
|
||||
ros_msg_definition = generate_ros_msg_from_pydantic(obj)
|
||||
save_ros_msg_file(name, ros_msg_definition)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
0
unilabos/ros/utils/__init__.py
Normal file
0
unilabos/ros/utils/__init__.py
Normal file
268
unilabos/ros/utils/driver_creator.py
Normal file
268
unilabos/ros/utils/driver_creator.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
设备类实例创建工厂
|
||||
|
||||
这个模块包含用于创建设备类实例的工厂类。
|
||||
基础工厂类提供通用的实例创建方法,而特定工厂类提供针对特定设备类的创建方法。
|
||||
"""
|
||||
import asyncio
|
||||
import inspect
|
||||
import traceback
|
||||
from abc import abstractmethod
|
||||
from typing import Type, Any, Dict, Optional, TypeVar, Generic
|
||||
|
||||
from unilabos.resources.graphio import nested_dict_to_list, resource_ulab_to_plr
|
||||
from unilabos.ros.nodes.resource_tracker import DeviceNodeResourceTracker
|
||||
from unilabos.utils import logger, import_manager
|
||||
from unilabos.utils.cls_creator import create_instance_from_config
|
||||
|
||||
# 定义泛型类型变量
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ClassCreator(Generic[T]):
|
||||
@abstractmethod
|
||||
def create_instance(self, *args, **kwargs) -> T:
|
||||
pass
|
||||
|
||||
|
||||
class DeviceClassCreator(Generic[T]):
|
||||
"""
|
||||
设备类实例创建器基类
|
||||
|
||||
这个类提供了从任意类创建实例的通用方法。
|
||||
"""
|
||||
|
||||
def __init__(self, cls: Type[T]):
|
||||
"""
|
||||
初始化设备类创建器
|
||||
|
||||
Args:
|
||||
cls: 要创建实例的类
|
||||
"""
|
||||
self.device_cls = cls
|
||||
self.device_instance: Optional[T] = None
|
||||
|
||||
def create_instance(self, data: Dict[str, Any]) -> T:
|
||||
"""
|
||||
创建设备类实例
|
||||
|
||||
Args:
|
||||
|
||||
|
||||
Returns:
|
||||
设备类的实例
|
||||
"""
|
||||
self.device_instance = create_instance_from_config(
|
||||
{
|
||||
"_cls": self.device_cls.__module__ + ":" + self.device_cls.__name__,
|
||||
"_params": data,
|
||||
}
|
||||
)
|
||||
self.post_create()
|
||||
return self.device_instance
|
||||
|
||||
def get_instance(self) -> Optional[T]:
|
||||
"""
|
||||
获取当前实例
|
||||
|
||||
Returns:
|
||||
当前设备类实例,如果尚未创建则返回None
|
||||
"""
|
||||
return self.device_instance
|
||||
|
||||
def post_create(self):
|
||||
pass
|
||||
|
||||
|
||||
class PyLabRobotCreator(DeviceClassCreator[T]):
|
||||
"""
|
||||
PyLabRobot设备类创建器
|
||||
|
||||
这个类提供了针对PyLabRobot设备类的实例创建方法,特别处理deserialize方法。
|
||||
"""
|
||||
|
||||
def __init__(self, cls: Type[T], children: Dict[str, Any], resource_tracker: DeviceNodeResourceTracker):
|
||||
"""
|
||||
初始化PyLabRobot设备类创建器
|
||||
|
||||
Args:
|
||||
cls: PyLabRobot设备类
|
||||
children: 子资源字典,用于资源替换
|
||||
"""
|
||||
super().__init__(cls)
|
||||
self.children = children
|
||||
self.resource_tracker = resource_tracker
|
||||
# 检查类是否具有deserialize方法
|
||||
self.has_deserialize = hasattr(cls, "deserialize") and callable(getattr(cls, "deserialize"))
|
||||
if not self.has_deserialize:
|
||||
logger.warning(f"类 {cls.__name__} 没有deserialize方法,将使用标准构造函数")
|
||||
|
||||
def _process_resource_mapping(self, resource, source_type):
|
||||
if source_type == dict:
|
||||
from pylabrobot.resources.resource import Resource
|
||||
|
||||
return nested_dict_to_list(resource), Resource
|
||||
return resource, source_type
|
||||
|
||||
def _process_resource_references(self, data: Any, to_dict=False) -> Any:
|
||||
"""
|
||||
递归处理资源引用,替换_resource_child_name对应的资源
|
||||
|
||||
Args:
|
||||
data: 需要处理的数据,可能是字典、列表或其他类型
|
||||
to_dict: 转换成对应的实例,还是转换成对应的字典
|
||||
|
||||
Returns:
|
||||
处理后的数据
|
||||
"""
|
||||
from pylabrobot.resources import Deck, Resource
|
||||
|
||||
if isinstance(data, dict):
|
||||
# 检查是否包含资源引用
|
||||
if "_resource_child_name" in data:
|
||||
child_name = data["_resource_child_name"]
|
||||
if child_name in self.children:
|
||||
# 找到了对应的资源
|
||||
resource = self.children[child_name]
|
||||
|
||||
# 检查是否需要转换资源类型
|
||||
if "_resource_type" in data:
|
||||
type_path = data["_resource_type"]
|
||||
try:
|
||||
# 尝试导入指定的类型
|
||||
target_type = import_manager.get_class(type_path)
|
||||
contain_model = not issubclass(target_type, Deck)
|
||||
resource, target_type = self._process_resource_mapping(resource, target_type)
|
||||
# 在截图中格式,是deserialize,所以这里要转成plr resource可deserialize的字典
|
||||
# 这样后面执行deserialize的时候能够正确反序列化对应的物料
|
||||
resource_instance: Resource = resource_ulab_to_plr(resource, contain_model)
|
||||
if to_dict:
|
||||
return resource_instance.serialize()
|
||||
else:
|
||||
self.resource_tracker.add_resource(resource_instance)
|
||||
return resource_instance
|
||||
except Exception as e:
|
||||
logger.warning(f"无法导入资源类型 {type_path}: {e}")
|
||||
logger.warning(traceback.format_exc())
|
||||
else:
|
||||
logger.debug(f"找不到资源类型,请补全_resource_type {self.device_cls.__name__} {data.keys()}")
|
||||
return resource
|
||||
else:
|
||||
logger.warning(f"找不到资源引用 '{child_name}',保持原值不变")
|
||||
|
||||
# 递归处理字典的每个值
|
||||
result = {}
|
||||
for key, value in data.items():
|
||||
result[key] = self._process_resource_references(value, to_dict)
|
||||
return result
|
||||
|
||||
# 处理列表类型
|
||||
elif isinstance(data, list):
|
||||
return [self._process_resource_references(item, to_dict) for item in data]
|
||||
|
||||
# 其他类型直接返回
|
||||
return data
|
||||
|
||||
def create_instance(self, data: Dict[str, Any]) -> Optional[T]:
|
||||
"""
|
||||
从数据创建PyLabRobot设备实例
|
||||
|
||||
Args:
|
||||
data: 用于反序列化的数据
|
||||
|
||||
Returns:
|
||||
PyLabRobot设备类实例
|
||||
"""
|
||||
deserialize_error = None
|
||||
stack = None
|
||||
if self.has_deserialize:
|
||||
deserialize_method = getattr(self.device_cls, "deserialize")
|
||||
spect = inspect.signature(deserialize_method)
|
||||
spec_args = spect.parameters
|
||||
for param_name, param_value in data.copy().items():
|
||||
if "_resource_child_name" in param_value and "_resource_type" not in param_value:
|
||||
arg_value = spec_args[param_name].annotation
|
||||
data[param_name]["_resource_type"] = self.device_cls.__module__ + ":" + arg_value
|
||||
logger.debug(f"自动补充 _resource_type: {data[param_name]['_resource_type']}")
|
||||
|
||||
# 首先处理资源引用
|
||||
processed_data = self._process_resource_references(data, to_dict=True)
|
||||
|
||||
try:
|
||||
self.device_instance = deserialize_method(**processed_data)
|
||||
self.resource_tracker.add_resource(self.device_instance)
|
||||
self.post_create()
|
||||
return self.device_instance # type: ignore
|
||||
except Exception as e:
|
||||
# 先静默继续,尝试另外一种创建方法
|
||||
deserialize_error = e
|
||||
stack = traceback.format_exc()
|
||||
|
||||
if self.device_instance is None:
|
||||
try:
|
||||
spect = inspect.signature(self.device_cls.__init__)
|
||||
spec_args = spect.parameters
|
||||
for param_name, param_value in data.copy().items():
|
||||
if "_resource_child_name" in param_value and "_resource_type" not in param_value:
|
||||
arg_value = spec_args[param_name].annotation
|
||||
data[param_name]["_resource_type"] = self.device_cls.__module__ + ":" + arg_value
|
||||
logger.debug(f"自动补充 _resource_type: {data[param_name]['_resource_type']}")
|
||||
processed_data = self._process_resource_references(data, to_dict=False)
|
||||
self.device_instance = super(PyLabRobotCreator, self).create_instance(processed_data)
|
||||
except Exception as e:
|
||||
logger.error(f"PyLabRobot创建实例失败: {e}")
|
||||
logger.error(f"PyLabRobot创建实例堆栈: {traceback.format_exc()}")
|
||||
finally:
|
||||
if self.device_instance is None:
|
||||
if deserialize_error:
|
||||
logger.error(f"PyLabRobot反序列化失败: {deserialize_error}")
|
||||
logger.error(f"PyLabRobot反序列化堆栈: {stack}")
|
||||
|
||||
return self.device_instance
|
||||
|
||||
def post_create(self):
|
||||
if hasattr(self.device_instance, "setup") and asyncio.iscoroutinefunction(getattr(self.device_instance, "setup")):
|
||||
from unilabos.ros.nodes.base_device_node import ROS2DeviceNode
|
||||
ROS2DeviceNode.run_async_func(getattr(self.device_instance, "setup")).add_done_callback(lambda x: logger.debug(f"PyLabRobot设备实例 {self.device_instance} 设置完成"))
|
||||
# 2486229810384
|
||||
#2486232539792
|
||||
|
||||
class ProtocolNodeCreator(DeviceClassCreator[T]):
|
||||
"""
|
||||
ProtocolNode设备类创建器
|
||||
|
||||
这个类提供了针对ProtocolNode设备类的实例创建方法,处理children参数。
|
||||
"""
|
||||
|
||||
def __init__(self, cls: Type[T], children: Dict[str, Any]):
|
||||
"""
|
||||
初始化ProtocolNode设备类创建器
|
||||
|
||||
Args:
|
||||
cls: ProtocolNode设备类
|
||||
children: 子资源字典,用于资源替换
|
||||
"""
|
||||
super().__init__(cls)
|
||||
self.children = children
|
||||
|
||||
def create_instance(self, data: Dict[str, Any]) -> T:
|
||||
"""
|
||||
从数据创建ProtocolNode设备实例
|
||||
|
||||
Args:
|
||||
data: 用于创建实例的数据
|
||||
|
||||
Returns:
|
||||
ProtocolNode设备类实例
|
||||
"""
|
||||
try:
|
||||
|
||||
# 创建实例
|
||||
data["children"] = self.children
|
||||
self.device_instance = super(ProtocolNodeCreator, self).create_instance(data)
|
||||
self.post_create()
|
||||
return self.device_instance
|
||||
except Exception as e:
|
||||
logger.error(f"ProtocolNode创建实例失败: {e}")
|
||||
logger.error(f"ProtocolNode创建实例堆栈: {traceback.format_exc()}")
|
||||
raise
|
||||
0
unilabos/ros/x/__init__.py
Normal file
0
unilabos/ros/x/__init__.py
Normal file
182
unilabos/ros/x/rclpyx.py
Normal file
182
unilabos/ros/x/rclpyx.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import asyncio
|
||||
from asyncio import events
|
||||
import threading
|
||||
|
||||
import rclpy
|
||||
from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy
|
||||
from rclpy.executors import await_or_execute, Executor
|
||||
from rclpy.action import ActionClient, ActionServer
|
||||
from rclpy.action.server import ServerGoalHandle, GoalResponse, GoalInfo, GoalStatus
|
||||
from std_msgs.msg import String
|
||||
from action_tutorials_interfaces.action import Fibonacci
|
||||
|
||||
|
||||
loop = None
|
||||
|
||||
def get_event_loop():
|
||||
global loop
|
||||
return loop
|
||||
|
||||
|
||||
async def default_handle_accepted_callback_async(goal_handle):
|
||||
"""Execute the goal."""
|
||||
await goal_handle.execute()
|
||||
|
||||
|
||||
class ServerGoalHandleX(ServerGoalHandle):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def execute(self, execute_callback=None):
|
||||
# It's possible that there has been a request to cancel the goal prior to executing.
|
||||
# In this case we want to avoid the illegal state transition to EXECUTING
|
||||
# but still call the users execute callback to let them handle canceling the goal.
|
||||
if not self.is_cancel_requested:
|
||||
self._update_state(_rclpy.GoalEvent.EXECUTE)
|
||||
await self._action_server.notify_execute_async(self, execute_callback)
|
||||
|
||||
|
||||
class ActionServerX(ActionServer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.register_handle_accepted_callback(default_handle_accepted_callback_async)
|
||||
|
||||
async def _execute_goal_request(self, request_header_and_message):
|
||||
request_header, goal_request = request_header_and_message
|
||||
goal_uuid = goal_request.goal_id
|
||||
goal_info = GoalInfo()
|
||||
goal_info.goal_id = goal_uuid
|
||||
|
||||
self._node.get_logger().debug('New goal request with ID: {0}'.format(goal_uuid.uuid))
|
||||
|
||||
# Check if goal ID is already being tracked by this action server
|
||||
with self._lock:
|
||||
goal_id_exists = self._handle.goal_exists(goal_info)
|
||||
|
||||
accepted = False
|
||||
if not goal_id_exists:
|
||||
# Call user goal callback
|
||||
response = await await_or_execute(self._goal_callback, goal_request.goal)
|
||||
if not isinstance(response, GoalResponse):
|
||||
self._node.get_logger().warning(
|
||||
'Goal request callback did not return a GoalResponse type. Rejecting goal.')
|
||||
else:
|
||||
accepted = GoalResponse.ACCEPT == response
|
||||
|
||||
if accepted:
|
||||
# Stamp time of acceptance
|
||||
goal_info.stamp = self._node.get_clock().now().to_msg()
|
||||
|
||||
# Create a goal handle
|
||||
try:
|
||||
with self._lock:
|
||||
goal_handle = ServerGoalHandleX(self, goal_info, goal_request.goal)
|
||||
except RuntimeError as e:
|
||||
self._node.get_logger().error(
|
||||
'Failed to accept new goal with ID {0}: {1}'.format(goal_uuid.uuid, e))
|
||||
accepted = False
|
||||
else:
|
||||
self._goal_handles[bytes(goal_uuid.uuid)] = goal_handle
|
||||
|
||||
# Send response
|
||||
response_msg = self._action_type.Impl.SendGoalService.Response()
|
||||
response_msg.accepted = accepted
|
||||
response_msg.stamp = goal_info.stamp
|
||||
self._handle.send_goal_response(request_header, response_msg)
|
||||
|
||||
if not accepted:
|
||||
self._node.get_logger().debug('New goal rejected: {0}'.format(goal_uuid.uuid))
|
||||
return
|
||||
|
||||
self._node.get_logger().debug('New goal accepted: {0}'.format(goal_uuid.uuid))
|
||||
|
||||
# Provide the user a reference to the goal handle
|
||||
# await await_or_execute(self._handle_accepted_callback, goal_handle)
|
||||
asyncio.create_task(self._handle_accepted_callback(goal_handle))
|
||||
|
||||
async def notify_execute_async(self, goal_handle, execute_callback):
|
||||
# Use provided callback, defaulting to a previously registered callback
|
||||
if execute_callback is None:
|
||||
if self._execute_callback is None:
|
||||
return
|
||||
execute_callback = self._execute_callback
|
||||
|
||||
# Schedule user callback for execution
|
||||
self._node.get_logger().info(f"{events.get_running_loop()}")
|
||||
asyncio.create_task(self._execute_goal(execute_callback, goal_handle))
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
# task = loop.create_task(self._execute_goal(execute_callback, goal_handle))
|
||||
# await task
|
||||
|
||||
|
||||
class ActionClientX(ActionClient):
|
||||
feedback_queue = asyncio.Queue()
|
||||
|
||||
async def feedback_cb(self, msg):
|
||||
await self.feedback_queue.put(msg)
|
||||
|
||||
async def send_goal_async(self, goal_msg):
|
||||
goal_future = super().send_goal_async(
|
||||
goal_msg,
|
||||
feedback_callback=self.feedback_cb
|
||||
)
|
||||
client_goal_handle = await asyncio.ensure_future(goal_future)
|
||||
if not client_goal_handle.accepted:
|
||||
raise Exception("Goal rejected.")
|
||||
result_future = client_goal_handle.get_result_async()
|
||||
while True:
|
||||
feedback_future = asyncio.ensure_future(self.feedback_queue.get())
|
||||
tasks = [result_future, feedback_future]
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
if result_future.done():
|
||||
result = result_future.result().result
|
||||
yield (None, result)
|
||||
break
|
||||
else:
|
||||
feedback = feedback_future.result().feedback
|
||||
yield (feedback, None)
|
||||
|
||||
|
||||
async def main(node):
|
||||
print('Node started.')
|
||||
action_client = ActionClientX(node, Fibonacci, 'fibonacci')
|
||||
goal_msg = Fibonacci.Goal()
|
||||
goal_msg.order = 10
|
||||
async for (feedback, result) in action_client.send_goal_async(goal_msg):
|
||||
if feedback:
|
||||
print(f'Feedback: {feedback}')
|
||||
else:
|
||||
print(f'Result: {result}')
|
||||
print('Finished.')
|
||||
|
||||
|
||||
async def ros_loop_node(node):
|
||||
while rclpy.ok():
|
||||
rclpy.spin_once(node, timeout_sec=0)
|
||||
await asyncio.sleep(1e-4)
|
||||
|
||||
|
||||
async def ros_loop(executor: Executor):
|
||||
while rclpy.ok():
|
||||
executor.spin_once(timeout_sec=0)
|
||||
await asyncio.sleep(1e-4)
|
||||
|
||||
|
||||
def run_event_loop():
|
||||
global loop
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_forever()
|
||||
|
||||
|
||||
def run_event_loop_in_thread():
|
||||
thread = threading.Thread(target=run_event_loop, args=())
|
||||
thread.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
rclpy.init()
|
||||
node = rclpy.create_node('async_subscriber')
|
||||
future = asyncio.wait([ros_loop(node), main()])
|
||||
asyncio.get_event_loop().run_until_complete(future)
|
||||
Reference in New Issue
Block a user