Source code for yawning_titan.envs.generic.wrappers.graph_embedding_observations

import gym
import networkx as nx
import numpy as np
from gym import ObservationWrapper
from gym.spaces import Box
from karateclub.graph_embedding.feathergraph import FeatherGraph

from yawning_titan.envs.generic.generic_env import GenericNetworkEnv


class FeatherGraphEmbedObservation(ObservationWrapper):
    """
    Gym Observation Space Wrapper that embeds the underlying environment graph using the Feather-G algorithm.

    This wrapper uses the Feather-G Whole Graph embedding algorithm to embed the
    underlying environment graph and then re-creates the observation space to
    include the embedding and all other observation space settings from the
    configuration file.
    """

    def __init__(self, env: GenericNetworkEnv, max_num_nodes: int = 100):
        """
        Initialise a Feather-G observation space wrapper.

        Args:
            env: the OpenAI Gym environment to be wrapped
            max_num_nodes: the maximum number of nodes required to be supported
                in the observation space

        Note:
            The max_num_nodes is for defining the maximum number of nodes you
            want the agent to support within its observation space. This is in
            order to support the training of agents which can work across a
            number of YAWNING TITAN environments with variable node counts.

            For example, if set to 100 (like the default), the agent could be
            trained in an environment with 10 nodes, 50 nodes or 100 nodes.

            NOTE(review): ``max_num_nodes`` is accepted for interface
            compatibility but is not read anywhere in this wrapper; the size of
            the new observation space comes entirely from
            ``env.calculate_observation_space_size(with_feather=True)``.
        """
        super().__init__(env)

        self.env: GenericNetworkEnv = env
        self.network_interface = env.network_interface

        # Size of the flattened observation once the adjacency matrix has been
        # replaced by the Feather-G embedding.
        self.new_ob_space_dim = env.calculate_observation_space_size(with_feather=True)

        # Keep the unwrapped space around before overriding it.
        self.original_observation_space: gym.spaces.Box = env.observation_space
        self.observation_space: gym.spaces.Box = Box(
            -np.inf, np.inf, shape=(self.new_ob_space_dim,)
        )

        # Cache of the adjacency matrix the current embedding was built from,
        # so the embedding is only recomputed when the topology changes.
        self.latest_adj_matrix = None
        self.latest_graph_embedding = None

    def observation(self, observation: np.ndarray) -> np.ndarray:
        """
        Observation Transformation Function.

        1. Reads the current adjacency matrix from the network interface
        2. Re-computes the Feather-G embedding if the adjacency matrix differs
           from the one the cached embedding was built from
        3. Fetches the current standard observation from the network interface
        4. If node connections are part of the configured observation space,
           drops the leading ``num_nodes ** 2`` entries of the standard
           observation and concatenates the graph embedding with the remainder
        5. Otherwise returns the standard observation unchanged

        Args:
            observation: The base, unwrapped observation generated by the
                environment. It is not used directly; the observation is
                re-read from the network interface instead.

        Returns:
            A newly formatted environment observation
        """
        current_adj_matrix = self.env.network_interface.adj_matrix

        # Compare the matrices element-wise. Comparing ``a.all() != b.all()``
        # only compares two booleans and so almost never detects a change in
        # topology, leaving a stale embedding in the observation.
        if self.latest_adj_matrix is None or not np.array_equal(
            current_adj_matrix, self.latest_adj_matrix
        ):
            # Store a copy so that in-place edits of the environment's matrix
            # cannot silently update the cache and mask a later change.
            self.latest_adj_matrix = np.array(current_adj_matrix, copy=True)
            self.latest_graph_embedding = self.make_embedding()

        standard_obs = self.env.network_interface.get_current_observation()

        if self.network_interface.game_mode.observation_space.node_connections.value:
            # Assumes the standard observation starts with the flattened
            # adjacency matrix (num_nodes ** 2 values) - TODO confirm against
            # get_current_observation.
            size_standard_adj = self.network_interface.get_total_num_nodes() ** 2
            extra_obs = standard_obs[size_standard_adj:]
            # axis=None flattens both inputs before joining them.
            observation = np.concatenate(
                (self.latest_graph_embedding, extra_obs), axis=None, dtype=np.float32
            )
        else:
            observation = standard_obs

        return observation

    def make_embedding(self) -> np.ndarray:
        """
        Create a FeatherGraph embedding from the cached adjacency matrix.

        Builds a NetworkX graph from ``self.latest_adj_matrix`` and fits a new
        ``FeatherGraph`` embedder on that single graph.

        Returns:
            A numpy array containing the Feather embedding
        """
        current_graph = nx.from_numpy_array(self.latest_adj_matrix)
        embedder = FeatherGraph()
        embedder.fit([current_graph])
        embedding = embedder.get_embedding()
        return embedding