refactor(AI): separate into files

This commit is contained in:
Kristofers Solo 2024-01-03 04:15:17 +02:00
parent e03465d2d3
commit 73548ed8f4
10 changed files with 105 additions and 76 deletions

BIN
best_genome Normal file

Binary file not shown.

BIN
best_genome.pkl Normal file

Binary file not shown.

View File

@ -9,7 +9,7 @@ from py2048 import Menu
@logger.catch
def main() -> None:
# Menu().run()
train()
train(100)
if __name__ == "__main__":

View File

@ -1,3 +1,4 @@
from .train import train
from .io import read_genome
from .training import train
__all__ = ["train"]
__all__ = ["train", "read_genome"]

13
src/ai/config.py Normal file
View File

@ -0,0 +1,13 @@
import neat
from path import BASE_PATH
def get_config() -> neat.Config:
config_path = BASE_PATH / "config.txt"
return neat.Config(
neat.DefaultGenome,
neat.DefaultReproduction,
neat.DefaultSpeciesSet,
neat.DefaultStagnation,
config_path,
)

46
src/ai/evaluation.py Normal file
View File

@ -0,0 +1,46 @@
import neat
from loguru import logger
from py2048 import Menu
def eval_genomes(genomes, config: neat.Config):
for genome_id, genome in genomes:
genome.fitness = 0
app = Menu()
net = neat.nn.FeedForwardNetwork.create(genome, config)
app.play()
app._game_active = False
while True:
output = net.activate(
(
*app.game.board.matrix(),
app.game.board.score,
)
)
decision = output.index(max(output))
decisions = {
0: app.game.move_up,
1: app.game.move_down,
2: app.game.move_left,
3: app.game.move_right,
}
decisions[decision]()
app._hande_events()
app.game.draw(app._surface)
max_val = app.game.board.max_val()
if app.game.board._is_full() or max_val >= 2048:
calculate_fitness(genome, max_val)
logger.info(f"{max_val=}")
app.game.restart()
break
def calculate_fitness(genome: neat.DefaultGenome, score: int):
genome.fitness += score

15
src/ai/io.py Normal file
View File

@ -0,0 +1,15 @@
import pickle
from pathlib import Path
import neat
from path import BASE_PATH
def read_genome(filename: Path) -> neat.DefaultGenome:
with open(filename, "rb") as f:
return pickle.load(f)
def save_genome(genome, filename: Path) -> None:
with open(filename, "wb") as f:
pickle.dump(genome, f)

View File

@ -1,73 +0,0 @@
import neat
from loguru import logger
from path import BASE_PATH
from py2048 import Menu
def _get_config() -> neat.Config:
config_path = BASE_PATH / "config.txt"
return neat.Config(
neat.DefaultGenome,
neat.DefaultReproduction,
neat.DefaultSpeciesSet,
neat.DefaultStagnation,
config_path,
)
def train() -> None:
config = _get_config()
# p = neat.Checkpointer.restore_checkpoint("neat-checkpoint-0")
p = neat.Population(config)
p.add_reporter(neat.StdOutReporter(True))
stats = neat.StatisticsReporter()
p.add_reporter(stats)
p.add_reporter(neat.Checkpointer(1))
winner = p.run(eval_genomes, 50)
logger.info(f"\nBest genome:\n{winner}")
def eval_genomes(genomes, config: neat.Config):
for genome_id, genome in genomes:
genome.fitness = 4.0
app = Menu()
net = neat.nn.FeedForwardNetwork.create(genome, config)
app.play()
app._game_active = False
while True:
output = net.activate(
(
*app.game.board.matrix(),
app.game.board.score,
)
)
decision = output.index(max(output))
decisions = {
0: app.game.move_up,
1: app.game.move_down,
2: app.game.move_left,
3: app.game.move_right,
}
decisions[decision]()
app._hande_events()
app.game.draw(app._surface)
if app.game.board._is_full() or app.game.board.score > 10_000:
calculate_fitness(genome, app.game.board.score)
logger.info(
f"Genome: {genome_id} fitness: {genome.fitness} score: {app.game.board.score}"
)
app.game.restart()
break
def calculate_fitness(genome, score: int):
genome.fitness += score

22
src/ai/training.py Normal file
View File

@ -0,0 +1,22 @@
import neat
from loguru import logger
from path import BASE_PATH
from .config import get_config
from .evaluation import eval_genomes
from .io import save_genome
def train(generations: int) -> None:
"""Train the AI for a given number of generations."""
config = get_config()
population = neat.Population(config)
population.add_reporter(neat.StdOutReporter(True))
stats = neat.StatisticsReporter()
population.add_reporter(stats)
population.add_reporter(neat.Checkpointer(1))
winner = population.run(eval_genomes, generations)
logger.info(winner)
save_genome(winner, BASE_PATH / "best_genome")

View File

@ -116,6 +116,11 @@ class Board(pygame.sprite.Group):
self.empty()
self._initiate_game()
def max_val(self) -> int:
"""Return the maximum value of the tiles."""
tile: Tile
return int(max(tile.value for tile in self.sprites()))
def get_tile(self, position: Position) -> Optional[Tile]:
"""Return the tile at the specified position."""
tile: Tile