diff --git a/src/ai/io.py b/src/ai/io.py index abc0680..935908b 100644 --- a/src/ai/io.py +++ b/src/ai/io.py @@ -2,14 +2,14 @@ import pickle from pathlib import Path import neat -from utils import BASE_PATH +from utils import CONFIG def load_genome() -> neat.DefaultGenome: - with open(BASE_PATH / "winner.pkl", "rb") as f: + with open(CONFIG.ai.winner_path, "rb") as f: return pickle.load(f) def save_genome(genome: neat.DefaultGenome) -> None: - with open(BASE_PATH / "winner.pkl", "wb") as f: + with open(CONFIG.ai.winner_path, "wb") as f: pickle.dump(genome, f) diff --git a/src/ai/training.py b/src/ai/training.py index 6accdf4..32aaa58 100644 --- a/src/ai/training.py +++ b/src/ai/training.py @@ -3,7 +3,7 @@ import time import neat import pygame from game import Main -from utils import BASE_PATH +from utils import BASE_PATH, CONFIG from .config import get_config from .evaluations import eval_genome @@ -12,19 +12,19 @@ from .log import log from .visualize import plot_progress, plot_species, plot_stats -def train(gen_count: int, parallel: int = 1) -> None: +def train( + gen_count: int = CONFIG.ai.generations, parallel: int = CONFIG.ai.parallels +) -> None: """ Train the AI Args: - gen_count: Number of generations to train. + gen_count: Number of generations to train (default is 200). threads: Number of threads to use (default is 1). """ config = get_config() - chekpoint_path = BASE_PATH / "checkpoints" - plots_path = BASE_PATH / "plots" # population = neat.Checkpointer().restore_checkpoint( - # BASE_PATH / "checkpoints" / "neat-checkpoint-199" + # CONFIG.ai.checkpoint_path / "neat-checkpoint-199" # ) population = neat.Population(config) population.add_reporter(neat.StdOutReporter(True)) @@ -39,9 +39,9 @@ def train(gen_count: int, parallel: int = 1) -> None: stats, ylog=False, view=False, - filename=plots_path / "avg_fitness.png", + filename=CONFIG.ai.plot_path / "avg_fitness.png", ) - plot_species(stats, view=False, filename=plots_path / "speciation.png") + plot_species(stats, view=False, filename=CONFIG.ai.plot_path / "speciation.png") log.info("Saving best genome") save_genome(winner) diff --git a/src/utils/config.py b/src/utils/config.py index 8dc7e5e..d58924a 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -58,6 +58,15 @@ class Music: volume: float = 0.01 +@define +class AI: + generations: int = 200 + parallels: int = 1 + winner_path: Path = BASE_PATH / "winner" + plot_path: Path = BASE_PATH / "plots" + checkpoint_path: Path = BASE_PATH / "checkpoints" + + @define class Config: log_level: str = "warning" @@ -68,6 +77,7 @@ class Config: font: Font = Font() music: Music = Music() colors = TokyoNightNight() + ai = AI() fps: int = 60