graphenv.examples.tsp.tsp_model.BaseTSPModel
- class BaseTSPModel(*args, num_nodes, hidden_dim=32, embed_dim=32, **kwargs)
Bases: graphenv.graph_model.GraphModel
Methods
forward – TensorFlow/Keras-style forward method.
forward_vertex – Forward function returning a value tensor and a weight tensor for the vertices observed via input_dict (a dict of tensors, one per vertex property).
value_function – Returns a tensor of current state values.
- Parameters
num_nodes (int) – required; keyword-only, because it follows *args in the signature
hidden_dim (int) – defaults to 32
embed_dim (int) – defaults to 32
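Because num_nodes comes after *args in the signature and has no default, it can only be supplied by keyword; a bare positional value is swallowed by *args instead. The sketch below illustrates this with ToyTSPModel, a hypothetical stand-in that copies only the constructor signature of BaseTSPModel. It stores its arguments unchanged and does not reproduce anything else the real class does.

```python
from typing import Any


class ToyTSPModel:
    """Hypothetical stand-in: same constructor signature as BaseTSPModel, nothing else."""

    def __init__(self, *args: Any, num_nodes: int, hidden_dim: int = 32,
                 embed_dim: int = 32, **kwargs: Any) -> None:
        # Extra positional arguments are collected in *args and simply stored here.
        self.args = args
        # num_nodes follows *args, so Python treats it as keyword-only and required.
        self.num_nodes = num_nodes
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim
        # Any other keyword arguments end up in **kwargs.
        self.kwargs = kwargs


model = ToyTSPModel(num_nodes=20)
print(model.num_nodes, model.hidden_dim, model.embed_dim)  # 20 32 32

try:
    ToyTSPModel(20)  # 20 lands in *args, so num_nodes is still missing
except TypeError as err:
    print("TypeError:", err)
```

The same rule applies to the real class: pass num_nodes=... explicitly, and override hidden_dim or embed_dim by keyword only when the defaults of 32 do not fit.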
- forward_vertex(input_dict)
Forward function returning a value tensor and a weight tensor for the vertices observed via input_dict (a dict of tensors, one per vertex property).
- Parameters
input_dict (Union[numpy.array, tf.Tensor, torch.Tensor, dict, tuple]) – per-vertex observations
- Returns
(value tensor, weight tensor) for the given observations
- Return type
Tuple[Union[numpy.array, tf.Tensor, torch.Tensor], Union[numpy.array, tf.Tensor, torch.Tensor]]
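To make the input/output contract concrete, here is a toy, hypothetical function with the same shape as forward_vertex: a dict of per-vertex properties goes in, and a (value, weight) pair with one entry per vertex comes out. It is not the BaseTSPModel implementation. It uses a hand-written distance heuristic instead of a neural network, plain Python lists instead of tensors, and the property names x, y and visited, as well as the convention that vertex 0 is the current city, are all assumptions made for illustration.

```python
import math
from typing import Dict, List, Sequence, Tuple


def toy_forward_vertex(
    input_dict: Dict[str, Sequence[float]],
) -> Tuple[List[float], List[float]]:
    """Hypothetical heuristic with forward_vertex's input/output shape.

    Each (assumed) property name maps to one entry per vertex; vertex 0 is
    assumed to be the current city.
    """
    xs, ys, visited = input_dict["x"], input_dict["y"], input_dict["visited"]
    if not (len(xs) == len(ys) == len(visited)):
        raise ValueError("every vertex property needs one entry per vertex")
    x0, y0 = xs[0], ys[0]
    values: List[float] = []
    weights: List[float] = []
    for x, y, was_visited in zip(xs, ys, visited):
        dist = math.hypot(x - x0, y - y0)
        # Toy value: negative length of the hop from the current city.
        values.append(-dist)
        # Toy weight (an unnormalized logit): nearer cities score higher,
        # and already-visited cities are masked out with -inf.
        weights.append(-math.inf if was_visited else -dist)
    return values, weights


obs = {"x": [0.0, 3.0, 0.0], "y": [0.0, 4.0, 1.0], "visited": [1.0, 0.0, 0.0]}
values, weights = toy_forward_vertex(obs)
print(values)   # [-0.0, -5.0, -1.0]
print(weights)  # [-inf, -5.0, -1.0]
```

A learned model would replace the distance heuristic with trainable layers, but the shape of the result stays the same: two sequences aligned with the observed vertices.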