graphenv.graph_model.TorchGraphModel

class TorchGraphModel(obs_space, action_space, num_outputs, model_config, name, *args, **kwargs)

Bases: graphenv.graph_model.GraphModel

Methods

forward

TensorFlow/Keras-style forward method.

forward_vertex

Forward function that returns a value tensor and a weight tensor for the vertices observed via input_dict (a dict of tensors, one per vertex property).
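The forward_vertex contract (per-vertex properties in, per-vertex values and weights out) can be illustrated with a small, dependency-free sketch. Everything in it is hypothetical: `toy_forward_vertex`, its linear scoring rule, and the `value_coef`/`weight_coef` parameters are illustrative assumptions rather than graphenv's API, and plain Python lists stand in for the tensors a real TorchGraphModel subclass would produce.

```python
from typing import Dict, List, Tuple


def toy_forward_vertex(
    input_dict: Dict[str, List[float]],
    value_coef: Dict[str, float],
    weight_coef: Dict[str, float],
) -> Tuple[List[float], List[float]]:
    """Hypothetical stand-in for forward_vertex (not graphenv's code).

    input_dict maps each vertex property name to one entry per observed
    vertex; the result is (values, weights) with one entry per vertex.
    """
    if not input_dict:
        return [], []
    n = len(next(iter(input_dict.values())))
    if any(len(col) != n for col in input_dict.values()):
        raise ValueError("every property needs one entry per vertex")

    def score(coef: Dict[str, float], i: int) -> float:
        # Linear combination of vertex i's properties (an assumed scoring rule).
        return sum(coef.get(name, 0.0) * col[i] for name, col in input_dict.items())

    values = [score(value_coef, i) for i in range(n)]
    weights = [score(weight_coef, i) for i in range(n)]
    return values, weights


# Two observed vertices with hypothetical "degree" and "reward" properties.
vals, wts = toy_forward_vertex(
    {"degree": [1.0, 2.0], "reward": [0.5, 0.0]},
    value_coef={"degree": 1.0},
    weight_coef={"reward": 2.0},
)
print(vals, wts)  # [1.0, 2.0] [1.0, 0.0]
```

Keeping the value and weight outputs separate mirrors the method description; a real subclass would compute both with torch modules over batched tensors.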

value_function

Returns a tensor of current state values.
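In RLlib's ModelV2 convention, value_function returns the value output of the most recent forward pass, so models typically cache values inside forward. The following is a minimal sketch of that caching pattern only, not graphenv's implementation: `ToyValueCachingModel` is hypothetical, lists stand in for tensors, and the row layout (current-vertex score at index 0, candidate scores after it) is an assumption made for illustration.

```python
from typing import List, Optional


class ToyValueCachingModel:
    """Hypothetical model showing the forward/value_function caching pattern."""

    def __init__(self) -> None:
        # Values from the most recent forward pass; None until forward runs.
        self._values: Optional[List[float]] = None

    def forward(self, batch: List[List[float]]) -> List[List[float]]:
        """Assumed row layout: [current-vertex score, candidate scores...]."""
        if any(len(row) == 0 for row in batch):
            raise ValueError("each row needs at least the current-vertex score")
        # Cache one state value per row for a later value_function() call.
        self._values = [row[0] for row in batch]
        # Return the remaining scores as per-row action weights.
        return [row[1:] for row in batch]

    def value_function(self) -> List[float]:
        """Return the state values cached by the most recent forward()."""
        if self._values is None:
            raise RuntimeError("call forward() before value_function()")
        return self._values


model = ToyValueCachingModel()
logits = model.forward([[0.5, 1.0, 2.0], [1.5, 3.0]])
print(logits)                  # [[1.0, 2.0], [3.0]]
print(model.value_function())  # [0.5, 1.5]
```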

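Parameters
  • obs_space (gymnasium.spaces.space.Space) – Observation space of the environment.

  • action_space (gymnasium.spaces.space.Space) – Action space of the environment.

  • num_outputs (int) – Number of model outputs.

  • model_config (Dict) – Model configuration dictionary.

  • name (str) – Name of the model.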