From 932dea1676673a17797d708be9e32973efcd965d Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Thu, 4 Jan 2024 18:29:40 +0200 Subject: [PATCH] feat(ai): add training func --- main.py | 1 + src/ai/__init__.py | 3 ++- src/ai/training.py | 20 ++++++++++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 src/ai/training.py diff --git a/main.py b/main.py index 4803b02..25e9a54 100755 --- a/main.py +++ b/main.py @@ -70,6 +70,7 @@ def main(args: argparse.ArgumentParser) -> None: if args.train is not None: ai.log.debug("Training the AI") + ai.train(args.train) else: game.log.debug("Running the game") game.Main().run() diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 466da48..c7794f3 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -1,3 +1,4 @@ from .log import log +from .training import train -__all__ = ["log"] +__all__ = ["log", "train"] diff --git a/src/ai/training.py b/src/ai/training.py new file mode 100644 index 0000000..7abf8d2 --- /dev/null +++ b/src/ai/training.py @@ -0,0 +1,20 @@ +import neat +from utils import BASE_PATH + +from .config import get_config +from .evaluations import eval_genomes +from .io import save_genome +from .log import log + + +def train(generations: int) -> None: + """Train the AI""" + config = get_config() + population = neat.Population(config) + population.add_reporter(neat.StdOutReporter(True)) + population.add_reporter(neat.StatisticsReporter()) + population.add_reporter(neat.Checkpointer(5, 900)) + winner = population.run(eval_genomes, generations) + + log.info("Saving best genome") + save_genome(winner)