Source code for yawning_titan.envs.specific.graph_explore

import random
from typing import Tuple

import gym
import networkx as nx
import numpy as np

from yawning_titan.envs.generic.helpers.graph2plot import CustomEnvGraph


class GraphExplore(gym.Env):
    """
    A custom environment that follows the gym interface spec.

    This environment emulates a network and enables an agent to select
    which node to visit; if it is not possible to move to the node, the
    agent is denied the move.
    """

    metadata = {"render.modes": ["human"]}
    NODES = 10  # the number of nodes within the network
    random_seed = 1010  # the seed used to generate the random network
    GAME_MAX = 1000  # the number of game moves allowed
    visualisation = None
    def __init__(self):
        """Initialise the environment."""
        print("GAME INIT")
        super(GraphExplore, self).__init__()
        self.G = nx.random_internet_as_graph(n=self.NODES, seed=self.random_seed)
        self.pos = nx.spring_layout(self.G)
        # Action and observation spaces must be gym.spaces objects.
        # One action per node, plus one extra "do nothing" action.
        self.action_space = gym.spaces.Discrete(self.NODES + 1)
        # Observation: a binary visited flag per node.
        self.observation_space = gym.spaces.Box(
            low=0, high=1, shape=(self.NODES, 1, 1), dtype=np.uint8
        )
        # all can communicate, or none can
        self.reward_range = (-1 * (self.NODES * self.NODES), self.NODES * self.NODES)
        # start blue
        self.INITIAL_BLUE = random.choice(list(self.G.nodes()))
        self.POS_BLUE = self.INITIAL_BLUE
        # blue visit list
        self.BLUE_VISIT = [self.INITIAL_BLUE]
        # current step
        self.CURRENT_STEP = 0
        # score
        self.BLUE_SCORE = 0
    def step(self, action: int) -> Tuple[np.ndarray, float, bool, dict]:
        """Execute one time step within the environment."""
        print(
            "GAME STEP {_step} of {_max}".format(
                _step=1 + self.CURRENT_STEP, _max=self.GAME_MAX
            )
        )
        self._take_action(action)
        self.CURRENT_STEP += 1
        reward = self._calc_reward()
        obs = self._next_observation()
        # The episode ends when every node has been visited or the move
        # budget is exhausted.
        done = len(self.BLUE_VISIT) == self.NODES or self.CURRENT_STEP == self.GAME_MAX
        return obs, reward, done, {}
    def _take_action(self, action: int):
        """Take an action within the environment by moving to the chosen node."""
        if action == self.NODES:
            # do nothing
            print("Passing turn")
        else:
            # attempt to visit node
            print(
                "Currently at node: {current} and want to move to: {future}".format(
                    current=self.POS_BLUE, future=action
                )
            )
            # The move is only possible if the target node is a neighbour
            # of the current node (or is the current node itself).
            if action in list(self.G.neighbors(self.POS_BLUE)) or (
                action == self.POS_BLUE
            ):
                # move
                print("Moved to: {future}".format(future=action))
                self.POS_BLUE = action
                # Keep BLUE_VISIT free of duplicates by removing any
                # previous entry before re-appending the node.
                try:
                    self.BLUE_VISIT.remove(action)
                except ValueError:
                    print("Never visited node {future} before".format(future=action))
                self.BLUE_VISIT.append(action)
            else:
                # do not move
                print("Cannot move to: {future}".format(future=action))

    def _calc_reward(self) -> float:
        """
        Calculate the agent's reward.

        The reward is the total number of nodes visited, with a penalty
        that grows over time to coax the agent into moving.
        """
        return len(self.BLUE_VISIT) - (
            (self.CURRENT_STEP / (1.0 * self.GAME_MAX)) * len(self.BLUE_VISIT)
        )

    def _next_observation(self) -> np.ndarray:
        """
        Get the next observation.

        The observation space is just a list of nodes visited, so the
        agent is blind to the connectivity of the network.
        """
        # Mark each node blue has visited with a 1.
        obs = np.zeros(self.NODES)
        for i in self.BLUE_VISIT:
            obs[i] = 1
        obs = np.array(obs).reshape(self.NODES, 1, 1)
        return obs
    def reset(self) -> np.ndarray:
        """Reset the environment to its initial state."""
        print("GAME RESET")
        self.CURRENT_STEP = 0
        self.G = nx.random_internet_as_graph(n=self.NODES, seed=self.random_seed)
        self.pos = nx.spring_layout(self.G)
        self.INITIAL_BLUE = random.choice(list(self.G.nodes()))
        self.POS_BLUE = self.INITIAL_BLUE
        self.BLUE_SCORE = 0
        self.BLUE_VISIT = [self.INITIAL_BLUE]
        return self._next_observation()
    def render(self, mode: str = "live", close: bool = False):
        """Render the environment to the screen so that it can be watched in real time."""
        if mode == "file":
            pass
        elif mode == "live":
            print("rendering..")
            if self.visualisation is None:
                self.visualisation = CustomEnvGraph(title="Network Visualisation")
            if self.CURRENT_STEP > 0:
                self.visualisation.render(
                    self.CURRENT_STEP,
                    self.G,
                    self.pos,
                    {},
                    self.BLUE_VISIT,
                    [],
                    self._calc_reward(),
                    None,
                    {i: 0.5 for i in range(len(self.G.nodes))},
                    [self.POS_BLUE],
                    "Graph explore",
                )
    def close(self):
        """Remove all open visualisations."""
        if self.visualisation is not None:
            self.visualisation.close()
            self.visualisation = None
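
For illustration, a minimal random-agent rollout against this environment could look like the sketch below. It is not part of the module: it assumes only the standard gym.Env loop implemented above (reset, step with the four-tuple return, and the Discrete(NODES + 1) action space).

if __name__ == "__main__":
    # Illustrative sketch only: drive the environment with random actions.
    env = GraphExplore()
    obs = env.reset()
    done = False
    while not done:
        # Sample from Discrete(NODES + 1): a node index, or NODES to pass.
        action = env.action_space.sample()
        obs, reward, done, info = env.step(action)
    env.close()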