Files
Uni-Lab-OS/unilabos/ros/nodes/presets/controller_node.py
Junhan Chang c78ac482d8 Initial commit
2025-04-17 15:19:47 +08:00

123 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Callable, Dict
from std_msgs.msg import Float64
from unilabos.ros.nodes.base_device_node import BaseROS2DeviceNode, DeviceNodeResourceTracker
class ControllerNode(BaseROS2DeviceNode):
    """Generic controller node.

    Subscribes to a set of input topics. On a fixed-rate timer, once every
    input has received a value, it calls a user-supplied controller function
    and publishes the result(s) as ``std_msgs/Float64`` messages.
    """

    namespace_prefix = "/controllers"

    def __init__(
        self,
        device_id: str,
        controller_func: Callable,
        update_rate: float,
        inputs: Dict[str, Dict[str, type | str]],
        outputs: Dict[str, Dict[str, type]],
        parameters: Dict,
        resource_tracker: DeviceNodeResourceTracker,
    ):
        """
        Create the controller node, its subscribers, publishers and timer.

        :param device_id: Unique identifier of the controller; forwarded to
            ``BaseROS2DeviceNode`` as ``device_id``.
        :param controller_func: Controller function. It is called with the
            latest input values (keyed by input name) plus ``parameters`` as
            keyword arguments. It returns either a dict mapping output names
            to numbers, or a single number, which is published on the first
            configured output.
        :param update_rate: Controller update frequency in Hz. The timer
            period is ``1.0 / update_rate``.
        :param inputs: Mapping of input name -> ``{"type": msg_type, "topic": topic}``.
            The input name is also the keyword under which the received value
            is passed to ``controller_func``.
        :param outputs: Mapping of output name -> ``{"type": msg_type, "topic": topic}``.
            Values are always published as ``Float64``; presumably ``type``
            should therefore be ``Float64`` (NOTE(review): confirm).
        :param parameters: Extra keyword arguments passed to ``controller_func``.
        :param resource_tracker: Forwarded to ``BaseROS2DeviceNode``.
        """
        # Set up attributes before calling the base-class initializer, so they
        # already exist while it runs (it receives ``self`` as driver_instance).
        self.device_id = device_id
        self.controller_func = controller_func
        self.update_rate = update_rate
        self.update_time = 1.0 / update_rate
        self.parameters = parameters
        # Latest value per input name. None means "not received since the last
        # successful controller call".
        self.inputs = {name: None for name in inputs}
        self.control_input_subscribers = {}
        self.control_output_publishers = {}
        # Topic string -> input/output name. Keys are normalized with str() so
        # they match the topic strings the subscription callbacks pass to
        # input_callback.
        self.topic_mapping = {
            **{str(info["topic"]): name for name, info in inputs.items()},
            **{str(info["topic"]): name for name, info in outputs.items()},
        }

        # Initialize the base node with this object as its own driver instance.
        # NOTE(review): the original comment says the short ID is used as the
        # node name and the full ID (with namespace_prefix) as device_id; that
        # presumably happens inside BaseROS2DeviceNode — confirm.
        BaseROS2DeviceNode.__init__(
            self,
            driver_instance=self,
            device_id=device_id,
            status_types={},
            action_value_mappings={},
            hardware_interface={},
            print_publish=False,
            resource_tracker=resource_tracker,
        )

        # One subscriber per input. The topic is bound as a lambda default
        # argument to avoid the late-binding closure pitfall inside the loop.
        for name, info in inputs.items():
            topic = str(info["topic"])
            self.control_input_subscribers[name] = self.create_subscription(
                info["type"], topic, lambda msg, t=topic: self.input_callback(t, msg), 10
            )

        # One publisher per output, keyed by output name.
        for name, info in outputs.items():
            self.lab_logger().info(f"Creating publisher for {name} {info}")
            self.control_output_publishers[name] = self.create_publisher(
                info["type"], str(info["topic"]), 10
            )

        # Periodic timer that drives the control logic.
        self.timer = self.create_timer(self.update_time, self.control_loop)

    def input_callback(self, topic: str, msg):
        """
        Store the payload of a message received on ``topic``.

        The value is stored under the input name that ``topic`` maps to. Only
        the ``data`` field is read, so single-field ``std_msgs`` messages are
        supported.
        """
        self.inputs[self.topic_mapping[topic]] = msg.data
        self.lab_logger().info(f"Received input on topic {topic}: {msg.data}")

    def control_loop(self):
        """
        Run one controller step.

        Until every input has received a value, this only logs that it is
        waiting. After that it calls ``controller_func``. If the call raises,
        the error is logged and the inputs are kept, so the call is retried
        on the next tick. On success, the inputs are cleared and the result
        is published.
        """
        if any(value is None for value in self.inputs.values()):
            self.lab_logger().info("Waiting for all inputs to be received.")
            return

        self.lab_logger().info(
            f"Calling controller function with inputs: {self.inputs}, parameters: {self.parameters}"
        )
        try:
            outputs = self.controller_func(**self.inputs, **self.parameters)
        except Exception as e:
            # Top-level timer-callback boundary: log instead of crashing the executor.
            self.lab_logger().error(f"Controller function execution failed: {e}")
            return
        self.lab_logger().info(f"Inputs: {self.inputs}, Outputs: {outputs}")
        self.inputs = {name: None for name in self.inputs}

        # Convert plain Python numbers to ROS Float64 messages and publish them.
        if isinstance(outputs, dict):
            for name, value in outputs.items():
                publisher = self.control_output_publishers.get(name)
                if publisher is None:
                    self.lab_logger().warning(f"Output topic {name} is not defined in output_types.")
                elif isinstance(value, (float, int)):
                    # Float64's generated field setter requires a real float,
                    # so int results must be converted explicitly.
                    publisher.publish(Float64(data=float(value)))
                else:
                    self.lab_logger().error(f"Unsupported output type for topic {name}: {type(value)}")
            return

        # A non-dict result goes to the first configured output, if there is one.
        if not self.control_output_publishers:
            self.lab_logger().error(f"No output publishers configured; dropping output: {outputs}")
            return
        publisher = next(iter(self.control_output_publishers.values()))
        if isinstance(outputs, (float, int)):
            publisher.publish(Float64(data=float(outputs)))
        else:
            self.lab_logger().error(f"Unsupported output type: {type(outputs)}")