Source code for yawning_titan.experiment_helpers.sb3

import logging
from statistics import mean

import gym
from scipy.stats import describe, iqr
from stable_baselines3 import A2C, DQN, PPO
from stable_baselines3.a2c import MlpPolicy as A2CMlp
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.dqn import MlpPolicy as DQNMlp
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from tabulate import tabulate

from yawning_titan.agents.random import RandomAgent

logger = logging.getLogger(__name__)


def init_env(env: str, experiment_id: str):
    """
    Use the Stable Baselines 3 Monitor wrapper to wrap an environment in order to enable monitoring.

    Args:
        env: the registered name of an OpenAI gym environment (str)
        experiment_id: a UID for the experiment (str)

    Returns:
        A Stable Baselines 3 Monitor wrapped Gym environment
    """
    wrapped_env = Monitor(
        gym.make(env), filename=f"./logs/{experiment_id}", allow_early_resets=True
    )

    return wrapped_env
def train_and_eval(
    agent_name: str, environment, training_timesteps: int, n_eval_episodes: int
):
    """
    Train and evaluate an agent.

    Args:
        agent_name: the algorithm name (str)
        environment: an initialised OpenAI Gym environment
        training_timesteps: total no. of training timesteps (int)
        n_eval_episodes: no. of episodes to run during evaluation (int)

    Returns:
        chosen_agent: a trained Stable Baselines 3 agent
        eval_pol: the output from the Stable Baselines 3 'evaluate_policy' function
    """
    agent_dic = {"ppo": 0, "a2c": 1, "random": 2, "dqn": 3}
    agent_list = [
        PPO(PPOMlp, environment, verbose=1, tensorboard_log="./logs/ppo-tensorboard"),
        A2C(A2CMlp, environment, verbose=1, tensorboard_log="./logs/a2c-tensorboard"),
        RandomAgent(environment.action_space),
        DQN(DQNMlp, environment, verbose=1, tensorboard_log="./logs/dqn-tensorboard"),
    ]
    chosen_agent = agent_list[agent_dic[agent_name]]
    logger.debug(f"{agent_name} Agent Initialised")

    chosen_agent.learn(total_timesteps=training_timesteps)
    logger.debug(f"Completed Training {agent_name}")
    print("Training Complete - Entering Evaluation")

    eval_pol = evaluate_policy(
        chosen_agent,
        environment,
        return_episode_rewards=True,
        n_eval_episodes=n_eval_episodes,
    )

    return chosen_agent, eval_pol
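
# Usage sketch (not part of the original module): a minimal example of how
# init_env and train_and_eval could be combined. The environment ID
# "five-node-def-v0" and the experiment ID below are assumptions for
# illustration only; substitute whichever Yawning-Titan environment is
# registered with gym in your installation.
if __name__ == "__main__":
    # Wrap an assumed registered gym environment so episode stats are logged.
    env = init_env("five-node-def-v0", experiment_id="example-experiment")

    # Train a PPO agent for a small number of timesteps, then evaluate it.
    agent, (episode_rewards, episode_lengths) = train_and_eval(
        agent_name="ppo",
        environment=env,
        training_timesteps=1000,
        n_eval_episodes=5,
    )

    # evaluate_policy with return_episode_rewards=True yields per-episode
    # rewards and lengths, so summary statistics can be computed directly.
    print(f"Mean evaluation reward: {mean(episode_rewards)}")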