graphenv.examples.hallway.hallway_model_torch.TorchHallwayModel
- class TorchHallwayModel(*args, hidden_dim=1, **kwargs)
Bases: graphenv.graph_model.TorchGraphModel, ray.rllib.models.torch.torch_modelv2.TorchModelV2, torch.nn.modules.module.Module
An example GraphModel implementation for the HallwayEnv and HallwayState graph. It uses a dense, fully connected Torch network.
- Parameters
hidden_dim (int, optional) – The number of hidden layers to use. Defaults to 1.
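The signature above makes hidden_dim keyword-only: it is pulled out before the remaining positional and keyword arguments are forwarded to the base class. The dependency-free sketch below shows that pattern with hypothetical stub classes (BaseModelStub, HallwayModelStub) that use RLlib's usual constructor argument order (obs_space, action_space, num_outputs, model_config, name). It also assumes that custom keyword arguments such as hidden_dim are supplied through the model config's "custom_model_config" entry; treat that wiring as an assumption rather than a description of graphenv's code.

```python
class BaseModelStub:
    """Hypothetical stand-in for the RLlib base model; stores what it receives."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        self.num_outputs = num_outputs
        self.model_config = model_config
        self.name = name


class HallwayModelStub(BaseModelStub):
    """Hypothetical subclass with the same signature shape as TorchHallwayModel."""

    def __init__(self, *args, hidden_dim=1, **kwargs):
        # hidden_dim is keyword-only, so it is consumed here and never reaches
        # the base-class constructor (which would reject it with a TypeError).
        super().__init__(*args, **kwargs)
        self.hidden_dim = hidden_dim


# Assumption: custom keyword arguments come from "custom_model_config"
# and are unpacked into the constructor call.
model_config = {"custom_model_config": {"hidden_dim": 8}}
model = HallwayModelStub(
    None, None, 2, model_config, "hallway_model",
    **model_config["custom_model_config"],
)
print(model.hidden_dim)   # 8
print(model.num_outputs)  # 2

# Omitting the keyword falls back to the documented default of 1.
default_model = HallwayModelStub(None, None, 2, {}, "hallway_model")
print(default_model.hidden_dim)  # 1
```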
Initialize a TorchModelV2.
Here is an example implementation for a subclass MyModelClass(TorchModelV2, nn.Module):

    def __init__(self, *args, **kwargs):
        TorchModelV2.__init__(self, *args, **kwargs)
        nn.Module.__init__(self)
        self._hidden_layers = nn.Sequential(...)
        self._logits = ...
        self._value_branch = ...
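The subclass template above initializes each base class explicitly because TorchModelV2 and nn.Module take different constructor arguments, so a single super().__init__(*args, **kwargs) call cannot serve both. In real PyTorch, assigning a submodule before nn.Module.__init__() has run raises an AttributeError, so the nn.Module call must come before the layers are created. The sketch below reproduces that ordering with hypothetical stubs (ModelV2Stub, ModuleStub) so it runs without Torch or RLlib; the string placeholders stand in for nn.Sequential(...) and the output heads.

```python
class ModelV2Stub:
    """Hypothetical stand-in for TorchModelV2 (RLlib-style constructor arguments)."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        self.num_outputs = num_outputs
        self.name = name


class ModuleStub:
    """Hypothetical stand-in for nn.Module (takes no constructor arguments)."""

    def __init__(self):
        self._modules = {}

    def add_module(self, name, module):
        # Mimic nn.Module: refuse to register children before __init__ has run.
        if "_modules" not in self.__dict__:
            raise AttributeError("cannot assign module before __init__() call")
        self._modules[name] = module


class MyModelClass(ModelV2Stub, ModuleStub):
    def __init__(self, *args, **kwargs):
        # The bases take incompatible arguments, so each is initialized
        # explicitly instead of through one cooperative super() chain.
        ModelV2Stub.__init__(self, *args, **kwargs)
        ModuleStub.__init__(self)
        # Placeholders for nn.Sequential(...) and the logits/value heads.
        self.add_module("_hidden_layers", "hidden-layers placeholder")
        self.add_module("_logits", "logits placeholder")
        self.add_module("_value_branch", "value-branch placeholder")


model = MyModelClass(None, None, 2, {}, "my_model")
print(model.num_outputs)       # 2
print(sorted(model._modules))  # ['_hidden_layers', '_logits', '_value_branch']
```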
Methods
add_module
Adds a child module to the current module.
apply
Applies fn recursively to every submodule (as returned by .children()) as well as self.
bfloat16
Casts all floating point parameters and buffers to bfloat16 datatype.
buffers
Returns an iterator over module buffers.
children
Returns an iterator over immediate children modules.
context
Returns a context manager for the current forward pass.
cpu
Moves all model parameters and buffers to the CPU.
cuda
Moves all model parameters and buffers to the GPU.
custom_loss
Override to customize the loss function used to optimize this model.
double
Casts all floating point parameters and buffers to double datatype.
eval
Sets the module in evaluation mode.
extra_repr
Set the extra representation of the module.
float
Casts all floating point parameters and buffers to float datatype.
forward
TensorFlow/Keras-style forward method.
forward_vertex
Forward function returning a value and weight tensor for the vertices observed via input_dict (a dict of tensors for each vertex property).
from_batch
get_buffer
Returns the buffer given by target if it exists, otherwise throws an error.
get_extra_state
Returns any extra state to include in the module's state_dict.
get_initial_state
Get the initial recurrent state values for the model.
get_parameter
Returns the parameter given by target if it exists, otherwise throws an error.
get_submodule
Returns the submodule given by target if it exists, otherwise throws an error.
half
Casts all floating point parameters and buffers to half datatype.
import_from_h5
Imports weights from an h5 file.
ipu
Moves all model parameters and buffers to the IPU.
is_time_major
If True, data for calling this ModelV2 must be in time-major format.
last_output
Returns the last output returned from calling the model.
load_state_dict
Copies parameters and buffers from state_dict into this module and its descendants.
metrics
Override to return custom metrics from your model.
modules
Returns an iterator over all modules in the network.
named_buffers
Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children
Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules
Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters
Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
parameters
Returns an iterator over module parameters.
register_backward_hook
Registers a backward hook on the module.
register_buffer
Adds a buffer to the module.
register_forward_hook
Registers a forward hook on the module.
register_forward_pre_hook
Registers a forward pre-hook on the module.
register_full_backward_hook
Registers a backward hook on the module.
register_full_backward_pre_hook
Registers a backward pre-hook on the module.
register_load_state_dict_post_hook
Registers a post hook to be run after the module's load_state_dict is called.
register_module
Alias for add_module().
register_parameter
Adds a parameter to the module.
register_state_dict_pre_hook
These hooks will be called with arguments self, prefix, and keep_vars before calling state_dict on self.
requires_grad_
Change if autograd should record operations on parameters in this module.
set_extra_state
This function is called from load_state_dict() to handle any extra state found within the state_dict.
share_memory
See torch.Tensor.share_memory_().
state_dict
Returns a dictionary containing references to the whole state of the module.
to
Moves and/or casts the parameters and buffers.
to_empty
Moves the parameters and buffers to the specified device without copying storage.
train
Sets the module in training mode.
trainable_variables
Returns the list of trainable variables for this model.
type
Casts all parameters and buffers to dst_type.
value_function
Returns a tensor of current state values.
variables
Returns the list (or a dict) of variables for this model.
xpu
Moves all model parameters and buffers to the XPU.
zero_grad
Sets gradients of all model parameters to zero.
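In RLlib, value_function is meant to be called after forward and to return the value output of the most recent forward pass. The sketch below shows one hypothetical way a graph model could combine that with the per-vertex (value, weight) pairs produced by forward_vertex. It assumes, purely for illustration, that vertex 0 of an observation is the current state and the remaining vertices are its successors, whose weights act as action logits. GraphModelSketch, toy_forward_vertex, and the "position" property are invented names, not graphenv's TorchGraphModel.

```python
class GraphModelSketch:
    """Hypothetical: vertex 0 is the current state, vertices 1..n are successors."""

    def __init__(self, forward_vertex):
        self._forward_vertex = forward_vertex
        self._value_out = None  # cached by forward(), read by value_function()

    def forward(self, input_dict):
        values, weights = self._forward_vertex(input_dict)
        self._value_out = values[0]  # value of the current state
        return list(weights[1:])     # one logit per successor (action)

    def value_function(self):
        if self._value_out is None:
            raise RuntimeError("call forward() before value_function()")
        return self._value_out


def toy_forward_vertex(input_dict):
    # Hypothetical scoring: value is minus the distance to a goal at position 5,
    # and weight is simply the position.
    positions = input_dict["position"]
    return [-abs(5 - p) for p in positions], [float(p) for p in positions]


model = GraphModelSketch(toy_forward_vertex)
logits = model.forward({"position": [2, 1, 3]})
print(logits)                  # [1.0, 3.0]
print(model.value_function())  # -3
```

Caching the value inside forward keeps a single pass over the vertices; value_function only reads the cached result, which is why it cannot be called first.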
Attributes
T_destination
alias of TypeVar('T_destination', bound=Dict[str, Any])
call_super_init
dump_patches
- forward_vertex(input_dict)
Forward function returning a value and weight tensor for the vertices observed via input_dict (a dict of tensors for each vertex property)
- Parameters
input_dict (Union[numpy.array, tf.Tensor, torch.Tensor, dict, tuple]) – per-vertex observations
- Returns
(value tensor, weight tensor) for the given observations
- Return type
Tuple[Union[numpy.array, tf.Tensor, torch.Tensor], Union[numpy.array, tf.Tensor, torch.Tensor]]
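As a dependency-free sketch of the forward_vertex contract, the code below runs a small dense, fully connected network (ReLU hidden layers, linear output) over every vertex and returns a (values, weights) pair with one entry per vertex. It assumes input_dict maps each vertex property to a list with one entry per vertex; the property name "position", the hidden_sizes parameter, and the random initialization are hypothetical choices and do not describe how TorchHallwayModel's hidden_dim is used.

```python
import random


def init_dense(n_in, n_out, rng):
    """Random weights in [-1, 1] and zero biases for one fully connected layer."""
    weights = [[rng.uniform(-1.0, 1.0) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def apply_dense(x, layer, relu):
    """Compute W x + b for one layer, optionally followed by ReLU."""
    weights, biases = layer
    out = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, biases)]
    return [max(z, 0.0) for z in out] if relu else out


class DenseVertexModel:
    """Hypothetical dense network mapping vertex features to (value, weight)."""

    def __init__(self, n_features, hidden_sizes, seed=0):
        rng = random.Random(seed)
        sizes = [n_features, *hidden_sizes, 2]  # final layer emits (value, weight)
        self.layers = [init_dense(a, b, rng) for a, b in zip(sizes, sizes[1:])]

    def forward_vertex(self, input_dict):
        # Regroup {property: [per-vertex values]} into one feature vector per vertex.
        features = list(zip(*input_dict.values()))
        values, weights = [], []
        for x in features:
            h = list(x)
            for i, layer in enumerate(self.layers):
                # ReLU on hidden layers only; the output layer stays linear.
                h = apply_dense(h, layer, relu=i < len(self.layers) - 1)
            values.append(h[0])
            weights.append(h[1])
        return values, weights


model = DenseVertexModel(n_features=1, hidden_sizes=[4])
values, weights = model.forward_vertex({"position": [0.0, 1.0, 2.0]})
print(len(values), len(weights))  # 3 3
print(values[0], weights[0])      # 0.0 0.0 (zero input and zero biases)
```

A batched Torch version would replace the Python loops with nn.Linear layers applied to a stacked tensor of vertex features, but the per-vertex (value, weight) output shape is the same idea.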