Source code for tests.integration_tests.generic_env.test_new_entry_nodes

import random
from collections import Counter

import numpy as np
import pytest
from stable_baselines3.common.env_checker import check_env


@pytest.mark.integration_test()
def test_new_entry_nodes(create_yawning_titan_run):
    """Test the selection of entry nodes and validate they are correct.

    Runs random blue actions in the ``new_entry_nodes`` game mode on the
    ``mesh_18`` network. Whenever an episode finishes, the entry nodes of the
    current graph are tallied and the environment is reset. The test then
    checks that:

    * every node of the network was selected as an entry node at least once;
    * each node was selected roughly equally often, i.e. within 10% of the
      mean number of selections per node.
    """
    num_steps = 10000
    num_network_nodes = 18  # size of the "mesh_18" network requested below

    yt_run = create_yawning_titan_run(
        game_mode_name="new_entry_nodes", network_name="mesh_18"
    )
    env = yt_run.env
    check_env(env, warn=True)
    env.reset()

    # Loop-invariant: the size of the blue action space does not change.
    num_blue_actions = env.BLUE.get_number_of_actions()
    entry_node_counts = Counter()
    for _ in range(num_steps):
        _, _, done, _ = env.step(random.randint(0, num_blue_actions - 1))
        if done:
            # Tally the entry nodes used in the episode that just ended.
            entry_node_counts.update(
                node.uuid
                for node in env.network_interface.current_graph.entry_nodes
            )
            # reset() is expected to pick a new set of entry nodes for the
            # next episode -- that re-selection is what this test exercises.
            env.reset()

    # check that every node can be chosen as an entry node
    assert len(entry_node_counts) == num_network_nodes

    # check that each node is roughly chosen equally. The target is the mean
    # tally over all recorded selections; tallies are only taken at episode
    # end, so the number of steps alone does not determine the target.
    target_count = sum(entry_node_counts.values()) / len(entry_node_counts)
    for count in entry_node_counts.values():
        assert np.isclose(count, target_count, atol=(target_count * 0.1))