mirror of
https://github.com/kristoferssolo/2048.git
synced 2025-10-21 15:20:35 +00:00
refactor(AI): separate into files
This commit is contained in:
parent
e03465d2d3
commit
73548ed8f4
BIN
best_genome
Normal file
BIN
best_genome
Normal file
Binary file not shown.
BIN
best_genome.pkl
Normal file
BIN
best_genome.pkl
Normal file
Binary file not shown.
2
main.py
2
main.py
@ -9,7 +9,7 @@ from py2048 import Menu
|
|||||||
@logger.catch
|
@logger.catch
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
# Menu().run()
|
# Menu().run()
|
||||||
train()
|
train(100)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
from .train import train
|
from .io import read_genome
|
||||||
|
from .training import train
|
||||||
|
|
||||||
__all__ = ["train"]
|
__all__ = ["train", "read_genome"]
|
||||||
|
|||||||
13
src/ai/config.py
Normal file
13
src/ai/config.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import neat
|
||||||
|
from path import BASE_PATH
|
||||||
|
|
||||||
|
|
||||||
|
def get_config() -> neat.Config:
|
||||||
|
config_path = BASE_PATH / "config.txt"
|
||||||
|
return neat.Config(
|
||||||
|
neat.DefaultGenome,
|
||||||
|
neat.DefaultReproduction,
|
||||||
|
neat.DefaultSpeciesSet,
|
||||||
|
neat.DefaultStagnation,
|
||||||
|
config_path,
|
||||||
|
)
|
||||||
46
src/ai/evaluation.py
Normal file
46
src/ai/evaluation.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import neat
|
||||||
|
from loguru import logger
|
||||||
|
from py2048 import Menu
|
||||||
|
|
||||||
|
|
||||||
|
def eval_genomes(genomes, config: neat.Config):
|
||||||
|
for genome_id, genome in genomes:
|
||||||
|
genome.fitness = 0
|
||||||
|
app = Menu()
|
||||||
|
net = neat.nn.FeedForwardNetwork.create(genome, config)
|
||||||
|
|
||||||
|
app.play()
|
||||||
|
app._game_active = False
|
||||||
|
|
||||||
|
while True:
|
||||||
|
output = net.activate(
|
||||||
|
(
|
||||||
|
*app.game.board.matrix(),
|
||||||
|
app.game.board.score,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
decision = output.index(max(output))
|
||||||
|
|
||||||
|
decisions = {
|
||||||
|
0: app.game.move_up,
|
||||||
|
1: app.game.move_down,
|
||||||
|
2: app.game.move_left,
|
||||||
|
3: app.game.move_right,
|
||||||
|
}
|
||||||
|
|
||||||
|
decisions[decision]()
|
||||||
|
|
||||||
|
app._hande_events()
|
||||||
|
app.game.draw(app._surface)
|
||||||
|
max_val = app.game.board.max_val()
|
||||||
|
|
||||||
|
if app.game.board._is_full() or max_val >= 2048:
|
||||||
|
calculate_fitness(genome, max_val)
|
||||||
|
logger.info(f"{max_val=}")
|
||||||
|
app.game.restart()
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_fitness(genome: neat.DefaultGenome, score: int):
|
||||||
|
genome.fitness += score
|
||||||
15
src/ai/io.py
Normal file
15
src/ai/io.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import neat
|
||||||
|
from path import BASE_PATH
|
||||||
|
|
||||||
|
|
||||||
|
def read_genome(filename: Path) -> neat.DefaultGenome:
|
||||||
|
with open(filename, "rb") as f:
|
||||||
|
return pickle.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def save_genome(genome, filename: Path) -> None:
|
||||||
|
with open(filename, "wb") as f:
|
||||||
|
pickle.dump(genome, f)
|
||||||
@ -1,73 +0,0 @@
|
|||||||
import neat
|
|
||||||
from loguru import logger
|
|
||||||
from path import BASE_PATH
|
|
||||||
from py2048 import Menu
|
|
||||||
|
|
||||||
|
|
||||||
def _get_config() -> neat.Config:
|
|
||||||
config_path = BASE_PATH / "config.txt"
|
|
||||||
return neat.Config(
|
|
||||||
neat.DefaultGenome,
|
|
||||||
neat.DefaultReproduction,
|
|
||||||
neat.DefaultSpeciesSet,
|
|
||||||
neat.DefaultStagnation,
|
|
||||||
config_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def train() -> None:
|
|
||||||
config = _get_config()
|
|
||||||
# p = neat.Checkpointer.restore_checkpoint("neat-checkpoint-0")
|
|
||||||
p = neat.Population(config)
|
|
||||||
p.add_reporter(neat.StdOutReporter(True))
|
|
||||||
stats = neat.StatisticsReporter()
|
|
||||||
p.add_reporter(stats)
|
|
||||||
p.add_reporter(neat.Checkpointer(1))
|
|
||||||
|
|
||||||
winner = p.run(eval_genomes, 50)
|
|
||||||
|
|
||||||
logger.info(f"\nBest genome:\n{winner}")
|
|
||||||
|
|
||||||
|
|
||||||
def eval_genomes(genomes, config: neat.Config):
|
|
||||||
for genome_id, genome in genomes:
|
|
||||||
genome.fitness = 4.0
|
|
||||||
app = Menu()
|
|
||||||
net = neat.nn.FeedForwardNetwork.create(genome, config)
|
|
||||||
|
|
||||||
app.play()
|
|
||||||
app._game_active = False
|
|
||||||
|
|
||||||
while True:
|
|
||||||
output = net.activate(
|
|
||||||
(
|
|
||||||
*app.game.board.matrix(),
|
|
||||||
app.game.board.score,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
decision = output.index(max(output))
|
|
||||||
|
|
||||||
decisions = {
|
|
||||||
0: app.game.move_up,
|
|
||||||
1: app.game.move_down,
|
|
||||||
2: app.game.move_left,
|
|
||||||
3: app.game.move_right,
|
|
||||||
}
|
|
||||||
|
|
||||||
decisions[decision]()
|
|
||||||
|
|
||||||
app._hande_events()
|
|
||||||
app.game.draw(app._surface)
|
|
||||||
|
|
||||||
if app.game.board._is_full() or app.game.board.score > 10_000:
|
|
||||||
calculate_fitness(genome, app.game.board.score)
|
|
||||||
logger.info(
|
|
||||||
f"Genome: {genome_id} fitness: {genome.fitness} score: {app.game.board.score}"
|
|
||||||
)
|
|
||||||
app.game.restart()
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_fitness(genome, score: int):
|
|
||||||
genome.fitness += score
|
|
||||||
22
src/ai/training.py
Normal file
22
src/ai/training.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import neat
|
||||||
|
from loguru import logger
|
||||||
|
from path import BASE_PATH
|
||||||
|
|
||||||
|
from .config import get_config
|
||||||
|
from .evaluation import eval_genomes
|
||||||
|
from .io import save_genome
|
||||||
|
|
||||||
|
|
||||||
|
def train(generations: int) -> None:
|
||||||
|
"""Train the AI for a given number of generations."""
|
||||||
|
config = get_config()
|
||||||
|
population = neat.Population(config)
|
||||||
|
population.add_reporter(neat.StdOutReporter(True))
|
||||||
|
stats = neat.StatisticsReporter()
|
||||||
|
population.add_reporter(stats)
|
||||||
|
population.add_reporter(neat.Checkpointer(1))
|
||||||
|
|
||||||
|
winner = population.run(eval_genomes, generations)
|
||||||
|
|
||||||
|
logger.info(winner)
|
||||||
|
save_genome(winner, BASE_PATH / "best_genome")
|
||||||
@ -116,6 +116,11 @@ class Board(pygame.sprite.Group):
|
|||||||
self.empty()
|
self.empty()
|
||||||
self._initiate_game()
|
self._initiate_game()
|
||||||
|
|
||||||
|
def max_val(self) -> int:
|
||||||
|
"""Return the maximum value of the tiles."""
|
||||||
|
tile: Tile
|
||||||
|
return int(max(tile.value for tile in self.sprites()))
|
||||||
|
|
||||||
def get_tile(self, position: Position) -> Optional[Tile]:
|
def get_tile(self, position: Position) -> Optional[Tile]:
|
||||||
"""Return the tile at the specified position."""
|
"""Return the tile at the specified position."""
|
||||||
tile: Tile
|
tile: Tile
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user