# Source code for graphenv.examples.hallway.hallway_state

import random
from typing import Dict, Sequence

import gymnasium as gym
import numpy as np
from graphenv import tf
from graphenv.vertex import Vertex

# Shorthand for the Keras layers namespace exposed through graphenv's `tf`.
# NOTE(review): `layers` is not referenced by HallwayState in this module;
# presumably kept so other modules can import it from here — confirm before
# removing.
layers = tf.keras.layers


class HallwayState(Vertex):
    """Example Vertex implementation of a simple hallway process graph.

    The hallway graph is a simple bidirectional chain of vertices. The root
    vertex is on one end of the chain and the terminal goal vertex is on the
    opposite end. The length is configurable.

    Args:
        corridor_length (int): length of the vertex chain; must be at least 1.
        cur_pos (int, optional): initial vertex index. Defaults to 0.

    Raises:
        ValueError: if ``corridor_length`` is less than 1. The goal index
            ``corridor_length - 1`` would then be negative, and the
            observation space would have ``high < low``.
    """

    def __init__(
        self,
        corridor_length: int,
        cur_pos: int = 0,
    ) -> None:
        # Fail fast here instead of later, when observation_space builds a
        # Box whose upper bound is below its lower bound of 0.
        if corridor_length < 1:
            raise ValueError(
                f"corridor_length must be at least 1, got {corridor_length}"
            )
        super().__init__()
        # Index of the goal (terminal) vertex; the root is always index 0.
        self.end_pos = corridor_length - 1
        self.cur_pos = cur_pos

    @property
    def observation_space(self) -> gym.spaces.Dict:
        """HallwayStates are observed with a dictionary containing a single
        key, 'cur_pos', with an integer value between 0 and self.end_pos,
        indicating the index of the vertex.

        Returns:
            gym.spaces.Dict: The observation space for HallwayStates.
        """
        return gym.spaces.Dict(
            {
                "cur_pos": gym.spaces.Box(
                    low=np.array([0]), high=np.array([self.end_pos]), dtype=int
                ),
            }
        )

    @property
    def root(self) -> "HallwayState":
        """
        Returns:
            HallwayState: initial state (vertex at index 0) of a hallway with
            the same length as this one.
        """
        return self.new(0)

    @property
    def reward(self) -> float:
        """The reward function for the HallwayState graph.

        Returns:
            float: a random reward in [0, 2) when at (or past) the goal
            vertex, -0.1 otherwise.
        """
        return random.random() * 2 if self.cur_pos >= self.end_pos else -0.1

    def new(self, cur_pos: int) -> "HallwayState":
        """Create a vertex in the same hallway at a (possibly) different index.

        Args:
            cur_pos (int): vertex index for the new state.

        Returns:
            HallwayState: a new HallwayState with the same corridor length as
            this one, positioned at ``cur_pos``.
        """
        # end_pos + 1 recovers the original corridor_length.
        return HallwayState(self.end_pos + 1, cur_pos)

    @property
    def info(self) -> Dict:
        """
        Debugging information compiled and returned by the environment step()
        method about vertices passed through or considered.

        Returns:
            Dict: the base Vertex info extended with the index of this vertex
            under the key 'cur_pos'.
        """
        # Build a new dict instead of mutating the object returned by the base
        # class, in case that object is shared between calls.
        return {**super().info, "cur_pos": self.cur_pos}

    def _get_children(self) -> Sequence["HallwayState"]:
        """Gets child vertices of this vertex.

        Interior vertices have two children, the adjacent vertices at
        ``cur_pos - 1`` and ``cur_pos + 1``. The root (index 0) has only the
        ``cur_pos + 1`` child. The goal vertex (or any index at or beyond
        ``end_pos``) has no children.

        Returns:
            Sequence[HallwayState]: child vertices of this vertex, lower index
            first.
        """
        # Return a real list so the value honors the Sequence annotation
        # (len() and indexing work); a generator would not.
        if self.cur_pos >= self.end_pos:
            return []
        children = []
        if self.cur_pos > 0:  # Stop the hallway from going negative
            children.append(self.new(self.cur_pos - 1))
        children.append(self.new(self.cur_pos + 1))
        return children

    def _make_observation(self) -> Dict[str, np.ndarray]:
        """Makes an observation of this HallwayState vertex.

        Returns:
            Dict[str, np.ndarray]: dictionary containing the current position
            index, as a one-element integer array, under the key 'cur_pos'.
        """
        return {
            "cur_pos": np.array([self.cur_pos], dtype=int),
        }