graphenv.graph_model.TorchGraphModel
- class TorchGraphModel(obs_space, action_space, num_outputs, model_config, name, *args, **kwargs)
Bases: graphenv.graph_model.GraphModel
Methods
forward
TensorFlow/Keras-style forward method.
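The one-line description of forward does not say how per-vertex outputs become action logits. The pure-Python sketch below shows one plausible reduction. It assumes, without confirmation from this page, that vertex 0 of each observation is the current state, that vertices 1..N are candidate next states, and that a boolean action mask marks which candidates are valid. The helper names `reduce_vertex_outputs` and `softmax` are hypothetical, and the real method works on PyTorch tensors rather than lists.

```python
import math
from typing import List, Sequence, Tuple

# Hypothetical helper: one way a graph model's forward pass could turn
# per-vertex outputs into action logits for a single observation.
# ASSUMPTION: index 0 is the current state; indices 1..N are candidates.


def reduce_vertex_outputs(
    values: Sequence[float],
    weights: Sequence[float],
    action_mask: Sequence[bool],
) -> Tuple[List[float], float]:
    """Return (action_logits, state_value) for one observation."""
    if not (len(values) == len(weights) == len(action_mask) + 1):
        raise ValueError("expected one more vertex than action-mask entries")

    # The value estimate of the current state comes from vertex 0.
    state_value = values[0]

    # Each candidate vertex's weight becomes an action logit; invalid
    # candidates get -inf so a softmax gives them zero probability.
    logits = [
        w if valid else -math.inf
        for w, valid in zip(weights[1:], action_mask)
    ]
    return logits, state_value


def softmax(logits: Sequence[float]) -> List[float]:
    """Stable softmax; -inf entries are fine if at least one is finite."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    logits, value = reduce_vertex_outputs(
        values=[0.5, 0.1, 0.2, 0.3],
        weights=[0.0, 1.0, 2.0, 3.0],
        action_mask=[True, False, True],
    )
    print(value)   # 0.5
    print(logits)  # [1.0, -inf, 3.0]
    print(softmax(logits))
```

Using -inf is only one common masking choice. A tensor implementation might instead use a large negative finite number to avoid NaNs when every candidate is masked.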
forward_vertex
Forward function that returns a value tensor and a weight tensor for the vertices observed via input_dict, a dict holding one tensor per vertex property.
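The following framework-free stand-in makes the forward_vertex contract concrete. It assumes that input_dict maps each vertex property name to a sequence with one entry per vertex, and it returns one value and one weight per vertex. The function name `toy_forward_vertex`, the property names `distance` and `reward`, and the scoring rule are all invented for illustration. A real subclass would use PyTorch layers and return tensors.

```python
from typing import Dict, List, Sequence, Tuple


def toy_forward_vertex(
    input_dict: Dict[str, Sequence[float]],
) -> Tuple[List[float], List[float]]:
    """Hypothetical forward_vertex: one (value, weight) pair per vertex."""
    # Every property must describe the same number of vertices.
    lengths = {len(v) for v in input_dict.values()}
    if len(lengths) != 1:
        raise ValueError("every vertex property needs one entry per vertex")

    distances = input_dict["distance"]  # hypothetical property, assumed >= 0
    rewards = input_dict["reward"]      # hypothetical property

    # Toy heuristics: the value favors high reward minus a distance penalty;
    # the weight favors vertices that are closer.
    values = [r - 0.1 * d for r, d in zip(rewards, distances)]
    weights = [1.0 / (1.0 + d) for d in distances]
    return values, weights


values, weights = toy_forward_vertex(
    {"distance": [0.0, 2.0, 5.0], "reward": [1.0, 0.0, 3.0]}
)
print(values)
print(weights)
```

Keeping the per-vertex computation independent of how many vertices an observation contains is what lets a caller flatten a batch of observations, score all vertices at once, and regroup the results afterwards.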
value_function
- Returns
A tensor of the current state values.
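In RLlib's ModelV2 interface, value_function() takes no arguments and reports the value estimates computed during the most recent forward() call. The sketch below assumes TorchGraphModel follows that convention and illustrates the caching pattern in plain Python. The class `CachedValueModel` and its attribute `_last_values` are hypothetical.

```python
from typing import List, Optional, Sequence


class CachedValueModel:
    """Hypothetical model showing the forward()/value_function() pairing."""

    def __init__(self) -> None:
        # Filled in by forward(); None until the first forward pass.
        self._last_values: Optional[List[float]] = None

    def forward(
        self,
        logits: Sequence[Sequence[float]],
        state_values: Sequence[float],
    ) -> List[List[float]]:
        # A real model computes both outputs from the observations; here they
        # are passed in directly so the caching pattern stays visible.
        self._last_values = list(state_values)
        return [list(row) for row in logits]

    def value_function(self) -> List[float]:
        # Takes no inputs: it returns the values from the latest forward().
        if self._last_values is None:
            raise RuntimeError("call forward() before value_function()")
        return self._last_values


model = CachedValueModel()
model.forward([[1.0, 2.0], [0.0, 3.0]], [0.5, -1.0])
print(model.value_function())  # [0.5, -1.0]
```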
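- Parameters (constructor)
obs_space (gymnasium.spaces.space.Space) – the observation space.
action_space (gymnasium.spaces.space.Space) – the action space.
num_outputs (int) – the number of output units of the model.
model_config (Dict) – the model configuration dictionary.
name (str) – the name of the model.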