Source code for yawning_titan.envs.generic.generic_env

"""
A generic class that creates OpenAI Gym environments within YAWNING TITAN.

This class has several key inputs which determine aspects of the environment such
as how the red agent behaves, what the red team and blue team objectives are, the size
and topology of the network being defended and what data should be collected during the simulation.
"""

import copy
import json
from collections import Counter
from typing import Dict, Tuple

import gym
import numpy as np
from gym import spaces
from stable_baselines3.common.utils import set_random_seed

import yawning_titan.envs.generic.core.reward_functions as reward_functions
from yawning_titan.envs.generic.core.blue_interface import BlueInterface
from yawning_titan.envs.generic.core.network_interface import NetworkInterface
from yawning_titan.envs.generic.core.red_interface import RedInterface
from yawning_titan.envs.generic.helpers.eval_printout import EvalPrintout
from yawning_titan.envs.generic.helpers.graph2plot import CustomEnvGraph


class GenericNetworkEnv(gym.Env):
    """Class to create a generic YAWNING TITAN gym environment."""
    def __init__(
        self,
        red_agent: RedInterface,
        blue_agent: BlueInterface,
        network_interface: NetworkInterface,
        print_metrics: bool = False,
        show_metrics_every: int = 1,
        collect_additional_per_ts_data: bool = True,
        print_per_ts_data: bool = False,
    ):
        """
        Initialise the generic network environment.

        Args:
            red_agent: Object from the RedInterface class
            blue_agent: Object from the BlueInterface class
            network_interface: Object from the NetworkInterface class
            print_metrics: Whether or not to print metrics (boolean)
            show_metrics_every: Number of completed games between printed summary metrics (int)
            collect_additional_per_ts_data: Whether or not to collect additional per timestep data (boolean)
            print_per_ts_data: Whether or not to print collected per timestep data (boolean)

        Note:
            The ``notes`` variable returned at the end of each timestep contains the per
            timestep data. By default it contains a base level of info required for some of
            the reward functions. When ``collect_additional_per_ts_data`` is toggled on, a
            lot more data is collected.
        """
        super(GenericNetworkEnv, self).__init__()

        self.RED = red_agent
        self.BLUE = blue_agent
        self.blue_actions = blue_agent.get_number_of_actions()
        self.network_interface = network_interface
        self.current_duration = 0
        self.game_stats_list = []
        self.num_games_since_avg = 0
        self.avg_every = show_metrics_every
        self.current_game_blue = {}
        self.current_game_stats = {}
        self.total_games = 0
        self.made_safe_nodes = []
        self.current_reward = 0
        self.print_metrics = print_metrics
        self.print_notes = print_per_ts_data

        self.random_seed = self.network_interface.random_seed

        self.graph_plotter = None
        self.eval_printout = EvalPrintout(self.avg_every)

        self.action_space = spaces.Discrete(self.blue_actions)

        self.network_interface.get_observation_size()

        # Sets up the observation space. This is a (n+2 by n) matrix. The first two columns
        # show the state of all the nodes. The remaining n columns show the connections
        # between the nodes (effectively the adjacency matrix).
        self.observation_space = spaces.Box(
            low=0,
            high=1,
            shape=(self.network_interface.get_observation_size(),),
            dtype=np.float32,
        )
        # The gym environment can only properly deal with a 1d array so the observation is flattened

        self.collect_data = collect_additional_per_ts_data
        self.env_observation = self.network_interface.get_current_observation()
    def reset(self) -> np.ndarray:
        """
        Reset the environment to the default state.

        :todo: May need to add customization of cuda setting.

        :return: A new starting observation (numpy array).
        """
        if self.random_seed is not None:  # conditionally set random_seed
            set_random_seed(self.random_seed, True)
        self.network_interface.reset()
        self.RED.reset()
        self.current_duration = 0
        self.env_observation = self.network_interface.get_current_observation()
        self.current_game_blue = {}

        return self.env_observation
    def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict[str, dict]]:
        """
        Take a time step and execute the actions for both the Blue RL agent and the non-learning Red agent.

        Args:
            action: The action value generated from the Blue RL agent (int)

        Returns:
            A four-tuple containing the next observation as a numpy array, the reward for that
            timestep, a boolean for whether the episode is complete and additional notes
            containing timestep information from the environment.
        """
        # Sets the nodes that have been made safe this turn to an empty list
        self.made_safe_nodes = []

        # Gets the initial state of the environment for logging and testing purposes
        if self.collect_data:
            # notes collects information about the state of the env
            notes = {
                "initial_state": self.network_interface.get_all_node_compromised_states(),
                "initial_blue_view": self.network_interface.get_all_node_blue_view_compromised_states(),
                "initial_vulnerabilities": self.network_interface.get_all_vulnerabilities(),
                "initial_red_location": copy.deepcopy(
                    self.network_interface.red_current_location
                ),
                "initial_graph": self.network_interface.get_current_graph_as_dict(),
                "current_step": self.current_duration,
            }
        else:
            # If not logging everything, the program still needs to collect some information
            # (required by other parts of the program)
            notes = {}

        # Resets the attack list for the red agent (so that only the current turn's attacks are held)
        self.network_interface.reset_stored_attacks()

        # The red agent performs their turn
        if (
            self.network_interface.game_mode.game_rules.grace_period_length.value
            <= self.current_duration
        ):
            red_info = self.RED.perform_action()
        else:
            red_info = {
                0: {
                    "Action": "do_nothing",
                    "Attacking_Nodes": [],
                    "Target_Nodes": [],
                    "Successes": [True],
                }
            }

        # Gets the number of nodes that are safe
        number_uncompromised = len(
            self.network_interface.current_graph.get_nodes(filter_true_safe=True)
        )

        # Collects data on the red agent's turn (including natural spreading)
        if self.collect_data:
            notes["red_info"] = red_info

        # The states of the nodes after red has had their turn (used by the reward functions)
        notes["post_red_state"] = self.network_interface.get_all_node_compromised_states()
        # Blue's view of the environment after red has had their turn
        notes["post_red_blue_view"] = (
            self.network_interface.get_all_node_blue_view_compromised_states()
        )
        # A dictionary of vulnerabilities after red has had their turn
        notes["post_red_vulnerabilities"] = self.network_interface.get_all_vulnerabilities()
        # The isolation status of all the nodes
        notes["post_red_isolation"] = self.network_interface.get_all_isolation()

        # Collects extra data if turned on
        if self.collect_data:
            # The location of the red agent after red has had their turn
            notes["post_red_red_location"] = copy.deepcopy(
                self.network_interface.red_current_location
            )

        # Set up initial variables that are reassigned based on the action that blue takes
        done = False
        reward = 0
        blue_action = ""
        blue_node = None

        # Check if the game is over and red has won
        if self.network_interface.game_mode.game_rules.blue_loss_condition.all_nodes_lost.value:
            if number_uncompromised == 0:
                done = True
                reward = self.network_interface.game_mode.rewards.for_loss.value
                blue_action = "failed"
        if self.network_interface.game_mode.game_rules.blue_loss_condition.n_percent_nodes_lost.use.value:
            # Calculate the percentage of compromised nodes
            percent_comp = len(
                self.network_interface.current_graph.get_nodes(filter_true_compromised=True)
            ) / self.network_interface.current_graph.number_of_nodes()
            if (
                percent_comp
                >= self.network_interface.game_mode.game_rules.blue_loss_condition.n_percent_nodes_lost.value.value
            ):
                done = True
                reward = self.network_interface.game_mode.rewards.for_loss.value
                # If the game ends before blue has had their turn then the blue action is set to failed
                blue_action = "failed"
        if self.network_interface.game_mode.game_rules.blue_loss_condition.high_value_node_lost.value:
            # Check if a high value node was compromised
            compromised_hvn = False
            for hvn in self.network_interface.current_graph.high_value_nodes:
                if hvn.true_compromised_status == 1:
                    compromised_hvn = True
                    break
            if compromised_hvn:
                # If this mode is selected then the game ends if a high value node has been compromised
                done = True
                reward = self.network_interface.game_mode.rewards.for_loss.value
                blue_action = "failed"

        # if self.network_interface.gr_loss_tn:
        tn = self.network_interface.get_target_node()
        if (
            tn is not None
            and self.network_interface.game_mode.game_rules.blue_loss_condition.target_node_lost.value
        ):
            if tn.true_compromised_status == 1:
                # If this mode is selected then the game ends if the target node has been compromised
                done = True
                reward = self.network_interface.game_mode.rewards.for_loss.value
                blue_action = "failed"

        if done:
            if self.network_interface.game_mode.rewards.reduce_negative_rewards_for_closer_fails.value:
                reward = reward * (
                    1
                    - (
                        self.current_duration
                        / self.network_interface.game_mode.game_rules.max_steps.value
                    )
                )

        if not done:
            blue_action, blue_node = self.BLUE.perform_action(action)

            if blue_action == "make_node_safe" or blue_action == "restore_node":
                self.made_safe_nodes.append(blue_node)

            if blue_action in self.current_game_blue:
                self.current_game_blue[blue_action] += 1
            else:
                self.current_game_blue[blue_action] = 1

            # Calculates the reward from the current state of the network
            reward_args = {
                "network_interface": self.network_interface,
                "blue_action": blue_action,
                "blue_node": blue_node,
                "start_state": notes["post_red_state"],
                "end_state": self.network_interface.get_all_node_compromised_states(),
                "start_vulnerabilities": notes["post_red_vulnerabilities"],
                "end_vulnerabilities": self.network_interface.get_all_vulnerabilities(),
                "start_isolation": notes["post_red_isolation"],
                "end_isolation": self.network_interface.get_all_isolation(),
                "start_blue": notes["post_red_blue_view"],
                "end_blue": self.network_interface.get_all_node_blue_view_compromised_states(),
            }

            reward = getattr(
                reward_functions,
                self.network_interface.game_mode.rewards.function.value,
            )(reward_args)

            # Gets the current observation from the environment
            self.env_observation = (
                self.network_interface.get_current_observation().flatten()
            )
            self.current_duration += 1

            # If the total number of steps reaches the set end then the blue agent wins and is rewarded accordingly
            if (
                self.current_duration
                == self.network_interface.game_mode.game_rules.max_steps.value
            ):
                if self.network_interface.game_mode.rewards.end_rewards_are_multiplied_by_end_state.value:
                    reward = self.network_interface.game_mode.rewards.for_reaching_max_steps.value * (
                        len(
                            self.network_interface.current_graph.get_nodes(
                                filter_true_safe=True
                            )
                        )
                        / self.network_interface.current_graph.number_of_nodes()
                    )
                else:
                    reward = (
                        self.network_interface.game_mode.rewards.for_reaching_max_steps.value
                    )
                done = True

        # Gets the state of the environment at the end of the current time step
        if self.collect_data:
            # The blue agent's view of the network
            notes["end_blue_view"] = (
                self.network_interface.get_all_node_blue_view_compromised_states()
            )
            # The state of the nodes (safe/compromised)
            notes["end_state"] = self.network_interface.get_all_node_compromised_states()
            # A dictionary of vulnerabilities
            notes["final_vulnerabilities"] = self.network_interface.get_all_vulnerabilities()
            # The location of the red agent
            notes["final_red_location"] = copy.deepcopy(
                self.network_interface.red_current_location
            )

        if self.network_interface.game_mode.miscellaneous.output_timestep_data_to_json.value:
            current_state = self.network_interface.create_json_time_step()
            self.network_interface.save_json(current_state, self.current_duration)

        if self.print_metrics and done:
            # Prints end of game metrics such as who won and how long the game lasted
            self.num_games_since_avg += 1
            self.total_games += 1

            # Populate the current game's dictionary of stats with the episode winner and the number of timesteps
            if (
                self.current_duration
                == self.network_interface.game_mode.game_rules.max_steps.value
            ):
                self.current_game_stats = {
                    "Winner": "blue",
                    "Duration": self.current_duration,
                }
            else:
                self.current_game_stats = {
                    "Winner": "red",
                    "Duration": self.current_duration,
                }

            # Add the actions taken by blue during the episode to the stats dictionary
            self.current_game_stats.update(self.current_game_blue)

            # Add the current game dictionary to the list of dictionaries to average over
            self.game_stats_list.append(Counter(dict(self.current_game_stats.items())))

            # Every self.avg_every episodes, print the stats to console
            if self.num_games_since_avg == self.avg_every:
                self.eval_printout.print_stats(self.game_stats_list, self.total_games)

                self.num_games_since_avg = 0
                self.game_stats_list = []

        self.current_reward = reward

        if self.collect_data:
            notes["safe_nodes"] = len(
                self.network_interface.current_graph.get_nodes(filter_true_safe=True)
            )
            notes["blue_action"] = blue_action
            notes["blue_node"] = blue_node
            notes["attacks"] = self.network_interface.true_attacks
            notes["end_isolation"] = self.network_interface.get_all_isolation()

        if self.print_notes:
            json_data = json.dumps(notes)
            print(json_data)

        # Returns the environment information that AI gym uses and all of the information collected in a dictionary
        return self.env_observation, reward, done, notes
    def render(
        self,
        mode: str = "human",
        show_only_blue_view: bool = False,
        show_node_names: bool = False,
    ):
        """
        Render the environment using Matplotlib to create an animation.

        Args:
            mode: the mode of the rendering
            show_only_blue_view: If true shows only what the blue agent can see
            show_node_names: Show the names of the nodes
        """
        if self.graph_plotter is None:
            self.graph_plotter = CustomEnvGraph()

        # Gets information about the current state from the network interface:
        # the networkx graph and the attacks to display (blue's detected attacks
        # if only the blue view is shown, otherwise the true attacks)
        main_graph = self.network_interface.current_graph

        if show_only_blue_view:
            attacks = self.network_interface.detected_attacks
        else:
            attacks = self.network_interface.true_attacks

        reward = round(self.current_reward, 2)

        # Sends the current information to a graph plotter to display the information visually
        self.graph_plotter.render(
            current_step=self.current_duration,
            g=main_graph,
            attacked_nodes=attacks,
            current_time_step_reward=reward,
            # self.network_interface.red_current_location,
            made_safe_nodes=self.made_safe_nodes,
            target_node=self.network_interface.get_target_node(),
            # "RL blue agent vs probabilistic red in a generic network environment",
            show_only_blue_view=show_only_blue_view,
            show_node_names=show_node_names,
        )
    def calculate_observation_space_size(self, with_feather: bool) -> int:
        """
        Calculate the observation space size.

        This is done using the current active observation space configuration and the
        number of nodes within the environment.

        Args:
            with_feather: Whether to include the size of the Feather Wrapper output

        Returns:
            The observation space size
        """
        return self.network_interface.get_observation_size_base(with_feather)
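
The snippet below is a minimal, illustrative usage sketch and is not part of this module. It assumes that a ``game_mode`` and a ``network`` object have already been created elsewhere in YAWNING TITAN, and that ``NetworkInterface``, ``RedInterface`` and ``BlueInterface`` are constructed roughly as shown (the exact constructor signatures are assumptions and may differ between versions); the loop itself simply follows the standard OpenAI Gym ``reset``/``step`` API implemented above.

# Illustrative sketch only: wiring the interfaces into the environment and
# running one episode with random blue actions. `game_mode` and `network`
# are assumed to exist already; the constructor signatures below are
# assumptions, not guarantees.
network_interface = NetworkInterface(game_mode, network)
red_agent = RedInterface(network_interface)
blue_agent = BlueInterface(network_interface)

env = GenericNetworkEnv(red_agent, blue_agent, network_interface, print_metrics=True)

obs = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # a trained RL policy would choose the action here
    obs, reward, done, notes = env.step(action)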