mirror of
https://github.com/kristoferssolo/2048.git
synced 2025-10-21 15:20:35 +00:00
refactor(AI): separate into files
This commit is contained in:
parent
e03465d2d3
commit
73548ed8f4
BIN
best_genome
Normal file
BIN
best_genome
Normal file
Binary file not shown.
BIN
best_genome.pkl
Normal file
BIN
best_genome.pkl
Normal file
Binary file not shown.
2
main.py
2
main.py
@ -9,7 +9,7 @@ from py2048 import Menu
|
||||
@logger.catch
|
||||
def main() -> None:
|
||||
# Menu().run()
|
||||
train()
|
||||
train(100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from .train import train
|
||||
from .io import read_genome
|
||||
from .training import train
|
||||
|
||||
__all__ = ["train"]
|
||||
__all__ = ["train", "read_genome"]
|
||||
|
||||
13
src/ai/config.py
Normal file
13
src/ai/config.py
Normal file
@ -0,0 +1,13 @@
|
||||
import neat
|
||||
from path import BASE_PATH
|
||||
|
||||
|
||||
def get_config() -> neat.Config:
|
||||
config_path = BASE_PATH / "config.txt"
|
||||
return neat.Config(
|
||||
neat.DefaultGenome,
|
||||
neat.DefaultReproduction,
|
||||
neat.DefaultSpeciesSet,
|
||||
neat.DefaultStagnation,
|
||||
config_path,
|
||||
)
|
||||
46
src/ai/evaluation.py
Normal file
46
src/ai/evaluation.py
Normal file
@ -0,0 +1,46 @@
|
||||
import neat
|
||||
from loguru import logger
|
||||
from py2048 import Menu
|
||||
|
||||
|
||||
def eval_genomes(genomes, config: neat.Config):
|
||||
for genome_id, genome in genomes:
|
||||
genome.fitness = 0
|
||||
app = Menu()
|
||||
net = neat.nn.FeedForwardNetwork.create(genome, config)
|
||||
|
||||
app.play()
|
||||
app._game_active = False
|
||||
|
||||
while True:
|
||||
output = net.activate(
|
||||
(
|
||||
*app.game.board.matrix(),
|
||||
app.game.board.score,
|
||||
)
|
||||
)
|
||||
|
||||
decision = output.index(max(output))
|
||||
|
||||
decisions = {
|
||||
0: app.game.move_up,
|
||||
1: app.game.move_down,
|
||||
2: app.game.move_left,
|
||||
3: app.game.move_right,
|
||||
}
|
||||
|
||||
decisions[decision]()
|
||||
|
||||
app._hande_events()
|
||||
app.game.draw(app._surface)
|
||||
max_val = app.game.board.max_val()
|
||||
|
||||
if app.game.board._is_full() or max_val >= 2048:
|
||||
calculate_fitness(genome, max_val)
|
||||
logger.info(f"{max_val=}")
|
||||
app.game.restart()
|
||||
break
|
||||
|
||||
|
||||
def calculate_fitness(genome: neat.DefaultGenome, score: int):
|
||||
genome.fitness += score
|
||||
15
src/ai/io.py
Normal file
15
src/ai/io.py
Normal file
@ -0,0 +1,15 @@
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import neat
|
||||
from path import BASE_PATH
|
||||
|
||||
|
||||
def read_genome(filename: Path) -> neat.DefaultGenome:
|
||||
with open(filename, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
def save_genome(genome, filename: Path) -> None:
|
||||
with open(filename, "wb") as f:
|
||||
pickle.dump(genome, f)
|
||||
@ -1,73 +0,0 @@
|
||||
import neat
|
||||
from loguru import logger
|
||||
from path import BASE_PATH
|
||||
from py2048 import Menu
|
||||
|
||||
|
||||
def _get_config() -> neat.Config:
|
||||
config_path = BASE_PATH / "config.txt"
|
||||
return neat.Config(
|
||||
neat.DefaultGenome,
|
||||
neat.DefaultReproduction,
|
||||
neat.DefaultSpeciesSet,
|
||||
neat.DefaultStagnation,
|
||||
config_path,
|
||||
)
|
||||
|
||||
|
||||
def train() -> None:
|
||||
config = _get_config()
|
||||
# p = neat.Checkpointer.restore_checkpoint("neat-checkpoint-0")
|
||||
p = neat.Population(config)
|
||||
p.add_reporter(neat.StdOutReporter(True))
|
||||
stats = neat.StatisticsReporter()
|
||||
p.add_reporter(stats)
|
||||
p.add_reporter(neat.Checkpointer(1))
|
||||
|
||||
winner = p.run(eval_genomes, 50)
|
||||
|
||||
logger.info(f"\nBest genome:\n{winner}")
|
||||
|
||||
|
||||
def eval_genomes(genomes, config: neat.Config):
|
||||
for genome_id, genome in genomes:
|
||||
genome.fitness = 4.0
|
||||
app = Menu()
|
||||
net = neat.nn.FeedForwardNetwork.create(genome, config)
|
||||
|
||||
app.play()
|
||||
app._game_active = False
|
||||
|
||||
while True:
|
||||
output = net.activate(
|
||||
(
|
||||
*app.game.board.matrix(),
|
||||
app.game.board.score,
|
||||
)
|
||||
)
|
||||
|
||||
decision = output.index(max(output))
|
||||
|
||||
decisions = {
|
||||
0: app.game.move_up,
|
||||
1: app.game.move_down,
|
||||
2: app.game.move_left,
|
||||
3: app.game.move_right,
|
||||
}
|
||||
|
||||
decisions[decision]()
|
||||
|
||||
app._hande_events()
|
||||
app.game.draw(app._surface)
|
||||
|
||||
if app.game.board._is_full() or app.game.board.score > 10_000:
|
||||
calculate_fitness(genome, app.game.board.score)
|
||||
logger.info(
|
||||
f"Genome: {genome_id} fitness: {genome.fitness} score: {app.game.board.score}"
|
||||
)
|
||||
app.game.restart()
|
||||
break
|
||||
|
||||
|
||||
def calculate_fitness(genome, score: int):
|
||||
genome.fitness += score
|
||||
22
src/ai/training.py
Normal file
22
src/ai/training.py
Normal file
@ -0,0 +1,22 @@
|
||||
import neat
|
||||
from loguru import logger
|
||||
from path import BASE_PATH
|
||||
|
||||
from .config import get_config
|
||||
from .evaluation import eval_genomes
|
||||
from .io import save_genome
|
||||
|
||||
|
||||
def train(generations: int) -> None:
|
||||
"""Train the AI for a given number of generations."""
|
||||
config = get_config()
|
||||
population = neat.Population(config)
|
||||
population.add_reporter(neat.StdOutReporter(True))
|
||||
stats = neat.StatisticsReporter()
|
||||
population.add_reporter(stats)
|
||||
population.add_reporter(neat.Checkpointer(1))
|
||||
|
||||
winner = population.run(eval_genomes, generations)
|
||||
|
||||
logger.info(winner)
|
||||
save_genome(winner, BASE_PATH / "best_genome")
|
||||
@ -116,6 +116,11 @@ class Board(pygame.sprite.Group):
|
||||
self.empty()
|
||||
self._initiate_game()
|
||||
|
||||
def max_val(self) -> int:
|
||||
"""Return the maximum value of the tiles."""
|
||||
tile: Tile
|
||||
return int(max(tile.value for tile in self.sprites()))
|
||||
|
||||
def get_tile(self, position: Position) -> Optional[Tile]:
|
||||
"""Return the tile at the specified position."""
|
||||
tile: Tile
|
||||
|
||||
Loading…
Reference in New Issue
Block a user