# Source code for yawning_titan.envs.specific.core.nsa_node_collection

import random
from typing import List, Tuple

import networkx as nx
import numpy as np

from yawning_titan.envs.specific.core.nsa_node import Node


class NodeCollection:
    """Class representing a collection of nodes for the 18-node Ridley Environment."""

    def __init__(self, network: Tuple[np.array, dict], chance_to_spread_during_patch):
        """
        Build one node per row of the adjacency matrix.

        Args:
            network: a pair of (adjacency matrix, node positions dictionary)
            chance_to_spread_during_patch: probability (0.0 to 1.0) used by
                ``spread`` for each individual infection attempt
        """
        self.adj_matrix = network[0]
        self.pos_dic = network[1]
        self.nodes = [Node() for _ in range(len(self.adj_matrix))]
        self.chance_to_spread_during_patch = chance_to_spread_during_patch

    def get_number_of_nodes(self) -> int:
        """
        Return the number of nodes in the network.

        Returns:
            The number of nodes in the network (int)
        """
        return len(self.nodes)

    def get_observation(self) -> np.array:
        """
        Get the states of all the nodes in the network.

        Row ``i`` holds ``[isolated, compromised, <adjacency row i>]``. The
        adjacency part of an isolated node is reported as all zeros.

        Returns:
            observation: The current state of the environment (numpy array of
                float32 with shape (n, n + 2))
        """
        node_count = len(self.nodes)
        observation = np.zeros((node_count, node_count + 2), dtype=np.float32)
        for i, node in enumerate(self.nodes):
            # Query the node once per row rather than once per column.
            condition = node.get_condition()
            isolated = condition[0]
            observation[i, 0] = isolated
            observation[i, 1] = condition[1]
            if not isolated:
                # Isolated rows keep the zeros they were initialised with.
                observation[i, 2:] = self.adj_matrix[i]
        return observation

    def modify_node(self, number: int, changes: Tuple[bool, int]):
        """
        Change the state of a single node.

        Args:
            number: the number of the node to change
            changes: a pair of [isolate, compromise]
                isolate: if true, toggles the isolation status of the node
                    (true -> false, false -> true) (boolean)
                compromise: a mode signal that will change the state of a
                    node. 0 does nothing, 1 makes it safe and 2 compromises
                    the node (int)
        """
        isolate, compromise = changes
        node = self.nodes[number]
        if isolate:
            node.change_isolated()
        node.change_compromised(compromise)

    def get_compromised_nodes(self) -> List[int]:
        """
        Create a list of all the nodes in the network that are compromised.

        Returns:
            compromised_nodes: A list of nodes that are compromised (list of ints)
        """
        return [i for i, node in enumerate(self.nodes) if node.get_condition()[1]]

    def get_un_compromised_nodes(self) -> List[int]:
        """
        Create a list of all the safe nodes in the network.

        Returns:
            un_compromised_nodes: A list of nodes that are safe (list of ints)
        """
        return [
            i for i, node in enumerate(self.nodes) if not node.get_condition()[1]
        ]

    def get_isolated_nodes(self) -> List[int]:
        """
        Create a list of all the isolated nodes in the network.

        Returns:
            isolated_nodes: A list of nodes that are isolated (list of ints)
        """
        return [i for i, node in enumerate(self.nodes) if node.get_condition()[0]]

    def get_number_of_isolated(self) -> int:
        """
        Get the number of isolated nodes in the network.

        Returns:
            the number of isolated nodes in the network (int)
        """
        return len(self.get_isolated_nodes())

    def get_number_of_un_compromised(self) -> int:
        """
        Get the number of safe nodes in the network.

        Returns:
            the number of safe nodes in the network (int)
        """
        return len(self.get_un_compromised_nodes())

    def get_connected_nodes(self, number: int) -> List[int]:
        """
        When given a node returns a list of all of the nodes connected to that node.

        An isolated node has no connections, and isolated neighbours are
        left out of the result.

        Args:
            number: the number of the node to run on

        Returns:
            a list of all the nodes connected to a specified node (list of ints)
        """
        if self.nodes[number].get_condition()[0]:
            return []
        adjacency_row = self.adj_matrix[number]
        # Keep neighbours linked in the adjacency matrix that are not isolated.
        return [
            i
            for i, node in enumerate(self.nodes)
            if adjacency_row[i] == 1 and not node.get_condition()[0]
        ]

    def spread(self, number: int):
        """
        Spread the red agent through all connected nodes.

        Nothing happens when the source node is isolated or not compromised.
        Otherwise each connected node is independently compromised with
        probability ``chance_to_spread_during_patch``.

        Args:
            number: the number of the node to spread from
        """
        condition = self.nodes[number].get_condition()
        if condition[0] or not condition[1]:
            # Isolated or uninfected nodes cannot spread.
            return
        for neighbour in self.get_connected_nodes(number):
            # random.random() lies in [0.0, 1.0), so a chance of 1.0 always
            # spreads and a chance of 0.0 never does.
            if random.random() < self.chance_to_spread_during_patch:
                self.nodes[neighbour].change_compromised(2)

    def calculate_reward(self) -> float:
        """
        Calculate a reward for the current network state.

        Each safe node contributes 0.2; each node that is compromised but
        isolated contributes 0.01; other nodes contribute nothing.

        Returns:
            reward: the reward for being in the current state
        """
        reward = 0.0
        for node in self.nodes:
            condition = node.get_condition()
            if not condition[1]:
                # reward for safe
                reward += 0.2
            elif condition[0]:
                # reward for unsafe but isolated
                reward += 0.01
        return reward

    def get_netx_graph(self) -> "nx.Graph":
        """
        Get the underlying networkx graph.

        Nodes are labelled with the string form of their index and an edge is
        added for every adjacency matrix entry equal to 1.

        Returns:
            A networkx graph object
        """
        graph = nx.Graph()
        graph.add_nodes_from(str(i) for i in range(self.get_number_of_nodes()))
        for i, row in enumerate(self.adj_matrix):
            for j, entry in enumerate(row):
                if entry == 1:
                    graph.add_edge(str(i), str(j))
        return graph

    def get_netx_pos(self) -> dict:
        """Get graph positions."""
        return self.pos_dic