graphenv.examples.tsp.tsp_model.BaseTSPModel

class BaseTSPModel(*args, num_nodes, hidden_dim=32, embed_dim=32, **kwargs)

Bases: graphenv.graph_model.GraphModel

Methods

  • forward – TensorFlow/Keras-style forward method.

  • forward_vertex – Forward function returning a value tensor and a weight tensor for the vertices observed via input_dict (a dict of tensors, one per vertex property).

  • value_function – Returns a tensor of current state values.
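The summary does not say how forward and value_function interact. A common convention in reinforcement-learning model APIs (for example, RLlib's ModelV2) is that value_function returns the values computed during the most recent forward call, so forward must run first. The sketch below illustrates that caching pattern in plain Python for a single, unbatched observation. It rests on assumptions not stated in this reference: the first vertex is taken to be the current state (its value is cached), the remaining vertices are candidate actions whose weights become logits, and invalid actions are suppressed through a hypothetical action_mask argument. ToyGraphModel and its toy forward_vertex are hypothetical stand-ins, not graphenv classes.

```python
from typing import Dict, List, Optional, Sequence, Tuple


class ToyGraphModel:
    """Hypothetical illustration of a forward/value_function caching pattern."""

    def __init__(self) -> None:
        # Value of the current-state vertex from the last forward() call.
        self._current_value: Optional[float] = None

    def forward_vertex(
        self, input_dict: Dict[str, Sequence[float]]
    ) -> Tuple[List[float], List[float]]:
        # Toy per-vertex model (assumed): value = feature, weight = 2 * feature.
        feats = input_dict["feature"]
        return [float(f) for f in feats], [2.0 * f for f in feats]

    def forward(
        self, input_dict: Dict[str, Sequence[float]], action_mask: Sequence[bool]
    ) -> List[float]:
        values, weights = self.forward_vertex(input_dict)
        # Assumption: vertex 0 is the current state; cache its value.
        self._current_value = values[0]
        # Remaining vertices are candidate actions; masked-out actions get a
        # large negative logit so a softmax assigns them ~zero probability.
        return [w if ok else -1e9 for w, ok in zip(weights[1:], action_mask)]

    def value_function(self) -> float:
        if self._current_value is None:
            raise RuntimeError("call forward() before value_function()")
        return self._current_value


model = ToyGraphModel()
logits = model.forward({"feature": [0.5, 1.0, 3.0, 2.0]}, action_mask=[True, False, True])
print(logits)                  # [2.0, -1000000000.0, 4.0]
print(model.value_function())  # 0.5
```

Caching the value inside forward avoids a second pass over the same vertices when the value estimate is requested later in the training step.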

Parameters
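  • num_nodes (int) – number of nodes in the TSP graph (required, keyword-only).

  • hidden_dim (int) – hidden-layer dimension (default 32).

  • embed_dim (int) – embedding dimension (default 32).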
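Because num_nodes follows *args in the signature, Python treats it as keyword-only, and since it has no default it must always be passed by name; any positional arguments are collected into *args (presumably forwarded to the GraphModel base class, which is an assumption). The sketch below uses a hypothetical stand-in class, SignatureDemo, that copies only the constructor signature, so the behaviour can be checked without importing graphenv. The positional strings are placeholders.

```python
class SignatureDemo:
    """Hypothetical stand-in that copies BaseTSPModel's constructor signature only."""

    def __init__(self, *args, num_nodes: int, hidden_dim: int = 32, embed_dim: int = 32, **kwargs):
        # Positional arguments land in *args; in the real class they would
        # presumably be passed on to the base class (assumption).
        self.args = args
        self.kwargs = kwargs
        self.num_nodes = num_nodes
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim


# num_nodes must be given by keyword; hidden_dim and embed_dim fall back to 32.
m = SignatureDemo("placeholder_a", "placeholder_b", num_nodes=20)
print(m.args, m.num_nodes, m.hidden_dim, m.embed_dim)
# ('placeholder_a', 'placeholder_b') 20 32 32

try:
    SignatureDemo(20)  # 20 is swallowed by *args, so num_nodes is missing
except TypeError as err:
    print("TypeError:", err)
```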

forward_vertex(input_dict)

Forward function returning a value tensor and a weight tensor for the vertices observed via input_dict (a dict of tensors, one per vertex property).

Parameters

input_dict (Union[numpy.array, tf.Tensor, torch.Tensor, dict, tuple]) – per-vertex observations

Returns

(value tensor, weight tensor) for the given observations

Return type

Tuple[Union[numpy.array, tf.Tensor, torch.Tensor], Union[numpy.array, tf.Tensor, torch.Tensor]]
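The sketch below illustrates the forward_vertex contract in plain Python: it takes a dict of per-vertex properties and returns one value and one weight per vertex. The architecture is an assumption made for illustration only (an embedding table of num_nodes × embed_dim rows and columns, one ReLU hidden layer of width hidden_dim, and two linear heads without biases), as is the property name node_idx. ToyVertexModel is hypothetical and is not graphenv's implementation, which operates on framework tensors.

```python
import random
from typing import Dict, List, Sequence, Tuple


def _matrix(rng: random.Random, rows: int, cols: int) -> List[List[float]]:
    # Small random weights, fixed by the seed for reproducibility.
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


class ToyVertexModel:
    """Hypothetical per-vertex model mirroring BaseTSPModel's constructor arguments."""

    def __init__(self, num_nodes: int, hidden_dim: int = 32, embed_dim: int = 32, seed: int = 0):
        rng = random.Random(seed)
        self.embedding = _matrix(rng, num_nodes, embed_dim)  # one row per node index
        self.hidden = _matrix(rng, embed_dim, hidden_dim)    # maps embed_dim -> hidden_dim
        self.value_head = [rng.uniform(-0.1, 0.1) for _ in range(hidden_dim)]
        self.weight_head = [rng.uniform(-0.1, 0.1) for _ in range(hidden_dim)]

    def forward_vertex(
        self, input_dict: Dict[str, Sequence[int]]
    ) -> Tuple[List[float], List[float]]:
        values, weights = [], []
        for idx in input_dict["node_idx"]:  # hypothetical per-vertex property
            emb = self.embedding[idx]
            # Hidden layer: vector-matrix product followed by ReLU.
            h = [
                max(0.0, sum(e * self.hidden[i][j] for i, e in enumerate(emb)))
                for j in range(len(self.value_head))
            ]
            # Two linear heads: a scalar value and a scalar weight per vertex.
            values.append(sum(a * b for a, b in zip(h, self.value_head)))
            weights.append(sum(a * b for a, b in zip(h, self.weight_head)))
        return values, weights


model = ToyVertexModel(num_nodes=5, hidden_dim=8, embed_dim=4)
values, weights = model.forward_vertex({"node_idx": [0, 3, 4]})
print(len(values), len(weights))  # 3 3
```

Sharing one trunk between the two heads keeps the value estimate and the action weights consistent with the same per-vertex representation; a real model would batch this computation as tensor operations.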