mirror of
https://github.com/kristoferssolo/2048.git
synced 2025-10-21 15:20:35 +00:00
feat(ai): add ai
This commit is contained in:
parent
97a64b44b6
commit
e03465d2d3
79
config.txt
Normal file
79
config.txt
Normal file
@ -0,0 +1,79 @@
|
||||
[NEAT]
|
||||
fitness_criterion = max
|
||||
fitness_threshold = 3.9
|
||||
pop_size = 150
|
||||
reset_on_extinction = False
|
||||
|
||||
[DefaultGenome]
|
||||
# node activation options
|
||||
activation_default = sigmoid
|
||||
activation_mutate_rate = 0.0
|
||||
activation_options = sigmoid
|
||||
|
||||
# node aggregation options
|
||||
aggregation_default = sum
|
||||
aggregation_mutate_rate = 0.0
|
||||
aggregation_options = sum
|
||||
|
||||
# node bias options
|
||||
bias_init_mean = 0.0
|
||||
bias_init_stdev = 1.0
|
||||
bias_max_value = 30.0
|
||||
bias_min_value = -30.0
|
||||
bias_mutate_power = 0.5
|
||||
bias_mutate_rate = 0.7
|
||||
bias_replace_rate = 0.1
|
||||
|
||||
# genome compatibility options
|
||||
compatibility_disjoint_coefficient = 1.0
|
||||
compatibility_weight_coefficient = 0.5
|
||||
|
||||
# connection add/remove rates
|
||||
conn_add_prob = 0.5
|
||||
conn_delete_prob = 0.5
|
||||
|
||||
# connection enable options
|
||||
enabled_default = True
|
||||
enabled_mutate_rate = 0.01
|
||||
|
||||
feed_forward = True
|
||||
initial_connection = full
|
||||
|
||||
# node add/remove rates
|
||||
node_add_prob = 0.2
|
||||
node_delete_prob = 0.2
|
||||
|
||||
# network parameters
|
||||
num_hidden = 0
|
||||
num_inputs = 17
|
||||
num_outputs = 4
|
||||
|
||||
# node response options
|
||||
response_init_mean = 1.0
|
||||
response_init_stdev = 0.0
|
||||
response_max_value = 30.0
|
||||
response_min_value = -30.0
|
||||
response_mutate_power = 0.0
|
||||
response_mutate_rate = 0.0
|
||||
response_replace_rate = 0.0
|
||||
|
||||
# connection weight options
|
||||
weight_init_mean = 0.0
|
||||
weight_init_stdev = 1.0
|
||||
weight_max_value = 30
|
||||
weight_min_value = -30
|
||||
weight_mutate_power = 0.5
|
||||
weight_mutate_rate = 0.8
|
||||
weight_replace_rate = 0.1
|
||||
|
||||
[DefaultSpeciesSet]
|
||||
compatibility_threshold = 3.0
|
||||
|
||||
[DefaultStagnation]
|
||||
species_fitness_func = max
|
||||
max_stagnation = 20
|
||||
species_elitism = 2
|
||||
|
||||
[DefaultReproduction]
|
||||
elitism = 2
|
||||
survival_threshold = 0.2
|
||||
4
main.py
4
main.py
@ -1,13 +1,15 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
|
||||
from ai import train
|
||||
from loguru import logger
|
||||
from py2048 import Menu
|
||||
|
||||
|
||||
@logger.catch
|
||||
def main() -> None:
|
||||
Menu().run()
|
||||
# Menu().run()
|
||||
train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -6,7 +6,12 @@ authors = [{ name = "Kristofers Solo", email = "dev@kristofers.xyz" }]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
license = { text = "GPLv3" }
|
||||
dependencies = ["pygame-ce==2.3.2", "loguru==0.7.2", "attrs==23.1.0"]
|
||||
dependencies = [
|
||||
"pygame-ce==2.3.2",
|
||||
"loguru==0.7.2",
|
||||
"attrs==23.1.0",
|
||||
"neat-python>=0.92",
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
check_untyped_defs = true
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
pygame-ce>=2.3.2
|
||||
loguru>=0.7.2
|
||||
attrs>=23.1.0
|
||||
neat-python>=0.92
|
||||
.
|
||||
|
||||
3
src/ai/__init__.py
Normal file
3
src/ai/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .train import train
|
||||
|
||||
__all__ = ["train"]
|
||||
73
src/ai/train.py
Normal file
73
src/ai/train.py
Normal file
@ -0,0 +1,73 @@
|
||||
import neat
|
||||
from loguru import logger
|
||||
from path import BASE_PATH
|
||||
from py2048 import Menu
|
||||
|
||||
|
||||
def _get_config() -> neat.Config:
|
||||
config_path = BASE_PATH / "config.txt"
|
||||
return neat.Config(
|
||||
neat.DefaultGenome,
|
||||
neat.DefaultReproduction,
|
||||
neat.DefaultSpeciesSet,
|
||||
neat.DefaultStagnation,
|
||||
config_path,
|
||||
)
|
||||
|
||||
|
||||
def train() -> None:
|
||||
config = _get_config()
|
||||
# p = neat.Checkpointer.restore_checkpoint("neat-checkpoint-0")
|
||||
p = neat.Population(config)
|
||||
p.add_reporter(neat.StdOutReporter(True))
|
||||
stats = neat.StatisticsReporter()
|
||||
p.add_reporter(stats)
|
||||
p.add_reporter(neat.Checkpointer(1))
|
||||
|
||||
winner = p.run(eval_genomes, 50)
|
||||
|
||||
logger.info(f"\nBest genome:\n{winner}")
|
||||
|
||||
|
||||
def eval_genomes(genomes, config: neat.Config):
|
||||
for genome_id, genome in genomes:
|
||||
genome.fitness = 4.0
|
||||
app = Menu()
|
||||
net = neat.nn.FeedForwardNetwork.create(genome, config)
|
||||
|
||||
app.play()
|
||||
app._game_active = False
|
||||
|
||||
while True:
|
||||
output = net.activate(
|
||||
(
|
||||
*app.game.board.matrix(),
|
||||
app.game.board.score,
|
||||
)
|
||||
)
|
||||
|
||||
decision = output.index(max(output))
|
||||
|
||||
decisions = {
|
||||
0: app.game.move_up,
|
||||
1: app.game.move_down,
|
||||
2: app.game.move_left,
|
||||
3: app.game.move_right,
|
||||
}
|
||||
|
||||
decisions[decision]()
|
||||
|
||||
app._hande_events()
|
||||
app.game.draw(app._surface)
|
||||
|
||||
if app.game.board._is_full() or app.game.board.score > 10_000:
|
||||
calculate_fitness(genome, app.game.board.score)
|
||||
logger.info(
|
||||
f"Genome: {genome_id} fitness: {genome.fitness} score: {app.game.board.score}"
|
||||
)
|
||||
app.game.restart()
|
||||
break
|
||||
|
||||
|
||||
def calculate_fitness(genome, score: int):
|
||||
genome.fitness += score
|
||||
@ -1,4 +1,5 @@
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import pygame
|
||||
from loguru import logger
|
||||
@ -114,3 +115,25 @@ class Board(pygame.sprite.Group):
|
||||
"""Reset the board."""
|
||||
self.empty()
|
||||
self._initiate_game()
|
||||
|
||||
def get_tile(self, position: Position) -> Optional[Tile]:
|
||||
"""Return the tile at the specified position."""
|
||||
tile: Tile
|
||||
for tile in self.sprites():
|
||||
if tile.pos == position:
|
||||
return tile
|
||||
return None
|
||||
|
||||
def matrix(self) -> list[int]:
|
||||
"""Return a 1d matrix of values of the tiles."""
|
||||
matrix: list[int] = []
|
||||
|
||||
for i in range(1, Config.BOARD.len + 1):
|
||||
for j in range(1, Config.BOARD.len + 1):
|
||||
tile = self.get_tile(Position(j, i))
|
||||
if tile:
|
||||
matrix.append(tile.value)
|
||||
else:
|
||||
matrix.append(0)
|
||||
|
||||
return matrix
|
||||
|
||||
@ -10,9 +10,11 @@ from py2048.utils import ColorScheme, Direction, Position, Size
|
||||
from .abc import MovableUIElement, UIElement
|
||||
|
||||
|
||||
def _grid_pos(pos: int) -> int:
|
||||
def _grid_pos(position: Position) -> Position:
|
||||
"""Return the position in the grid."""
|
||||
return pos // Config.TILE.size + 1
|
||||
x = (position.x - Config.BOARD.pos.x) // Config.TILE.size + 1
|
||||
y = (position.y - Config.BOARD.pos.y) // Config.TILE.size + 1
|
||||
return Position(x, y)
|
||||
|
||||
|
||||
class Tile(MovableUIElement, pygame.sprite.Sprite):
|
||||
@ -191,9 +193,9 @@ class Tile(MovableUIElement, pygame.sprite.Sprite):
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return a hash of the tile."""
|
||||
return hash((self.rect.x, self.rect.y, self.value))
|
||||
return hash((self.rect.x, self.rect.y))
|
||||
|
||||
@property
|
||||
def pos(self) -> Position:
|
||||
"""Return the position of the tile."""
|
||||
return Position(_grid_pos(self.rect.x), _grid_pos(self.rect.y))
|
||||
return _grid_pos(Position(*self.rect.topleft))
|
||||
|
||||
@ -49,10 +49,6 @@ class Game:
|
||||
logger.info("Game over!")
|
||||
self.restart()
|
||||
|
||||
def restart(self) -> None:
|
||||
self.board.reset()
|
||||
self.update_score(0)
|
||||
|
||||
def move_up(self) -> None:
|
||||
self.move(Direction.UP)
|
||||
|
||||
@ -65,6 +61,11 @@ class Game:
|
||||
def move_right(self) -> None:
|
||||
self.move(Direction.RIGHT)
|
||||
|
||||
def restart(self) -> None:
|
||||
self.board.reset()
|
||||
self.board.score = 0
|
||||
self.update_score(0)
|
||||
|
||||
def _create_labels(self) -> pygame.sprite.Group:
|
||||
size = Size(60, 40)
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import sys
|
||||
|
||||
import neat
|
||||
import pygame
|
||||
from loguru import logger
|
||||
|
||||
@ -61,6 +62,7 @@ class Menu:
|
||||
|
||||
def run(self) -> None:
|
||||
"""Run the game loop."""
|
||||
|
||||
while True:
|
||||
self._hande_events()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user