graphenv.examples.hallway.hallway_model.HallwayModel

class HallwayModel(*args, hidden_dim=1, **kwargs)

Bases: graphenv.examples.hallway.hallway_model.BaseHallwayModel, ray.rllib.models.tf.tf_modelv2.TFModelV2
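The signature and base list together imply two things. First, because hidden_dim follows *args, it can only be passed by keyword. Second, with two bases, cooperative super().__init__ calls follow the method resolution order (MRO). The sketch below illustrates both with hypothetical stub classes (FrameworkModelStub, BaseModelStub and HallwayModelSketch are illustrative stand-ins, not graphenv or RLlib classes). It also assumes that hidden_dim is consumed by the subclass and not forwarded to the bases, which the documentation does not confirm.

```python
class FrameworkModelStub:
    """Hypothetical stand-in for TFModelV2; it only records its arguments."""

    def __init__(self, *args, **kwargs):
        self.framework_args = args
        self.framework_kwargs = kwargs


class BaseModelStub:
    """Hypothetical stand-in for BaseHallwayModel."""

    def __init__(self, *args, **kwargs):
        # Cooperative init: hand everything on to the next class in the MRO.
        super().__init__(*args, **kwargs)


class HallwayModelSketch(BaseModelStub, FrameworkModelStub):
    """Mirrors the documented signature (*args, hidden_dim=1, **kwargs)."""

    def __init__(self, *args, hidden_dim=1, **kwargs):
        # hidden_dim follows *args, so it can only be passed by keyword.
        # In this sketch it is consumed here and not forwarded to the bases.
        super().__init__(*args, **kwargs)
        self.hidden_dim = hidden_dim


model = HallwayModelSketch("obs_space", "action_space", hidden_dim=8, name="hallway")
print([cls.__name__ for cls in HallwayModelSketch.__mro__])
# ['HallwayModelSketch', 'BaseModelStub', 'FrameworkModelStub', 'object']
print(model.hidden_dim, model.framework_args, model.framework_kwargs)
# 8 ('obs_space', 'action_space') {'name': 'hallway'}
```

Passing a third positional argument does not set hidden_dim. The argument is forwarded to the bases instead, and hidden_dim keeps its default of 1.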

Initializes a TFModelV2 instance.

Here is an example implementation for a subclass MyModelClass(TFModelV2):

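import tensorflow as tf
from ray.rllib.models.tf.tf_modelv2 import TFModelV2

class MyModelClass(TFModelV2):
    def __init__(self, *args, **kwargs):
        super(MyModelClass, self).__init__(*args, **kwargs)
        # Layer arguments (input shape, number of units) are elided here.
        input_layer = tf.keras.layers.Input(...)
        hidden_layer = tf.keras.layers.Dense(...)(input_layer)
        # Two heads share the hidden layer: the model output and the value estimate.
        output_layer = tf.keras.layers.Dense(...)(hidden_layer)
        value_layer = tf.keras.layers.Dense(...)(hidden_layer)
        self.base_model = tf.keras.Model(
            input_layer, [output_layer, value_layer])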
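The example above wires one shared hidden layer into two output heads. The pure-Python sketch below shows the arithmetic such a Keras model performs on a single input. It does not use TensorFlow, and dense and two_head_forward are hypothetical helpers. The ReLU on the hidden layer is an assumption, because the example elides all layer arguments, and a Keras Dense layer is linear unless an activation is given.

```python
from typing import List, Sequence, Tuple

Matrix = Sequence[Sequence[float]]


def dense(x: Sequence[float], weights: Matrix, bias: Sequence[float],
          relu: bool = False) -> List[float]:
    """Fully connected layer, y[j] = sum_i x[i] * W[i][j] + b[j].

    W has shape (inputs, units), the same layout Keras uses for a Dense kernel.
    """
    out = []
    for j, b in enumerate(bias):
        y = sum(x[i] * weights[i][j] for i in range(len(x))) + b
        out.append(max(0.0, y) if relu else y)
    return out


def two_head_forward(x, hidden, policy_head, value_head) -> Tuple[List[float], List[float]]:
    """Shared hidden layer feeding an output (logits) head and a value head.

    Each layer argument is a (weights, bias) pair.
    """
    h = dense(x, *hidden, relu=True)  # ReLU on the shared layer is an assumption
    return dense(h, *policy_head), dense(h, *value_head)


logits, value = two_head_forward(
    [1.0, 2.0],
    hidden=([[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0]),
    policy_head=([[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]], [0.0, 0.0, 0.0]),
    value_head=([[0.5], [1.0]], [1.0]),
)
print(logits, value)  # [2.0, 4.0, 6.0] [2.0]
```

Both heads read the same hidden activations, so one forward pass yields the output logits and the value estimate together. That is why the example returns them as a two-element output list.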

Methods

context

Returns a context manager for the current TF graph.

custom_loss

Override to customize the loss function used to optimize this model.

forward

TensorFlow/Keras-style forward method.

forward_vertex

Forward function computing the evaluation of vertex observations.

from_batch

get_initial_state

Get the initial recurrent state values for the model.

import_from_h5

Imports weights from an h5 file.

is_time_major

If True, data for calling this ModelV2 must be in time-major format.

last_output

Returns the last output returned from calling the model.

metrics

Override to return custom metrics from your model.

register_variables

Register the given list of variables with this model.

trainable_variables

Returns the list of trainable variables for this model.

update_ops

Return the list of update ops for this model.

value_function

Returns a tensor of current state values.

variables

Returns the list (or a dict) of variables for this model.
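The table lists forward, forward_vertex and value_function separately. One plausible way they fit together is that forward scores every vertex observation with forward_vertex, masks invalid actions, and stores a state value for value_function to return. The sketch below follows that idea in pure Python. The layout is an assumption: vertex 0 as the current state, the remaining vertices as one candidate per action, and forward_vertex returning a (value, weight) pair. The names forward_sketch and toy_vertex are hypothetical and not part of graphenv.

```python
import math
from typing import Callable, List, Sequence, Tuple

# Hypothetical per-vertex scorer: observation -> (value, action weight).
VertexFn = Callable[[Sequence[float]], Tuple[float, float]]


def forward_sketch(vertex_obs: Sequence[Sequence[float]],
                   action_mask: Sequence[bool],
                   forward_vertex: VertexFn) -> Tuple[List[float], float]:
    """Score every vertex, keep the current vertex's value, mask the rest.

    Assumes vertex_obs[0] is the current vertex and vertex_obs[1:] are the
    candidate successors, one per action.
    """
    scored = [forward_vertex(obs) for obs in vertex_obs]
    value = scored[0][0]  # state value comes from the current vertex
    logits = [
        weight if allowed else -math.inf  # masked actions get zero probability
        for (_, weight), allowed in zip(scored[1:], action_mask)
    ]
    return logits, value


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax; needs at least one finite logit."""
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def toy_vertex(obs: Sequence[float]) -> Tuple[float, float]:
    # Hypothetical scorer: position along the hallway as both value and weight.
    return obs[0], obs[0]


logits, value = forward_sketch([[2.0], [1.0], [3.0]], [True, False], toy_vertex)
print(logits, value)     # [1.0, -inf] 2.0
print(softmax(logits))   # [1.0, 0.0]
```

Setting a masked logit to negative infinity gives that action exactly zero probability after the softmax. TensorFlow code often uses the dtype's most negative finite value instead, which avoids producing NaNs if every action happens to be masked.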

Parameters

hidden_dim (int) – Hidden dimension of the model; keyword-only, defaults to 1.
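To get a rough sense of what hidden_dim controls, the sketch below counts trainable parameters as a function of it. It assumes hidden_dim is the width of a single hidden Dense layer that feeds a scalar action-weight head and a scalar value head. That layout, the helper names dense_params and two_head_params, and obs_dim=2 are all illustrative assumptions. Only the default of 1 comes from the signature.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Kernel weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def two_head_params(obs_dim: int, hidden_dim: int = 1) -> int:
    """Trainable parameters of an assumed trunk: one hidden layer of width
    hidden_dim feeding a scalar action-weight head and a scalar value head."""
    return (
        dense_params(obs_dim, hidden_dim)   # hidden layer
        + dense_params(hidden_dim, 1)       # action-weight head
        + dense_params(hidden_dim, 1)       # value head
    )


for h in (1, 8, 32):
    print(h, two_head_params(obs_dim=2, hidden_dim=h))
# 1 7
# 8 42
# 32 162
```

Under these assumptions the count grows linearly with hidden_dim, as (obs_dim + 3) * hidden_dim + 2. The default of 1 therefore gives a very small model.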