graphenv.examples.hallway.hallway_model_torch.TorchHallwayModel
- class TorchHallwayModel(*args, hidden_dim=1, **kwargs)
Bases: graphenv.graph_model.TorchGraphModel, ray.rllib.models.torch.torch_modelv2.TorchModelV2, torch.nn.modules.module.Module
An example GraphModel implementation for the HallwayEnv and HallwayState Graph. Uses a dense fully connected Torch network.
- Parameters
hidden_dim (int, optional) – The number of hidden layers to use. Defaults to 1.
Initialize a TorchModelV2.
Here is an example implementation for a subclass MyModelClass(TorchModelV2, nn.Module):

    def __init__(self, *args, **kwargs):
        TorchModelV2.__init__(self, *args, **kwargs)
        nn.Module.__init__(self)
        self._hidden_layers = nn.Sequential(...)
        self._logits = ...
        self._value_branch = ...
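The hidden_dim keyword argument is typically supplied through the RLlib model configuration rather than by calling the constructor directly. RLlib attempts to pass the entries of custom_model_config to a custom ModelV2 constructor as keyword arguments. The fragment below is a minimal sketch of that configuration shape only: the registered name "torch_hallway_model" and the value 32 are hypothetical and do not come from this page.

```python
# Sketch of an RLlib-style "model" config section (plain dict, no Ray import).
model_config = {
    # Hypothetical registered model name; the model class itself could be
    # given here instead.
    "custom_model": "torch_hallway_model",
    # Entries here are forwarded to the model constructor as keyword
    # arguments, so this would set hidden_dim=32 (illustrative value).
    "custom_model_config": {"hidden_dim": 32},
}

# The keyword arguments the model constructor would receive.
extra_kwargs = model_config["custom_model_config"]
print(extra_kwargs)
```

In a full training setup this dict would sit under the "model" key of the algorithm configuration.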
Methods
add_module – Adds a child module to the current module.
apply – Applies fn recursively to every submodule (as returned by .children()) as well as self.
bfloat16 – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers – Returns an iterator over module buffers.
children – Returns an iterator over immediate children modules.
context – Returns a contextmanager for the current forward pass.
cpu – Moves all model parameters and buffers to the CPU.
cuda – Moves all model parameters and buffers to the GPU.
custom_loss – Override to customize the loss function used to optimize this model.
double – Casts all floating point parameters and buffers to double datatype.
eval – Sets the module in evaluation mode.
extra_repr – Set the extra representation of the module.
float – Casts all floating point parameters and buffers to float datatype.
forward – Tensorflow/Keras style forward method.
forward_vertex – Forward function returning a value and weight tensor for the vertices observed via input_dict (a dict of tensors for each vertex property).
from_batch
get_buffer – Returns the buffer given by target if it exists, otherwise throws an error.
get_extra_state – Returns any extra state to include in the module's state_dict.
get_initial_state – Get the initial recurrent state values for the model.
get_parameter – Returns the parameter given by target if it exists, otherwise throws an error.
get_submodule – Returns the submodule given by target if it exists, otherwise throws an error.
half – Casts all floating point parameters and buffers to half datatype.
import_from_h5 – Imports weights from an h5 file.
ipu – Moves all model parameters and buffers to the IPU.
is_time_major – If True, data for calling this ModelV2 must be in time-major format.
last_output – Returns the last output returned from calling the model.
load_state_dict – Copies parameters and buffers from state_dict into this module and its descendants.
metrics – Override to return custom metrics from your model.
modules – Returns an iterator over all modules in the network.
named_buffers – Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children – Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules – Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters – Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
parameters – Returns an iterator over module parameters.
register_backward_hook – Registers a backward hook on the module.
register_buffer – Adds a buffer to the module.
register_forward_hook – Registers a forward hook on the module.
register_forward_pre_hook – Registers a forward pre-hook on the module.
register_full_backward_hook – Registers a backward hook on the module.
register_full_backward_pre_hook – Registers a backward pre-hook on the module.
register_load_state_dict_post_hook – Registers a post hook to be run after module's load_state_dict is called.
register_module – Alias for add_module().
register_parameter – Adds a parameter to the module.
register_state_dict_pre_hook – These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self.
requires_grad_ – Change if autograd should record operations on parameters in this module.
set_extra_state – This function is called from load_state_dict() to handle any extra state found within the state_dict.
share_memory – See torch.Tensor.share_memory_().
state_dict – Returns a dictionary containing references to the whole state of the module.
to – Moves and/or casts the parameters and buffers.
to_empty – Moves the parameters and buffers to the specified device without copying storage.
train – Sets the module in training mode.
trainable_variables – Returns the list of trainable variables for this model.
type – Casts all parameters and buffers to dst_type.
value_function – Returns a tensor of current state values.
variables – Returns the list (or a dict) of variables for this model.
xpu – Moves all model parameters and buffers to the XPU.
zero_grad – Sets gradients of all model parameters to zero.
Attributes
T_destination – alias of TypeVar('T_destination', bound=Dict[str, Any])
call_super_init
dump_patches
- forward_vertex(input_dict)
Forward function returning a value and weight tensor for the vertices observed via input_dict (a dict of tensors for each vertex property)
- Parameters
input_dict (Union[numpy.array, tf.Tensor, torch.Tensor, dict, tuple]) – per-vertex observations
- Returns
(value tensor, weight tensor) for the given observations
- Return type
Tuple[Union[numpy.array, tf.Tensor, torch.Tensor], Union[numpy.array, tf.Tensor, torch.Tensor]]
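The contract described above (a dict of per-vertex tensors in, a pair of per-vertex tensors out) can be illustrated with a small standalone Torch module. This is a rough sketch under stated assumptions, not the library's implementation: the class name HallwayVertexNet, the observation key "cur_pos", and the two-head layout are hypothetical, and hidden_dim is read here as the width of a single hidden layer, whereas the parameter description on this page calls it the number of hidden layers.

```python
from typing import Dict, Tuple

import torch
from torch import nn


class HallwayVertexNet(nn.Module):
    """Hypothetical stand-in for a dense per-vertex network."""

    def __init__(self, hidden_dim: int = 1) -> None:
        super().__init__()
        # Assumption: each vertex is described by one scalar feature.
        self.hidden = nn.Linear(1, hidden_dim)
        # Two separate linear heads: one for values, one for weights.
        self.value_head = nn.Linear(hidden_dim, 1)
        self.weight_head = nn.Linear(hidden_dim, 1)

    def forward_vertex(
        self, input_dict: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # "cur_pos" is an assumed observation key, not confirmed by this page.
        x = input_dict["cur_pos"].float()
        h = torch.relu(self.hidden(x))
        # One value and one weight per vertex, each of shape (num_vertices, 1).
        return self.value_head(h), self.weight_head(h)


net = HallwayVertexNet(hidden_dim=4)
obs = {"cur_pos": torch.tensor([[0], [1], [2]])}  # three vertices
values, weights = net.forward_vertex(obs)
print(tuple(values.shape), tuple(weights.shape))
```

Each row of the input corresponds to one vertex, so both returned tensors have one row per vertex.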