Source code for graphenv.examples.hallway.hallway_model_torch

from typing import Tuple

from graphenv import nn
from graphenv.graph_model import TorchGraphModel
from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.typing import TensorStructType, TensorType


[docs]class TorchHallwayModel( TorchGraphModel, TorchModelV2, nn.Module, ): """An example GraphModel implementation for the HallwayEnv and HallwayState Graph. Uses a dense fully connected Torch network. Args: hidden_dim (int, optional): The number of hidden layers to use. Defaults to 1. """ def __init__( self, *args, hidden_dim: int = 1, **kwargs, ): super().__init__(*args, **kwargs) nn.Module.__init__(self) self.hidden_layer = nn.Linear(1, hidden_dim) self.action_value_output = nn.Linear(hidden_dim, 1) self.action_weight_output = nn.Linear(hidden_dim, 1) self.relu = nn.ReLU()
[docs] def forward_vertex( self, input_dict: TensorStructType, ) -> Tuple[TensorType, TensorType]: x = self.hidden_layer(input_dict["cur_pos"].float()) x = self.relu(x) return ( self.action_value_output(x), self.action_weight_output(x), )
[docs]class TorchHallwayQModel(TorchHallwayModel, DQNTorchModel): pass
# class TorchHallwayModel(BaseTorchHallwayModel): # pass