import logging
import warnings
from typing import Any, Dict, List, Optional, Tuple
import inspect
import gymnasium as gym
import numpy as np
from ray.rllib.env.env_context import EnvContext
from ray.rllib.utils.spaces.repeated import Repeated
from graphenv.vertex import V
# Module-level logger; GraphEnv emits its construction and per-step details via debug().
logger: logging.Logger = logging.getLogger(name=__name__)
class GraphEnv(gym.Env):
    """
    Defines a Gymnasium (OpenAI Gym) Env for traversing a graph using the
    current vertex as the state, and the successor vertices as actions.

    GraphEnv uses composition to supply the per-vertex model of type Vertex,
    which defines the graph via its `_get_children()` method.

    The `env_config` dictionary should contain the following keys::

        state (V): Current (starting) vertex
        max_num_children (int): maximum number of children considered at a time.

    Args:
        env_config (dict): A dictionary of parameters, required to conform with
            rllib's environment initialization.
    """

    #: graphenv.vertex.Vertex: current vertex
    state: V
    #: int: maximum number of actions considered at a time
    max_num_children: int
    #: the observation space of the graph environment
    observation_space: gym.Space
    #: the action space, a Discrete space over `max_num_children`
    action_space: gym.Space

    # For environment rendering
    metadata: Dict[str, Any] = {"render_modes": ["human", None]}
    render_mode: Optional[str] = None

    def __init__(self, env_config: EnvContext) -> None:
        """Build the observation and action spaces from ``env_config``.

        Args:
            env_config (EnvContext): Must provide ``"state"`` (the starting
                vertex, whose ``observation_space`` is used per entry) and
                ``"max_num_children"`` (int).
        """
        super().__init__()
        logger.debug("entering graphenv construction")

        self.state = env_config["state"]
        self.max_num_children = env_config["max_num_children"]

        # One entry for the current vertex plus one per possible child.
        num_vertex_observations = 1 + self.max_num_children
        self.observation_space = Repeated(
            self.state.observation_space, num_vertex_observations
        )
        self.action_space = gym.spaces.Discrete(self.max_num_children)

        logger.debug("leaving graphenv construction")

    # RLlib 2.3.1 does not yet support an explicit `reset(*, seed=None, options=None)`
    # signature here; it warns "Seeding will take place using 'env.seed()' and the
    # info dict will not be returned from reset." Accepting **kwargs quiets that
    # warning while still receiving gymnasium's `seed`/`options` keywords.
    def reset(self, **kwargs) -> Tuple[List[Any], Dict]:
        """Reset this state to the root vertex. It is possible for state.root to
        return different root vertices on each call.

        Args:
            **kwargs: gymnasium reset keywords. A non-None ``seed`` is forwarded
                to ``gym.Env.reset`` so that ``self.np_random`` is seeded; other
                keywords (e.g. ``options``) are ignored.

        Returns:
            Tuple[List[Any], Dict]: The root vertex's observation (see
            ``make_observation``) and the root vertex's ``info`` dict.
        """
        # gymnasium's contract: subclasses call super().reset(seed=seed) so the
        # env's RNG is (re)seeded; a None seed leaves the RNG untouched.
        super().reset(seed=kwargs.get("seed"))
        self.state = self.state.root
        return self.make_observation(), self.state.info

    def step(self, action: int) -> Tuple[List[Any], float, bool, bool, dict]:
        """Steps the environment to a new state by taking an action. In the
        case of GraphEnv, the action specifies which next vertex to move to and
        this method advances the environment to that vertex.

        If ``action`` is inside the action space but there is no child at that
        index (a masked action), the state is left unchanged and a
        RuntimeWarning is issued, unless the call comes from rllib's
        ``check_gym_environments`` pre-check.

        Args:
            action (int): The index of the child vertex of self.state to move to.

        Raises:
            RuntimeError: When the current state has more than
                ``max_num_children`` children, or when action is outside the
                action space.

        Returns:
            Tuple[List[Any], float, bool, bool, dict]: Tuple of:
                the new state's observation (see ``make_observation``),
                the reward received by moving to the new state's vertex,
                a bool which is true iff the new state is a terminal vertex,
                a bool which is true if the search is truncated (always False
                here),
                a dictionary of debugging information related to this call
        """
        children = self.state.children
        if len(children) > self.max_num_children:
            raise RuntimeError(
                f"State {self.state} has {len(children)} children "
                f"(> {self.max_num_children})"
            )

        if action not in self.action_space:
            raise RuntimeError(
                f"Action {action} outside the action space of state {self.state}: "
                f"{self.max_num_children} max children"
            )

        try:
            # Move the state to the chosen child vertex
            self.state = children[action]
        except IndexError:
            # Skip this warning message if the call
            # came from rllib's precheck function
            # https://github.com/ray-project/ray/blob/e6dad0b961b5e962f6dc4986947ccac2d2e032cd/rllib/utils/pre_checks/env.py#L220
            # context=0: only function names are needed, so don't read source lines.
            called_from_precheck = any(
                frame_info.function == "check_gym_environments"
                for frame_info in inspect.stack(0)
            )
            if not called_from_precheck:
                warnings.warn(
                    "Attempting to choose a masked child state. This is either due to "
                    "rllib's env pre_check module, or due to a failure of the policy model "
                    "to mask invalid actions. Returning the current state to satisfy the "
                    "pre_check module.",
                    RuntimeWarning,
                )

        # In RLlib 2.3, the config options "no_done_at_end", "horizon", and
        # "soft_horizon" are no longer supported, according to the migration guide
        # https://docs.google.com/document/d/1lxYK1dI5s0Wo_jmB6V6XiP-_aEBsXDykXkD1AXRase4/edit#
        # Instead, wrap the environment with gymnasium's TimeLimit wrapper,
        # which sets truncated according to the number of timesteps; see
        # https://gymnasium.farama.org/api/wrappers/misc_wrappers/#gymnasium.wrappers.TimeLimit
        truncated = False

        observation = self.make_observation()
        reward = self.state.reward
        terminated = self.state.terminal
        info = self.state.info

        logger.debug(
            "%s: %s %s, %s, %s",
            type(self),
            reward,
            terminated,
            truncated,
            len(self.state.children),
        )
        return observation, reward, terminated, truncated, info

    def make_observation(self) -> List[Any]:
        """
        Makes an observation for this state which includes the observation of
        the current state and the observation of each of its children (the
        possible actions).

        The current state's observation is the 0th entry of the returned list
        and the children follow in order, so the child chosen by action ``i``
        is at index ``i + 1``. The list length is at most
        ``1 + max_num_children``, matching the ``Repeated`` observation space
        built in ``__init__``.

        Returns:
            List[Any]: The current state's observation followed by each
            child's observation.
        """
        # NOTE: assert is stripped under `python -O`; step() performs the same
        # check with a RuntimeError before moving.
        assert (
            len(self.state.children) <= self.max_num_children
        ), f"{self.state} exceeds the maximum number of children"
        return [state.observation for state in (self.state, *self.state.children)]

    def render(self, mode: str = "human") -> Any:
        """Delegates to Vertex.render()

        Args:
            mode (str): Only ``"human"`` triggers rendering; any other value
                renders nothing and returns None.

        Returns:
            Any: Whatever ``self.state.render()`` returns in ``"human"`` mode.
        """
        if mode == "human":
            return self.state.render()