diff --git a/src/ai/training.py b/src/ai/training.py index 119b767..bddfc1f 100644 --- a/src/ai/training.py +++ b/src/ai/training.py @@ -33,7 +33,9 @@ def train( population.add_reporter(neat.StdOutReporter(True)) stats = neat.StatisticsReporter() population.add_reporter(stats) - population.add_reporter(neat.Checkpointer(5, 900)) + population.add_reporter( + neat.Checkpointer(CONFIG.ai.checkpoint_interval, CONFIG.ai.checkpoint_delay) + ) pe = neat.ParallelEvaluator(int(parallel), eval_genome) diff --git a/src/utils/config.py b/src/utils/config.py index 413f225..c6c2d01 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -66,6 +66,8 @@ class AI: plot_path: Path = BASE_PATH / "plots" checkpoint_path: Path = BASE_PATH / "checkpoints" config_path: Path = BASE_PATH / "config" + checkpoint_interval: int = 10 + checkpoint_delay: int = 900 @define