diff --git a/config.txt b/config.txt new file mode 100644 index 0000000..4d036e0 --- /dev/null +++ b/config.txt @@ -0,0 +1,79 @@ +[NEAT] +fitness_criterion = max +fitness_threshold = 3.9 +pop_size = 150 +reset_on_extinction = False + +[DefaultGenome] +# node activation options +activation_default = sigmoid +activation_mutate_rate = 0.0 +activation_options = sigmoid + +# node aggregation options +aggregation_default = sum +aggregation_mutate_rate = 0.0 +aggregation_options = sum + +# node bias options +bias_init_mean = 0.0 +bias_init_stdev = 1.0 +bias_max_value = 30.0 +bias_min_value = -30.0 +bias_mutate_power = 0.5 +bias_mutate_rate = 0.7 +bias_replace_rate = 0.1 + +# genome compatibility options +compatibility_disjoint_coefficient = 1.0 +compatibility_weight_coefficient = 0.5 + +# connection add/remove rates +conn_add_prob = 0.5 +conn_delete_prob = 0.5 + +# connection enable options +enabled_default = True +enabled_mutate_rate = 0.01 + +feed_forward = True +initial_connection = full + +# node add/remove rates +node_add_prob = 0.2 +node_delete_prob = 0.2 + +# network parameters +num_hidden = 0 +num_inputs = 17 +num_outputs = 4 + +# node response options +response_init_mean = 1.0 +response_init_stdev = 0.0 +response_max_value = 30.0 +response_min_value = -30.0 +response_mutate_power = 0.0 +response_mutate_rate = 0.0 +response_replace_rate = 0.0 + +# connection weight options +weight_init_mean = 0.0 +weight_init_stdev = 1.0 +weight_max_value = 30 +weight_min_value = -30 +weight_mutate_power = 0.5 +weight_mutate_rate = 0.8 +weight_replace_rate = 0.1 + +[DefaultSpeciesSet] +compatibility_threshold = 3.0 + +[DefaultStagnation] +species_fitness_func = max +max_stagnation = 20 +species_elitism = 2 + +[DefaultReproduction] +elitism = 2 +survival_threshold = 0.2 diff --git a/main.py b/main.py index 256fa24..59a035c 100755 --- a/main.py +++ b/main.py @@ -1,13 +1,15 @@ #!/usr/bin/env python +from ai import train from loguru import logger from py2048 import Menu @logger.catch def main() -> None: - Menu().run() + # Menu().run() + train() if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 69e430c..eb9e1a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,12 @@ authors = [{ name = "Kristofers Solo", email = "dev@kristofers.xyz" }] readme = "README.md" requires-python = ">=3.11" license = { text = "GPLv3" } -dependencies = ["pygame-ce==2.3.2", "loguru==0.7.2", "attrs==23.1.0"] +dependencies = [ + "pygame-ce==2.3.2", + "loguru==0.7.2", + "attrs==23.1.0", + "neat-python>=0.92", +] [tool.mypy] check_untyped_defs = true diff --git a/requirements.txt b/requirements.txt index 9c0c746..1646e49 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ pygame-ce>=2.3.2 loguru>=0.7.2 attrs>=23.1.0 +neat-python>=0.92 . diff --git a/src/ai/__init__.py b/src/ai/__init__.py new file mode 100644 index 0000000..4102b72 --- /dev/null +++ b/src/ai/__init__.py @@ -0,0 +1,3 @@ +from .train import train + +__all__ = ["train"] diff --git a/src/ai/train.py b/src/ai/train.py new file mode 100644 index 0000000..861e7cb --- /dev/null +++ b/src/ai/train.py @@ -0,0 +1,73 @@ +import neat +from loguru import logger +from path import BASE_PATH +from py2048 import Menu + + +def _get_config() -> neat.Config: + config_path = BASE_PATH / "config.txt" + return neat.Config( + neat.DefaultGenome, + neat.DefaultReproduction, + neat.DefaultSpeciesSet, + neat.DefaultStagnation, + config_path, + ) + + +def train() -> None: + config = _get_config() + # p = neat.Checkpointer.restore_checkpoint("neat-checkpoint-0") + p = neat.Population(config) + p.add_reporter(neat.StdOutReporter(True)) + stats = neat.StatisticsReporter() + p.add_reporter(stats) + p.add_reporter(neat.Checkpointer(1)) + + winner = p.run(eval_genomes, 50) + + logger.info(f"\nBest genome:\n{winner}") + + +def eval_genomes(genomes, config: neat.Config): + for genome_id, genome in genomes: + genome.fitness = 4.0 + app = Menu() + net = neat.nn.FeedForwardNetwork.create(genome, config) + + app.play() + app._game_active = False + + while True: + output = net.activate( + ( + *app.game.board.matrix(), + app.game.board.score, + ) + ) + + decision = output.index(max(output)) + + decisions = { + 0: app.game.move_up, + 1: app.game.move_down, + 2: app.game.move_left, + 3: app.game.move_right, + } + + decisions[decision]() + + app._hande_events() + app.game.draw(app._surface) + + if app.game.board._is_full() or app.game.board.score > 10_000: + calculate_fitness(genome, app.game.board.score) + logger.info( + f"Genome: {genome_id} fitness: {genome.fitness} score: {app.game.board.score}" + ) + app.game.restart() + break + + +def calculate_fitness(genome, score: int): + genome.fitness += score diff --git a/src/py2048/objects/board.py b/src/py2048/objects/board.py index 004d549..5fa6a6c 100644 --- a/src/py2048/objects/board.py +++ b/src/py2048/objects/board.py @@ -1,4 +1,5 @@ import random +from typing import Optional import pygame from loguru import logger @@ -114,3 +115,25 @@ class Board(pygame.sprite.Group): """Reset the board.""" self.empty() self._initiate_game() + + def get_tile(self, position: Position) -> Optional[Tile]: + """Return the tile at the specified position.""" + tile: Tile + for tile in self.sprites(): + if tile.pos == position: + return tile + return None + + def matrix(self) -> list[int]: + """Return a 1d matrix of values of the tiles.""" + matrix: list[int] = [] + + for i in range(1, Config.BOARD.len + 1): + for j in range(1, Config.BOARD.len + 1): + tile = self.get_tile(Position(j, i)) + if tile: + matrix.append(tile.value) + else: + matrix.append(0) + + return matrix diff --git a/src/py2048/objects/tile.py b/src/py2048/objects/tile.py index 30feb94..380c06f 100644 --- a/src/py2048/objects/tile.py +++ b/src/py2048/objects/tile.py @@ -10,9 +10,11 @@ from py2048.utils import ColorScheme, Direction, Position, Size from .abc import MovableUIElement, UIElement -def _grid_pos(pos: int) -> int: +def _grid_pos(position: Position) -> Position: """Return the position in the grid.""" - return pos // Config.TILE.size + 1 + x = (position.x - Config.BOARD.pos.x) // Config.TILE.size + 1 + y = (position.y - Config.BOARD.pos.y) // Config.TILE.size + 1 + return Position(x, y) class Tile(MovableUIElement, pygame.sprite.Sprite): @@ -191,9 +193,9 @@ class Tile(MovableUIElement, pygame.sprite.Sprite): def __hash__(self) -> int: """Return a hash of the tile.""" - return hash((self.rect.x, self.rect.y, self.value)) + return hash((self.rect.x, self.rect.y)) @property def pos(self) -> Position: """Return the position of the tile.""" - return Position(_grid_pos(self.rect.x), _grid_pos(self.rect.y)) + return _grid_pos(Position(*self.rect.topleft)) diff --git a/src/py2048/screens/game.py b/src/py2048/screens/game.py index 19d54eb..857923c 100644 --- a/src/py2048/screens/game.py +++ b/src/py2048/screens/game.py @@ -49,10 +49,6 @@ class Game: logger.info("Game over!") self.restart() - def restart(self) -> None: - self.board.reset() - self.update_score(0) - def move_up(self) -> None: self.move(Direction.UP) @@ -65,6 +61,11 @@ class Game: def move_right(self) -> None: self.move(Direction.RIGHT) + def restart(self) -> None: + self.board.reset() + self.board.score = 0 + self.update_score(0) + def _create_labels(self) -> pygame.sprite.Group: size = Size(60, 40) diff --git a/src/py2048/screens/menu.py b/src/py2048/screens/menu.py index f9e6f91..2e5d224 100644 --- a/src/py2048/screens/menu.py +++ b/src/py2048/screens/menu.py @@ -1,5 +1,6 @@ import sys +import neat import pygame from loguru import logger @@ -61,6 +62,7 @@ class Menu: def run(self) -> None: """Run the game loop.""" + while True: self._hande_events()