Source code for yawning_titan.envs.generic.core.action_loops

"""
The ``ActionLoop`` class helps reduce boilerplate code when evaluating an agent within a target environment.

Serves a similar function to library helpers such as Stable Baselines 3 ``evaluate_policy()``.
"""

import os
import re
import string
from datetime import datetime
from pathlib import Path
from threading import Thread
from typing import Any
from uuid import uuid4

import imageio
import matplotlib.pyplot as plt
import moviepy.editor as mp
import pandas as pd

from yawning_titan import APP_IMAGES_DIR, IMAGES_DIR, VIDEOS_DIR
from yawning_titan.envs.generic.generic_env import GenericNetworkEnv


class ActionLoop:
    """A class that represents different post-training action loops for agents.

    Every loop runs ``episode_count`` episodes: the environment is reset, then
    the agent's ``predict`` output is fed to ``env.step`` until the environment
    reports ``done``.
    """

    # Splits a string into alternating non-digit / digit runs; used to sort
    # frame file names so that ``..._10.png`` comes after ``..._9.png``.
    _DIGIT_RUN = re.compile(r"([0-9]+)")

    def __init__(self, env, agent, filename=None, episode_count=None):
        """
        Initialise Class.

        Args:
            env: The environment to run through
            agent: The agent to run in the environment
            filename: The save name for the action loop. A random UUID is
                used when omitted.
            episode_count: The number of episodes to go through. The loops
                pass this straight to ``range()``, so it must be an int by
                the time a loop is run.
        """
        if filename is None:
            filename = uuid4()
        self.env: GenericNetworkEnv = env
        self.agent = agent
        self.filename = filename
        self.episode_count = episode_count

    def gif_action_loop(
        self,
        render_network=True,
        prompt_to_close=False,
        save_gif=False,
        save_webm=False,
        deterministic=False,
        gif_output_directory: Path = None,
        webm_output_directory: Path = None,
        *args,
        **kwargs,
    ):
        """
        Run the agent in evaluation and create a gif from episodes.

        Args:
            render_network: Bool to toggle rendering on or off. Has a default
                value of True.
            prompt_to_close: When False (the default) the environment is
                closed once all episodes have finished.
            save_gif: Bool to toggle if a gif file should be saved
            save_webm: Bool to toggle if a webm file should be saved
            deterministic: Bool to toggle if the agents actions should be
                deterministic
            gif_output_directory: Directory where the GIF will be output.
                Defaults to ``IMAGES_DIR``.
            webm_output_directory: Directory where the WEBM file will be
                output. Defaults to ``VIDEOS_DIR``.
            *args: Forwarded to ``env.render``.
            **kwargs: Forwarded to ``env.render``.

        Returns:
            A list with one ``pandas.DataFrame`` per episode, each holding the
            ``action``, ``rewards`` and ``info`` of every step.
        """
        gif_uuid = str(uuid4())
        capture_frames = save_gif or save_webm

        # Resolve the default output locations once instead of per episode.
        if save_gif and gif_output_directory is None:
            gif_output_directory = IMAGES_DIR
        if save_webm and webm_output_directory is None:
            webm_output_directory = VIDEOS_DIR

        complete_results = []
        for _ in range(self.episode_count):
            # temporary log to satisfy repeatability tests until logging can
            # be fully implemented
            results = pd.DataFrame(columns=["action", "rewards", "info"])
            obs = self.env.reset()
            done = False
            frame_names = []

            while not done:
                # gets the agents prediction for the best next action to take
                action, _states = self.agent.predict(obs, deterministic=deterministic)
                # TODO: setup logging properly here

                # step the env
                obs, rewards, done, info = self.env.step(action)
                results.loc[len(results.index)] = [action, rewards, info]

                if capture_frames:
                    # The figure is saved *before* this step is rendered, so
                    # the first saved frame of an episode is empty; the
                    # generators skip it.
                    current_name = os.path.join(
                        APP_IMAGES_DIR, f"{gif_uuid}_{len(frame_names)}.png"
                    )
                    self._get_render_figure(current_name)
                    frame_names.append(current_name)

                if render_network:
                    self.env.render(*args, **kwargs)

            # get current time
            string_time = datetime.now().strftime("%d-%m-%Y_%H-%M")
            frame_names = sorted(frame_names, key=self._natural_sort_key)

            # NOTE(review): the output name uses the total episode count, not
            # the episode index, and the timestamp has minute resolution, so
            # episodes finishing within the same minute write to the same
            # path - confirm this is intended.
            output_stem = f"{self.filename}_{string_time}_{self.episode_count}"

            render_threads = []
            if save_gif:
                gif_path = os.path.join(gif_output_directory, f"{output_stem}.gif")
                render_threads.append(
                    Thread(target=self.generate_gif, args=(gif_path, frame_names))
                )
            if save_webm:
                webm_path = os.path.join(webm_output_directory, f"{output_stem}.webm")
                render_threads.append(
                    Thread(target=self.generate_webm, args=(webm_path, frame_names))
                )

            if render_threads:
                # Start every generator before waiting on any of them so they
                # actually run side by side, then wait for all of them before
                # the frames they read are deleted.
                for thread in render_threads:
                    thread.start()
                for thread in render_threads:
                    thread.join()
                # clean up once done
                self.render_cleanup(frame_names)

            complete_results.append(results)

        if not prompt_to_close:
            self.env.close()
        return complete_results

    def standard_action_loop(self, deterministic=False):
        """
        Act within the environment for ``episode_count`` episodes using a trained agent.

        Args:
            deterministic: Bool to toggle if the agents actions should be
                deterministic

        Returns:
            A list with one ``pandas.DataFrame`` per episode, each holding the
            ``action``, ``rewards`` and ``info`` of every step.
        """
        complete_results = []
        for _ in range(self.episode_count):
            # temporary log to satisfy repeatability tests until logging can
            # be fully implemented
            results = pd.DataFrame(columns=["action", "rewards", "info"])
            obs = self.env.reset()
            done = False
            while not done:
                action, _states = self.agent.predict(obs, deterministic=deterministic)
                # TODO: setup logging properly here
                obs, rewards, done, info = self.env.step(action)
                results.loc[len(results.index)] = [action, rewards, info]
            complete_results.append(results)
        return complete_results

    def random_action_loop(self, deterministic=False):
        """
        Act within the environment for ``episode_count`` episodes taking random actions.

        Unlike the other loops, the agent's ``predict`` is called with the
        observation, the previous reward and the done flag, and whatever it
        returns is passed unmodified to ``env.step``.

        Args:
            deterministic: Forwarded to the agent's ``predict`` call.
        """
        for _ in range(self.episode_count):
            obs = self.env.reset()
            done = False
            reward = 0
            while not done:
                action = self.agent.predict(
                    obs, reward, done, deterministic=deterministic
                )
                # Keep the latest observation so the next prediction is made
                # from the current state rather than the initial one.
                obs, reward, done, _ = self.env.step(action)

    @staticmethod
    def _natural_sort_key(name):
        """Return a sort key that orders embedded digit runs numerically."""
        return [
            int(text) if text.isdigit() else text.lower()
            for text in ActionLoop._DIGIT_RUN.split(name)
        ]

    @classmethod
    def _get_render_figure(cls, gif_name: str) -> Any:
        """Save the current matplotlib figure to ``gif_name`` and return the figure."""
        fig = plt.gcf()
        # save the current image
        plt.savefig(gif_name, bbox_inches="tight", dpi=100)
        return fig

    def generate_gif(self, gif_path, frame_names):
        """
        Generate GIF from images.

        Args:
            gif_path: Path of the GIF file to write.
            frame_names: Ordered paths of the frame images. The first frame is
                skipped because it is empty.
        """
        with imageio.get_writer(gif_path, mode="I") as writer:
            image = None
            # skip first frame because it is empty
            for filename in frame_names[1:]:
                image = imageio.imread(filename)
                writer.append_data(image)
            # repeat the last frame so the result can be seen longer
            if image is not None:
                for _ in range(10):
                    writer.append_data(image)

    def generate_webm(self, webm_path, frame_names):
        """
        Create webm from image files.

        Args:
            webm_path: Path of the output file to write.
            frame_names: Ordered paths of the frame images. The first frame is
                skipped, matching ``generate_gif``.
        """
        clip = mp.ImageSequenceClip(frame_names[1:], fps=5)
        # NOTE(review): the clip is written through moviepy's GIF writer with
        # the ffmpeg backend; presumably ffmpeg picks the container from the
        # path's extension - confirm this yields a real WEBM file.
        clip.write_gif(webm_path, program="ffmpeg")

    def render_cleanup(self, frame_names):
        """
        Delete the frames image files.

        Args:
            frame_names: Paths of the frame images to delete. Duplicates are
                collapsed so each file is removed only once.
        """
        for filename in set(frame_names):
            os.remove(filename)