graphenv.graph_model.GraphModel
- class GraphModel(obs_space, action_space, num_outputs, model_config, name, *args, **kwargs)
Bases: object
Defines an RLlib-compatible model for running RL algorithms on a GraphEnv.
- Parameters
obs_space (gymnasium.spaces.space.Space) – The observation space to use.
action_space (gymnasium.spaces.space.Space) – The action space to use.
num_outputs (int) – The number of scalar outputs per state to produce.
model_config (Dict) – Config forwarded to TFModelV2.__init__().
name (str) – Model name, forwarded to TFModelV2.__init__().
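Because model_config and name are forwarded to TFModelV2.__init__(), GraphModel appears to be combined with an RLlib model base class through multiple inheritance. The sketch below illustrates that constructor-forwarding pattern with framework-free stand-ins. FakeTFModelV2, GraphModelMixin and MyGraphModel are hypothetical names, and the real constructor may do more than forward its arguments.

```python
# Framework-free sketch of constructor forwarding; every class here is a hypothetical stand-in.


class FakeTFModelV2:
    """Stand-in for RLlib's TFModelV2: it only records the arguments it receives."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        self.obs_space = obs_space
        self.action_space = action_space
        self.num_outputs = num_outputs
        self.model_config = model_config
        self.name = name


class GraphModelMixin:
    """Hypothetical mixin mirroring the documented GraphModel constructor signature."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, *args, **kwargs):
        # Hand every argument to the next class in the MRO (FakeTFModelV2 below).
        super().__init__(obs_space, action_space, num_outputs, model_config, name, *args, **kwargs)


class MyGraphModel(GraphModelMixin, FakeTFModelV2):
    """Concrete model: the mixin is listed first, so its __init__ runs first."""


model = MyGraphModel(None, None, 1, {}, "my_model")
print(model.name, model.num_outputs)  # my_model 1
```

Listing the mixin before the framework base class is what lets a single super().__init__() call reach TFModelV2 in this pattern.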
Methods
forward(input_dict, state, seq_lens) – TensorFlow/Keras-style forward method.
forward_vertex(input_dict) – Forward function returning a value and weight tensor for the vertices observed via input_dict (a dict of tensors, one per vertex property).
Returns a tensor of current state values.
- forward(input_dict, state, seq_lens)
TensorFlow/Keras-style forward method. Sets up the computation graph used by this model.
- Parameters
input_dict (Dict[str, Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor]]) – Observation input to the model: a dictionary that includes the key ‘obs’, which stores the raw observation from the process.
state (List[Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor]]) – List of current state tensors, passed through this function unchanged.
seq_lens (Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor]) – Unused. Required by API.
- Returns
(action weights tensor, state)
- Return type
Tuple[Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor], List[Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor]]]
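The forward contract above can be illustrated without TensorFlow: the method reads input_dict["obs"], delegates per-vertex scoring to forward_vertex, and returns the action weights together with the unchanged state list. This is a minimal sketch under stated assumptions: plain Python lists stand in for tensors, the observation is assumed to already be a dict mapping each vertex property to a list over vertices, and the values are cached in a hypothetical last_values attribute. GraphModelSketch and ToyModel are hypothetical names, and the batching and action masking the real class may perform are not shown.

```python
from abc import ABC, abstractmethod


class GraphModelSketch(ABC):
    """Hypothetical, framework-free stand-in for GraphModel; lists replace tensors."""

    def forward(self, input_dict, state, seq_lens):
        # Assumed layout: input_dict["obs"] maps each vertex property to a list over vertices.
        obs = input_dict["obs"]
        values, weights = self.forward_vertex(obs)
        self.last_values = values  # hypothetical cache for a later value query
        return weights, state  # state is returned untouched; seq_lens is ignored

    @abstractmethod
    def forward_vertex(self, input_dict):
        """Return (values, weights) with one entry per vertex."""


class ToyModel(GraphModelSketch):
    def forward_vertex(self, input_dict):
        rewards = input_dict["reward"]
        values = list(rewards)  # toy value head: the reward itself
        weights = [2.0 * r for r in rewards]  # toy weight head: scaled reward
        return values, weights


model = ToyModel()
state = []
weights, new_state = model.forward({"obs": {"reward": [1.0, 0.5]}}, state, None)
print(weights, new_state is state)  # [2.0, 1.0] True
```

Keeping forward concrete and forward_vertex abstract means a subclass only decides how to score individual vertices, while the sketch's base class handles the RLlib-facing signature.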
- abstract forward_vertex(input_dict)
Forward function returning a value tensor and a weight tensor for the vertices observed via input_dict (a dict of tensors, one per vertex property).
- Parameters
input_dict – Per-vertex observations.
- Returns
(value tensor, weight tensor) for the given observations
- Return type
Tuple[Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor], Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor]]
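As a concrete illustration of the (value tensor, weight tensor) return contract, the sketch below scores each vertex with two toy linear heads over its properties. The property names (degree, reward) and the coefficients are invented for illustration, and plain Python lists stand in for tensors; a real subclass would compute both outputs with framework layers.

```python
def forward_vertex(input_dict):
    """Toy forward_vertex (hypothetical, not graphenv's implementation).

    input_dict maps each vertex property to a list with one entry per vertex.
    Returns (values, weights), each a list with one float per vertex.
    """
    degrees = input_dict["degree"]
    rewards = input_dict["reward"]
    # Toy value head: a weighted sum of the two properties.
    values = [0.5 * d + r for d, r in zip(degrees, rewards)]
    # Toy weight (logit) head: favours high-reward, low-degree vertices.
    weights = [r - 0.1 * d for d, r in zip(degrees, rewards)]
    return values, weights


values, weights = forward_vertex({"degree": [2.0, 4.0, 1.0], "reward": [1.0, 0.0, 3.0]})
print(values)  # [2.0, 2.0, 3.5]
print(weights)
```

Because each vertex is scored independently in this sketch, the same function works for any number of observed vertices.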