Source code for yawning_titan.envs.generic.wrappers.graph_embedding_observations
import gym
import networkx as nx
import numpy as np
from gym import ObservationWrapper
from gym.spaces import Box
from karateclub.graph_embedding.feathergraph import FeatherGraph
from yawning_titan.envs.generic.generic_env import GenericNetworkEnv
class FeatherGraphEmbedObservation(ObservationWrapper):
    """
    Gym observation wrapper that embeds the environment graph using the Feather-G algorithm.

    The wrapper uses Karateclub's Feather-G whole-graph embedding to summarise the
    environment's current adjacency matrix as a vector. When the game mode observes
    node connections, the adjacency-matrix portion of the standard observation is
    replaced by this embedding; all other observation components are kept as
    produced by the network interface.
    """

    def __init__(self, env: GenericNetworkEnv, max_num_nodes: int = 100):
        """
        Initialise a Feather-G observation space wrapper.

        Args:
            env: the OpenAI Gym environment to be wrapped
            max_num_nodes: the maximum number of nodes required to be supported in the
                observation space

        Note:
            ``max_num_nodes`` is meant to define the maximum number of nodes the agent
            should support within its observation space, so that agents can be trained
            across YAWNING TITAN environments with variable node counts (e.g. with the
            default of 100, environments with 10, 50 or 100 nodes).
            NOTE(review): this wrapper does not currently read ``max_num_nodes``; the
            observation size comes from
            ``env.calculate_observation_space_size(with_feather=True)``.
        """
        super().__init__(env)
        self.env: GenericNetworkEnv = env
        self.network_interface = env.network_interface
        # Length of the observation vector once the adjacency matrix is swapped for
        # the Feather-G embedding (as computed by the environment itself).
        self.new_ob_space_dim = env.calculate_observation_space_size(with_feather=True)
        self.original_observation_space: gym.spaces.Box = env.observation_space
        self.observation_space: gym.spaces.Box = Box(
            -np.inf, np.inf, shape=(self.new_ob_space_dim,)
        )
        # Cache of the adjacency matrix that was last embedded, and its embedding,
        # so the embedding is only recomputed when the graph changes.
        self.latest_adj_matrix = None
        self.latest_graph_embedding = None

    def observation(self, observation: np.ndarray) -> np.ndarray:
        """
        Observation Transformation Function.

        1. Re-embeds the graph with Feather-G on the first call, or whenever the
           adjacency matrix differs from the one embedded last time.
        2. Fetches the current standard observation from the network interface.
        3. If node connections are part of the observation, drops the leading
           adjacency-matrix entries and concatenates the graph embedding with the
           remaining observation values (as float32).
        4. Returns the new observation.

        Args:
            observation: The base, unwrapped observation generated by the environment.
                It is not used directly; the observation is rebuilt from
                ``network_interface.get_current_observation()``.

        Returns:
            A newly formatted environment observation
        """
        current_adj_matrix = self.env.network_interface.adj_matrix
        # Compare the full matrices element-wise (also handles a change in node
        # count, i.e. a shape mismatch). Comparing ``a.all()`` with ``b.all()``
        # only compares two booleans and misses almost every graph change.
        if self.latest_adj_matrix is None or not np.array_equal(
            current_adj_matrix, self.latest_adj_matrix
        ):
            # Keep a private copy: if the interface mutates its matrix in place, a
            # stored reference would change along with it and hide the change.
            self.latest_adj_matrix = np.array(current_adj_matrix, copy=True)
            self.latest_graph_embedding = self.make_embedding()

        standard_obs = self.env.network_interface.get_current_observation()
        if self.network_interface.game_mode.observation_space.node_connections.value:
            # Drop the first N**2 entries (presumably the flattened adjacency
            # matrix) and put the graph embedding in their place.
            size_standard_adj = self.network_interface.get_total_num_nodes() ** 2
            extra_obs = standard_obs[size_standard_adj:]
            # axis=None flattens both inputs before joining them.
            observation = np.concatenate(
                (self.latest_graph_embedding, extra_obs), axis=None, dtype=np.float32
            )
        else:
            observation = standard_obs
        return observation

    def make_embedding(self) -> np.ndarray:
        """
        Create a FeatherGraph embedding of the cached adjacency matrix.

        Builds a NetworkX graph from ``self.latest_adj_matrix`` and fits a Karateclub
        ``FeatherGraph`` model (default parameters) on it.

        Returns:
            A numpy array containing the Feather embedding
        """
        current_graph = nx.from_numpy_array(self.latest_adj_matrix)
        embedder = FeatherGraph()
        embedder.fit([current_graph])
        return embedder.get_embedding()