mirror of
https://github.com/kristoferssolo/Tetris.git
synced 2025-10-21 20:00:35 +00:00
feat(ai): add training func
This commit is contained in:
parent
1ead412528
commit
932dea1676
1
main.py
1
main.py
@ -70,6 +70,7 @@ def main(args: argparse.ArgumentParser) -> None:
|
||||
|
||||
if args.train is not None:
|
||||
ai.log.debug("Training the AI")
|
||||
ai.train(args.train)
|
||||
else:
|
||||
game.log.debug("Running the game")
|
||||
game.Main().run()
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from .log import log
|
||||
from .training import train
|
||||
|
||||
__all__ = ["log"]
|
||||
__all__ = ["log", "train"]
|
||||
|
||||
20
src/ai/training.py
Normal file
20
src/ai/training.py
Normal file
@ -0,0 +1,20 @@
|
||||
import neat
|
||||
from utils import BASE_PATH
|
||||
|
||||
from .config import get_config
|
||||
from .evaluations import eval_genomes
|
||||
from .io import save_genome
|
||||
from .log import log
|
||||
|
||||
|
||||
def train(generations: int) -> None:
|
||||
"""Train the AI"""
|
||||
config = get_config()
|
||||
population = neat.Population(config)
|
||||
population.add_reporter(neat.StdOutReporter(True))
|
||||
population.add_reporter(neat.StatisticsReporter())
|
||||
population.add_reporter(neat.Checkpointer(5, 900))
|
||||
winner = population.run(eval_genomes, generations)
|
||||
|
||||
log.info("Saving best genome")
|
||||
save_genome(winner)
|
||||
Loading…
Reference in New Issue
Block a user