import random
from typing import List, Tuple
import networkx as nx
import numpy as np
from yawning_titan.envs.specific.core.nsa_node import Node
[docs]class NodeCollection:
"""Class representing a collection of nodes for the 18-node Ridley Environment."""
[docs] def __init__(self, network: Tuple[np.array, dict], chance_to_spread_during_patch):
self.adj_matrix = network[0]
self.pos_dic = network[1]
self.nodes = []
for _ in range(0, len(self.adj_matrix)):
self.nodes.append(Node())
self.chance_to_spread_during_patch = chance_to_spread_during_patch
[docs] def get_number_of_nodes(self) -> int:
"""
Return the number of nodes in the network.
Returns:
The number of nodes in the network (int)
"""
return len(self.nodes)
[docs] def get_observation(self) -> np.array:
"""
Get the states of all the nodes in the network.
Returns:
observation: The current state of the environment (numpy array)
"""
observation = np.zeros(
(len(self.nodes), (len(self.nodes) + 2)), dtype=np.float32
)
for i in range(0, len(self.nodes)):
data = self.nodes[i].get_condition()
observation[i][0] = data[0]
observation[i][1] = data[1]
for j in range(0, len(self.nodes)):
if self.nodes[i].get_condition()[0]:
observation[i][j + 2] = 0
else:
observation[i][j + 2] = self.adj_matrix[i][j]
return observation
[docs] def modify_node(self, number: int, changes: Tuple[bool, int]):
"""
Change the state of a single node.
Args:
number: the number of the node to change
changes: a list with two variables in [isolate, compromise]
isolate: A boolean that will if true change the isolation status of the node (true -> false,
false -> true) (boolean)
compromise: a mode signal that will change the state of a node. 0 does nothing, 1 makes it safe and
2 compromises the node (int)
"""
[isolate, compromise] = changes
if isolate:
self.nodes[number].change_isolated()
self.nodes[number].change_compromised(compromise)
[docs] def get_compromised_nodes(self) -> List[int]:
"""
Create a list of all the nodes in the network that are compromised.
Returns:
compromised_nodes: A list of nodes that are compromised (list of ints)
"""
compromised_nodes = []
for i in range(0, len(self.nodes)):
if self.nodes[i].get_condition()[1]:
# check if compromised
compromised_nodes.append(i)
return compromised_nodes
[docs] def get_un_compromised_nodes(self) -> List[int]:
"""
Create a list of all the safe nodes in the network.
Returns:
un_compromised_nodes: A list of nodes that are safe (list of ints)
"""
un_compromised_nodes = []
for i in range(0, len(self.nodes)):
if not self.nodes[i].get_condition()[1]:
un_compromised_nodes.append(i)
return un_compromised_nodes
[docs] def get_isolated_nodes(self) -> List[int]:
"""
Create a list of all the isolated nodes in the network.
Returns:
isolated_nodes: A list of nodes that are isolated (list of ints)
"""
isolated_nodes = []
for i in range(0, len(self.nodes)):
if self.nodes[i].get_condition()[0]:
isolated_nodes.append(i)
return isolated_nodes
[docs] def get_number_of_isolated(self) -> int:
"""
Get the number of isolated nodes in the network.
Returns:
the number of isolated nodes in the network (int)
"""
return len(self.get_isolated_nodes())
[docs] def get_number_of_un_compromised(self) -> int:
"""
Get the number of safe nodes in the network.
Returns:
the number of safe nodes in the network (int)
"""
return len(self.get_un_compromised_nodes())
[docs] def get_connected_nodes(self, number: int) -> List[int]:
"""
When given a node returns a list of all of the nodes connected to that node.
Args:
number: the number of the node to run on
Returns:
a list of all the nodes connected to a specified node (list of ints)
"""
if self.nodes[number].get_condition()[0]:
return []
else:
# checks the connected nodes though the adj matrix and checks if the nodes are not isolated
return [
i
for i in range(0, len(self.nodes))
if self.adj_matrix[number][i] == 1
and self.nodes[i].get_condition()[0] is False
]
[docs] def spread(self, number: int):
"""
Spread the red agent through all connected nodes.
Args:
number: the number of the node to spread from
"""
if (
self.nodes[number].get_condition()[0]
or self.nodes[number].get_condition()[1] is False
):
# If the nodes is isolated or not infected
pass
else:
# get all the connected nodes
connected_nodes = self.get_connected_nodes(number)
for i in connected_nodes:
n = random.randint(1, 100)
# attempt to spread
if n < self.chance_to_spread_during_patch * 100:
self.nodes[i].change_compromised(2)
[docs] def calculate_reward(self) -> float:
"""
Calculate a reward for the current network state.
Returns:
reward: the reward for being in the current state
"""
reward = 0
for i in self.nodes:
# gets the conditions of all the nodes
node_state = i.get_condition()
if node_state[1] is False:
# reward for safe
reward = reward + 0.2
elif node_state[1] is True and node_state[0]:
# reward for unsafe but isolated
reward = reward + 0.01
return reward
[docs] def get_netx_graph(self) -> nx.Graph:
"""
Get the underlying networkx graph.
Returns:
A networkx graph object
"""
nodes = [str(i) for i in range(self.get_number_of_nodes())]
graph = nx.Graph()
graph.add_nodes_from(nodes)
for i in range(len(self.adj_matrix)):
for j in range(len(self.adj_matrix[i])):
if self.adj_matrix[i][j] == 1:
graph.add_edge(str(i), str(j))
return graph
[docs] def get_netx_pos(self) -> dict:
"""Get graph positions."""
return self.pos_dic