# Reinforcement Learning Workflow
Reinforcement Learning (RL) is one of the primary focus areas of the Space Robotics Bench. Although several RL frameworks exist, each with its own peculiarities, SRB offers a unified interface for training and evaluating policies across a diverse set of space robotics tasks.
## 1. Train your 1st RL Agent
Reference: `srb agent train` — Train Agent
The fastest way to get started with training an RL agent is the `srb agent train` command, which provides a streamlined interface for all integrated RL frameworks. In general, you want to specify the RL algorithm to use, the environment to train on, and the number of parallel environment instances used for rollout collection.
Let's start with a simple TODO_ENV environment using the TODO_ALGO algorithm. For now, omit the `--headless` flag so that you can observe the convergence in real time:
```bash
srb agent train --algo TODO_ALGO --env TODO_ENV env.num_envs=512
```
While the simulation runs, you can also monitor the training progress in your terminal. After roughly TODO_CONVERGENCE timesteps, the agent should converge to a stable policy that successfully solves the task. Checkpoints are saved regularly, so you are free to stop the training process at any point by sending an interrupt signal (Ctrl+C in most terminals).
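If your GPU runs out of memory with 512 parallel environments, lowering `env.num_envs` is a reasonable first adjustment. The sketch below only recombines the flags shown above; the optimal count depends entirely on your hardware:

```bash
# Trade wall-clock time for memory by collecting rollouts from fewer parallel environments
srb agent train --algo TODO_ALGO --env TODO_ENV env.num_envs=256
```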
## 2. Evaluate your Agent
Reference: `srb agent eval` — Evaluate Agent
Once training is complete, you can evaluate your agent with the `srb agent eval` command:
```bash
srb agent eval --algo TODO_ALGO --env TODO_ENV env.num_envs=16
```
By default, the latest checkpoint from the training run is loaded for evaluation. However, you might want to run the evaluation for a specific checkpoint, which you can pass via `--model`:
```bash
srb agent eval --algo TODO_ALGO --env TODO_ENV env.num_envs=16 --model space_robotics_bench/logs/TODO_ENV/TODO_ALGO/TODO_CHECKPOINT
```
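If you are unsure which checkpoints a run has produced, listing its log directory is a quick way to find them. This is just a plain `ls` over the path layout used in the `--model` example above; the exact file names depend on the framework that produced them:

```bash
# Inspect the checkpoints saved for a particular run
ls space_robotics_bench/logs/TODO_ENV/TODO_ALGO/
```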
## 3. Try a Different Algorithm
SRB directly supports several popular RL algorithms from different frameworks:
| Algorithm Type | DreamerV3 | Stable-Baselines3 | SBX | skrl |
|---|---|---|---|---|
| Model-based | `dreamer` | | | |
| Value-based | | `sb3_dqn` | `sbx_dqn` | `skrl_dqn` |
| | | `sb3_qrdqn` | | `skrl_ddqn` |
| | | `sb3_crossq` | `sbx_crossq` | |
| Policy Gradient | | `sb3_ppo` | `sbx_ppo` | `skrl_ppo` |
| | | `sb3_ppo_lstm` | | `skrl_ppo_rnn` |
| | | `sb3_trpo` | | `skrl_trpo` |
| | | `sb3_a2c` | | `skrl_a2c` |
| | | | | `skrl_rpo` |
| Actor-Critic | | `sb3_ddpg` | `sbx_ddpg` | `skrl_ddpg` |
| | | `sb3_td3` | `sbx_td3` | `skrl_td3` |
| | | `sb3_sac` | `sbx_sac` | `skrl_sac` |
| | | `sb3_tqc` | `sbx_tqc` | |
| Evolutionary | | `sb3_ars` | | |
| | | | | `skrl_cem` |
| Imitation-based | | | | `skrl_amp` |
| Multi-agent | | | | `skrl_ippo` |
| | | | | `skrl_mappo` |
This time, you can train another agent using an algorithm of your choice:
```bash
srb agent train --headless --algo <ALGO> --env TODO_ENV env.num_envs=1024
```
Hint: Use `--headless` mode with more parallel environments for faster convergence.
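As a concrete instance, the sketch below simply fills the placeholder with one identifier from the table above; any of the listed algorithms can be substituted the same way, though each converges at its own pace and with its own hyperparameters:

```bash
# Example: train with the PPO implementation from skrl
srb agent train --headless --algo skrl_ppo --env TODO_ENV env.num_envs=1024
```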
## 4. Monitor Training Progress
While training, you might want to monitor the progress and compare different runs through a visual interface. By default, TensorBoard logs are saved for all algorithms and environments in the `space_robotics_bench/logs` directory. You can start TensorBoard to visualize the training progress:
```bash
tensorboard --logdir ./space_robotics_bench/logs --bind_all
```
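TensorBoard serves its dashboard on port 6006 by default, so you can open it at http://localhost:6006 once it is running. If the combined view gets cluttered, you can also point `--logdir` at a single run; the path below just reuses the log layout shown earlier and should be adjusted to your actual run directory:

```bash
# Visualize a single run instead of the whole logs directory
tensorboard --logdir ./space_robotics_bench/logs/TODO_ENV/TODO_ALGO
```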
Furthermore, you can enable Weights & Biases (`wandb`) logging by passing framework-specific flags [subject to future standardization]:
- DreamerV3: `srb agent train ... +agent.logger.outputs=wandb`
- SB3 & SBX: `srb agent train ... agent.track=true`
- skrl: `srb agent train ... +agent.experiment.wandb=true`
Note: Logging to Weights & Biases requires an account and API key.
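A minimal sketch of a typical setup, assuming the standard `WANDB_API_KEY` environment variable that `wandb` reads for authentication (SRB does not mandate this exact workflow; `wandb login` works just as well):

```bash
# Authenticate with Weights & Biases, then launch a tracked run
# (SB3/SBX flag shown; use the flag that matches your framework from the list above)
export WANDB_API_KEY=<your-api-key>
srb agent train --headless --algo sb3_ppo --env TODO_ENV env.num_envs=1024 agent.track=true
```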
## 5. Configure Hyperparameters
Reference: Agent Configuration
The default hyperparameters for all algorithms and environments are available under the `space_robotics_bench/hyperparams` directory. Similar to the environment configuration, you can adjust the hyperparameters of the selected RL algorithm through Hydra. However, the available hyperparameters and their structure are specific to each framework and algorithm.
Here are some examples (consult the hyperparameter configs for more details):

```bash
srb agent train --algo dreamer agent.run.train_ratio=128 ...
srb agent train --algo sb3_ppo agent.gamma=0.99 ...
srb agent train --algo sbx_sac agent.learning_rate=0.0002 ...
srb agent train --algo skrl_ppo_rnn agent.models.separate=True ...
```
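The override syntax follows Hydra conventions: keys that already exist in the default config are overridden directly, while a `+` prefix (as used for the wandb flags above) appends a key that is not present in the defaults. Environment and agent overrides can also be combined in a single command; the sketch below assumes that the referenced `agent.*` keys exist in the chosen algorithm's hyperparameter config:

```bash
# Combine environment and agent overrides in one invocation
srb agent train --headless --algo sb3_ppo --env TODO_ENV \
    env.num_envs=1024 \
    agent.gamma=0.99
```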