SheepRL

The source code of all examples described in this section is available in our DIAMBRA Agents repository.

Getting Ready

We highly recommend using virtual environments to isolate your Python installs, especially to avoid dependency conflicts. In what follows we use Conda, but any other tool should work too.

Create and activate a new dedicated virtual environment:

conda create -n diambra-arena-sheeprl python=3.9
conda activate diambra-arena-sheeprl

Install DIAMBRA Arena with the SheepRL interface (the quotes prevent shells such as zsh from interpreting the square brackets):

pip install "diambra-arena[sheeprl]"

This should be enough to prepare your system to execute the following examples. You can refer to the official SheepRL documentation or reach out on our Discord server for specific needs.

Remember that to train agents, you must have installed the diambra CLI (python3 -m pip install diambra) and set the DIAMBRAROMSPATH environment variable properly.

All the examples presented below are available here: DIAMBRA Agents - SheepRL. They have been created following the high-level approach described on the SheepRL DIAMBRA page, which makes them easy to extend and shows how they interface with the different components.

These examples only aim to demonstrate the core functionalities and high-level aspects: they will not produce well-performing agents, even if training is extended to a large number of steps. Users will need to build upon them, exploring aspects such as policy network architecture, algorithm hyperparameter tuning, observation space tweaking, reward wrapping, and similar ones.

General Environment Settings

SheepRL provides many different environments that share a common set of parameters. Moreover, SheepRL leverages Hydra to define hierarchical configurations. Below are the general structure of an environment configuration and a table describing its arguments.

id: ???
num_envs: 4
frame_stack: 1
sync_env: False
screen_size: 64
action_repeat: 1
grayscale: False
clip_rewards: False
capture_video: True
frame_stack_dilation: 1
max_episode_steps: null
reward_as_observation: False
wrapper: ???
| Argument | Type | Default Value(s) | Description |
|---|---|---|---|
| id | str | - | Game environment identifier |
| num_envs | int | 4 | The number of environments to initialize for training |
| frame_stack | int | 1 | The number of frames to stack |
| sync_env | bool | False | Whether to use gymnasium.vector.SyncVectorEnv (True) or gymnasium.vector.AsyncVectorEnv (False) to handle vectorized environments |
| screen_size | int or Tuple[int, int] | 64 | Screen size of the frames |
| action_repeat | int | 1 | How many times to repeat the same action |
| grayscale | bool | False | Whether to use grayscale frames |
| clip_rewards | bool | False | Whether to clip rewards using a tanh |
| capture_video | bool | True | Whether to capture videos of the episodes during training |
| frame_stack_dilation | int | 1 | The number of frames to be skipped between frames in the frame_stack |
| max_episode_steps | int or None | null | The maximum number of steps in a single episode |
| reward_as_observation | bool | False | Whether to add the reward to the observations |
| wrapper | Dict[str, Any] | - | Environment-related arguments (see here) |
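To make frame_stack and frame_stack_dilation more concrete, the sketch below computes which past frames end up in the stack at a given step. It reads the dilation as the spacing between stacked frames (so the default of 1 stacks consecutive frames) and assumes the stack is pre-filled with the first frame after a reset; both are assumptions for illustration, not SheepRL's wrapper code, and stacked_frame_indices is a hypothetical helper.

```python
from typing import List


def stacked_frame_indices(t: int, frame_stack: int, dilation: int) -> List[int]:
    """Return the time indices of the frames stacked at step t, oldest first.

    Assumptions: frames are taken every `dilation` steps going back from t,
    and indices before the first frame are clamped to 0 (stack pre-filled
    with the initial frame after reset).
    """
    if frame_stack < 1 or dilation < 1:
        raise ValueError("frame_stack and dilation must be >= 1")
    indices = [t - k * dilation for k in range(frame_stack)]
    return [max(i, 0) for i in reversed(indices)]


# With the defaults (frame_stack=1, dilation=1) only the current frame is used
print(stacked_frame_indices(10, 1, 1))  # [10]
# Four frames, one frame skipped between consecutive stacked frames
print(stacked_frame_indices(10, 4, 2))  # [4, 6, 8, 10]
```

Larger dilations let a small stack cover a longer time window without increasing the observation size.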

If you have never used Hydra, we strongly recommend checking the Hydra official documentation and the SheepRL-related section before continuing.

Native interface

The DIAMBRA Arena native interface with SheepRL covers a wide range of use cases, automating the handling of vectorized environments and monitoring wrappers. In most cases, users can directly import and use it, with no additional customization. Below are its interface and a table describing its arguments.

from typing import Any, Dict, Tuple, Union

import gymnasium as gym


class DiambraWrapper(gym.Wrapper):
    def __init__(
        self,
        id: str,
        action_space: str = "DISCRETE",
        screen_size: Union[int, Tuple[int, int]] = 64,
        grayscale: bool = False,
        repeat_action: int = 1,
        rank: int = 0,
        diambra_settings: Dict[str, Any] = {},
        diambra_wrappers: Dict[str, Any] = {},
        render_mode: str = "rgb_array",
        log_level: int = 0,
        increase_performance: bool = True,
    ):
        ...  # body omitted: see the SheepRL source code linked below
| Argument | Type | Default Value(s) | Description |
|---|---|---|---|
| id | str | - | Game environment identifier |
| action_space | str | "DISCRETE" | Which action space to use: either "DISCRETE" or "MULTI_DISCRETE" |
| screen_size | int or Tuple[int, int] | 64 | Screen size of the frames |
| grayscale | bool | False | Whether to use grayscale frames |
| repeat_action | int | 1 | How many times to repeat the same action |
| rank | int | 0 | Rank of the environment |
| diambra_settings | Dict[str, Any] | {} | The settings of the environment. See here to check which settings you can specify. |
| diambra_wrappers | Dict[str, Any] | {} | The wrappers to apply to the environment. See here to check which wrappers you can specify. |
| render_mode | str | "rgb_array" | Rendering mode |
| log_level | int | 0 | Log level |
| increase_performance | bool | True | Whether to modify frames on the engine side (True) or use the wrapper (False) |
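The repeat_action argument (fed by action_repeat in the environment configuration) applies the same action for several consecutive steps. The following is a generic sketch of that idea on a toy environment. ToyEnv and repeat_step are hypothetical names, and summing the rewards over the repeated steps and stopping early when the episode ends are common conventions assumed here, not necessarily what DiambraWrapper does internally.

```python
from typing import Tuple


class ToyEnv:
    """Hypothetical environment: reward 1.0 per step, episode ends after `length` steps."""

    def __init__(self, length: int) -> None:
        self.length = length
        self.t = 0

    def step(self, action: int) -> Tuple[int, float, bool]:
        self.t += 1
        return self.t, 1.0, self.t >= self.length  # (obs, reward, done)


def repeat_step(env: ToyEnv, action: int, repeat: int) -> Tuple[int, float, bool]:
    """Apply `action` up to `repeat` times, summing rewards and stopping at episode end."""
    if repeat < 1:
        raise ValueError("repeat must be >= 1")
    total_reward = 0.0
    obs, done = env.t, False
    for _ in range(repeat):
        obs, reward, done = env.step(action)
        total_reward += reward
        if done:
            break
    return obs, total_reward, done


env = ToyEnv(length=5)
print(repeat_step(env, action=0, repeat=4))  # (4, 4.0, False)
print(repeat_step(env, action=0, repeat=4))  # (5, 1.0, True)
```

The agent therefore takes one decision every `repeat` environment steps, which shortens the effective horizon it has to reason about.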

For low-level details of the interface, users can review the corresponding source code here.

Agent Settings

SheepRL provides several SOTA algorithms, both model-free and model-based. Here you can find the default configurations for these agents. Of course, you can change algorithm-related hyperparameters to customize your experiments.

Basic

As anticipated, SheepRL provides default configurations for all its components, which can be composed to set up an experiment. Alternatively, you can customize the ones you need: the two main components to define for an experiment are the agent and the environment.

Regarding the environment, some constraints must be respected: for example, dictionary observation spaces cannot be nested. For this reason, the DIAMBRA flattening wrapper is always used. For more information about the constraints of the SheepRL library, check here.
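As a rough illustration of what flattening does, the sketch below turns a nested observation dictionary into a flat one by joining nested keys with an underscore. The nested layout and the values are made up, modeled on the observation keys used in the experiment configuration of this page (own_health, opp_health, ...); flatten_obs is a hypothetical helper, not the DIAMBRA flattening wrapper's code.

```python
from typing import Any, Dict


def flatten_obs(obs: Dict[str, Any], prefix: str = "", sep: str = "_") -> Dict[str, Any]:
    """Recursively flatten nested dicts, joining parent and child keys with `sep`."""
    flat: Dict[str, Any] = {}
    for key, value in obs.items():
        name = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_obs(value, name, sep))
        else:
            flat[name] = value
    return flat


# Made-up nested observation; "<image>" stands in for the frame array
nested = {
    "frame": "<image>",
    "stage": 1,
    "own": {"health": 160, "side": 0},
    "opp": {"health": 120, "side": 1},
}
print(flatten_obs(nested))
# {'frame': '<image>', 'stage': 1, 'own_health': 160, 'own_side': 0, 'opp_health': 120, 'opp_side': 1}
```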

Regarding the agent, the only constraints concern the observation and action spaces it supports. The supported observation and action spaces are listed in Table 1 and Table 2, respectively, of the README in the SheepRL GitHub repository.

Customising the Configurations

The default configurations are available here. If you want to define your custom experiments, you just need to follow a few steps:

  1. Create a folder (with the same structure as the SheepRL configs folder) in which to place your custom configurations.
  2. Define the SHEEPRL_SEARCH_PATH environment variable in the .env file as follows: SHEEPRL_SEARCH_PATH=file://relative/path/to/custom/configs/folder;pkg://sheeprl.configs.
  3. Define the custom configurations, making sure that each filename differs from the default ones: a file with the same name as a default one will override the default configuration.
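Steps 1 and 2 can be scripted. The sketch below creates a custom configs folder with the sub-folders used on this page (env, algo, exp) and writes the .env file with the search path. scaffold_configs is a hypothetical helper, and the folder name and root directory are assumptions to adapt to your own layout.

```python
import tempfile
from pathlib import Path


def scaffold_configs(root: Path, configs_dirname: str = "configs") -> Path:
    """Create <root>/<configs_dirname>/{env,algo,exp} and a matching .env file."""
    for group in ("env", "algo", "exp"):
        (root / configs_dirname / group).mkdir(parents=True, exist_ok=True)
    env_file = root / ".env"
    # Custom configs folder first, then SheepRL's packaged defaults (as in step 2)
    env_file.write_text(
        f"SHEEPRL_SEARCH_PATH=file://{configs_dirname};pkg://sheeprl.configs\n"
    )
    return env_file


# Try it in a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    env_file = scaffold_configs(Path(tmp))
    print(env_file.read_text().strip())
    # SHEEPRL_SEARCH_PATH=file://configs;pkg://sheeprl.configs
```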

Basic Example

This example demonstrates how to:

  • Leverage SheepRL to define the environment for training.
  • Define a PPO Agent to be trained.
  • Define custom configurations for your experiment.
  • Train the agent.
  • Run the trained agent in the environment for one episode.

SheepRL natively supports dictionary observation spaces: you only need to define the keys of the observations you want to process. For more information about observation selection, check here.

Configs Folder

First, create a folder for the configuration files. We create the configs folder under the ./sheeprl/ folder in the DIAMBRA Arena GitHub repository. Then we add the .env file in the ./sheeprl/ folder, in which we define the SHEEPRL_SEARCH_PATH environment variable as follows:

SHEEPRL_SEARCH_PATH=file://configs;pkg://sheeprl.configs

Define the Environment

Now, in the ./sheeprl/configs folder, we create the env folder in which the custom_env.yaml file is placed. Below is a possible configuration of the environment.

defaults:
  - default
  - _self_

# Override from `default` config
# `default` config contains the arguments shared
# among all the environments in SheepRL
id: doapp
frame_stack: 1
sync_env: True
action_repeat: 1
num_envs: 1
screen_size: 64
grayscale: False
clip_rewards: False
capture_video: True
frame_stack_dilation: 1
max_episode_steps: null
reward_as_observation: False

# DOAPP-related arguments
wrapper:
  # class to be instantiated
  _target_: sheeprl.envs.diambra.DiambraWrapper
  id: ${env.id}
  action_space: DISCRETE # or "MULTI_DISCRETE"
  screen_size: ${env.screen_size}
  grayscale: ${env.grayscale}
  repeat_action: ${env.action_repeat}
  rank: null
  log_level: 0
  increase_performance: True
  diambra_settings:
    role: P1 # or "P2" or null
    step_ratio: 6
    difficulty: 4
    continue_game: 0.0
    show_final: False
    outfits: 2
    splash_screen: False
  diambra_wrappers:
    stack_actions: 1
    no_op_max: 0
    no_attack_buttons_combinations: False
    add_last_action: True
    scale: False
    exclude_image_scaling: False
    process_discrete_binary: False
    role_relative: True
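The _target_ key tells Hydra which class to build from the wrapper block: the dotted path is imported and called with the remaining keys as keyword arguments (the PPO code later on this page uses hydra.utils.instantiate in the same way for the optimizer). Here is a deliberately simplified, standard-library-only sketch of that mechanism. It is not Hydra's implementation, which also handles nested configs, interpolation and conversion, and fractions.Fraction stands in for sheeprl.envs.diambra.DiambraWrapper so the example runs without DIAMBRA installed.

```python
import importlib
from typing import Any, Dict


def simple_instantiate(config: Dict[str, Any]) -> Any:
    """Import the object named by config['_target_'] and call it with the other keys."""
    kwargs = dict(config)  # copy, so the caller's config is left untouched
    target = kwargs.pop("_target_")
    module_path, _, attr_name = target.rpartition(".")
    factory = getattr(importlib.import_module(module_path), attr_name)
    return factory(**kwargs)


# Stand-in target so the example runs anywhere
wrapper_cfg = {"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4}
print(simple_instantiate(wrapper_cfg))  # 3/4
```

This is why every key under wrapper (other than _target_) must match a parameter of DiambraWrapper.__init__.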
Define the Agent

As for the environment, we need a dedicated folder for the custom agent configurations: we create the algo folder in the ./sheeprl/configs folder and place the custom_ppo_agent.yaml file in it. The defaults keyword lets you retrieve configurations specified in other files; in our case, since we are defining the agent, we take the configuration from the SheepRL algorithm config folder, in which several SOTA agents are defined.

When defining an agent, it is mandatory to define the name of the algorithm (it must be equal to the filename of the file in which the algorithm is defined). The value of this parameter determines which algorithm is used for training. If you inherit the default configurations of a specific algorithm, you do not need to define it, since it is already defined in the default configs of that algorithm.
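Inheriting the defaults of an algorithm and then redefining some keys behaves like a recursive merge in which the values written in your file win (in Hydra, this precedence is controlled by the position of _self_ in the defaults list). The sketch below mimics that precedence; merge_configs is a hypothetical helper, the inherited values in ppo_defaults are made up, and real Hydra composition does more than this (config groups, package directives, interpolation).

```python
from typing import Any, Dict


def merge_configs(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict where `override` wins over `base`, recursing into nested dicts."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_configs(merged[key], value)
        else:
            merged[key] = value
    return merged


# Made-up inherited defaults vs. overrides like those in custom_ppo_agent.yaml
ppo_defaults = {"name": "ppo", "update_epochs": 10, "optimizer": {"lr": 1e-3, "eps": 1e-4}}
custom = {"update_epochs": 1, "optimizer": {"lr": 5e-3}}
print(merge_configs(ppo_defaults, custom))
# {'name': 'ppo', 'update_epochs': 1, 'optimizer': {'lr': 0.005, 'eps': 0.0001}}
```

Keys you do not mention (here name and optimizer.eps) keep their inherited values, which is why name can be omitted when inheriting from ppo.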

Below is a configuration file for a PPO agent.

defaults:
  # Take default configurations of PPO
  - ppo
  # define Adam optimizer under the `optimizer` key 
  # from the sheeprl/configs/optim folder
  - override /optim@optimizer: adam
  - _self_

# Override default ppo arguments
# `name` is a mandatory attribute, it must be equal to the filename
# of the file in which the algorithm is defined.
# If you inherit the default configurations of a specific algorithm,
# then you do not need to define it, since it is already defined in the default configs
name: ppo
update_epochs: 1
normalize_advantages: True
rollout_steps: 32
dense_units: 16
mlp_layers: 1
dense_act: torch.nn.Tanh
max_grad_norm: 1.0

# Encoder
encoder:
  cnn_features_dim: 128
  mlp_features_dim: 32
  dense_units: ${algo.dense_units}
  mlp_layers: ${algo.mlp_layers}
  dense_act: ${algo.dense_act}
  layer_norm: ${algo.layer_norm}

# Actor
actor:
  dense_units: ${algo.dense_units}
  mlp_layers: ${algo.mlp_layers}
  dense_act: ${algo.dense_act}
  layer_norm: ${algo.layer_norm}

# Critic
critic:
  dense_units: ${algo.dense_units}
  mlp_layers: ${algo.mlp_layers}
  dense_act: ${algo.dense_act}
  layer_norm: ${algo.layer_norm}

# Single optimizer for both actor and critic
optimizer:
  lr: 5e-3
  eps: 1e-6
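Values such as ${algo.dense_units} are interpolations: they point to another key of the composed configuration, so the encoder, actor and critic share the same settings, and Hydra (through OmegaConf) resolves them when the config is accessed. The sketch below resolves whole-value interpolations of that form over plain dictionaries. It is a simplified stand-in for OmegaConf's resolver, which also supports interpolations embedded in longer strings and custom resolvers; resolve is a hypothetical helper.

```python
import re
from typing import Any, Dict

_INTERP = re.compile(r"^\$\{([^}]+)\}$")


def resolve(node: Any, root: Dict[str, Any]) -> Any:
    """Replace whole-value '${a.b.c}' strings with the value at that dotted path in `root`."""
    if isinstance(node, dict):
        return {k: resolve(v, root) for k, v in node.items()}
    if isinstance(node, str):
        match = _INTERP.match(node)
        if match:
            value: Any = root
            for part in match.group(1).split("."):
                value = value[part]
            return resolve(value, root)  # the target may itself be an interpolation
    return node


cfg = {
    "algo": {
        "dense_units": 16,
        "mlp_layers": 1,
        "actor": {"dense_units": "${algo.dense_units}", "mlp_layers": "${algo.mlp_layers}"},
    }
}
print(resolve(cfg, cfg)["algo"]["actor"])  # {'dense_units': 16, 'mlp_layers': 1}
```

Changing dense_units once at the top of the file therefore updates every component that references it.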

Define the Experiment

The last step is to define the experiment. You just need to create a custom_exp.yaml file in the ./sheeprl/configs/exp folder and assemble the environment, the agent, and the other components of the SheepRL framework. In particular, four parameters must be defined:

  1. algo.total_steps: the total number of policy steps to compute during training (for more information, check here).
  2. buffer.size: the size of the replay buffer.
  3. algo.cnn_keys: the keys of the image observations that must be encoded (and possibly reconstructed by the decoder).
  4. algo.mlp_keys: the keys of the vector observations that must be encoded (and possibly reconstructed by the decoder).

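Both algo.cnn_keys and algo.mlp_keys must be lists, and at least one key must be specified overall (the PPO code quoted below raises an error when both are empty). Moreover, the user-specified keys must be a subset of the environment observation keys.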
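These constraints can be checked before launching a run. The sketch below mirrors the check performed in the PPO main() function quoted later on this page (at least one encoder key overall) and adds the subset check against the environment observation keys. validate_keys is a hypothetical helper, and env_obs_keys is a partial, illustrative set taken from the experiment configuration below.

```python
from typing import Iterable, List


def validate_keys(cnn_keys: List[str], mlp_keys: List[str], obs_keys: Iterable[str]) -> None:
    """Raise ValueError if no key is selected or a key is not an observation key."""
    if not cnn_keys and not mlp_keys:
        raise ValueError("Specify at least one CNN or MLP encoder key")
    unknown = sorted(set(cnn_keys + mlp_keys) - set(obs_keys))
    if unknown:
        raise ValueError(f"Keys not found in the observation space: {unknown}")


env_obs_keys = {"frame", "own_health", "opp_health", "stage", "timer", "action"}
validate_keys(["frame"], ["own_health", "opp_health"], env_obs_keys)  # passes silently
try:
    validate_keys(["frame"], ["own_healt"], env_obs_keys)  # typo in the key
except ValueError as err:
    print(err)  # Keys not found in the observation space: ['own_healt']
```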

Below is an example of an experiment config file.

# @package _global_

defaults:
  # Selects the algorithm and the environment
  - override /algo: custom_ppo_agent
  - override /env: custom_env
  - _self_

# Buffer
buffer:
  share_data: False
  size: ${algo.rollout_steps}

checkpoint:
  save_last: True

# Experiment
algo:
  total_steps: 1024
  per_rank_batch_size: 16
  cnn_keys:
    encoder: [frame]
  mlp_keys:
    encoder:
      - own_character
      - own_health
      - own_side
      - own_wins
      - opp_character
      - opp_health
      - opp_side
      - opp_wins
      - stage
      - timer
      - action
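With these values, it is easy to work out how much training actually happens. The sketch below repeats the arithmetic used in the PPO main() and train() functions quoted later (policy_steps_per_update, num_updates, and the mini-batches produced by a BatchSampler with drop_last=False), assuming a single training process (world_size = 1), num_envs: 1 from custom_env.yaml, and update_epochs: 1 and rollout_steps: 32 from custom_ppo_agent.yaml.

```python
import math

# Values from custom_env.yaml, custom_ppo_agent.yaml and custom_exp.yaml
num_envs = 1
rollout_steps = 32
update_epochs = 1
total_steps = 1024
per_rank_batch_size = 16
world_size = 1  # assumption: a single training process

# Policy steps collected (across all envs and processes) before each update
policy_steps_per_update = num_envs * rollout_steps * world_size
num_updates = total_steps // policy_steps_per_update

# Each update trains on rollout_steps * num_envs samples per process,
# split into mini-batches of per_rank_batch_size (the last one may be smaller)
batches_per_epoch = math.ceil(rollout_steps * num_envs / per_rank_batch_size)
gradient_steps = num_updates * update_epochs * batches_per_epoch

print(policy_steps_per_update, num_updates, gradient_steps)  # 32 32 64
```

Note that buffer.size is set to ${algo.rollout_steps}, the smallest value the PPO code accepts (it raises an error when the buffer is smaller than the rollout).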

When defining the experiment configuration, you can also specify how frequently to save model checkpoints and whether to save the final agent. For more information, check here.
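Checkpoints can only be written at the end of an update, i.e. every policy_steps_per_update policy steps. When checkpoint.every is not a multiple of that value, the PPO code quoted later warns that the checkpoint is saved at the nearest greater multiple. A small sketch of that rounding (effective_interval is a hypothetical helper name):

```python
def effective_interval(every: int, policy_steps_per_update: int) -> int:
    """Smallest multiple of policy_steps_per_update that is >= every."""
    if every <= 0 or policy_steps_per_update <= 0:
        raise ValueError("both values must be positive")
    # Ceiling division with integers only: -(-a // b) == ceil(a / b)
    return -(-every // policy_steps_per_update) * policy_steps_per_update


print(effective_interval(100, 32))  # 128
print(effective_interval(64, 32))   # 64
```

Choosing checkpoint.every as a multiple of num_envs * rollout_steps * world_size avoids the warning altogether.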

Train and Evaluate the Agent

To run the experiment you just need to go into the ./sheeprl folder and run the following command:

diambra run -s=2 python train.py exp=custom_exp

Two Docker containers are needed (-s=2) because SheepRL automatically tests the agent after training.

After training, you can decide to evaluate the agent as many times as you want. You can specify only a few parameters for evaluating your agent:

  1. The checkpoint of the agent that you want to evaluate (checkpoint_path, mandatory).
  2. The type of device on which you want to run the evaluation (fabric.device, defaults to cpu).
  3. Whether to capture a video of the evaluation (env.capture_video, defaults to True).

Only these three parameters can be specified in order to avoid inconsistencies, for example a checkpoint belonging to one agent being evaluated with the configuration of another, or a checkpointed model whose dimensions differ from those of the model in the configuration. As a consequence, the evaluation script expects a specific directory structure: the log directory can be moved as a whole, but its internal structure must not be changed and the checkpoint must not be moved on its own, otherwise the script cannot automatically retrieve the environment and agent configurations.

The evaluation configuration is the following:

# @package _global_

# specify here default evaluation configuration
defaults:
  - _self_
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

hydra:
  output_subdir: null
  run:
    dir: .

fabric:
  accelerator: cpu

env:
  capture_video: True

seed: null
num_threads: 1
disable_grads: True
checkpoint_path: ???
float32_matmul_precision: "high"
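To see why the log directory must stay intact, consider how an evaluation script can recover the training configuration from nothing but checkpoint_path. The sketch below assumes a layout in which the checkpoint lives in a checkpoint/ sub-folder and the saved config.yaml sits in the parent run folder; the folder names, file names and example path are assumptions for illustration, so check them against your own log directory. config_path_for is a hypothetical helper.

```python
from pathlib import Path


def config_path_for(checkpoint_path: str) -> Path:
    """Assumed layout: <run_dir>/checkpoint/<name>.ckpt -> <run_dir>/config.yaml."""
    return Path(checkpoint_path).parent.parent / "config.yaml"


# Made-up checkpoint path
ckpt = "logs/runs/ppo/doapp/run_1/version_0/checkpoint/ckpt_1024_0.ckpt"
print(config_path_for(ckpt).as_posix())
# logs/runs/ppo/doapp/run_1/version_0/config.yaml
```

Moving the whole run folder keeps this relative relationship valid; moving only the .ckpt file breaks it.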

To evaluate the agent you just need to run the following command:

diambra run python evaluate.py checkpoint_path=/path/to/checkpoint.ckpt

If you want to specify the device to use, for instance cuda, you have to run the following command:

diambra run python evaluate.py checkpoint_path=/path/to/checkpoint.ckpt fabric.device=cuda

If you want to disable video capture, run the following command:

diambra run python evaluate.py checkpoint_path=/path/to/checkpoint.ckpt env.capture_video=False

Train and Evaluate Scripts

In this section, we show the two scripts for training and evaluating agents. For training, the environment selected by the user is checked first: if it is not a DIAMBRA environment, an exception is raised. Next, the SheepRL run() function is called, which initializes all components and starts the training.

For evaluation, the configurations are simply passed to the SheepRL evaluation() function. There is no need to check the environment, as it has already been checked before training.

The train.py script:

# Diambra Agents

import hydra
from diambra.arena.sheeprl import CONFIGS_PATH
from omegaconf import DictConfig

from sheeprl.cli import run


def check_configs(cfg: DictConfig):
    if "diambra" not in cfg.env.wrapper._target_:
        raise ValueError(
            f"You must choose a DIAMBRA environment. "
            f"Got '{cfg.env.id}' provided by '{cfg.env.wrapper._target_.split('.')[-2]}'."
        )


@hydra.main(version_base="1.3", config_path=CONFIGS_PATH, config_name="config")
def train(cfg: DictConfig):
    check_configs(cfg)
    run(cfg)


if __name__ == "__main__":
    train()

The evaluate.py script:

# Diambra Agents

import hydra
from diambra.arena.sheeprl import CONFIGS_PATH
from omegaconf import DictConfig

from sheeprl.cli import evaluation


@hydra.main(version_base="1.3", config_path=CONFIGS_PATH, config_name="eval_config")
def run(cfg: DictConfig):
    evaluation(cfg)


if __name__ == "__main__":
    run()

PPO Implementation

In this paragraph, we quote the code of our PPO implementation (the ppo.py file in the SheepRL PPO folder) to give more context on how SheepRL works. The main() function instantiates all the components needed for training (i.e., the agent, the environments, the buffer, the logger, and so on). Then it interacts with the environments and, after collecting rollout_steps steps, calls the train() function.
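Between rollout collection and the call to train(), main() estimates returns and advantages with Generalized Advantage Estimation through SheepRL's gae() helper (imported from sheeprl.utils.utils). As a self-contained reference, here is a minimal plain-Python sketch of GAE for a single environment. It assumes dones[t] = 1 when the episode ended at transition t, so no value is bootstrapped across that boundary; compute_gae is a hypothetical name and this is not the sheeprl.utils.utils.gae source.

```python
from typing import List, Tuple


def compute_gae(
    rewards: List[float],
    values: List[float],
    dones: List[int],
    next_value: float,
    gamma: float,
    gae_lambda: float,
) -> Tuple[List[float], List[float]]:
    """Return (returns, advantages) for one rollout of a single environment."""
    advantages = [0.0] * len(rewards)
    last_adv = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t]
        # Value of the next state: bootstrap with next_value after the last step
        v_next = next_value if t == len(rewards) - 1 else values[t + 1]
        delta = rewards[t] + gamma * v_next * not_done - values[t]
        last_adv = delta + gamma * gae_lambda * not_done * last_adv
        advantages[t] = last_adv
    returns = [a + v for a, v in zip(advantages, values)]
    return returns, advantages


print(compute_gae([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0, 1, 0], 0.0, gamma=1.0, gae_lambda=1.0))
# ([2.0, 1.0, 1.0], [2.0, 1.0, 1.0])
```

The episode boundary after the second step stops the third reward from flowing back into the first two advantages.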

The train() function splits the data among processes (using a DistributedSampler) when multiple processes are launched and buffer.share_data is set to True. Then, for each of the update_epochs epochs and each mini-batch, it computes the losses and updates the agent.
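The loss computed in train() combines three terms: loss = pg_loss + vf_coef * v_loss + ent_coef * ent_loss. The sketch below spells them out on plain Python lists for one mini-batch, following the standard PPO formulation (negated clipped surrogate objective, mean-squared value error, negated mean entropy). It is an illustration under simplifying assumptions, not the code of sheeprl.algos.ppo.loss: for instance, it only uses mean reduction and ignores clip_vloss.

```python
import math
from typing import List


def mean(xs: List[float]) -> float:
    return sum(xs) / len(xs)


def policy_loss(new_logp: List[float], old_logp: List[float], adv: List[float], clip_coef: float) -> float:
    """Negated PPO clipped surrogate objective, averaged over the mini-batch."""
    losses = []
    for new, old, a in zip(new_logp, old_logp, adv):
        ratio = math.exp(new - old)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_coef), 1.0 + clip_coef)
        # Max of the negated terms == min of the (unclipped, clipped) objectives
        losses.append(max(-a * ratio, -a * clipped))
    return mean(losses)


def value_loss(new_values: List[float], returns: List[float]) -> float:
    """Mean squared error between predicted values and returns (no value clipping)."""
    return mean([(v - r) ** 2 for v, r in zip(new_values, returns)])


def entropy_loss(entropy: List[float]) -> float:
    """Negative mean entropy: minimizing it encourages exploration."""
    return -mean(entropy)


def ppo_loss(
    new_logp: List[float],
    old_logp: List[float],
    adv: List[float],
    new_values: List[float],
    returns: List[float],
    entropy: List[float],
    clip_coef: float,
    vf_coef: float,
    ent_coef: float,
) -> float:
    """Same combination as in train(): pg_loss + vf_coef * v_loss + ent_coef * ent_loss."""
    pg = policy_loss(new_logp, old_logp, adv, clip_coef)
    return pg + vf_coef * value_loss(new_values, returns) + ent_coef * entropy_loss(entropy)


# Ratio 1.5 with a positive advantage is clipped to 1.2, so the policy term is -2.4
print(round(policy_loss([math.log(1.5)], [0.0], [2.0], clip_coef=0.2), 6))  # -2.4
```

Clipping removes the incentive to move the new policy's probability ratio further than clip_coef away from 1 in a single update. The actual ppo.py source follows.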

from __future__ import annotations

import copy
import os
import warnings
from typing import Any, Dict, Union

import gymnasium as gym
import hydra
import numpy as np
import torch
from lightning.fabric import Fabric
from lightning.fabric.wrappers import _FabricModule
from torch import nn
from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler
from torchmetrics import SumMetric

from sheeprl.algos.ppo.agent import build_agent
from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss
from sheeprl.algos.ppo.utils import normalize_obs, prepare_obs, test
from sheeprl.data.buffers import ReplayBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
from sheeprl.utils.metric import MetricAggregator
from sheeprl.utils.registry import register_algorithm
from sheeprl.utils.timer import timer
from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, save_configs


def train(
    fabric: Fabric,
    agent: Union[nn.Module, _FabricModule],
    optimizer: torch.optim.Optimizer,
    data: Dict[str, torch.Tensor],
    aggregator: MetricAggregator | None,
    cfg: Dict[str, Any],
):
    """Train the agent on the data collected from the environment."""
    indexes = list(range(next(iter(data.values())).shape[0]))
    if cfg.buffer.share_data:
        sampler = DistributedSampler(
            indexes,
            num_replicas=fabric.world_size,
            rank=fabric.global_rank,
            shuffle=True,
            seed=cfg.seed,
        )
    else:
        sampler = RandomSampler(indexes)
    sampler = BatchSampler(sampler, batch_size=cfg.algo.per_rank_batch_size, drop_last=False)

    for epoch in range(cfg.algo.update_epochs):
        if cfg.buffer.share_data:
            sampler.sampler.set_epoch(epoch)
        for batch_idxes in sampler:
            batch = {k: v[batch_idxes] for k, v in data.items()}
            normalized_obs = normalize_obs(
                batch, cfg.algo.cnn_keys.encoder, cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder
            )
            _, logprobs, entropy, new_values = agent(
                normalized_obs, torch.split(batch["actions"], agent.actions_dim, dim=-1)
            )

            if cfg.algo.normalize_advantages:
                batch["advantages"] = normalize_tensor(batch["advantages"])

            # Policy loss
            pg_loss = policy_loss(
                logprobs,
                batch["logprobs"],
                batch["advantages"],
                cfg.algo.clip_coef,
                cfg.algo.loss_reduction,
            )

            # Value loss
            v_loss = value_loss(
                new_values,
                batch["values"],
                batch["returns"],
                cfg.algo.clip_coef,
                cfg.algo.clip_vloss,
                cfg.algo.loss_reduction,
            )

            # Entropy loss
            ent_loss = entropy_loss(entropy, cfg.algo.loss_reduction)

            # Equation (9) in the paper
            loss = pg_loss + cfg.algo.vf_coef * v_loss + cfg.algo.ent_coef * ent_loss

            optimizer.zero_grad(set_to_none=True)
            fabric.backward(loss)
            if cfg.algo.max_grad_norm > 0.0:
                fabric.clip_gradients(agent, optimizer, max_norm=cfg.algo.max_grad_norm)
            optimizer.step()

            # Update metrics
            if aggregator and not aggregator.disabled:
                aggregator.update("Loss/policy_loss", pg_loss.detach())
                aggregator.update("Loss/value_loss", v_loss.detach())
                aggregator.update("Loss/entropy_loss", ent_loss.detach())


@register_algorithm()
def main(fabric: Fabric, cfg: Dict[str, Any]):
    if "minedojo" in cfg.env.wrapper._target_.lower():
        raise ValueError(
            "MineDojo is not currently supported by PPO agent, since it does not take "
            "into consideration the action masks provided by the environment, but needed "
            "in order to play correctly the game. "
            "As an alternative you can use one of the Dreamers' agents."
        )

    initial_ent_coef = copy.deepcopy(cfg.algo.ent_coef)
    initial_clip_coef = copy.deepcopy(cfg.algo.clip_coef)

    # Initialize Fabric
    rank = fabric.global_rank
    world_size = fabric.world_size
    device = fabric.device

    # Resume from checkpoint
    if cfg.checkpoint.resume_from:
        state = fabric.load(cfg.checkpoint.resume_from)

    # Create Logger. This will create the logger only on the
    # rank-0 process
    logger = get_logger(fabric, cfg)
    if logger and fabric.is_global_zero:
        fabric._loggers = [logger]
        fabric.logger.log_hyperparams(cfg)
    log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name)
    fabric.print(f"Log dir: {log_dir}")

    # Environment setup
    vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv
    envs = vectorized_env(
        [
            make_env(
                cfg,
                cfg.seed + rank * cfg.env.num_envs + i,
                rank * cfg.env.num_envs,
                log_dir if rank == 0 else None,
                "train",
                vector_env_idx=i,
            )
            for i in range(cfg.env.num_envs)
        ]
    )
    observation_space = envs.single_observation_space

    if not isinstance(observation_space, gym.spaces.Dict):
        raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}")
    if cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder == []:
        raise RuntimeError(
            "You should specify at least one CNN keys or MLP keys from the cli: "
            "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`"
        )
    if cfg.metric.log_level > 0:
        fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder)
        fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder)
    obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder

    is_continuous = isinstance(envs.single_action_space, gym.spaces.Box)
    is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete)
    actions_dim = tuple(
        envs.single_action_space.shape
        if is_continuous
        else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n])
    )
    # Create the actor and critic models
    agent, player = build_agent(
        fabric,
        actions_dim,
        is_continuous,
        cfg,
        observation_space,
        state["agent"] if cfg.checkpoint.resume_from else None,
    )

    # Define the optimizer
    optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters(), _convert_="all")

    if fabric.is_global_zero:
        save_configs(cfg, log_dir)

    # Load the state from the checkpoint
    if cfg.checkpoint.resume_from:
        optimizer.load_state_dict(state["optimizer"])

    # Setup agent and optimizer with Fabric
    optimizer = fabric.setup_optimizers(optimizer)

    # Create a metric aggregator to log the metrics
    aggregator = None
    if not MetricAggregator.disabled:
        aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device)

    # Local data
    if cfg.buffer.size < cfg.algo.rollout_steps:
        raise ValueError(
            f"The size of the buffer ({cfg.buffer.size}) cannot be lower "
            f"than the rollout steps ({cfg.algo.rollout_steps})"
        )
    rb = ReplayBuffer(
        cfg.buffer.size,
        cfg.env.num_envs,
        memmap=cfg.buffer.memmap,
        memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"),
        obs_keys=obs_keys,
    )

    # Global variables
    last_train = 0
    train_step = 0
    start_step = (
        # + 1 because the checkpoint is at the end of the update step
        # (when resuming from a checkpoint, the update at the checkpoint
        # is ended and you have to start with the next one)
        (state["update"] // fabric.world_size) + 1
        if cfg.checkpoint.resume_from
        else 1
    )
    policy_step = state["update"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
    last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
    last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
    policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
    num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
    if cfg.checkpoint.resume_from:
        cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size

    # Warning for log and checkpoint every
    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
        warnings.warn(
            f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
            f"policy_steps_per_update value ({policy_steps_per_update}), so "
            "the metrics will be logged at the nearest greater multiple of the "
            "policy_steps_per_update value."
        )
    if cfg.checkpoint.every % policy_steps_per_update != 0:
        warnings.warn(
            f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
            f"policy_steps_per_update value ({policy_steps_per_update}), so "
            "the checkpoint will be saved at the nearest greater multiple of the "
            "policy_steps_per_update value."
        )

    # Linear learning rate scheduler
    if cfg.algo.anneal_lr:
        from torch.optim.lr_scheduler import PolynomialLR

        scheduler = PolynomialLR(optimizer=optimizer, total_iters=num_updates, power=1.0)
        if cfg.checkpoint.resume_from:
            scheduler.load_state_dict(state["scheduler"])

    # Get the first environment observation and start the optimization
    step_data = {}
    next_obs = envs.reset(seed=cfg.seed)[0]  # [N_envs, N_obs]
    for k in obs_keys:
        if k in cfg.algo.cnn_keys.encoder:
            next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:])
        step_data[k] = next_obs[k][np.newaxis]

    for update in range(start_step, num_updates + 1):
        with torch.inference_mode():
            for _ in range(0, cfg.algo.rollout_steps):
                policy_step += cfg.env.num_envs * world_size

                # Measure environment interaction time: this considers both the model forward
                # to get the action given the observation and the time taken into the environment
                with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
                    # Sample an action given the observation received by the environment
                    torch_obs = prepare_obs(
                        fabric, next_obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs
                    )
                    actions, logprobs, values = player(torch_obs)
                    if is_continuous:
                        real_actions = torch.cat(actions, -1).cpu().numpy()
                    else:
                        real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
                    actions = torch.cat(actions, -1).cpu().numpy()

                    # Single environment step
                    obs, rewards, terminated, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape))
                    truncated_envs = np.nonzero(truncated)[0]
                    if len(truncated_envs) > 0:
                        real_next_obs = {
                            k: torch.empty(
                                len(truncated_envs),
                                *observation_space[k].shape,
                                dtype=torch.float32,
                                device=device,
                            )
                            for k in obs_keys
                        }
                        for i, truncated_env in enumerate(truncated_envs):
                            for k, v in info["final_observation"][truncated_env].items():
                                torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
                                if k in cfg.algo.cnn_keys.encoder:
                                    torch_v = torch_v.view(-1, *v.shape[-2:])
                                    torch_v = torch_v / 255.0 - 0.5
                                real_next_obs[k][i] = torch_v
                        vals = player.get_values(real_next_obs).cpu().numpy()
                        rewards[truncated_envs] += cfg.algo.gamma * vals.reshape(rewards[truncated_envs].shape)
                    dones = np.logical_or(terminated, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8)
                    rewards = rewards.reshape(cfg.env.num_envs, -1)

                # Update the step data
                step_data["dones"] = dones[np.newaxis]
                step_data["values"] = values.cpu().numpy()[np.newaxis]
                step_data["actions"] = actions[np.newaxis]
                step_data["logprobs"] = logprobs.cpu().numpy()[np.newaxis]
                step_data["rewards"] = rewards[np.newaxis]
                if cfg.buffer.memmap:
                    step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape))
                    step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape))

                # Append data to buffer
                rb.add(step_data, validate_args=cfg.buffer.validate_args)

                # Update the observation and dones
                next_obs = {}
                for k in obs_keys:
                    _obs = obs[k]
                    if k in cfg.algo.cnn_keys.encoder:
                        _obs = _obs.reshape(cfg.env.num_envs, -1, *_obs.shape[-2:])
                    step_data[k] = _obs[np.newaxis]
                    next_obs[k] = _obs

                if cfg.metric.log_level > 0 and "final_info" in info:
                    for i, agent_ep_info in enumerate(info["final_info"]):
                        if agent_ep_info is not None:
                            ep_rew = agent_ep_info["episode"]["r"]
                            ep_len = agent_ep_info["episode"]["l"]
                            if aggregator and "Rewards/rew_avg" in aggregator:
                                aggregator.update("Rewards/rew_avg", ep_rew)
                            if aggregator and "Game/ep_len_avg" in aggregator:
                                aggregator.update("Game/ep_len_avg", ep_len)
                            fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}")

        # Transform the data into PyTorch Tensors
        local_data = rb.to_tensor(dtype=None, device=device, from_numpy=cfg.buffer.from_numpy)

        # Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
        with torch.inference_mode():
            torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
            next_values = player.get_values(torch_obs)
            returns, advantages = gae(
                local_data["rewards"].to(torch.float64),
                local_data["values"],
                local_data["dones"],
                next_values,
                cfg.algo.rollout_steps,
                cfg.algo.gamma,
                cfg.algo.gae_lambda,
            )
            # Add returns and advantages to the buffer
            local_data["returns"] = returns.float()
            local_data["advantages"] = advantages.float()

        if cfg.buffer.share_data and fabric.world_size > 1:
            # Gather all the tensors from all the world and reshape them
            gathered_data: Dict[str, torch.Tensor] = fabric.all_gather(local_data)
            # Flatten the first three dimensions: [World_Size, Buffer_Size, Num_Envs]
            gathered_data = {k: v.flatten(start_dim=0, end_dim=2).float() for k, v in gathered_data.items()}
        else:
            # Flatten the first two dimensions: [Buffer_Size, Num_Envs]
            gathered_data = {k: v.flatten(start_dim=0, end_dim=1).float() for k, v in local_data.items()}

        with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
            train(fabric, agent, optimizer, gathered_data, aggregator, cfg)
        train_step += world_size

        if cfg.metric.log_level > 0:
            # Log lr and coefficients
            if cfg.algo.anneal_lr:
                fabric.log("Info/learning_rate", scheduler.get_last_lr()[0], policy_step)
            else:
                fabric.log("Info/learning_rate", cfg.algo.optimizer.lr, policy_step)
            fabric.log("Info/clip_coef", cfg.algo.clip_coef, policy_step)
            fabric.log("Info/ent_coef", cfg.algo.ent_coef, policy_step)

            # Log metrics
            if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates):
                # Sync distributed metrics
                if aggregator and not aggregator.disabled:
                    metrics_dict = aggregator.compute()
                    fabric.log_dict(metrics_dict, policy_step)
                    aggregator.reset()

                # Sync distributed timers
                if not timer.disabled:
                    timer_metrics = timer.compute()
                    if "Time/train_time" in timer_metrics:
                        fabric.log(
                            "Time/sps_train",
                            (train_step - last_train) / timer_metrics["Time/train_time"],
                            policy_step,
                        )
                    if "Time/env_interaction_time" in timer_metrics:
                        fabric.log(
                            "Time/sps_env_interaction",
                            ((policy_step - last_log) / world_size * cfg.env.action_repeat)
                            / timer_metrics["Time/env_interaction_time"],
                            policy_step,
                        )
                    timer.reset()

                # Reset counters
                last_log = policy_step
                last_train = train_step

        # Update lr and coefficients
        if cfg.algo.anneal_lr:
            scheduler.step()
        if cfg.algo.anneal_clip_coef:
            cfg.algo.clip_coef = polynomial_decay(
                update, initial=initial_clip_coef, final=0.0, max_decay_steps=num_updates, power=1.0
            )
        if cfg.algo.anneal_ent_coef:
            cfg.algo.ent_coef = polynomial_decay(
                update, initial=initial_ent_coef, final=0.0, max_decay_steps=num_updates, power=1.0
            )

        # Checkpoint model
        if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or (
            update == num_updates and cfg.checkpoint.save_last
        ):
            last_checkpoint = policy_step
            state = {
                "agent": agent.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict() if cfg.algo.anneal_lr else None,
                "update": update * world_size,
                "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size,
                "last_log": last_log,
                "last_checkpoint": last_checkpoint,
            }
            ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt")
            fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)

    envs.close()
    if fabric.is_global_zero and cfg.algo.run_test:
        test(player, fabric, cfg, log_dir)

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.ppo.utils import log_models
        from sheeprl.utils.mlflow import register_model

        models_to_log = {"agent": agent}
        register_model(fabric, log_models, cfg, models_to_log)
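
The loop above delegates the return and advantage computation to SheepRL's `gae` helper. To show the recursion that such a helper evaluates, here is a minimal pure-Python sketch for a single environment. `gae_sketch` is a hypothetical name and this is not SheepRL's implementation. It assumes that `dones[t]` is 1.0 when the episode ended at step t, so bootstrapping stops at episode boundaries.

```python
from typing import List, Tuple


def gae_sketch(
    rewards: List[float],
    values: List[float],
    dones: List[float],
    next_value: float,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> Tuple[List[float], List[float]]:
    """Generalized Advantage Estimation for one environment (illustrative sketch).

    rewards[t], values[t] and dones[t] describe rollout step t; dones[t] is assumed
    to be 1.0 when the episode ended at step t. next_value is V(s) for the
    observation that follows the last rollout step.
    """
    num_steps = len(rewards)
    advantages = [0.0] * num_steps
    last_gae = 0.0
    # Walk the rollout backwards so each advantage reuses the one after it
    for t in reversed(range(num_steps)):
        next_v = next_value if t == num_steps - 1 else values[t + 1]
        not_done = 1.0 - dones[t]  # do not bootstrap across an episode end
        # One-step TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_v * not_done - values[t]
        # A_t = delta_t + gamma * lambda * A_{t+1}
        last_gae = delta + gamma * gae_lambda * not_done * last_gae
        advantages[t] = last_gae
    # Value targets: A_t + V(s_t)
    returns = [a + v for a, v in zip(advantages, values)]
    return returns, advantages


returns, advantages = gae_sketch(
    [1.0, 1.0], [0.0, 0.0], [0.0, 0.0], next_value=0.0, gamma=0.5, gae_lambda=1.0
)
print(returns)  # [1.5, 1.0]
```

With `gae_lambda=1.0` the advantages reduce to discounted returns minus the value baseline, while `gae_lambda=0.0` keeps only the one-step TD error; values in between trade bias for variance.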

Parallel Environments

In addition to what is shown in the previous examples, this one demonstrates how to run training using parallel environments, with the same PPO algorithm used before. To train the agent with multiple parallel environments, you need to properly define a few environment parameters and then run the script, instantiating the correct number of Docker containers.

You can create a custom_parallel_env.yaml config file that inherits the configurations from the custom_env.yaml file:

defaults:
  # Inherit environment configurations from custom_env.yaml
  - custom_env
  - _self_

# Override parameters
sync_env: False # True if you want to use the gymnasium.vector.SyncVectorEnv
num_envs: 4

If you set env.sync_env to False, you must instantiate one additional Docker container, because gymnasium.vector.AsyncVectorEnv instantiates a dummy environment when it is created.

Then, create a new file for the experiment (custom_parallel_env_exp.yaml). It inherits the configurations of the custom_exp file and overrides the environment with the newly defined configurations (custom_parallel_env):

# @package _global_

defaults:
  # Inherit configs from custom_exp
  - custom_exp
  # Override the environment configurations
  - override /env: custom_parallel_env
  - _self_

How to run it:

# s=6 comes from: 4 for the envs, 1 for testing, 1 for `gymnasium.vector.AsyncVectorEnv`
diambra run -s=6 python train.py exp=custom_parallel_env_exp
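
The container counts used on this page follow a simple rule, sketched below as a hypothetical helper (not part of the DIAMBRA CLI). Assumptions: each training process creates num_envs environments, each process using AsyncVectorEnv (sync_env: False) creates one extra dummy environment, and only the rank-0 process runs the final test (cfg.algo.run_test), which needs one more environment. The count of 3 used in the Fabric example below is consistent with the inherited custom_env.yaml setting sync_env: True, which is an assumption here since that file is not shown.

```python
def containers_needed(
    num_envs: int, sync_env: bool, num_processes: int = 1, run_test: bool = True
) -> int:
    """Estimate the value to pass to `diambra run -s=...` (hypothetical helper).

    Assumes: num_envs envs per process, one extra dummy env per process when
    AsyncVectorEnv is used (sync_env=False), and one test env on rank 0.
    """
    per_process = num_envs + (0 if sync_env else 1)
    return num_processes * per_process + (1 if run_test else 0)


print(containers_needed(num_envs=4, sync_env=False))  # 6, as in the command above
print(containers_needed(num_envs=1, sync_env=True, num_processes=2))  # 3
```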

Advanced

Fabric

SheepRL allows training to be distributed thanks to Lightning Fabric.

The default Fabric configuration is the following:

_target_: lightning.fabric.Fabric
devices: 1
num_nodes: 1
strategy: "auto"
accelerator: "cpu"
precision: "32-true"
callbacks:
  - _target_: sheeprl.utils.callback.CheckpointCallback
    keep_last: "${checkpoint.keep_last}"

The sheeprl.utils.callback.CheckpointCallback is used for saving the checkpoint during training and for saving the trained agent.

To modify the Fabric configs, you can add a fabric field in the experiment file, as shown below. In this case, we select 2 devices, set the accelerator to "cuda", and train with bf16 mixed precision (16-bit). As before, the file inherits the configurations from custom_exp and then sets the Fabric parameters.

# @package _global_

defaults:
  # Inherit configs from custom_exp
  - custom_exp
  - _self_

# Set Fabric parameters
fabric:
  devices: 2
  accelerator: cuda
  precision: bf16-mixed

How to run it:

# Remember to set properly the number of containers to create
#   - Each process has 1 environment
#   - There are 2 processes
#   - Only the zero-rank process will perform the evaluation after the training
diambra run -s=3 python train.py exp=custom_fabric_exp

To run the Fabric experiment, make sure a CUDA GPU is available on your machine; otherwise, change the accelerator from cuda to cpu (or to another accelerator).
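
Hydra composes this experiment by loading the defaults (here custom_exp, which brings in the default Fabric config) and then applying the file itself, because _self_ is listed last. Only the keys set under fabric change; everything else is inherited. The sketch below imitates that key-by-key override on plain dictionaries with a hypothetical deep_merge helper. It is a simplified illustration, not Hydra's or OmegaConf's actual merge code.

```python
import copy
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `base` with `override` applied on top.

    Nested dicts are merged key by key; any other value (including lists)
    is replaced as a whole. Simplified illustration only.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Default Fabric config shown above
default_fabric = {
    "_target_": "lightning.fabric.Fabric",
    "devices": 1,
    "num_nodes": 1,
    "strategy": "auto",
    "accelerator": "cpu",
    "precision": "32-true",
    "callbacks": [
        {
            "_target_": "sheeprl.utils.callback.CheckpointCallback",
            "keep_last": "${checkpoint.keep_last}",
        }
    ],
}
# Values set under `fabric:` in the experiment file
experiment_override = {"devices": 2, "accelerator": "cuda", "precision": "bf16-mixed"}

fabric_cfg = deep_merge(default_fabric, experiment_override)
print(fabric_cfg["devices"], fabric_cfg["accelerator"], fabric_cfg["num_nodes"])  # 2 cuda 1
```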

Metric and Logging

SheepRL also allows you to visualize and monitor training using TensorBoard.

We strongly recommend reading the SheepRL logging documentation to learn how to enable or disable logging.

Below is reported the default logging configuration and a table describing the arguments.

defaults:
  - _self_
  - /logger@logger: tensorboard

log_every: 5000
disable_timer: False

# Level of Logging:
#   0: No log
#   1: Log everything
log_level: 1

# Metric related parameters. Please have a look at
# https://torchmetrics.readthedocs.io/en/stable/references/metric.html#torchmetrics.Metric
# for more information
sync_on_compute: False

aggregator:
  _target_: sheeprl.utils.metric.MetricAggregator
  raise_on_missing: False
  metrics:
    Rewards/rew_avg: 
      _target_: torchmetrics.MeanMetric
      sync_on_compute: ${metric.sync_on_compute}
    Game/ep_len_avg: 
      _target_: torchmetrics.MeanMetric
      sync_on_compute: ${metric.sync_on_compute}

| Argument | Type | Default Value(s) | Description |
|---|---|---|---|
| log_every | int | 5000 | Number of steps between one log and the next |
| disable_timer | bool | False | Whether or not to disable timer information (training and environment interaction) |
| log_level | int | 1 | The level of logging (0: disabled, 1: log everything) |
| sync_on_compute | bool | False | Whether to synchronize the metrics between processes |
| aggregator | Dict[str, Any] | - | Configurations of the aggregator to be instantiated, containing the metrics to log |

You can modify the default metric configurations by creating a new experiment file (custom_metric_exp.yaml) and setting your custom configuration under the metric key, as shown below. In this example, the file inherits from custom_fabric_exp, disables the timer information, and synchronizes the metrics between the 2 processes. Moreover, it adds 3 metrics to the aggregator (in addition to reward and episode length): the value loss, the policy loss, and the entropy loss.

# @package _global_

defaults:
  # Inherit configs from custom_fabric_exp
  - custom_fabric_exp
  - _self_

# Set Metric parameters
metric:
  disable_timer: True
  sync_on_compute: True
  aggregator:
    metrics:
      Loss/value_loss:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: ${metric.sync_on_compute}
      Loss/policy_loss:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: ${metric.sync_on_compute}
      Loss/entropy_loss:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: ${metric.sync_on_compute}

How to run it:

# s=3 since `custom_metric_exp` extends from the fabric experiments
diambra run -s=3 python train.py exp=custom_metric_exp
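
To make the aggregator's role concrete, below is a minimal single-process stand-in that mirrors how the training loop uses it: a membership check, update(), compute() at log time, then reset(). MeanAggregatorSketch is hypothetical and is not sheeprl.utils.metric.MetricAggregator. It ignores sync_on_compute, and it assumes that with raise_on_missing: False an update to an unconfigured metric is skipped with a warning.

```python
import warnings
from typing import Dict, Iterable, List


class MeanAggregatorSketch:
    """Hypothetical single-process stand-in for an aggregator of mean metrics."""

    def __init__(self, names: Iterable[str], raise_on_missing: bool = False) -> None:
        self._values: Dict[str, List[float]] = {name: [] for name in names}
        self._raise_on_missing = raise_on_missing

    def __contains__(self, name: str) -> bool:
        return name in self._values

    def update(self, name: str, value: float) -> None:
        if name not in self._values:
            if self._raise_on_missing:
                raise KeyError(f"Metric {name!r} is not configured")
            warnings.warn(f"Skipping unconfigured metric {name!r}")
            return
        self._values[name].append(float(value))

    def compute(self) -> Dict[str, float]:
        # Mean of everything accumulated since the last reset; empty metrics are skipped
        return {n: sum(v) / len(v) for n, v in self._values.items() if v}

    def reset(self) -> None:
        for values in self._values.values():
            values.clear()


aggregator = MeanAggregatorSketch(["Rewards/rew_avg", "Game/ep_len_avg", "Loss/value_loss"])
for ep_rew in (10.0, 20.0):
    if "Rewards/rew_avg" in aggregator:
        aggregator.update("Rewards/rew_avg", ep_rew)
print(aggregator.compute())  # {'Rewards/rew_avg': 15.0}
aggregator.reset()
```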

The logs are stored in the ./logs/runs/<algo_name>/<env_id>/<datetime_experiment>/ folder, and to visualize the plots, you just need to run the following command:

tensorboard --logdir /path/to/logging/directory

Then open your browser and go to http://localhost:6006/. Optionally, you can change the port of the process; for instance, to use port 6010, run the following command:

tensorboard --logdir /path/to/logging/directory --port 6010
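
Besides the TensorBoard event files, each run folder holds the config.yaml and the checkpoint files written by the training loop, named checkpoint/ckpt_<policy_step>_<global_rank>.ckpt (see the ckpt_path format in the PPO code above). Those two files are what the submission script below takes as --cfg_path and --checkpoint_path. The hypothetical helper below picks the most recent checkpoint for a given rank; it relies only on that naming scheme.

```python
import re
from pathlib import Path
from typing import Optional

# Matches f"ckpt_{policy_step}_{global_rank}.ckpt", as written by the training loop
CKPT_PATTERN = re.compile(r"ckpt_(\d+)_(\d+)\.ckpt$")


def find_latest_checkpoint(run_dir: str, rank: int = 0) -> Optional[Path]:
    """Return the checkpoint with the highest policy step saved by `rank` under `run_dir`."""
    best_step, best_path = -1, None
    for path in Path(run_dir).rglob("ckpt_*.ckpt"):
        match = CKPT_PATTERN.search(path.name)
        if match is None or int(match.group(2)) != rank:
            continue
        step = int(match.group(1))
        if step > best_step:
            best_step, best_path = step, path
    return best_path
```

For example, find_latest_checkpoint("./logs/runs/ppo") would return the rank-0 checkpoint with the largest policy step found anywhere under that folder, or None if there is none.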

Agent Script for Competition

Finally, after the agent training is completed, besides running it locally on your own machine, you may want to submit it to our Competition Platform! To do so, you can use the following script that provides a ready-to-use, flexible example that can accommodate different games and settings.

To submit your trained agent to our platform, compete for the first leaderboard positions, and unlock our achievements, follow the simple steps described in the “How to Submit an Agent” section.

The script for submitting a PPO agent is shown below.

import argparse
import json

import gymnasium as gym
import torch
from lightning import Fabric
from omegaconf import OmegaConf
from sheeprl.algos.ppo.agent import build_agent
from sheeprl.algos.ppo.utils import prepare_obs
from sheeprl.utils.env import make_env
from sheeprl.utils.utils import dotdict

"""This is an example agent based on SheepRL.

Usage:
cd sheeprl
diambra run python agent-ppo.py --cfg_path "/absolute/path/to/example-logs/runs/ppo/doapp/experiment/version_0/config.yaml" --checkpoint_path "/absolute/path/to/example-logs/runs/ppo/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt"
"""


def main(cfg_path: str, checkpoint_path: str, test=False):
    # Read the cfg file
    cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
    print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))

    # Override configs for evaluation
    # You do not need to capture the video since you are submitting the agent and the video is recorded by DIAMBRA
    cfg.env.capture_video = False
    # Only one environment is used for evaluation
    cfg.env.num_envs = 1

    # Instantiate Fabric
    # You must use the same precision and plugins used for training.
    precision = getattr(cfg.fabric, "precision", None)
    plugins = getattr(cfg.fabric, "plugins", None)
    fabric = Fabric(
        accelerator="auto",
        devices=1,
        num_nodes=1,
        precision=precision,
        plugins=plugins,
        strategy="auto",
    )

    # Create Environment
    env = make_env(cfg, 0, 0)()
    observation_space = env.observation_space
    is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
    actions_dim = tuple(
        env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]
    )
    cnn_keys = cfg.algo.cnn_keys.encoder

    # Load the trained agent
    state = fabric.load(checkpoint_path)
    # You need to retrieve only the player
    # Check for each algorithm what models the `build_agent()` function returns
    # (placed in the `agent.py` file of the algorithm), and which arguments it needs.
    # Check also which are the keys of the checkpoint: if the `build_agent()` parameter
    # is called `model_state`, then you retrieve the model state with `state["model"]`.
    agent = build_agent(
        fabric=fabric,
        actions_dim=actions_dim,
        is_continuous=False,
        cfg=cfg,
        obs_space=observation_space,
        agent_state=state["agent"],
    )[-1]
    agent.eval()

    # Print policy network architecture
    print("Policy architecture:")
    print(agent)

    obs, info = env.reset()

    while True:
        # Convert numpy observations into torch observations and normalize image observations
        # Every algorithm has its own way to do it, you must import the correct method
        torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)

        # Select actions: the agent returns one one-hot categorical distribution,
        # or several of them for multi-discrete action spaces
        actions = agent.get_actions(torch_obs, greedy=True)
        # Convert actions from one-hot categorical to categorical (index) form
        actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

        obs, _, terminated, truncated, info = env.step(
            actions.cpu().numpy().reshape(env.action_space.shape)
        )

        if terminated or truncated:
            obs, info = env.reset()
            if info["env_done"] or test is True:
                break

    # Close the environment
    env.close()

    # Return success
    return 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cfg_path", type=str, required=True, help="Configuration file"
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default="model", help="Model checkpoint"
    )
    parser.add_argument("--test", action="store_true", help="Test mode")
    opt = parser.parse_args()
    print(opt)

    main(opt.cfg_path, opt.checkpoint_path, opt.test)
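
In the loop above, the one-hot outputs of get_actions() are turned back into the integer actions the environment expects before calling env.step(). As a framework-free illustration for a single environment, the same argmax step can be written in plain Python. one_hots_to_multidiscrete is a hypothetical helper, not part of SheepRL or DIAMBRA Arena, and the MultiDiscrete sizes 9 and 4 are illustrative only.

```python
from typing import List, Sequence


def one_hots_to_multidiscrete(one_hots: Sequence[Sequence[float]]) -> List[int]:
    """Return the index of the active entry in each one-hot vector (one per action dimension).

    Plays the role of torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
    for a single environment.
    """
    return [list(vec).index(max(vec)) for vec in one_hots]


# Hypothetical MultiDiscrete([9, 4]) space: 9 options in the first dimension, 4 in the second
first_dim = [0, 0, 0, 1, 0, 0, 0, 0, 0]
second_dim = [0, 0, 1, 0]
print(one_hots_to_multidiscrete([first_dim, second_dim]))  # [3, 2]
```

For a Discrete action space there is a single one-hot vector, so the result has one element; the script then reshapes it to env.action_space.shape before stepping.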

If you have trained an agent different from PPO, you can simply reuse this script, modifying a couple of things:

  1. The build_agent() and prepare_obs() functions: every algorithm has its own, which you can find in the ./sheeprl/algos/<algo_name>/agent.py and ./sheeprl/algos/<algo_name>/utils.py files, respectively. Check the parameters of build_agent(): every agent has its own models, so you must pass the state of each of the agent’s models as a parameter. The parameter is always called model_name_state, and you can retrieve the corresponding state from the state dictionary (state in the script shown above) with the "model_name" key.
  2. Whether or not to initialize the recurrent states: the models with recurrent neural networks (e.g., all the Dreamer algorithms) need to initialize the recurrent states at every reset of the environment. It is recommended to look at the test function of the algorithm you want to submit.

The only algorithm you need to pay a little more attention to is PPO Recurrent. It is recommended to look at its test function available here.
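
The naming rule from point 1 (a model_name_state parameter is filled with state["model_name"]) is mechanical enough to automate. The sketch below is a hypothetical helper that inspects the signature of a build_agent() function and collects the matching entries from the loaded checkpoint. It assumes every model state parameter ends in _state, as in the PPO and DreamerV3 scripts on this page, and it leaves models missing from the checkpoint to build_agent()'s own defaults.

```python
import inspect
from typing import Any, Callable, Dict


def state_kwargs_for(build_agent: Callable[..., Any], state: Dict[str, Any]) -> Dict[str, Any]:
    """Map every `<model_name>_state` parameter of `build_agent` to `state["<model_name>"]`."""
    kwargs: Dict[str, Any] = {}
    for param_name in inspect.signature(build_agent).parameters:
        if param_name.endswith("_state"):
            model_name = param_name[: -len("_state")]
            # Models absent from the checkpoint are left to build_agent's defaults
            if model_name in state:
                kwargs[param_name] = state[model_name]
    return kwargs


# Hypothetical stand-in exposing the same state parameters as the DreamerV3 build_agent() used on this page
def fake_build_agent(fabric, actions_dim, is_continuous, cfg, obs_space,
                     world_model_state=None, actor_state=None,
                     critic_state=None, target_critic_state=None):
    return world_model_state, actor_state, critic_state, target_critic_state


checkpoint_state = {"world_model": "wm", "actor": "a", "critic": "c",
                    "target_critic": "tc", "optimizer": "opt"}
print(state_kwargs_for(fake_build_agent, checkpoint_state))
# {'world_model_state': 'wm', 'actor_state': 'a', 'critic_state': 'c', 'target_critic_state': 'tc'}
```

In a submission script, the result could be unpacked into the real call, e.g. build_agent(fabric=fabric, actions_dim=actions_dim, is_continuous=False, cfg=cfg, obs_space=observation_space, **state_kwargs_for(build_agent, state)). Checking the actual signature by hand, as described above, remains the safer option.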

For example, the changes required to submit a DreamerV3 agent are shown below as a diff against the PPO script.

import argparse
import json

import gymnasium as gym
import torch
from lightning import Fabric
from omegaconf import OmegaConf
-from sheeprl.algos.ppo.agent import build_agent
+from sheeprl.algos.dreamer_v3.agent import build_agent
-from sheeprl.algos.ppo.utils import prepare_obs
+from sheeprl.algos.dreamer_v3.utils import prepare_obs
from sheeprl.utils.env import make_env
from sheeprl.utils.utils import dotdict

def main(cfg_path: str, checkpoint_path: str, test=False):
    ...

    agent = build_agent(
        fabric=fabric,
        actions_dim=actions_dim,
        is_continuous=False,
        cfg=cfg,
        obs_space=observation_space,
-       agent_state=state["agent"],
+       world_model_state=state["world_model"],
+       actor_state=state["actor"],
+       critic_state=state["critic"],
+       target_critic_state=state["target_critic"],
    )[-1]
    agent.eval()

    # Print policy network architecture
    print("Policy architecture:")
    print(agent)

    obs, info = env.reset()
+   agent.init_states()

    while True:
        ...

        if terminated or truncated:
            obs, info = env.reset()
+           agent.init_states()
            if info["env_done"] or test is True:
                break

    ...

The final script for the submission of dreamer_v3 is shown below.

import argparse
import json

import gymnasium as gym
import torch
from lightning import Fabric
from omegaconf import OmegaConf
from sheeprl.algos.dreamer_v3.agent import build_agent
from sheeprl.algos.dreamer_v3.utils import prepare_obs
from sheeprl.utils.env import make_env
from sheeprl.utils.utils import dotdict

"""This is an example agent based on SheepRL.

Usage:
cd sheeprl
diambra run python agent-dreamer_v3.py --cfg_path "/absolute/path/to/example-logs/runs/dreamer_v3/doapp/experiment/version_0/config.yaml" --checkpoint_path "/absolute/path/to/example-logs/runs/dreamer_v3/doapp/experiment/version_0/checkpoint/ckpt_1024_0.ckpt"
"""


def main(cfg_path: str, checkpoint_path: str, test=False):
    # Read the cfg file
    cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
    print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))

    # Override configs for evaluation
    # You do not need to capture the video since you are submitting the agent and the video is recorded by DIAMBRA
    cfg.env.capture_video = False
    # Only one environment is used for evaluation
    cfg.env.num_envs = 1

    # Instantiate Fabric
    # You must use the same precision and plugins used for training.
    precision = getattr(cfg.fabric, "precision", None)
    plugins = getattr(cfg.fabric, "plugins", None)
    fabric = Fabric(
        accelerator="auto",
        devices=1,
        num_nodes=1,
        precision=precision,
        plugins=plugins,
        strategy="auto",
    )

    # Create Environment
    env = make_env(cfg, 0, 0)()
    observation_space = env.observation_space
    is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
    actions_dim = tuple(
        env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]
    )
    cnn_keys = cfg.algo.cnn_keys.encoder

    # Load the trained agent
    state = fabric.load(checkpoint_path)
    # You need to retrieve only the player
    # Check for each algorithm what models the `build_agent()` function returns
    # (placed in the `agent.py` file of the algorithm), and which arguments it needs.
    # Check also which are the keys of the checkpoint: if the `build_agent()` parameter
    # is called `model_state`, then you retrieve the model state with `state["model"]`.
    agent = build_agent(
        fabric=fabric,
        actions_dim=actions_dim,
        is_continuous=False,
        cfg=cfg,
        obs_space=observation_space,
        world_model_state=state["world_model"],
        actor_state=state["actor"],
        critic_state=state["critic"],
        target_critic_state=state["target_critic"],
    )[-1]
    agent.eval()

    # Print policy network architecture
    print("Policy architecture:")
    print(agent)

    obs, info = env.reset()
    # Every time you reset the environment, you must reset the initial states of the model
    agent.init_states()

    while True:
        # Convert numpy observations into torch observations and normalize image observations
        # Every algorithm has its own way to do it, you must import the correct method
        torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)

        # Select actions: the agent returns one one-hot categorical distribution,
        # or several of them for multi-discrete action spaces
        actions = agent.get_actions(torch_obs, greedy=False)
        # Convert actions from one-hot categorical to categorical (index) form
        actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

        obs, _, terminated, truncated, info = env.step(
            actions.cpu().numpy().reshape(env.action_space.shape)
        )

        if terminated or truncated:
            obs, info = env.reset()
            # Every time you reset the environment, you must reset the initial states of the model
            agent.init_states()
            if info["env_done"] or test is True:
                break

    # Close the environment
    env.close()

    # Return success
    return 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cfg_path", type=str, required=True, help="Configuration file"
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default="model", help="Model checkpoint"
    )
    parser.add_argument("--test", action="store_true", help="Test mode")
    opt = parser.parse_args()
    print(opt)

    main(opt.cfg_path, opt.checkpoint_path, opt.test)