Source code for yawning_titan.envs.generic.helpers.graph2plot

import math
import statistics
from typing import Dict, List

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

from yawning_titan.networks.network import Network
from yawning_titan.networks.node import Node


[docs]def repeat_check(node: Dict, legend_list: List[Line2D]): """ Checks if a node already exists by comparing the nodes colour and description with nodes already in the legend. Args: node: A node dict. legend_list: The legend list. Returns: ``True`` if is already exists, otherwise ``False``. """ for legend in legend_list: if ( legend.get_markerfacecolor() == node["colour"] and legend.get_label() == node["description"] ): return True return False
[docs]class CustomEnvGraph: """A network graph rendering environment for Open AI Gym use."""
[docs] def __init__(self, title: str = None): """ Initialise the CustomEnvGraph that is used to visualise and render node environments. Args: title: The name of the plot """ self.fig = plt.figure(figsize=(12, 6)) self.fig.suptitle(title) # Create subplot for network graph drawing self.vis_ax = plt.subplot2grid(shape=(1, 1), loc=(0, 0), rowspan=1, colspan=1) plt.subplots_adjust( # left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0 left=0.12, bottom=0.25, right=0.90, top=1.00, wspace=0.25, hspace=0, ) plt.tight_layout() # Show the graph without blocking the rest of the program plt.show(block=False)
[docs] def render( self, current_step: int, g: Network, # pos: dict, # compromised_nodes: dict, # uncompromised_nodes: list, attacked_nodes: List[List[Node]], current_time_step_reward: float, # red_previous_node, # vulnerability_dict: dict, made_safe_nodes: list, # title: str, # special_nodes: dict = None, # entrance_nodes: list = None, show_only_blue_view: bool = False, target_node: Node = None, show_node_names: bool = False, ): """ Render the current network into an axis. Args: current_step: the current step in the environment (int) g: a networkx object that stores the current connectivity (networkx graph) pos: a dictionary that contains the points and their positions compromised_nodes: a dictionary of all the compromised nodes and a boolean value if blue can see the intrusion or not uncompromised_nodes: a list of all the uncompromised nodes attacked_nodes: a list of the nodes where an attack is happening (infected node, target node) current_time_step_reward: the current total reward red_previous_node: CURRENTLY NOT USED vulnerability_dict: A dictionary that stores the vulnerability of the nodes made_safe_nodes: a list of nodes that the blue agent has made safe this turn title: The title for the render special_nodes: A dictionary containing dictionaries of: nodes, node descriptions and colours for the nodes entrance_nodes: Nodes that serve as a gateway for the red agent to be able to access the network show_only_blue_view: If true only shows what the blue agent can see show_node_names: Show the names of nodes """ self.vis_ax.clear() special_node_info = { "high_value_node": { "description": "high value node", "colour": "#da2fed", } } # Creates a list that contains the details for the legend legend_objects = [ # Each item in the legend is an marker with a colour and a description # Compromised Nodes Line2D( [0], [0], color="white", marker="o", markerfacecolor="orange", label="Compromised Node", markersize=15, ), # Vulnerable safe nodes Line2D( [0], [0], color="white", marker="o", markerfacecolor="#00FF13", label="Safe Node: Weak", markersize=15, ), # Safe nodes with low vulnerability Line2D( [0], [0], color="white", marker="o", markerfacecolor="#006007", label="Safe Node: Strong", markersize=15, ), # Nodes that have just been taken over by red Line2D( [0], [0], color="white", marker="o", markerfacecolor="red", label="Attacked Node", markersize=15, ), # The nodes that blue has "patched" or "fixed" this turn Line2D( [0], [0], color="white", marker="o", markerfacecolor="#4ef2e7", label="Blue Patch", markersize=15, ), ] # If a target node is specified add to the legend if target_node is not None: legend_objects.append( Line2D( [0], [0], color="white", marker="o", markerfacecolor="#2c195e", label="Target Node", markersize=15, ) ) # plots the target node plt.scatter( [target_node.x_pos], [target_node.y_pos], color="#2c195e", s=324, zorder=8, ) legend_objects.extend( [ # An edge that red has attacked along this turn Line2D( [0], [0], color="red", marker="_", markerfacecolor="red", label="Attack Path", markersize=15, ), # An edge Line2D( [0], [0], color="gray", marker="_", markerfacecolor="gray", label="Connection", markersize=15, ), ] ) # If only showing the blue view then only render red nodes that blue can see if not show_only_blue_view: legend_objects.append( Line2D( [0], [0], color="white", marker="$\\bf{O}$", markerfacecolor="red", label="Unknown Compromise", markersize=12, ) ) legend_objects.append( Line2D( [0], [0], color="white", marker="$\\bf{O}$", markerfacecolor="blue", label="Known Compromise", markersize=12, ) ) # Some environments may have special custom nodes that they want to add if len(special_node_info) > 0: for node_info in special_node_info.values(): # only insert if the legend is not in the list yet if not repeat_check(node_info, legend_objects): # Inserts the object into the legends at position 3. This is because it looks better if there are any # special nodes added that they are added at the some point as the other nodes in the legend legend_objects.insert( 3, Line2D( [0], [0], color="white", marker="o", markerfacecolor=node_info["colour"], label=node_info["description"], markersize=15, ), ) # If entrance nodes are used then they are added to the legend if g.entry_nodes: legend_objects.append( Line2D( [0], [0], color="white", marker="$\\bf{E}$", markerfacecolor="black", label="Entry Node", markersize=12, ) ) # plots all of the edges in the graph for edge in g.edges: plt.plot( [edge[0].x_pos, edge[1].x_pos], [edge[0].y_pos, edge[1].y_pos], color="grey", zorder=1, ) # plots all of the current turns attacks red_nodes_x = [] red_nodes_y = [] for node_set in attacked_nodes: red_nodes_x.append(node_set[1].x_pos) red_nodes_y.append(node_set[1].y_pos) if node_set[0] is not None: plt.plot( [node_set[0].x_pos, node_set[1].x_pos], [node_set[0].y_pos, node_set[1].y_pos], color="red", zorder=2, ) # All the shades of green for the different levels of vulnerability green_shades = [ "#00FF13", "#00DF11", "#00BF0E", "#009F0C", "#00800A", "#006007", ] max_x = 0 max_y = 0 min_x = 100000 min_y = 100000 comp_x = [] comp_y = [] known_comp_x = [] known_comp_y = [] unknown_comp_x = [] unknown_comp_y = [] safe_x = [] safe_y = [] safe_colours = [] void_x = [] void_y = [] special_x = [] special_y = [] special_colour = [] made_safe_x = [] made_safe_y = [] for n in g.get_nodes(): max_x = max(max_x, n.x_pos) max_y = max(max_y, n.y_pos) min_x = min(min_x, n.x_pos) min_y = min(min_y, n.y_pos) if n in made_safe_nodes: # get the locations of nodes that have been made safe made_safe_x.append(n.x_pos) made_safe_y.append(n.y_pos) elif n.high_value_node: # get the locations of special nodes special_x.append(n.x_pos) special_y.append(n.y_pos) special_colour.append(special_node_info["high_value_node"]["colour"]) elif n.true_compromised_status == 1: if n.blue_knows_intrusion: # get the locations of compromised nodes (unknown) comp_x.append(n.x_pos) comp_y.append(n.y_pos) if not show_only_blue_view: # get the locations of compromised nodes (known) known_comp_x.append(n.x_pos) known_comp_y.append(n.y_pos) else: if not show_only_blue_view: comp_x.append(n.x_pos) comp_y.append(n.y_pos) unknown_comp_x.append(n.x_pos) unknown_comp_y.append(n.y_pos) else: # get the location of the safe nodes vuln = n.vulnerability_score index = 5 - math.floor(vuln * (len(green_shades) - 1)) safe_colours.append(green_shades[index]) safe_x.append(n.x_pos) safe_y.append(n.y_pos) elif n.true_compromised_status == 0: # get the locations of safe nodes vuln = n.vulnerability_score index = 5 - math.floor(vuln * (len(green_shades) - 1)) safe_colours.append(green_shades[index]) safe_x.append(n.x_pos) safe_y.append(n.y_pos) else: void_x.append(n.x_pos) void_y.append(n.y_pos) # plots any nodes that have no features plt.scatter(void_x, void_y, color="grey", s=300, zorder=1) # plots all of the compromised nodes plt.scatter(comp_x, comp_y, color="orange", s=324, zorder=8) # plot the circles around unknown compromised nodes plt.scatter(unknown_comp_x, unknown_comp_y, color="red", s=484, zorder=7) # plot the circles around known compromised nodes plt.scatter(known_comp_x, known_comp_y, color="blue", s=484, zorder=7) # plots all of the safe nodes plt.scatter(safe_x, safe_y, color=safe_colours, s=324, zorder=5) # plots all of the recently taken red nodes plt.scatter(red_nodes_x, red_nodes_y, color="red", s=324, zorder=9) # plots all of the nodes that have just been patched plt.scatter(made_safe_x, made_safe_y, color="#4ef2e7", s=324, zorder=10) # plots any special nodes for the env plt.scatter(special_x, special_y, color=special_colour, s=324, zorder=6) # plot the entrance nodes for node in g.entry_nodes: plt.scatter( [node.x_pos], [node.y_pos], color="black", zorder=11, s=121, marker="$E$", ) if show_node_names: for node in g.nodes: plt.text( node.x_pos + 0.1, node.y_pos + 0.1, node, color="red", fontsize=12, zorder=11, ) # Creates a string containing information about the current state of the network info = ( "Current Step: " + str(current_step) + "\nReward for current time step: " + str(current_time_step_reward) + "\nCurrent Avg vulnerability: " + str(round(statistics.mean([n.vulnerability_score for n in g.nodes]), 2)) ) ax = plt.gca() ax.legend( handles=legend_objects, loc="center left", bbox_to_anchor=(1, 0.5), borderpad=1, labelspacing=1, fontsize=10, edgecolor="black", ) ax.axes.xaxis.set_ticks([]) ax.axes.yaxis.set_ticks([]) ax.axes.set_xlim(min_x - 0.1 * max_x, max_x * 1.1) ax.axes.set_ylim(min_y - 0.1 * max_y, max_y * 1.1) ax.set_xlabel(info) for pos in ["left", "right", "top", "bottom"]: plt.gca().spines[pos].set_visible(False) # plt.show() # invert y axis - computer coords to cartesian conversion plt.gca().invert_yaxis()
[docs] def close(self): """Close all handles to external renderers.""" plt.close()