Source code for graphenv.examples.tsp.tsp_state

from typing import Any, Callable, Dict, List, Optional, Sequence

import gymnasium as gym
import networkx as nx
import numpy as np
from graphenv import tf
from graphenv.examples.tsp.graph_utils import plot_network
from graphenv.vertex import Vertex

# Module-level alias for the Keras layers namespace, reached through the
# ``tf`` object re-exported by ``graphenv``. It is not referenced by the code
# in this module; presumably kept for importers — verify before removing.
layers = tf.keras.layers


class TSPState(Vertex):
    """Create a TSP vertex that defines the graph search problem.

    Args:
        graph_generator: A zero-argument function that creates a networkx
            graph. It is called when ``G`` is not supplied and whenever a new
            root is requested.
        G: A fully connected networkx graph. If None, one is created by
            calling ``graph_generator``.
        tour: A list of nodes in visitation order that led to this state.
            Defaults to ``[0]``, which begins the tour at node 0.
    """

    def __init__(
        self,
        graph_generator: Callable[[], nx.Graph],
        G: Optional[nx.Graph] = None,
        tour: Optional[List[int]] = None,
    ) -> None:
        super().__init__()
        self.G = G if G is not None else graph_generator()
        self.num_nodes = self.G.number_of_nodes()
        # Create a fresh default list per instance instead of sharing one
        # mutable default object across every state built without a tour.
        self.tour = [0] if tour is None else tour
        self.graph_generator = graph_generator

    @property
    def observation_space(self) -> gym.spaces.Dict:
        """Returns the graph env's observation space.

        Returns:
            Dict observation space whose keys match ``_make_observation``.
        """
        # NOTE(review): the sqrt(2) distance bounds and the [0, 1] node_obs
        # bounds assume node positions lie in the unit square — confirm
        # against graph_generator.
        return gym.spaces.Dict(
            {
                "node_obs": gym.spaces.Box(
                    low=np.zeros(2), high=np.ones(2), dtype=float
                ),
                "node_idx": gym.spaces.Box(
                    low=0, high=self.num_nodes, shape=(1,), dtype=int
                ),
                "parent_dist": gym.spaces.Box(
                    low=0.0, high=np.sqrt(2), shape=(1,), dtype=float
                ),
                "nbr_dist": gym.spaces.Box(
                    low=0.0, high=np.sqrt(2), shape=(1,), dtype=float
                ),
            }
        )

    @property
    def root(self) -> "TSPState":
        """Returns the root node of the graph env.

        Returns:
            Node with node 0 as the starting point of the tour, on a new graph
            produced by ``graph_generator``.
        """
        return self.new([0], new_graph=True)

    @property
    def reward(self) -> float:
        """Returns the graph env reward.

        Returns:
            Negative weight of the edge between the last two nodes in the
            tour, or 0.0 for the root (a single-node tour).

        Raises:
            RuntimeError: If the tour is empty.
        """
        if len(self.tour) == 1:
            # The root has no incoming edge; this should never be requested.
            return 0.0
        if len(self.tour) >= 2:
            src, dst = self.tour[-2:]
            return -self.G[src][dst]["weight"]
        raise RuntimeError(f"Invalid tour: {self.tour}")

    def new(
        self,
        tour: Optional[List[int]] = None,
        new_graph: bool = False,
        **kwargs,
    ) -> "TSPState":
        """Convenience function for duplicating the existing node.

        Args:
            tour: List of visited nodes. Defaults to ``[0]``.
            new_graph: If True, build a new graph with ``graph_generator``;
                otherwise reuse this state's graph.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            New TSP state of the same class as ``self``.
        """
        # Pass an explicit fresh list (not None) so subclasses whose
        # constructors expect a concrete tour keep working.
        if tour is None:
            tour = [0]
        G = self.graph_generator() if new_graph else self.G
        return self.__class__(self.graph_generator, G=G, tour=tour, **kwargs)

    def render(self) -> Any:
        """Draw the graph and current tour via ``plot_network``.

        Returns:
            Whatever ``plot_network`` returns.
        """
        return plot_network(self.G, self.tour, draw_all_edges=False)

    @property
    def info(self) -> Dict:
        """Returns an empty info dict; no extra diagnostics are reported."""
        return {}

    def _unvisited_neighbors(self, node: int) -> List[int]:
        """Return neighbors of ``node`` not yet on the tour, in graph order."""
        # A set makes each membership test O(1) instead of scanning the tour.
        visited = set(self.tour)
        return [n for n in self.G.neighbors(node) if n not in visited]

    def _get_children(self) -> Sequence["TSPState"]:
        """Yields a TSPState for each node that may be visited next.

        This is a generator, so children are produced lazily.

        Yields:
            New instance of the TSPState with the next node added to the tour.
        """
        nbrs = self._unvisited_neighbors(self.tour[-1])

        # Go back to the first node if we've visited every other already.
        # After the circuit is closed (len(tour) == num_nodes + 1) every
        # neighbor is already on the tour, so nbrs stays empty and no
        # children are yielded.
        if not nbrs and len(self.tour) == self.num_nodes:
            nbrs = [self.tour[0]]

        for nbr in nbrs:
            # Build a new list so this state's tour is never mutated.
            yield self.new(self.tour + [nbr])

    def _make_observation(self) -> Dict[str, np.ndarray]:
        """Return an observation.

        The dict returned here needs to match both the self.observation_space
        in this class, as well as the input layer in tsp_model.TSPModel.

        Returns:
            Observation dict with:
              * ``node_obs``: the current node's ``"pos"`` attribute as a
                float array.
              * ``node_idx``: the current node's index.
              * ``parent_dist``: weight of the edge from the previous tour
                node, or 0.0 at the root.
              * ``nbr_dist``: smallest edge weight to an unvisited neighbor,
                or 0.0 if none remain.
        """
        cur_node = self.tour[-1]
        cur_pos = np.array(self.G.nodes[cur_node]["pos"], dtype=float).squeeze()

        # Compute distance to parent node, or 0 if this is the root.
        if len(self.tour) == 1:
            parent_dist = 0.0
        else:
            parent_dist = self.G[cur_node][self.tour[-2]]["weight"]

        # Get list of all neighbors that are unvisited. If none, then the only
        # remaining neighbor is the root so dist is 0.
        nbrs = self._unvisited_neighbors(cur_node)
        nbr_dist = 0.0
        if nbrs:
            nbr_dist = np.min([self.G[cur_node][n]["weight"] for n in nbrs])

        return {
            "node_obs": cur_pos,
            "node_idx": np.array([cur_node]),
            "parent_dist": np.array([parent_dist]),
            "nbr_dist": np.array([nbr_dist]),
        }