feat(ai): add ai

This commit is contained in:
Kristofers Solo 2024-01-03 03:58:38 +02:00
parent 97a64b44b6
commit e03465d2d3
10 changed files with 201 additions and 10 deletions

79
config.txt Normal file
View File

@ -0,0 +1,79 @@
[NEAT]
fitness_criterion = max
fitness_threshold = 3.9
pop_size = 150
reset_on_extinction = False
[DefaultGenome]
# node activation options
activation_default = sigmoid
activation_mutate_rate = 0.0
activation_options = sigmoid
# node aggregation options
aggregation_default = sum
aggregation_mutate_rate = 0.0
aggregation_options = sum
# node bias options
bias_init_mean = 0.0
bias_init_stdev = 1.0
bias_max_value = 30.0
bias_min_value = -30.0
bias_mutate_power = 0.5
bias_mutate_rate = 0.7
bias_replace_rate = 0.1
# genome compatibility options
compatibility_disjoint_coefficient = 1.0
compatibility_weight_coefficient = 0.5
# connection add/remove rates
conn_add_prob = 0.5
conn_delete_prob = 0.5
# connection enable options
enabled_default = True
enabled_mutate_rate = 0.01
feed_forward = True
initial_connection = full
# node add/remove rates
node_add_prob = 0.2
node_delete_prob = 0.2
# network parameters
num_hidden = 0
num_inputs = 17
num_outputs = 4
# node response options
response_init_mean = 1.0
response_init_stdev = 0.0
response_max_value = 30.0
response_min_value = -30.0
response_mutate_power = 0.0
response_mutate_rate = 0.0
response_replace_rate = 0.0
# connection weight options
weight_init_mean = 0.0
weight_init_stdev = 1.0
weight_max_value = 30
weight_min_value = -30
weight_mutate_power = 0.5
weight_mutate_rate = 0.8
weight_replace_rate = 0.1
[DefaultSpeciesSet]
compatibility_threshold = 3.0
[DefaultStagnation]
species_fitness_func = max
max_stagnation = 20
species_elitism = 2
[DefaultReproduction]
elitism = 2
survival_threshold = 0.2

View File

@ -1,13 +1,15 @@
#!/usr/bin/env python
from ai import train
from loguru import logger
from py2048 import Menu
@logger.catch
def main() -> None:
Menu().run()
# Menu().run()
train()
if __name__ == "__main__":

View File

@ -6,7 +6,12 @@ authors = [{ name = "Kristofers Solo", email = "dev@kristofers.xyz" }]
readme = "README.md"
requires-python = ">=3.11"
license = { text = "GPLv3" }
dependencies = ["pygame-ce==2.3.2", "loguru==0.7.2", "attrs==23.1.0"]
dependencies = [
"pygame-ce==2.3.2",
"loguru==0.7.2",
"attrs==23.1.0",
"neat-python>=0.92",
]
[tool.mypy]
check_untyped_defs = true

View File

@ -1,4 +1,5 @@
pygame-ce>=2.3.2
loguru>=0.7.2
attrs>=23.1.0
neat-python>=0.92
.

3
src/ai/__init__.py Normal file
View File

@ -0,0 +1,3 @@
from .train import train
__all__ = ["train"]

73
src/ai/train.py Normal file
View File

@ -0,0 +1,73 @@
import neat
from loguru import logger
from path import BASE_PATH
from py2048 import Menu
def _get_config() -> neat.Config:
config_path = BASE_PATH / "config.txt"
return neat.Config(
neat.DefaultGenome,
neat.DefaultReproduction,
neat.DefaultSpeciesSet,
neat.DefaultStagnation,
config_path,
)
def train() -> None:
config = _get_config()
# p = neat.Checkpointer.restore_checkpoint("neat-checkpoint-0")
p = neat.Population(config)
p.add_reporter(neat.StdOutReporter(True))
stats = neat.StatisticsReporter()
p.add_reporter(stats)
p.add_reporter(neat.Checkpointer(1))
winner = p.run(eval_genomes, 50)
logger.info(f"\nBest genome:\n{winner}")
def eval_genomes(genomes, config: neat.Config):
for genome_id, genome in genomes:
genome.fitness = 4.0
app = Menu()
net = neat.nn.FeedForwardNetwork.create(genome, config)
app.play()
app._game_active = False
while True:
output = net.activate(
(
*app.game.board.matrix(),
app.game.board.score,
)
)
decision = output.index(max(output))
decisions = {
0: app.game.move_up,
1: app.game.move_down,
2: app.game.move_left,
3: app.game.move_right,
}
decisions[decision]()
app._hande_events()
app.game.draw(app._surface)
if app.game.board._is_full() or app.game.board.score > 10_000:
calculate_fitness(genome, app.game.board.score)
logger.info(
f"Genome: {genome_id} fitness: {genome.fitness} score: {app.game.board.score}"
)
app.game.restart()
break
def calculate_fitness(genome, score: int):
genome.fitness += score

View File

@ -1,4 +1,5 @@
import random
from typing import Optional
import pygame
from loguru import logger
@ -114,3 +115,25 @@ class Board(pygame.sprite.Group):
"""Reset the board."""
self.empty()
self._initiate_game()
def get_tile(self, position: Position) -> Optional[Tile]:
"""Return the tile at the specified position."""
tile: Tile
for tile in self.sprites():
if tile.pos == position:
return tile
return None
def matrix(self) -> list[int]:
"""Return a 1d matrix of values of the tiles."""
matrix: list[int] = []
for i in range(1, Config.BOARD.len + 1):
for j in range(1, Config.BOARD.len + 1):
tile = self.get_tile(Position(j, i))
if tile:
matrix.append(tile.value)
else:
matrix.append(0)
return matrix

View File

@ -10,9 +10,11 @@ from py2048.utils import ColorScheme, Direction, Position, Size
from .abc import MovableUIElement, UIElement
def _grid_pos(pos: int) -> int:
def _grid_pos(position: Position) -> Position:
"""Return the position in the grid."""
return pos // Config.TILE.size + 1
x = (position.x - Config.BOARD.pos.x) // Config.TILE.size + 1
y = (position.y - Config.BOARD.pos.y) // Config.TILE.size + 1
return Position(x, y)
class Tile(MovableUIElement, pygame.sprite.Sprite):
@ -191,9 +193,9 @@ class Tile(MovableUIElement, pygame.sprite.Sprite):
def __hash__(self) -> int:
"""Return a hash of the tile."""
return hash((self.rect.x, self.rect.y, self.value))
return hash((self.rect.x, self.rect.y))
@property
def pos(self) -> Position:
"""Return the position of the tile."""
return Position(_grid_pos(self.rect.x), _grid_pos(self.rect.y))
return _grid_pos(Position(*self.rect.topleft))

View File

@ -49,10 +49,6 @@ class Game:
logger.info("Game over!")
self.restart()
def restart(self) -> None:
self.board.reset()
self.update_score(0)
def move_up(self) -> None:
self.move(Direction.UP)
@ -65,6 +61,11 @@ class Game:
def move_right(self) -> None:
self.move(Direction.RIGHT)
def restart(self) -> None:
self.board.reset()
self.board.score = 0
self.update_score(0)
def _create_labels(self) -> pygame.sprite.Group:
size = Size(60, 40)

View File

@ -1,5 +1,6 @@
import sys
import neat
import pygame
from loguru import logger
@ -61,6 +62,7 @@ class Menu:
def run(self) -> None:
"""Run the game loop."""
while True:
self._hande_events()