graphenv.examples.tsp.tsp_model.TSPModel
class TSPModel(*args, num_nodes, hidden_dim=32, embed_dim=32, **kwargs)
Bases: graphenv.examples.tsp.tsp_model.BaseTSPModel, ray.rllib.models.tf.tf_modelv2.TFModelV2

Initializes a TFModelV2 instance.
Here is an example implementation of a subclass, MyModelClass(TFModelV2):

    import tensorflow as tf
    from ray.rllib.models.tf.tf_modelv2 import TFModelV2

    class MyModelClass(TFModelV2):
        def __init__(self, *args, **kwargs):
            super(MyModelClass, self).__init__(*args, **kwargs)
            input_layer = tf.keras.layers.Input(...)
            hidden_layer = tf.keras.layers.Dense(...)(input_layer)
            output_layer = tf.keras.layers.Dense(...)(hidden_layer)
            value_layer = tf.keras.layers.Dense(...)(hidden_layer)
            self.base_model = tf.keras.Model(
                input_layer, [output_layer, value_layer])
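The example above depends on TensorFlow. The following stdlib-only sketch shows the same structure without that dependency: one shared hidden layer feeds both an output (policy) head and a single-unit value head, so a single call returns both results, just as base_model returns [output_layer, value_layer]. The function names, layer sizes, and weights are hypothetical. They were chosen to keep the arithmetic easy to follow. In a real TFModelV2 these layers would be tf.keras.layers.Dense instances.

```python
from typing import Dict, List, Sequence, Tuple

Matrix = List[List[float]]


def dense(x: Sequence[float], w: Matrix, b: Sequence[float], relu: bool = False) -> List[float]:
    """Fully connected layer: y[j] = sum_i x[i] * w[i][j] + b[j], optionally ReLU'd."""
    if len(w) != len(x) or any(len(row) != len(b) for row in w):
        raise ValueError("weight matrix shape does not match input/bias sizes")
    y = [sum(x[i] * w[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]
    return [max(0.0, v) for v in y] if relu else y


def two_head_forward(x: Sequence[float], p: Dict[str, list]) -> Tuple[List[float], float]:
    """Shared trunk -> (output head, scalar value), mirroring base_model's two outputs."""
    hidden = dense(x, p["w_hidden"], p["b_hidden"], relu=True)  # shared hidden layer
    output = dense(hidden, p["w_out"], p["b_out"])                # output (policy) head
    value = dense(hidden, p["w_value"], p["b_value"])[0]         # single-unit value head
    return output, value


# Hypothetical toy parameters: 2 inputs -> 2 hidden units -> 3 outputs + 1 value.
params = {
    "w_hidden": [[1.0, 0.0], [0.0, 1.0]], "b_hidden": [0.0, 0.0],
    "w_out": [[1.0, 0.0, -1.0], [0.0, 1.0, 1.0]], "b_out": [0.0, 0.0, 0.0],
    "w_value": [[0.5], [0.25]], "b_value": [0.0],
}
output, value = two_head_forward([1.0, 2.0], params)
print(output, value)  # [1.0, 2.0, 1.0] 1.0
```

Sharing the trunk means the policy and value estimates are computed from the same features. This is the reason the TF example builds a single tf.keras.Model with two outputs instead of two separate models.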
Methods
context: Returns a context manager for the current TF graph.
custom_loss: Override to customize the loss function used to optimize this model.
forward: TensorFlow/Keras-style forward method.
forward_vertex: Forward function returning a value tensor and a weight tensor for the vertices observed via input_dict (a dict of tensors, one for each vertex property).
from_batch
get_initial_state: Gets the initial recurrent state values for the model.
import_from_h5: Imports weights from an h5 file.
is_time_major: If True, data for calling this ModelV2 must be in time-major format.
last_output: Returns the last output returned from calling the model.
metrics: Override to return custom metrics from your model.
register_variables: Registers the given list of variables with this model.
trainable_variables: Returns the list of trainable variables for this model.
update_ops: Returns the list of update ops for this model.
value_function: Returns a tensor of current state values.
variables: Returns the list (or a dict) of variables for this model.
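This page does not say how the outputs of forward_vertex become the action logits returned by forward and the values returned by value_function. The framework-free sketch below shows one plausible wiring. It is an assumption, not graphenv's actual code. The sketch makes three further assumptions: index 0 of the observed vertices is the current state, indices 1 onward are candidate next vertices gated by a boolean action mask, and invalid actions receive a large negative logit. The class MaskedVertexPolicy, the constant MASKED_LOGIT, and the helper softmax are hypothetical names. A single unbatched observation is used for readability.

```python
import math
from typing import List, Optional, Sequence

# Large negative logit for invalid actions (assumption: a common masking trick;
# TensorFlow code often uses tf.float32.min for the same purpose).
MASKED_LOGIT = -1e9


class MaskedVertexPolicy:
    """Hypothetical, framework-free stand-in for a forward()/value_function() pair.

    Assumes forward_vertex has already produced one value and one weight per
    observed vertex, that index 0 is the current-state vertex, and that indices
    1.. are candidate next vertices gated by a boolean action mask.
    """

    def __init__(self) -> None:
        self._last_value: Optional[float] = None

    def forward(self, values: Sequence[float], weights: Sequence[float],
                mask: Sequence[bool]) -> List[float]:
        if not (len(values) == len(weights) == len(mask)):
            raise ValueError("values, weights and mask must have equal length")
        self._last_value = values[0]  # cached so value_function() can return it
        return [w if ok else MASKED_LOGIT for w, ok in zip(weights[1:], mask[1:])]

    def value_function(self) -> float:
        # RLlib convention: return the value from the most recent forward pass.
        if self._last_value is None:
            raise RuntimeError("forward() must be called before value_function()")
        return self._last_value


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax; masked logits underflow to zero probability."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


model = MaskedVertexPolicy()
logits = model.forward(values=[0.5, 0.1, 0.2, 0.3],
                       weights=[0.0, 1.0, 2.0, 3.0],
                       mask=[True, True, False, True])
probs = softmax(logits)
print(logits)                  # [1.0, -1000000000.0, 3.0]
print(model.value_function())  # 0.5
```

Caching the value inside forward reflects the RLlib contract that value_function is only valid after a forward pass. Masking with a large negative logit, rather than removing entries, keeps the logit vector a fixed length, which is what a fixed-size discrete action space expects.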
Parameters
num_nodes (int) – Required keyword-only argument; it has no default.
hidden_dim (int) – Defaults to 32.
embed_dim (int) – Defaults to 32.
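In RLlib, custom models are usually built by the model catalog, which passes the entries of custom_model_config to the model constructor as keyword arguments. The minimal sketch below shows how TSPModel's parameters could be supplied that way. It assumes the class has already been registered with ModelCatalog.register_custom_model under the hypothetical name "tsp_model". The helper tsp_model_config is also hypothetical and is not part of graphenv. The value num_nodes=20 is arbitrary.

```python
from typing import Any, Dict


def tsp_model_config(num_nodes: int, hidden_dim: int = 32, embed_dim: int = 32,
                     model_name: str = "tsp_model") -> Dict[str, Any]:
    """Return an RLlib-style "model" config dict for TSPModel (hypothetical helper).

    Assumes TSPModel was registered beforehand under ``model_name`` via
    ModelCatalog.register_custom_model; the entries of "custom_model_config"
    are forwarded to the model constructor as keyword arguments.
    """
    if num_nodes < 1:
        raise ValueError("num_nodes must be a positive integer")
    return {
        "custom_model": model_name,
        "custom_model_config": {
            "num_nodes": num_nodes,    # required, keyword-only in the signature
            "hidden_dim": hidden_dim,  # defaults to 32, matching the signature
            "embed_dim": embed_dim,    # defaults to 32, matching the signature
        },
    }


# Example with an arbitrary num_nodes of 20, keeping the default widths.
config = {"model": tsp_model_config(num_nodes=20)}
print(config["model"]["custom_model_config"])
# {'num_nodes': 20, 'hidden_dim': 32, 'embed_dim': 32}
```

Because num_nodes follows *args in the signature, it can only be passed by keyword. That fits the catalog's keyword-argument forwarding.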