from abc import abstractmethod
from typing import Any, Dict, Generic, List, Optional, Sequence, TypeVar
import gymnasium as gym
# Type variable for the concrete vertex subclass; used as the type parameter
# of ``Vertex`` and as the element type of its root/children accessors.
V = TypeVar("V")
class Vertex(Generic[V]):
    """Abstract class defining a vertex in a graph.

    To implement a graph using this class, subclass ``Vertex`` and implement
    the abstract members below: ``observation_space``, ``root``, ``reward``,
    ``_get_children`` and ``_make_observation``.

    The type parameter ``V`` is the implementing vertex subclass.

    Note:
        The only base class is ``Generic[V]`` (no ``abc.ABC``), so the
        ``@abstractmethod`` markers document intent; an unimplemented member
        raises ``NotImplementedError`` when it is used.
    """

    def __init__(self) -> None:
        # ``None`` means "not computed yet" for both caches below.
        self._children: Optional[List[V]] = None  #: memoized list of child vertices
        self._observation: Optional[Any] = None  #: memoized observation of this vertex

    @property
    @abstractmethod
    def observation_space(self) -> gym.spaces.Space:
        """Gets the vertex observation space, used to define the structure of
        the data returned when observing a vertex.

        Returns:
            gym.spaces.Space: Vertex observation space
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def root(self) -> V:
        """Gets the root vertex of the graph. Not required to always return
        the same vertex.

        Returns:
            V: The root vertex of the graph.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def reward(self) -> float:
        """Gets the reward for this vertex.

        Returns:
            float: reward for this vertex
        """
        raise NotImplementedError

    def render(self) -> Any:
        """Optional method for rendering the current state of the environment.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_children(self) -> Sequence[V]:
        """Gets the child vertices of this vertex.

        Returns:
            Sequence[V]: Sequence of child vertices.
        """
        raise NotImplementedError

    @abstractmethod
    def _make_observation(self) -> Any:
        """Gets an observation of this vertex. This observation should have
        the same shape as described by the vertex observation space.

        Returns:
            Any: Observation with the same shape as defined by the
            observation space.
        """
        raise NotImplementedError

    @property
    def children(self) -> List[V]:
        """Gets the child vertices of this vertex.

        Acts as a wrapper that memoizes calls to ``_get_children()`` and
        ensures that the result is a list. If you would like a different
        behavior, such as stochastic child vertices, override this property.

        Returns:
            List[V]: List of child vertices
        """
        if self._children is None:
            # First access: materialize the sequence once and cache it.
            self._children = list(self._get_children())
        return self._children

    @property
    def observation(self) -> Any:
        """Gets the observation of this vertex.

        Acts as a wrapper that memoizes calls to ``_make_observation()``.
        If you would like a different behavior, such as stochastic
        observations, override this property.

        Because ``None`` marks the cache as empty, an observation that is
        itself ``None`` is recomputed on every access.

        Returns:
            Any: Observation of this vertex.
        """
        if self._observation is None:
            self._observation = self._make_observation()
        return self._observation

    @property
    def terminal(self) -> bool:
        """Whether this vertex has no children.

        Returns:
            bool: True if this is a terminal vertex in the graph.
        """
        return len(self.children) == 0

    @property
    def info(self) -> Dict:
        """Additional information about the state.

        Returns:
            Dict: An optional dictionary with additional information about
            the state; empty by default.
        """
        return {}