diff --git a/main.py b/main.py index 4803b02..25e9a54 100755 --- a/main.py +++ b/main.py @@ -70,6 +70,7 @@ def main(args: argparse.ArgumentParser) -> None: if args.train is not None: ai.log.debug("Training the AI") + ai.train(args.train) else: game.log.debug("Running the game") game.Main().run() diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 466da48..c7794f3 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -1,3 +1,4 @@ from .log import log +from .training import train -__all__ = ["log"] +__all__ = ["log", "train"] diff --git a/src/ai/training.py b/src/ai/training.py new file mode 100644 index 0000000..7abf8d2 --- /dev/null +++ b/src/ai/training.py @@ -0,0 +1,20 @@ +import neat +from utils import BASE_PATH + +from .config import get_config +from .evaluations import eval_genomes +from .io import save_genome +from .log import log + + +def train(generations: int) -> None: + """Train the AI""" + config = get_config() + population = neat.Population(config) + population.add_reporter(neat.StdOutReporter(True)) + population.add_reporter(neat.StatisticsReporter()) + population.add_reporter(neat.Checkpointer(5, 900)) + winner = population.run(eval_genomes, generations) + + log.info("Saving best genome") + save_genome(winner)