Running GraphEnv with ray.tune

Practical reinforcement learning typically leverages the ray.tune infrastructure to scale up environment rollouts and policy model training. For the hallway example, a TensorFlow implementation consists of the following:

```python
 1  import ray
 2  from graphenv.examples.hallway.hallway_model import HallwayModel
 3  from graphenv.examples.hallway.hallway_state import HallwayState
 4  from graphenv.graph_env import GraphEnv
 5  from ray import tune
 6
 7  config = {
 8      "env": GraphEnv,
 9      "env_config": {
10          "state": HallwayState(5),
11          "max_num_children": 2,
12      },
13      "model": {
14          "custom_model": HallwayModel,
15          "custom_model_config": {"hidden_dim": 32},
16      },
17      "framework": "tf2",
18      "eager_tracing": True,
19      "num_workers": 1,
20  }
21
22  stop = {
23      "training_iteration": 5,
24  }
25
26  if __name__ == "__main__":
27
28      ray.init()
29
30      tune.run(
31          "PPO",
32          config=config,
33          stop=stop,
34      )
```

In lines 7-20, we specify the configuration options for PPO, including a "framework" setting that matches the framework used by the provided HallwayModel policy. This script runs 5 iterations of the PPO training algorithm, and the results can be monitored with TensorBoard.
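The stop dictionary is not limited to an iteration count: ray.tune stops a trial as soon as any listed criterion is met, keyed on fields from the trainer's result dict. A minimal sketch of combining criteria (the reward threshold shown here is purely illustrative, not taken from the hallway example):

```python
# tune stops when ANY criterion is satisfied. "training_iteration" and
# "episode_reward_mean" are standard RLlib result keys; the 0.0 threshold
# below is an illustrative placeholder, not a value from the example.
stop = {
    "training_iteration": 5,     # hard cap on PPO training iterations
    "episode_reward_mean": 0.0,  # stop early once mean reward reaches 0.0
}
```

Passing this dict as `stop=stop` to `tune.run` leaves the rest of the script unchanged.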

Running the same experiment with PyTorch requires writing a PyTorch-compatible policy model, demonstrated in graphenv.examples.hallway.hallway_model_torch. Beyond this, only the model import, the "custom_model" entry, and the framework-related options change; the full PyTorch training script is shown below:

```python
 1  import ray
 2  from graphenv.examples.hallway.hallway_model_torch import TorchHallwayModel
 3  from graphenv.examples.hallway.hallway_state import HallwayState
 4  from graphenv.graph_env import GraphEnv
 5  from ray import tune
 6
 7  config = {
 8      "env": GraphEnv,
 9      "env_config": {
10          "state": HallwayState(5),
11          "max_num_children": 2,
12      },
13      "model": {
14          "custom_model": TorchHallwayModel,
15          "custom_model_config": {"hidden_dim": 32},
16      },
17      "framework": "torch",
18      "num_workers": 1,
19  }
20
21  stop = {
22      "training_iteration": 5,
23  }
24
25  if __name__ == "__main__":
26
27      ray.init()
28
29      tune.run(
30          "PPO",
31          config=config,
32          stop=stop,
33      )
```
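Comparing the two scripts, the PyTorch version differs from the TensorFlow one only in the imported model class, the "framework" key, and the removal of the tf2-only "eager_tracing" option. A minimal sketch of that delta (the model classes are stood in by placeholder strings here, purely for illustration):

```python
# Config keys that change between the two scripts. Real code would use the
# HallwayModel / TorchHallwayModel classes; strings stand in for them here.
tf_config = {
    "framework": "tf2",
    "eager_tracing": True,  # tf2-only option; dropped for torch
    "model": {"custom_model": "HallwayModel"},
}

# Derive the torch config: drop eager_tracing, swap framework and model.
torch_config = {k: v for k, v in tf_config.items() if k != "eager_tracing"}
torch_config["framework"] = "torch"
torch_config["model"] = {"custom_model": "TorchHallwayModel"}
```

Everything else, including the env, env_config, stop criteria, and the tune.run call, is identical across the two frameworks.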