Source code for graphenv.graph_model

import logging
from abc import abstractmethod
from typing import Dict, List, Tuple

import gymnasium as gym
from ray.rllib.models.repeated_values import RepeatedValues
from ray.rllib.utils.typing import TensorStructType, TensorType

from graphenv import tf, torch

logger = logging.getLogger(__file__)


class GraphModel:
    """Defines an RLlib-compatible model for using RL algorithms on a GraphEnv.

    Args:
        obs_space: The observation space to use.
        action_space: The action space to use.
        num_outputs: The number of scalar outputs per state to produce.
        model_config: Config forwarded to TFModelV2.__init__().
        name: Name forwarded to TFModelV2.__init__().
    """

    _tensorlib = "tf"

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: Dict,
        name: str,
        *args,
        **kwargs,
    ):
        super().__init__(
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            *args,
            **kwargs,
        )
        self.current_vertex_value = None
        self.action_values = None
        self.current_vertex_weight = None
        self.action_weights = None
        self.num_outputs = num_outputs
        logger.debug(f"num_outputs: {num_outputs}")
        assert self._tensorlib is not None

    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        """Tensorflow/Keras style forward method. Sets up the computation graph
        used by this model.

        Args:
            input_dict: Observation input to the model. Consists of a dictionary
                including key 'obs' which stores the raw observation from the
                process.
            state: Tensor of current states. Passes through this function untouched.
            seq_lens: Unused. Required by API.

        Returns:
            (action weights tensor, state)
        """
        mask = _create_action_mask(input_dict["obs"], self._tensorlib)
        flattened_observations = _stack_batch_dim(
            input_dict["obs"], mask, self._tensorlib
        )
        flat_values, flat_weights = self.forward_vertex(flattened_observations)

        # mask out invalid children and get current vertex value
        self.current_vertex_value, _ = _mask_and_split_values(
            flat_values, input_dict["obs"], self._tensorlib
        )
        _, action_weights = _mask_and_split_values(
            flat_weights, input_dict["obs"], self._tensorlib
        )

        self.total_value = self._forward_total_value()
        return action_weights, state

    def value_function(self):
        """
        Returns:
            A tensor of current state values.
        """
        return self.total_value

    @abstractmethod
    def forward_vertex(
        self,
        input_dict,
    ) -> Tuple[TensorType, TensorType]:
        """Forward function returning a value and weight tensor for the vertices
        observed via input_dict (a dict of tensors for each vertex property).

        Args:
            input_dict: per-vertex observations

        Returns:
            (value tensor, weight tensor) for the given observations
        """
        pass

    def _forward_total_value(self):
        """Forward method computing the value assessment of the current state, as
        returned by the value_function() method.

        The default implementation returns the action value of the current state.
        Breaking this into a separate method allows subclasses to override the
        state value assessment, for example with a Bellman backup returning the
        max over all successor states' values.

        Returns:
            current value tensor
        """
        return self.current_vertex_value
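
# Illustrative sketch only (not part of the graphenv library): one way a concrete
# model might combine GraphModel with RLlib's TFModelV2 and implement
# forward_vertex. The class name, layer sizes, and the "example_obs" observation
# key are assumptions made for this example; a real model would mirror the
# per-vertex observation space of its GraphEnv.
from ray.rllib.models.tf.tf_modelv2 import TFModelV2  # assumed available via RLlib


class ExampleTFGraphModel(GraphModel, TFModelV2):
    """Hypothetical TF model that scores each vertex with a small dense network."""

    def __init__(self, *args, hidden_dim: int = 32, **kwargs):
        super().__init__(*args, **kwargs)
        self.hidden_layer = tf.keras.layers.Dense(hidden_dim, activation="relu")
        self.value_head = tf.keras.layers.Dense(1)
        self.weight_head = tf.keras.layers.Dense(1)

    def forward_vertex(self, input_dict):
        # input_dict holds one tensor per vertex observation key, already stacked
        # to (num_valid_vertices, ...) by _stack_batch_dim before this call.
        hidden = self.hidden_layer(input_dict["example_obs"])
        return self.value_head(hidden), self.weight_head(hidden)


# A model like this would typically be registered with RLlib's ModelCatalog, e.g.
# ModelCatalog.register_custom_model("example_tf_graph_model", ExampleTFGraphModel),
# and then referenced through the "custom_model" entry of the algorithm's model
# config alongside a GraphEnv instance.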

class TorchGraphModel(GraphModel):
    _tensorlib = "torch"
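
# Illustrative sketch only: the _forward_total_value docstring above mentions
# overriding the state value with a Bellman backup over successor values. A
# hypothetical TF-side mixin is shown below. It assumes self.action_values holds
# the masked successor values (e.g. captured by an adapted forward() pass); the
# base forward() above initializes but does not currently populate that attribute.
class ExampleBellmanBackupMixin:
    """Hypothetical mixin: state value = max over masked successor values."""

    def _forward_total_value(self):
        # Invalid successor slots are filled with the dtype minimum by
        # _mask_and_split_values, so they can never win the max.
        return tf.reduce_max(self.action_values, axis=-1)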

def _create_action_mask(obs: RepeatedValues, tensorlib: str = "tf") -> TensorType:
    """Create an action mask array of valid actions from a given RepeatedValues
    tensor.

    Args:
        obs (RepeatedValues): The input observations
        tensorlib (str, optional): A reference to the current framework. Defaults
            to "tf".

    Raises:
        NotImplementedError: if the given framework is not supported

    Returns:
        TensorType: The boolean mask for valid actions (includes the current state
        as the first index).
    """
    if tensorlib == "tf":
        # The "dummy batch" rllib provides to initialize the policy model is a
        # matrix of all zeros, which ends with a batch size of zero provided to
        # the policy model. We can assume that at least the input state is valid,
        # and clip the row_lengths vector to a minimum of 1 per (state, *actions)
        # entry.
        row_lengths = tf.clip_by_value(tf.cast(obs.lengths, tf.int32), 1, tf.int32.max)
        num_elements = tf.reduce_sum(row_lengths)
        action_mask = tf.RaggedTensor.from_row_lengths(
            tf.ones(num_elements, dtype=tf.bool),
            row_lengths,
        ).to_tensor(shape=(None, obs.max_len))

    elif tensorlib == "torch":
        # Integer torch index tensors must be long type
        row_lengths = torch.clip(obs.lengths.long(), 1, torch.iinfo(torch.long).max)
        num_elements = row_lengths.sum().item()
        action_mask = torch.zeros(len(row_lengths), obs.max_len, dtype=bool)
        mask_index = torch.LongTensor(
            [(i, j) for i in range(len(row_lengths)) for j in range(row_lengths[i])]
        )
        action_mask.index_put_(
            tuple(mask_index.t()), torch.ones(num_elements, dtype=bool)
        )

    else:
        raise NotImplementedError(f"tensorlib {tensorlib} not implemented")

    return action_mask


def _apply_mask(
    values: TensorType, action_mask: TensorType, tensorlib: str = "tf"
) -> TensorType:
    if tensorlib == "tf":
        return tf.boolean_mask(values, action_mask)
    elif tensorlib == "torch":
        # masked_select returns a 1D tensor, so it needs reshaping. Pretty sure the
        # last dimension will always be the feature dim -- will action_mask always
        # be 2d? The .view(-1, feature_dim) call will fail if more than 2d.
        feature_dim = values.shape[-1]
        values = torch.masked_select(values, action_mask.view(*action_mask.shape, 1))
        return values.view(-1, feature_dim)
    else:
        raise NotImplementedError(f"tensorlib {tensorlib} not implemented")


def _stack_batch_dim(
    obs: TensorStructType, mask: TensorType, tensorlib: str = "tf"
) -> TensorType:
    if isinstance(obs, dict):
        return {k: _stack_batch_dim(v, mask, tensorlib) for k, v in obs.items()}

    elif isinstance(obs, tuple):
        return tuple(_stack_batch_dim(u, mask, tensorlib) for u in obs)

    elif isinstance(obs, RepeatedValues):
        return _stack_batch_dim(obs.values, mask, tensorlib)

    else:
        return _apply_mask(obs, mask, tensorlib)


def _mask_and_split_values(
    flat_values: TensorType, obs: RepeatedValues, tensorlib: str = "tf"
) -> Tuple[TensorType, TensorType]:
    """Returns the value for the current vertex (index 0 of values), and the masked
    values of the action vertices.

    Args:
        flat_values: Flat tensor of per-vertex values to mask and split.
        obs (RepeatedValues): The input observations.
        tensorlib (str, optional): A reference to the current framework. Defaults
            to "tf".

    Returns:
        (a current state value tensor, a masked action values tensor)
    """
    if tensorlib == "tf":
        row_lengths = tf.clip_by_value(tf.cast(obs.lengths, tf.int32), 1, tf.int32.max)
        flat_values = tf.squeeze(flat_values, axis=[-1])
        values = tf.RaggedTensor.from_row_lengths(flat_values, row_lengths)
        values = values.to_tensor(
            default_value=values.dtype.min,
            shape=(None, obs.max_len),
        )
        current_value = values[:, 0]
        masked_action_values = values[:, 1:]

    elif tensorlib == "torch":
        row_lengths = torch.clip(obs.lengths.long(), 1, torch.iinfo(torch.long).max)
        flat_values = flat_values.squeeze(dim=-1)
        value_index = torch.LongTensor(
            [(i, j) for i in range(len(row_lengths)) for j in range(row_lengths[i])]
        )
        _fmin = torch.finfo(flat_values.dtype).min
        values = _fmin * torch.ones(
            len(row_lengths), obs.max_len, dtype=flat_values.dtype
        )
        values.index_put_(tuple(value_index.t()), flat_values)
        current_value = values[:, 0]
        masked_action_values = values[:, 1:]

    else:
        raise NotImplementedError(f"tensorlib {tensorlib} not implemented")

    return current_value, masked_action_values
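
# Illustrative usage sketch (not part of the library): a minimal check of the
# masking helpers on a fake batch of two observations with 3 and 1 valid vertices.
# It assumes TensorFlow is installed and that RepeatedValues is constructed as
# RepeatedValues(values, lengths, max_len); the "x" key and all shapes are made up
# for this example.
def _example_mask_and_split():
    obs = RepeatedValues(
        values={"x": tf.random.normal((2, 5, 4))},  # (batch, max_len, features)
        lengths=tf.constant([3.0, 1.0]),  # valid (state, *actions) entries per row
        max_len=5,
    )
    mask = _create_action_mask(obs, "tf")  # bool tensor of shape (2, 5)
    flat_obs = _stack_batch_dim(obs, mask, "tf")  # {"x": (4, 4)} -- 3 + 1 valid rows
    flat_values = tf.random.normal((4, 1))  # one scalar per valid vertex
    current, actions = _mask_and_split_values(flat_values, obs, "tf")
    # current: (2,) values of the two current vertices
    # actions: (2, 4) successor values, with invalid slots at the dtype minimum
    return flat_obs, current, actions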