Running GraphEnv with ray.tune
Practical reinforcement learning typically uses the ray.tune infrastructure to scale up environment rollouts and policy model training. For the hallway example, a TensorFlow implementation consists of the following:
import ray
from graphenv.examples.hallway.hallway_model import HallwayModel
from graphenv.examples.hallway.hallway_state import HallwayState
from graphenv.graph_env import GraphEnv
from ray import tune

config = {
    "env": GraphEnv,
    "env_config": {
        "state": HallwayState(5),
        "max_num_children": 2,
    },
    "model": {
        "custom_model": HallwayModel,
        "custom_model_config": {"hidden_dim": 32},
    },
    "framework": "tf2",
    "eager_tracing": True,
    "num_workers": 1,
}

stop = {
    "training_iteration": 5,
}

if __name__ == "__main__":

    ray.init()

    tune.run(
        "PPO",
        config=config,
        stop=stop,
    )
Lines 7-20 specify the configuration options for PPO, including a framework setting that matches the framework used by the provided HallwayModel policy. The stop dictionary limits the run to 5 iterations of the PPO training algorithm, and the results can be monitored with TensorBoard.
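As a rough sketch of the monitoring step, the snippet below builds the shell command that points TensorBoard at the training logs. It assumes Tune wrote to its default results directory, taken here to be ~/ray_results; that path is an assumption to check against the installed Ray version and against any custom results location passed to tune.run. The helper name tensorboard_command is hypothetical.

```python
import shlex
from pathlib import Path

# Assumption: Tune's default results directory. Adjust this if the run
# was configured to write its logs somewhere else.
DEFAULT_RESULTS_DIR = Path.home() / "ray_results"


def tensorboard_command(logdir=DEFAULT_RESULTS_DIR):
    """Return the shell command that points TensorBoard at `logdir`."""
    # shlex.join quotes the path if it contains spaces or shell characters.
    return shlex.join(["tensorboard", "--logdir", str(logdir)])
```

Printing tensorboard_command() and running the result in a terminal should start TensorBoard over every experiment stored under that directory.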
Running the same experiment with PyTorch requires a PyTorch-compatible policy model, demonstrated in graphenv.examples.hallway.hallway_model_torch. Beyond this, the only modifications the training script needs are the model import, the custom_model entry, and the framework setting (the eager_tracing option is dropped), as shown below:
import ray
from graphenv.examples.hallway.hallway_model_torch import TorchHallwayModel
from graphenv.examples.hallway.hallway_state import HallwayState
from graphenv.graph_env import GraphEnv
from ray import tune

config = {
    "env": GraphEnv,
    "env_config": {
        "state": HallwayState(5),
        "max_num_children": 2,
    },
    "model": {
        "custom_model": TorchHallwayModel,
        "custom_model_config": {"hidden_dim": 32},
    },
    "framework": "torch",
    "num_workers": 1,
}

stop = {
    "training_iteration": 5,
}

if __name__ == "__main__":

    ray.init()

    tune.run(
        "PPO",
        config=config,
        stop=stop,
    )
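Because the two scripts differ only in the policy model and the framework setting, the shared configuration can be factored into one helper. The function below is a minimal sketch, not part of graphenv: the name make_ppo_config is hypothetical, and the caller passes in the environment class, initial state, and model class, so the sketch itself imports nothing from ray or graphenv.

```python
def make_ppo_config(env, state, model, framework):
    """Build the PPO config dictionary used by the hallway scripts.

    `env`, `state`, and `model` are supplied by the caller, for example
    GraphEnv, HallwayState(5), and HallwayModel or TorchHallwayModel.
    """
    config = {
        "env": env,
        "env_config": {
            "state": state,
            "max_num_children": 2,
        },
        "model": {
            "custom_model": model,
            "custom_model_config": {"hidden_dim": 32},
        },
        "framework": framework,
        "num_workers": 1,
    }
    if framework == "tf2":
        # Only the TensorFlow script sets eager_tracing.
        config["eager_tracing"] = True
    return config
```

With the imports from the scripts above, make_ppo_config(GraphEnv, HallwayState(5), TorchHallwayModel, "torch") would reproduce the PyTorch config, and passing HallwayModel with "tf2" would reproduce the TensorFlow one.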