From e6cef45d194991331ac07b0caae080dcea89e1f9 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Fri, 5 Jan 2024 16:26:19 +0200 Subject: [PATCH] feat(utils, ai): set neat config file location in `Config` --- src/ai/config.py | 13 ------------- src/ai/io.py | 10 ++++++++++ src/ai/training.py | 3 +-- src/utils/config.py | 1 + 4 files changed, 12 insertions(+), 15 deletions(-) delete mode 100644 src/ai/config.py diff --git a/src/ai/config.py b/src/ai/config.py deleted file mode 100644 index f761d5a..0000000 --- a/src/ai/config.py +++ /dev/null @@ -1,13 +0,0 @@ -import neat -from utils import BASE_PATH - - -def get_config() -> neat.Config: - config_path = BASE_PATH / "config.txt" - return neat.Config( - neat.DefaultGenome, - neat.DefaultReproduction, - neat.DefaultSpeciesSet, - neat.DefaultStagnation, - config_path, - ) diff --git a/src/ai/io.py b/src/ai/io.py index 935908b..9aa3b60 100644 --- a/src/ai/io.py +++ b/src/ai/io.py @@ -5,6 +5,16 @@ import neat from utils import CONFIG +def get_config() -> neat.Config: + return neat.Config( + neat.DefaultGenome, + neat.DefaultReproduction, + neat.DefaultSpeciesSet, + neat.DefaultStagnation, + CONFIG.ai.config_path, + ) + + def load_genome() -> neat.DefaultGenome: with open(CONFIG.ai.winner_path, "rb") as f: return pickle.load(f) diff --git a/src/ai/training.py b/src/ai/training.py index 32aaa58..df64f55 100644 --- a/src/ai/training.py +++ b/src/ai/training.py @@ -5,9 +5,8 @@ import pygame from game import Main from utils import BASE_PATH, CONFIG -from .config import get_config from .evaluations import eval_genome -from .io import save_genome +from .io import get_config, save_genome from .log import log from .visualize import plot_progress, plot_species, plot_stats diff --git a/src/utils/config.py b/src/utils/config.py index d58924a..413f225 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -65,6 +65,7 @@ class AI: winner_path: Path = BASE_PATH / "winner" plot_path: Path = BASE_PATH / "plots" checkpoint_path: Path = BASE_PATH / "checkpoints" + config_path: Path = BASE_PATH / "config" @define