mirror of
https://github.com/dptech-corp/Uni-Lab-OS.git
synced 2025-12-17 21:11:12 +00:00
123 lines
5.3 KiB
Python
123 lines
5.3 KiB
Python
from typing import Callable, Dict
|
||
from std_msgs.msg import Float64
|
||
|
||
from unilabos.ros.nodes.base_device_node import BaseROS2DeviceNode, DeviceNodeResourceTracker
|
||
|
||
|
||
class ControllerNode(BaseROS2DeviceNode):
|
||
namespace_prefix = "/controllers"
|
||
|
||
def __init__(
|
||
self,
|
||
device_id: str,
|
||
controller_func: Callable,
|
||
update_rate: float,
|
||
inputs: Dict[str, Dict[str, type | str]],
|
||
outputs: Dict[str, Dict[str, type]],
|
||
parameters: Dict,
|
||
resource_tracker: DeviceNodeResourceTracker,
|
||
):
|
||
"""
|
||
通用控制器节点
|
||
|
||
:param controller_id: 控制器的唯一标识符(作为命名空间的一部分)
|
||
:param update_rate: 控制器更新频率 (Hz)
|
||
:param controller_func: 控制器函数,接收 Python 格式的 inputs 和 parameters 返回 outputs
|
||
:param input_types: 输入话题及其消息类型的字典
|
||
:param output_types: 输出话题及其消息类型的字典
|
||
:param parameters: 控制器函数的额外参数
|
||
"""
|
||
# 先准备所需的属性,以便在调用父类初始化前就可以使用
|
||
self.device_id = device_id
|
||
self.controller_func = controller_func
|
||
self.update_rate = update_rate
|
||
self.update_time = 1.0 / update_rate
|
||
self.parameters = parameters
|
||
self.inputs = {topic: None for topic in inputs.keys()}
|
||
self.control_input_subscribers = {}
|
||
self.control_output_publishers = {}
|
||
self.topic_mapping = {
|
||
**{input_info["topic"]: input for input, input_info in inputs.items()},
|
||
**{output_info["topic"]: output for output, output_info in outputs.items()},
|
||
}
|
||
|
||
# 调用BaseROS2DeviceNode初始化,使用自身作为driver_instance
|
||
status_types = {}
|
||
action_value_mappings = {}
|
||
hardware_interface = {}
|
||
|
||
# 使用短ID作为节点名,完整ID(带namespace_prefix)作为device_id
|
||
BaseROS2DeviceNode.__init__(
|
||
self,
|
||
driver_instance=self,
|
||
device_id=device_id,
|
||
status_types=status_types,
|
||
action_value_mappings=action_value_mappings,
|
||
hardware_interface=hardware_interface,
|
||
print_publish=False,
|
||
resource_tracker=resource_tracker
|
||
)
|
||
|
||
# 原始初始化逻辑
|
||
# 初始化订阅者
|
||
for input, input_info in inputs.items():
|
||
msg_type = input_info["type"]
|
||
topic = str(input_info["topic"])
|
||
self.control_input_subscribers[input] = self.create_subscription(
|
||
msg_type, topic, lambda msg, t=topic: self.input_callback(t, msg), 10
|
||
)
|
||
|
||
# 初始化发布者
|
||
for output, output_info in outputs.items():
|
||
self.lab_logger().info(f"Creating publisher for {output} {output_info}")
|
||
msg_type = output_info["type"]
|
||
topic = str(output_info["topic"])
|
||
self.control_output_publishers[output] = self.create_publisher(msg_type, topic, 10)
|
||
|
||
# 定时器,用于定期调用控制逻辑
|
||
self.timer = self.create_timer(self.update_time, self.control_loop)
|
||
|
||
def input_callback(self, topic: str, msg):
|
||
"""
|
||
更新指定话题的输入数据,并将 ROS 消息转换为普通 Python 数据。
|
||
支持 `std_msgs` 类型消息。
|
||
"""
|
||
self.inputs[self.topic_mapping[topic]] = msg.data
|
||
self.lab_logger().info(f"Received input on topic {topic}: {msg.data}")
|
||
|
||
def control_loop(self):
|
||
"""主控制逻辑"""
|
||
# 检查所有输入是否已更新
|
||
if all(value is not None for value in self.inputs.values()):
|
||
self.lab_logger().info(
|
||
f"Calling controller function with inputs: {self.inputs}, parameters: {self.parameters}"
|
||
)
|
||
try:
|
||
# 调用控制器函数,传入 Python 格式的数据
|
||
outputs = self.controller_func(**self.inputs, **self.parameters)
|
||
self.lab_logger().info(f"Inputs: {self.inputs}, Outputs: {outputs}")
|
||
self.inputs = {topic: None for topic in self.inputs.keys()}
|
||
except Exception as e:
|
||
self.lab_logger().error(f"Controller function execution failed: {e}")
|
||
return
|
||
|
||
# 发布控制信号,将普通 Python 数据转换为 ROS 消息
|
||
if isinstance(outputs, dict):
|
||
for topic, value in outputs.items():
|
||
if topic in self.control_output_publishers:
|
||
# 支持 Float64 输出
|
||
if isinstance(value, (float, int)):
|
||
self.control_output_publishers[topic].publish(Float64(data=value))
|
||
else:
|
||
self.lab_logger().error(f"Unsupported output type for topic {topic}: {type(value)}")
|
||
else:
|
||
self.lab_logger().warning(f"Output topic {topic} is not defined in output_types.")
|
||
else:
|
||
publisher = list(self.control_output_publishers.values())[0]
|
||
if isinstance(outputs, (float, int)):
|
||
publisher.publish(Float64(data=outputs))
|
||
else:
|
||
self.lab_logger().error(f"Unsupported output type: {type(outputs)}")
|
||
else:
|
||
self.lab_logger().info("Waiting for all inputs to be received.")
|