# Source code for graphenv.examples.tsp.tsp_model

from typing import Tuple

from graphenv import tf
from graphenv.graph_model import GraphModel
from graphenv.graph_model_bellman_mixin import GraphModelBellmanMixin
from ray.rllib.algorithms.dqn.distributional_q_tf_model import DistributionalQTFModel
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.typing import TensorStructType, TensorType

layers = tf.keras.layers


class BaseTSPModel(GraphModel):
    """Keras vertex-scoring network shared by the TSP example models.

    Each candidate vertex is scored from its raw observation, two distance
    features and a learned embedding of its node index. The network emits an
    action value and an action weight per vertex; how those two outputs are
    consumed is decided by ``GraphModel`` (not shown here). Concrete models
    combine this class with an RLlib TF model base class.
    """

    def __init__(
        self,
        *args,
        num_nodes: int,
        hidden_dim: int = 32,
        embed_dim: int = 32,
        **kwargs,
    ):
        """Initialize the parent model(s), then build the Keras network.

        Args:
            *args: Forwarded unchanged to ``super().__init__``.
            num_nodes: Number of distinct node indices (embedding table size);
                ``node_idx`` values must lie in ``[0, num_nodes)``.
            hidden_dim: Width of both hidden ``Dense`` layers.
            embed_dim: Length of each learned node embedding.
            **kwargs: Forwarded unchanged to ``super().__init__``.
        """
        super().__init__(*args, **kwargs)
        network = self._create_base_model(num_nodes, hidden_dim, embed_dim)
        self.base_model = network

    @staticmethod
    def _create_base_model(
        num_nodes: int, hidden_dim: int = 32, embed_dim: int = 32
    ) -> tf.keras.Model:
        """Build the functional Keras model that scores a single vertex.

        Inputs, in model order: ``node_obs`` (2 floats), ``node_idx``
        (1 int32 index), ``parent_dist`` (1 float), ``nbr_dist`` (1 float).

        Returns:
            A ``tf.keras.Model`` whose outputs are
            ``[action_values, action_weights]``, each of shape ``(batch, 1)``.
        """
        # Model inputs. Their names are what a dict input must be keyed by.
        obs_in = layers.Input(shape=(2,), name="node_obs", dtype=tf.float32)
        idx_in = layers.Input(shape=(1,), name="node_idx", dtype=tf.int32)
        parent_in = layers.Input(shape=(1,), name="parent_dist", dtype=tf.float32)
        nbr_in = layers.Input(shape=(1,), name="nbr_dist", dtype=tf.float32)

        # Trainable layers.
        embedding = layers.Embedding(
            num_nodes, embed_dim, name="embed_layer", input_length=1
        )
        dense_1 = layers.Dense(hidden_dim, name="hidden_layer_1", activation="relu")
        dense_2 = layers.Dense(
            hidden_dim, name="hidden_layer_2", activation="linear"
        )
        value_head = layers.Dense(
            1, name="action_value_output", bias_initializer="ones"
        )
        weight_head = layers.Dense(
            1, name="action_weight_output", bias_initializer="ones"
        )

        # Continuous features -> relu dense -> (batch, 1, hidden_dim). The extra
        # length-1 axis gives them the same rank as the (batch, 1, embed_dim)
        # embedding output, so the two can be concatenated.
        numeric = layers.Concatenate(axis=-1)([obs_in, parent_in, nbr_in])
        numeric = dense_1(numeric)
        numeric = layers.Reshape((1, hidden_dim))(numeric)

        node_emb = embedding(idx_in)

        # Join along the feature axis, drop the length-1 axis, then mix.
        joined = layers.Concatenate(axis=-1)([numeric, node_emb])
        joined = layers.Flatten()(joined)
        joined = dense_2(joined)

        # Two scalar heads for the RLlib algorithms.
        model_inputs = [obs_in, idx_in, parent_in, nbr_in]
        model_outputs = [value_head(joined), weight_head(joined)]
        return tf.keras.Model(model_inputs, model_outputs)

    def forward_vertex(
        self,
        input_dict: TensorStructType,
    ) -> Tuple[TensorType, TensorType]:
        """Score a batch of vertices with the shared Keras network.

        Args:
            input_dict: Vertex observations passed straight to ``base_model``.
                Presumably keyed by the Input names (``node_obs``, ``node_idx``,
                ``parent_dist``, ``nbr_dist``); TODO confirm against GraphModel.

        Returns:
            ``(action_values, action_weights)`` as a tuple.
        """
        outputs = self.base_model(input_dict)
        return tuple(outputs)
class TSPModel(BaseTSPModel, TFModelV2):
    """TSP vertex-scoring network combined with RLlib's ``TFModelV2`` base."""
class TSPQModel(BaseTSPModel, DistributionalQTFModel):
    """TSP vertex-scoring network combined with RLlib's DQN-family base.

    ``DistributionalQTFModel`` is imported from ``ray.rllib.algorithms.dqn``.
    """
class TSPQModelBellman(
    GraphModelBellmanMixin, BaseTSPModel, DistributionalQTFModel
):
    """``TSPQModel`` variant with ``GraphModelBellmanMixin`` mixed in first.

    Listing the mixin first in the bases puts it ahead of ``BaseTSPModel``
    and ``DistributionalQTFModel`` in the MRO, so any method it defines takes
    precedence over theirs.
    """