mirror of
https://github.com/ZGCA-Forge/Elevator.git
synced 2025-12-17 13:01:03 +00:00
Full support for gui & algorithm
This commit is contained in:
@@ -4,6 +4,7 @@ Unified API Client for Elevator Saga
|
||||
使用统一数据模型的客户端API封装
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from typing import Any, Dict, Optional
|
||||
@@ -18,19 +19,25 @@ from elevator_saga.core.models import (
|
||||
SimulationState,
|
||||
StepResponse,
|
||||
)
|
||||
from elevator_saga.utils.debug import debug_log
|
||||
from elevator_saga.utils.logger import debug, error, info, warning
|
||||
|
||||
|
||||
class ElevatorAPIClient:
|
||||
"""统一的电梯API客户端"""
|
||||
|
||||
def __init__(self, base_url: str):
|
||||
def __init__(self, base_url: str, client_type: str = "algorithm"):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
# 客户端身份相关
|
||||
self.client_type = client_type
|
||||
self.client_id: Optional[str] = None
|
||||
# 缓存相关字段
|
||||
self._cached_state: Optional[SimulationState] = None
|
||||
self._cached_tick: int = -1
|
||||
self._tick_processed: bool = False # 标记当前tick是否已处理完成
|
||||
debug_log(f"API Client initialized for {self.base_url}")
|
||||
debug(f"API Client initialized for {self.base_url} with type {self.client_type}", prefix="CLIENT")
|
||||
|
||||
# 尝试自动注册
|
||||
self._auto_register()
|
||||
|
||||
def get_state(self, force_reload: bool = False) -> SimulationState:
|
||||
"""获取模拟状态
|
||||
@@ -92,7 +99,16 @@ class ElevatorAPIClient:
|
||||
|
||||
def step(self, ticks: int = 1) -> StepResponse:
|
||||
"""执行步进"""
|
||||
response_data = self._send_post_request("/api/step", {"ticks": ticks})
|
||||
# 携带当前tick信息,用于优先级队列控制
|
||||
# 如果没有缓存的state,先获取一次
|
||||
if self._cached_state is None:
|
||||
self.get_state(force_reload=True)
|
||||
|
||||
request_data = {"ticks": ticks}
|
||||
if self._cached_state is not None:
|
||||
request_data["current_tick"] = self._cached_state.tick
|
||||
|
||||
response_data = self._send_post_request("/api/step", request_data)
|
||||
|
||||
if "error" not in response_data:
|
||||
# 使用服务端返回的真实数据
|
||||
@@ -108,7 +124,7 @@ class ElevatorAPIClient:
|
||||
|
||||
event_dict["type"] = EventType(event_dict["type"])
|
||||
except ValueError:
|
||||
debug_log(f"Unknown event type: {event_dict['type']}")
|
||||
warning(f"Unknown event type: {event_dict['type']}", prefix="CLIENT")
|
||||
continue
|
||||
events.append(SimulationEvent.from_dict(event_dict))
|
||||
|
||||
@@ -118,6 +134,10 @@ class ElevatorAPIClient:
|
||||
events=events,
|
||||
)
|
||||
|
||||
# 更新缓存的tick(保持其他状态不变,只更新tick)
|
||||
if self._cached_state is not None:
|
||||
self._cached_state.tick = step_response.tick
|
||||
|
||||
# debug_log(f"Step response: tick={step_response.tick}, events={len(events)}")
|
||||
return step_response
|
||||
else:
|
||||
@@ -125,9 +145,20 @@ class ElevatorAPIClient:
|
||||
|
||||
def send_elevator_command(self, command: GoToFloorCommand) -> bool:
|
||||
"""发送电梯命令"""
|
||||
# 客户端拦截:检查是否有权限发送控制命令
|
||||
if not self._can_send_command():
|
||||
warning(
|
||||
f"Client type '{self.client_type}' cannot send control commands. "
|
||||
f"Command ignored: {command.command_type} elevator {command.elevator_id} to floor {command.floor}",
|
||||
prefix="CLIENT",
|
||||
)
|
||||
# 不抛出错误,直接返回True(但实际未执行)
|
||||
return True
|
||||
|
||||
endpoint = self._get_elevator_endpoint(command)
|
||||
debug_log(
|
||||
f"Sending elevator command: {command.command_type} to elevator {command.elevator_id} To:F{command.floor}"
|
||||
debug(
|
||||
f"Sending elevator command: {command.command_type} to elevator {command.elevator_id} To:F{command.floor}",
|
||||
prefix="CLIENT",
|
||||
)
|
||||
|
||||
response_data = self._send_post_request(endpoint, command.parameters)
|
||||
@@ -145,7 +176,7 @@ class ElevatorAPIClient:
|
||||
response = self.send_elevator_command(command)
|
||||
return response
|
||||
except Exception as e:
|
||||
debug_log(f"Go to floor failed: {e}")
|
||||
error(f"Go to floor failed: {e}", prefix="CLIENT")
|
||||
return False
|
||||
|
||||
def _get_elevator_endpoint(self, command: GoToFloorCommand) -> str:
|
||||
@@ -155,13 +186,69 @@ class ElevatorAPIClient:
|
||||
if isinstance(command, GoToFloorCommand):
|
||||
return f"{base}/go_to_floor"
|
||||
|
||||
def _auto_register(self) -> None:
|
||||
"""自动注册客户端"""
|
||||
try:
|
||||
# 从环境变量读取客户端类型(如果有的话)
|
||||
env_client_type = os.environ.get("ELEVATOR_CLIENT_TYPE")
|
||||
if env_client_type:
|
||||
self.client_type = env_client_type
|
||||
debug(f"Client type from environment: {self.client_type}", prefix="CLIENT")
|
||||
|
||||
# 直接发送注册请求(不使用_send_post_request以避免循环依赖)
|
||||
url = f"{self.base_url}/api/client/register"
|
||||
request_body = json.dumps({}).encode("utf-8")
|
||||
headers = {"Content-Type": "application/json", "X-Client-Type": self.client_type}
|
||||
req = urllib.request.Request(url, data=request_body, headers=headers)
|
||||
|
||||
with urllib.request.urlopen(req, timeout=60) as response:
|
||||
response_data = json.loads(response.read().decode("utf-8"))
|
||||
if response_data.get("success"):
|
||||
self.client_id = response_data.get("client_id")
|
||||
info(f"Client registered successfully with ID: {self.client_id}", prefix="CLIENT")
|
||||
else:
|
||||
warning(f"Client registration failed: {response_data.get('error')}", prefix="CLIENT")
|
||||
except Exception as e:
|
||||
error(f"Auto registration failed: {e}", prefix="CLIENT")
|
||||
|
||||
def _can_send_command(self) -> bool:
|
||||
"""检查客户端是否可以发送控制命令
|
||||
|
||||
Returns:
|
||||
True: 如果是算法客户端或未注册客户端
|
||||
False: 如果是GUI客户端
|
||||
"""
|
||||
# 算法客户端可以发送命令
|
||||
if self.client_type.lower() == "algorithm":
|
||||
return True
|
||||
# 未注册的客户端也可以发送命令(向后兼容)
|
||||
if self.client_id is None:
|
||||
return True
|
||||
# GUI客户端不能发送命令
|
||||
if self.client_type.lower() == "gui":
|
||||
return False
|
||||
# 其他未知类型,默认允许(向后兼容)
|
||||
return True
|
||||
|
||||
def _get_request_headers(self) -> Dict[str, str]:
|
||||
"""获取请求头,包含客户端身份信息"""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.client_id:
|
||||
headers["X-Client-ID"] = self.client_id
|
||||
headers["X-Client-Type"] = self.client_type
|
||||
return headers
|
||||
|
||||
def _send_get_request(self, endpoint: str) -> Dict[str, Any]:
|
||||
"""发送GET请求"""
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
# todo: 全部更改为post
|
||||
# debug_log(f"GET {url}")
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=60) as response:
|
||||
headers = self._get_request_headers()
|
||||
# 对于GET请求,只添加客户端标识头
|
||||
req = urllib.request.Request(url, headers={k: v for k, v in headers.items() if k != "Content-Type"})
|
||||
with urllib.request.urlopen(req, timeout=60) as response:
|
||||
data: Dict[str, Any] = json.loads(response.read().decode("utf-8"))
|
||||
# debug_log(f"GET {url} -> {response.status}")
|
||||
return data
|
||||
@@ -169,7 +256,7 @@ class ElevatorAPIClient:
|
||||
raise RuntimeError(f"GET {url} failed: {e}")
|
||||
|
||||
def reset(self) -> bool:
|
||||
"""重置模拟"""
|
||||
"""重置模拟并重新注册客户端"""
|
||||
try:
|
||||
response_data = self._send_post_request("/api/reset", {})
|
||||
success = bool(response_data.get("success", False))
|
||||
@@ -178,10 +265,14 @@ class ElevatorAPIClient:
|
||||
self._cached_state = None
|
||||
self._cached_tick = -1
|
||||
self._tick_processed = False
|
||||
debug_log("Cache cleared after reset")
|
||||
debug("Cache cleared after reset", prefix="CLIENT")
|
||||
|
||||
# 重新注册客户端(因为服务器已清除客户端记录)
|
||||
self._auto_register()
|
||||
debug("Client re-registered after reset", prefix="CLIENT")
|
||||
return success
|
||||
except Exception as e:
|
||||
debug_log(f"Reset failed: {e}")
|
||||
error(f"Reset failed: {e}", prefix="CLIENT")
|
||||
return False
|
||||
|
||||
def next_traffic_round(self, full_reset: bool = False) -> bool:
|
||||
@@ -194,10 +285,10 @@ class ElevatorAPIClient:
|
||||
self._cached_state = None
|
||||
self._cached_tick = -1
|
||||
self._tick_processed = False
|
||||
debug_log("Cache cleared after traffic round switch")
|
||||
debug("Cache cleared after traffic round switch", prefix="CLIENT")
|
||||
return success
|
||||
except Exception as e:
|
||||
debug_log(f"Next traffic round failed: {e}")
|
||||
error(f"Next traffic round failed: {e}", prefix="CLIENT")
|
||||
return False
|
||||
|
||||
def get_traffic_info(self) -> Optional[Dict[str, Any]]:
|
||||
@@ -207,10 +298,10 @@ class ElevatorAPIClient:
|
||||
if "error" not in response_data:
|
||||
return response_data
|
||||
else:
|
||||
debug_log(f"Get traffic info failed: {response_data.get('error')}")
|
||||
warning(f"Get traffic info failed: {response_data.get('error')}", prefix="CLIENT")
|
||||
return None
|
||||
except Exception as e:
|
||||
debug_log(f"Get traffic info failed: {e}")
|
||||
error(f"Get traffic info failed: {e}", prefix="CLIENT")
|
||||
return None
|
||||
|
||||
def _send_post_request(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@@ -220,7 +311,8 @@ class ElevatorAPIClient:
|
||||
|
||||
# debug_log(f"POST {url} with data: {data}")
|
||||
|
||||
req = urllib.request.Request(url, data=request_body, headers={"Content-Type": "application/json"})
|
||||
headers = self._get_request_headers()
|
||||
req = urllib.request.Request(url, data=request_body, headers=headers)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=600) as response:
|
||||
|
||||
Reference in New Issue
Block a user