graphenv.graph_model.GraphModel

class GraphModel(obs_space, action_space, num_outputs, model_config, name, *args, **kwargs)[source]

Bases: object

Defines an RLlib-compatible model for using RL algorithms on a GraphEnv.

Parameters
  • obs_space (gymnasium.spaces.space.Space) – The observation space to use.

  • action_space (gymnasium.spaces.space.Space) – The action space to use.

  • num_outputs (int) – The number of scalar outputs per state to produce.

  • model_config (Dict) – Config forwarded to TFModelV2.__init__().

  • name (str) – Model name forwarded to TFModelV2.__init__().

Methods

forward

TensorFlow/Keras-style forward method.

forward_vertex

Forward function returning a value tensor and a weight tensor for the vertices observed via input_dict (a dict of tensors, one per vertex property).

value_function

Returns a tensor of current state values.

forward(input_dict, state, seq_lens)[source]

TensorFlow/Keras-style forward method. Sets up the computation graph used by this model.

Parameters
  • input_dict (Dict[str, Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor]]) – Observation input to the model. A dictionary that includes the key ‘obs’, which stores the raw observation from the process.

  • state (List[Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor]]) – Tensor of current states. Passes through this function untouched.

  • seq_lens (Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor]) – Unused. Required by API.

Returns

(action weights tensor, state)

Return type

Tuple[Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor], List[Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor]]]
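The return contract above can be illustrated with a small stand-in. This is a sketch only, not graphenv's implementation: ToyModel is a hypothetical class, plain Python lists stand in for tensors, and treating the first entry of each observation as the current state (with the remaining entries as candidate actions) is an assumption made purely for illustration. What it does mirror from the documentation is that state is returned untouched, seq_lens is ignored, and value_function() reports values produced by the most recent forward call, as is conventional for RLlib models.

```python
from typing import Dict, List, Tuple


class ToyModel:
    """Hypothetical stand-in that mimics the forward / value_function contract."""

    def __init__(self) -> None:
        self._values: List[float] = []

    def forward(
        self,
        input_dict: Dict[str, List[List[float]]],
        state: list,
        seq_lens=None,  # unused, kept only to match the documented signature
    ) -> Tuple[List[List[float]], list]:
        obs = input_dict["obs"]  # one row per batch element
        # Assumption: entry 0 of each row scores the current state,
        # the remaining entries score the candidate actions.
        self._values = [row[0] for row in obs]
        action_weights = [row[1:] for row in obs]
        return action_weights, state  # state passes through untouched

    def value_function(self) -> List[float]:
        # Values cached by the most recent forward() call.
        return self._values


model = ToyModel()
state_in = ["placeholder-state"]
batch = {"obs": [[0.1, 2.0, 3.0], [0.4, 5.0, 6.0]]}
action_weights, state_out = model.forward(batch, state_in, None)
```

Calling model.value_function() before any forward call returns an empty list in this sketch; a real model would typically require forward to run first.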

abstract forward_vertex(input_dict)[source]

Forward function returning a value tensor and a weight tensor for the vertices observed via input_dict (a dict of tensors, one per vertex property).

Parameters

input_dict – per-vertex observations

Returns

(value tensor, weight tensor) for the given observations

Return type

Tuple[Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor], Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor]]
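Because forward_vertex is abstract, a concrete model has to supply it. The sketch below shows only the shape of that contract using the standard library. VertexModelSketch and DegreeModel are hypothetical stand-ins rather than graphenv classes, plain Python lists stand in for tensors, and the "degree" vertex property is an assumed example.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple


class VertexModelSketch(ABC):
    """Hypothetical stand-in for the abstract part of GraphModel."""

    @abstractmethod
    def forward_vertex(
        self, input_dict: Dict[str, List[float]]
    ) -> Tuple[List[float], List[float]]:
        """Return (values, weights), one entry per observed vertex."""


class DegreeModel(VertexModelSketch):
    """Toy model that scores each vertex from an assumed 'degree' property."""

    def forward_vertex(self, input_dict):
        degree = input_dict["degree"]
        values = [0.5 * d for d in degree]    # per-vertex value estimate
        weights = [float(d) for d in degree]  # per-vertex action weight
        return values, weights


# Three observed vertices, each described by a single property.
values, weights = DegreeModel().forward_vertex({"degree": [1, 2, 4]})
```

Both returned sequences have one entry per vertex, which matches the documented (value tensor, weight tensor) return pair. Instantiating the base class directly raises TypeError, which is the usual way Python enforces an abstract method.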

value_function()[source]

Returns

A tensor of current state values.