From 0e5e502898543d12f8e9fa6f6ad2aa66f2c46936 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Fri, 5 Jan 2024 16:34:26 +0200 Subject: [PATCH] feat(ai): add checkpoint path as argument on launch --- config.txt => config | 0 main.py | 5 ++--- src/ai/fitness.py | 2 +- src/ai/training.py | 18 +++++++++++------- 4 files changed, 14 insertions(+), 11 deletions(-) rename config.txt => config (100%) diff --git a/config.txt b/config similarity index 100% rename from config.txt rename to config diff --git a/main.py b/main.py index f2d154b..5438502 100755 --- a/main.py +++ b/main.py @@ -32,9 +32,8 @@ group.add_argument( parser.add_argument( "-t", "--train", - type=pos_int, - nargs=2, - metavar=("generations", "parallels"), + nargs=3, + metavar=("n generations", "n parallels", "checkpoint"), help="Trains the AI", ) diff --git a/src/ai/fitness.py b/src/ai/fitness.py index 2c2e81a..b2ce090 100644 --- a/src/ai/fitness.py +++ b/src/ai/fitness.py @@ -53,7 +53,7 @@ def _calc_holes(field: np.ndarray) -> float: return penalty -def _calc_height_penalty(field: np.ndarray) -> float: +def _calc_height_penalty(field: np.ndarray) -> tuple[float, float]: column_heights = np.max( np.where(field == 1, field.shape[0] - np.arange(field.shape[0])[:, None], 0), axis=0, diff --git a/src/ai/training.py b/src/ai/training.py index df64f55..119b767 100644 --- a/src/ai/training.py +++ b/src/ai/training.py @@ -1,4 +1,5 @@ import time +from typing import Optional import neat import pygame @@ -12,7 +13,9 @@ from .visualize import plot_progress, plot_species, plot_stats def train( - gen_count: int = CONFIG.ai.generations, parallel: int = CONFIG.ai.parallels + gen_count: int = CONFIG.ai.generations, + parallel: int = CONFIG.ai.parallels, + checkpoint_path: Optional[str] = None, ) -> None: """ Train the AI @@ -22,18 +25,19 @@ def train( """ config = get_config() - # population = neat.Checkpointer().restore_checkpoint( - # CONFIG.ai.checkpoint_path / "neat-checkpoint-199" - # ) - population = neat.Population(config) + population = ( + neat.Checkpointer().restore_checkpoint(checkpoint_path) + if checkpoint_path + else neat.Population(config) + ) population.add_reporter(neat.StdOutReporter(True)) stats = neat.StatisticsReporter() population.add_reporter(stats) population.add_reporter(neat.Checkpointer(5, 900)) - pe = neat.ParallelEvaluator(parallel, eval_genome) + pe = neat.ParallelEvaluator(int(parallel), eval_genome) - winner = population.run(pe.evaluate, gen_count) + winner = population.run(pe.evaluate, int(gen_count)) plot_stats( stats, ylog=False,